blob: 0f5bbf88963b83afb43b35628813462e7f90c9d1 [file] [edit]
/* Copyright 2023 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/pjrt/gpu/se_gpu_pjrt_compiler.h"
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "xla/client/executable_build_options.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/layout_util.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/gpu/se_gpu_topology_description.h"
#include "xla/pjrt/maybe_owning_mlir_module.h"
#include "xla/pjrt/mlir_to_hlo.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/stream_executor_executable.h"
#include "xla/pjrt/utils.h"
#include "xla/service/compiled_module.h"
#include "xla/service/compiler.h"
#include "xla/service/dump.h"
#include "xla/service/gpu_topology.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_module_util.h"
#include "xla/service/local_service_utils.h"
#include "xla/service/platform_util.h"
#include "xla/shape.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/platform/casts.h"
namespace xla {
namespace {
// Returns true when `client` reports one of the GPU PjRt platform ids
// (CUDA, ROCm or SYCL).
bool IsGpuClient(const PjRtClient& client) {
  const auto id = client.platform_id();
  return id == CudaId() || id == RocmId() || id == SyclId();
}
bool IsSameTopology(const PjRtTopologyDescription& topology1,
const PjRtTopologyDescription& topology2) {
const StreamExecutorGpuTopologyDescription& gpu_topology1 =
tensorflow::down_cast<const StreamExecutorGpuTopologyDescription&>(
topology1);
const StreamExecutorGpuTopologyDescription& gpu_topology2 =
tensorflow::down_cast<const StreamExecutorGpuTopologyDescription&>(
topology2);
return gpu_topology1 == gpu_topology2;
}
// Checks that `client` can be used to compile for `topology`.
//
// Returns:
//   - UnimplementedError if `client` is null;
//   - InvalidArgumentError if `client` is not a GPU client;
//   - the client's error if its topology description cannot be obtained;
//   - UnimplementedError if the client's topology differs from `topology`
//     (compiling through the client would otherwise silently target the
//     client's devices instead of the requested topology);
//   - OkStatus otherwise.
absl::Status IsValidTopologyAndClientForCompile(
    const PjRtTopologyDescription& topology, PjRtClient* client) {
  if (client == nullptr) {
    return absl::UnimplementedError(
        "SE:GPU compiler requires non-null client.");
  }
  if (!IsGpuClient(*client)) {
    return absl::InvalidArgumentError(
        "SE:GPU compiler requires a GPU PjRtClient.");
  }
  // The requested topology must match the one the client actually drives;
  // previously `topology` was accepted but never inspected.
  TF_ASSIGN_OR_RETURN(auto client_topology, client->GetTopologyDescription());
  if (client_topology == nullptr) {
    return absl::InternalError(
        "SE:GPU client returned a null topology description.");
  }
  if (!IsSameTopology(topology, *client_topology)) {
    return absl::UnimplementedError(
        "SE:GPU compiler requires the topology same as the one in the client.");
  }
  return absl::OkStatus();
}
// Looks up the platform that PlatformUtil resolves for the name "gpu" and
// returns the xla::Compiler registered for that platform's id. Lookup
// failures are propagated unchanged.
absl::StatusOr<std::unique_ptr<xla::Compiler>>
GetCompilerForDefaultGpuPlatform() {
  TF_ASSIGN_OR_RETURN(auto* gpu_platform, PlatformUtil::GetPlatform("gpu"));
  const auto gpu_platform_id = gpu_platform->id();
  return Compiler::GetForPlatform(gpu_platform_id);
}
// Returns the xla::Compiler for `platform_id`, or for the default GPU platform
// when no id is given. An explicit id is first resolved through
// PlatformManager, so an unregistered id surfaces as that lookup's error.
absl::StatusOr<std::unique_ptr<xla::Compiler>> GetCompilerForPlatform(
    std::optional<stream_executor::Platform::Id> platform_id) {
  if (platform_id.has_value()) {
    TF_ASSIGN_OR_RETURN(
        stream_executor::Platform* const resolved_platform,
        stream_executor::PlatformManager::PlatformWithId(*platform_id));
    return Compiler::GetForPlatform(resolved_platform->id());
  }
  return GetCompilerForDefaultGpuPlatform();
}
} // namespace
StreamExecutorGpuCompiler::StreamExecutorGpuCompiler(
PjRtPlatformId pjrt_platform_id, stream_executor::Platform::Id platform_id)
: requested_platform_id_(platform_id),
pjrt_platform_id_(pjrt_platform_id) {}
// Returns the cached xla::Compiler, creating it on first use under
// `compiler_mutex_`. On failure nothing is cached, so a later call retries.
//
// The lookup is deferred to here instead of the constructor because instances
// of this class may be created during static initialization, possibly before
// the compiler has been registered via Compiler::RegisterCompilerFactory.
// Construction cannot fail, so any lookup error is reported from here (and
// thus from Compile) instead.
absl::StatusOr<Compiler*> StreamExecutorGpuCompiler::GetOrCreateCompiler() {
  absl::MutexLock lock(compiler_mutex_);
  if (compiler_ != nullptr) {
    return compiler_.get();
  }
  TF_ASSIGN_OR_RETURN(compiler_,
                      GetCompilerForPlatform(requested_platform_id_));
  return compiler_.get();
}
// Compiles `computation` for the GPU described by `topology`.
//
// Target-config resolution:
//   - For early-exit compilations, the target config is taken from
//     `topology` (TF_RET_CHECK failure if the topology has none).
//   - Otherwise, if `options` carries no target config and `client` is
//     non-null, the client is validated and compilation is delegated to
//     `client->Compile`.
//   - Otherwise the target config is deserialized from
//     `topology.target_config()`; UnimplementedError if that is absent.
//
// With a target config in hand, the computation is converted to an HloModule
// and compiled ahead of time with the platform's xla::Compiler. The returned
// executable stores the caller's `options` as they were on entry, before
// early-exit handling and ApplyAllOptionOverrides().
absl::StatusOr<std::unique_ptr<PjRtExecutable>>
StreamExecutorGpuCompiler::Compile(CompileOptions options,
                                   const XlaComputation& computation,
                                   const PjRtTopologyDescription& topology,
                                   PjRtClient* client) {
  TF_ASSIGN_OR_RETURN(Compiler * gpu_compiler, GetOrCreateCompiler());
  // Snapshot of the caller's options; `options` is mutated below.
  CompileOptions input_options = options;

  if (xla::IsEarlyExitCompilation(options)) {
    // Early-exit compilation runs without an executor (see set_executor below),
    // so the target config has to come from the topology.
    auto* se_gpu_topology =
        tsl::down_cast<const xla::StreamExecutorGpuTopologyDescription*>(
            &topology);
    const xla::GpuTopology& gpu_topology = se_gpu_topology->gpu_topology();
    TF_RET_CHECK(gpu_topology.has_gpu_target_config())
        << "GPU cross-compile is not yet implemented for topology "
        << se_gpu_topology->ToProto()->ShortDebugString();
    options.gpu_target_config = gpu_topology.gpu_target_config();
  }

  if (!options.gpu_target_config) {
    if (client != nullptr) {
      // No target config: let the client compile for its own devices.
      TF_RET_CHECK(IsGpuClient(*client))
          << "GPU compilation requires a GPU PjRt client.";
      TF_RETURN_IF_ERROR(IsValidTopologyAndClientForCompile(topology, client));
      TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
                          client->Compile(computation, options));
      return executable;
    }
    // No client either: fall back to the target config serialized in the
    // topology, if present.
    const auto& gpu_topology =
        tensorflow::down_cast<const xla::StreamExecutorGpuTopologyDescription&>(
            topology);
    if (!gpu_topology.target_config().has_value()) {
      return absl::UnimplementedError(
          "Compilation without client and without target_config specified is "
          "not implemented");
    }
    TF_ASSIGN_OR_RETURN(
        Compiler::GpuTargetConfig target_config,
        Compiler::GpuTargetConfig::FromProto(*gpu_topology.target_config()));
    options.gpu_target_config.emplace(std::move(target_config));
  }

  TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides());

  // Resolve argument layouts. When auto layout is allowed via debug options,
  // shapes without a layout are left as-is; otherwise they get the default
  // layout.
  std::vector<const Shape*> argument_layout_pointers;
  const ExecutableBuildOptions& build_options =
      options.executable_build_options;
  const bool allow_auto_layout =
      build_options.has_debug_options() &&
      build_options.debug_options().xla_pjrt_allow_auto_layout_in_hlo();
  TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
      computation,
      [allow_auto_layout](Shape shape) {
        if (allow_auto_layout && !shape.has_layout()) {
          return shape;
        }
        return LayoutUtil::GetWithDefaultLayout(shape);
      },
      options.argument_layouts, &options.executable_build_options,
      &argument_layout_pointers));

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> hlo_config,
                      GetHloModuleConfig(computation, argument_layout_pointers,
                                         options.executable_build_options));
  // Build the module directly from the computation's proto; a local copy of
  // the (potentially large) proto is not needed for this read-only use.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> hlo_module,
      HloModule::CreateFromProto(computation.proto(), *hlo_config));
  // NOTE(review): auto layout is forced on for the module regardless of the
  // caller's setting — presumably so entry layouts left unset above can be
  // assigned during compilation; confirm.
  hlo_module->mutable_config()
      .mutable_debug_options()
      .set_xla_pjrt_allow_auto_layout_in_hlo(true);
  UpdateEntryComputationLayout(
      hlo_module.get(), [gpu_compiler](const Shape& shape) {
        return gpu_compiler->DefaultDeviceShapeRepresentation(shape);
      });
  DumpHloModuleIfEnabled(*hlo_module, kBeforeOptimizationsDumpName);

  AotCompilationOptions aot_options(gpu_compiler->PlatformId());
  GpuTopology xla_gpu_topology = GetSingleDeviceGpuTopology(
      /*platform_version=*/"", *options.gpu_target_config);
  aot_options.set_gpu_topology(xla_gpu_topology);
  aot_options.set_run_backend_only(
      options.executable_build_options.run_backend_only());
  if (IsEarlyExitCompilation(options)) {
    aot_options.set_early_exit_point(
        AotCompilationOptions::EarlyExitPoint::kAfterLayoutAssignment);
    aot_options.set_executor(nullptr);
  }

  // Capture module metadata before ownership moves into CompileAheadOfTime.
  const int num_replicas = hlo_module->config().replica_count();
  const int num_partitions = hlo_module->config().num_partitions();
  const std::string name = hlo_module->name();
  const std::string fingerprint = hlo_module->GetFingerprint128();
  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<CompiledModule>> aot_results,
      gpu_compiler->CompileAheadOfTime(std::move(hlo_module), aot_options));
  return std::make_unique<StreamExecutorExecutable>(
      std::move(input_options), std::move(aot_results), num_replicas,
      num_partitions, name, fingerprint,
      /*default_memory_kind=*/StreamExecutorGpuHbmMemorySpace::kKind);
}
// Compiles an MLIR module for the GPU described by `topology`.
//
// If `options` has no GPU target config and a client is supplied, the client
// is validated and compiles the module itself. Otherwise the module is
// lowered to an XlaComputation (using the GPU CHLO->MHLO lowering options)
// and handed to the XlaComputation overload of Compile.
absl::StatusOr<std::unique_ptr<PjRtExecutable>>
StreamExecutorGpuCompiler::Compile(CompileOptions options,
                                   MaybeOwningMlirModule module,
                                   const PjRtTopologyDescription& topology,
                                   PjRtClient* client) {
  const bool compile_with_client =
      client != nullptr && !options.gpu_target_config;
  if (compile_with_client) {
    TF_RET_CHECK(IsGpuClient(*client))
        << "GPU compilation requires a GPU PjRt client.";
    TF_RETURN_IF_ERROR(IsValidTopologyAndClientForCompile(topology, client));
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> client_executable,
                        client->Compile(std::move(module), options));
    return client_executable;
  }

  XlaComputation lowered_computation;
  TF_RETURN_IF_ERROR(MlirToXlaComputation(
      module.mlir_module(), lowered_computation,
      /*use_tuple_args=*/options.parameter_is_tupled_arguments,
      /*return_tuple=*/false,
      /*exec_build_options=*/&options.executable_build_options,
      mlir::mhlo::getGpuChloToHighLevelMhloOptions()));
  // The MLIR module is not needed past this point; reset it so any memory it
  // owns is released before compiling the lowered computation.
  module = MaybeOwningMlirModule();
  return Compile(std::move(options), lowered_computation, topology, client);
}
} // namespace xla