Make PjrtPlatformId available to StreamExecutorGpuPjRtCompiler To create instances of `PjrtExecutableAbiVersion`, each PjRtCompiler needs to know its associated PjRtPlatformId. This change makes it possible to pass in the platform ID when constructing the `StreamExecutorGpuPjRtCompiler` and adjusts all the users accordingly. PiperOrigin-RevId: 879547572
diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc index 992e39c..2dcaa5f 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc
@@ -310,7 +310,7 @@ TF_ASSIGN_OR_RETURN( xla::Compiler::GpuTargetConfig gpu_config, xla::Compiler::GpuTargetConfig::FromProto(gpu_target_config)); - xla::StreamExecutorGpuCompiler pjrt_gpu_compiler; + xla::StreamExecutorGpuCompiler pjrt_gpu_compiler(xla::CudaId()); // Create a trivial topology, which won't be used. xla::StreamExecutorGpuTopologyDescription topology(xla::CudaId(), xla::CudaName(), nullptr);
diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 08ef804..4216232 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD
@@ -472,6 +472,7 @@ "//xla/pjrt:maybe_owning_mlir_module", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", "//xla/pjrt:stream_executor_executable", @@ -486,6 +487,7 @@ "//xla/service:local_service_utils", "//xla/service:platform_util", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_id", "//xla/stream_executor:platform_manager", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", @@ -549,11 +551,15 @@ ":se_gpu_pjrt_compiler_impl", "//xla/hlo/builder:xla_computation", "//xla/pjrt:maybe_owning_mlir_module", + "//xla/pjrt:pjrt_abi_version", + "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", "//xla/service:compiler", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_id", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@llvm-project//mlir:IR", @@ -592,6 +598,7 @@ "//xla/mlir_hlo", "//xla/pjrt:maybe_owning_mlir_module", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", @@ -619,11 +626,13 @@ ":se_gpu_pjrt_compiler_impl", "//xla:literal", "//xla:literal_util", + "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/mlir_hlo", + "//xla/pjrt:compiled_memory_stats", "//xla/pjrt:maybe_owning_mlir_module", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler",
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index 65a26a9..0f5bbf8 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
@@ -37,6 +37,7 @@ #include "xla/pjrt/maybe_owning_mlir_module.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/stream_executor_executable.h" @@ -112,8 +113,9 @@ } // namespace StreamExecutorGpuCompiler::StreamExecutorGpuCompiler( - stream_executor::Platform::Id platform_id) - : requested_platform_id_(platform_id) {} + PjRtPlatformId pjrt_platform_id, stream_executor::Platform::Id platform_id) + : requested_platform_id_(platform_id), + pjrt_platform_id_(pjrt_platform_id) {} absl::StatusOr<Compiler*> StreamExecutorGpuCompiler::GetOrCreateCompiler() { absl::MutexLock lock(compiler_mutex_);
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.h index f6f5209..41a909f 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.h
@@ -26,20 +26,24 @@ #include "mlir/IR/BuiltinOps.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/pjrt/maybe_owning_mlir_module.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/service/compiler.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_id.h" namespace xla { // Implements the interfaces that are needed for the registered compiler. class StreamExecutorGpuCompiler : public PjRtCompiler { public: // Constructs a compiler for the default "gpu" platform. - explicit StreamExecutorGpuCompiler() = default; + explicit StreamExecutorGpuCompiler(PjRtPlatformId pjrt_platform_id) + : pjrt_platform_id_(pjrt_platform_id) {} // Constructs a compiler for the given platform. - explicit StreamExecutorGpuCompiler(stream_executor::Platform::Id platform_id); + explicit StreamExecutorGpuCompiler(PjRtPlatformId pjrt_platform_id, + stream_executor::PlatformId platform_id); // Setting CompileOptions.TargetConfig field will trigger deviceless // compilation, which will not query the GPU attached to the machine. @@ -58,6 +62,8 @@ CompileOptions options, MaybeOwningMlirModule module, const PjRtTopologyDescription& topology, PjRtClient* client) override; + PjRtPlatformId pjrt_platform_id() const { return pjrt_platform_id_; } + private: std::optional<stream_executor::Platform::Id> requested_platform_id_; mutable absl::Mutex compiler_mutex_; @@ -67,6 +73,8 @@ // GPU platform if none is specified). If one does not exist, creates one. The // compiler is cached for subsequent calls. absl::StatusOr<Compiler*> GetOrCreateCompiler(); + + PjRtPlatformId pjrt_platform_id_; }; } // namespace xla #endif // XLA_PJRT_GPU_SE_GPU_PJRT_COMPILER_H_
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc index 6bcf079..912f2ea 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc
@@ -32,9 +32,11 @@ #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/layout.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/pjrt/compiled_memory_stats.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/gpu/se_gpu_pjrt_compiler.h" #include "xla/pjrt/maybe_owning_mlir_module.h" @@ -43,6 +45,8 @@ #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" #include "xla/service/compiler.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" @@ -88,13 +92,14 @@ } TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileMlirAndLoad) { - TF_ASSERT_OK_AND_ASSIGN(auto client, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client, GetStreamExecutorGpuClient(GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast<StreamExecutorGpuClient*>(client.release())); Compiler::GpuTargetConfig gpu_target_config = xla::Compiler::GpuTargetConfig( se_client->client()->backend().default_stream_executor()); - StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id()); + StreamExecutorGpuCompiler compiler(se_client->platform_id(), + se_client->client()->platform()->id()); auto context = std::make_unique<mlir::MLIRContext>(); context->loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>(); @@ -122,13 +127,14 @@ } TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileXlaAndLoad) { - TF_ASSERT_OK_AND_ASSIGN(auto client, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client, GetStreamExecutorGpuClient(GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast<StreamExecutorGpuClient*>(client.release())); Compiler::GpuTargetConfig gpu_target_config{ se_client->client()->backend().default_stream_executor()};
- StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id()); + StreamExecutorGpuCompiler compiler(se_client->platform_id(), + se_client->client()->platform()->id()); TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, GetXlaComputation(kProgram)); @@ -151,11 +157,12 @@ } TEST(StreamExecutorGpuCompilerTest, SuccessLoadFromSerializedExecutable) { - TF_ASSERT_OK_AND_ASSIGN(auto client, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client, GetStreamExecutorGpuClient(GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast<StreamExecutorGpuClient*>(client.release())); - StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id()); + StreamExecutorGpuCompiler compiler(se_client->platform_id(), + se_client->client()->platform()->id()); xla::CompileOptions opts; opts.gpu_target_config = Compiler::GpuTargetConfig( se_client->client()->backend().default_stream_executor()); @@ -189,11 +196,12 @@ })"; TEST(StreamExecutorGpuCompilerTest, SuccessSerializeDeserialize) { - TF_ASSERT_OK_AND_ASSIGN(auto client, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client, GetStreamExecutorGpuClient(GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast<StreamExecutorGpuClient*>(client.release())); - StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id()); + StreamExecutorGpuCompiler compiler(se_client->platform_id(), + se_client->client()->platform()->id()); xla::CompileOptions opts; opts.gpu_target_config = Compiler::GpuTargetConfig( se_client->client()->backend().default_stream_executor()); @@ -233,11 +241,12 @@ } )"; TEST(StreamExecutorGpuCompilerTest, UnloadedExecutableMemoryStats) { - TF_ASSERT_OK_AND_ASSIGN(auto client, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client, GetStreamExecutorGpuClient(GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast<StreamExecutorGpuClient*>(client.release()));
- StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id()); + StreamExecutorGpuCompiler compiler(se_client->platform_id(), + se_client->client()->platform()->id()); xla::CompileOptions options; options.gpu_target_config = Compiler::GpuTargetConfig( se_client->client()->backend().default_stream_executor()); @@ -282,11 +291,12 @@ rhs_batch_dims={2}, rhs_contracting_dims={0} })"; - TF_ASSERT_OK_AND_ASSIGN(auto client, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client, GetStreamExecutorGpuClient(GpuClientOptions())); auto se_client = absl::WrapUnique( tensorflow::down_cast<StreamExecutorGpuClient*>(client.release())); - StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id()); + StreamExecutorGpuCompiler compiler(se_client->platform_id(), + se_client->client()->platform()->id()); TF_ASSERT_OK_AND_ASSIGN(const PjRtTopologyDescription* topology, se_client->GetTopologyDescription()); TF_ASSERT_OK_AND_ASSIGN(
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_cuda_registration.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_cuda_registration.cc index 2cc7dee..0420f90 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_cuda_registration.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_cuda_registration.cc
@@ -25,9 +25,9 @@ namespace xla { STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { - PjRtRegisterDefaultCompiler(CudaName(), - std::make_unique<StreamExecutorGpuCompiler>( - stream_executor::cuda::kCudaPlatformId)); + PjRtRegisterDefaultCompiler( + CudaName(), std::make_unique<StreamExecutorGpuCompiler>( + CudaId(), stream_executor::cuda::kCudaPlatformId)); CHECK_OK(StreamExecutorPlatformIdMapping::Global().AddMapping( stream_executor::cuda::kCudaPlatformId, CudaId())); });
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_rocm_registration.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_rocm_registration.cc index 7001896..0074151 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_rocm_registration.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_rocm_registration.cc
@@ -25,9 +25,9 @@ namespace xla { STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { - PjRtRegisterDefaultCompiler(RocmName(), - std::make_unique<StreamExecutorGpuCompiler>( - stream_executor::rocm::kROCmPlatformId)); + PjRtRegisterDefaultCompiler( + RocmName(), std::make_unique<StreamExecutorGpuCompiler>( + RocmId(), stream_executor::rocm::kROCmPlatformId)); CHECK_OK(StreamExecutorPlatformIdMapping::Global().AddMapping( stream_executor::rocm::kROCmPlatformId, RocmId())); });
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc index 4233188..10a31ca 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc
@@ -41,6 +41,7 @@ #include "xla/pjrt/gpu/se_gpu_topology_description.h" #include "xla/pjrt/maybe_owning_mlir_module.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" @@ -85,7 +86,7 @@ } TEST(StreamExecutorGpuCompilerTest, NoClientXla) { - StreamExecutorGpuCompiler compiler; + StreamExecutorGpuCompiler compiler(CudaId()); StreamExecutorGpuTopologyDescription topology( CudaId(), CudaName(), GetGpuTopology(kFakeDeviceName, 1, 1, 2, 10)); @@ -96,7 +97,7 @@ } TEST(StreamExecutorGpuCompilerTest, TopologyNotSameXla) { - StreamExecutorGpuCompiler compiler; + StreamExecutorGpuCompiler compiler(CudaId()); StreamExecutorGpuTopologyDescription topology( CudaId(), CudaName(), GetGpuTopology(kFakeDeviceName, 1, 1, 2, 10)); @@ -109,7 +110,7 @@ } TEST(StreamExecutorGpuCompilerTest, SuccessXla) { - StreamExecutorGpuCompiler compiler; + StreamExecutorGpuCompiler compiler(CudaId()); TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); @@ -132,7 +133,7 @@ } TEST(StreamExecutorGpuCompilerTest, NoClientMlir) { - StreamExecutorGpuCompiler compiler; + StreamExecutorGpuCompiler compiler(CudaId()); auto context = std::make_unique<mlir::MLIRContext>(); context->loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>(); @@ -152,7 +153,7 @@ } TEST(StreamExecutorGpuCompilerTest, TopologyNotSameMlir) { - StreamExecutorGpuCompiler compiler; + StreamExecutorGpuCompiler compiler(CudaId()); auto context = std::make_unique<mlir::MLIRContext>(); context->loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>(); @@ -173,7 +174,7 @@ } TEST(StreamExecutorGpuCompilerTest, SuccessMlir) { - StreamExecutorGpuCompiler compiler; + StreamExecutorGpuCompiler compiler(CudaId()); auto context = std::make_unique<mlir::MLIRContext>(); context->loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();
@@ -203,7 +204,7 @@ } TEST(StreamExecutorGpuCompilerTest, SuccessMlirCanBeSerialized) { - StreamExecutorGpuCompiler compiler; + StreamExecutorGpuCompiler compiler(CudaId()); auto context = std::make_unique<mlir::MLIRContext>(); context->loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>(); @@ -245,5 +246,11 @@ LiteralTestUtil::Equal(LiteralUtil::CreateR0(2), *result_literal)); } +TEST(StreamExecutorGpuCompilerTest, PlatformId) { + constexpr PjRtPlatformId kPlatformId = PjRtPlatformId(1234); + StreamExecutorGpuCompiler compiler(kPlatformId); + EXPECT_EQ(compiler.pjrt_platform_id(), kPlatformId); +} + } // namespace } // namespace xla