Make PjrtPlatformId available to StreamExecutorGpuPjRtCompiler

To create instances of `PjrtExecutableAbiVersion` each PjRtCompiler
needs to know its associated PjRtPlatformId.

This change makes it possible to pass in the platform ID when
constructing the `StreamExecutorGpuPjRtCompiler`, and adjusts all users accordingly.

PiperOrigin-RevId: 879547572
diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc
index 992e39c..2dcaa5f 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc
@@ -310,7 +310,7 @@
   TF_ASSIGN_OR_RETURN(
       xla::Compiler::GpuTargetConfig gpu_config,
       xla::Compiler::GpuTargetConfig::FromProto(gpu_target_config));
-  xla::StreamExecutorGpuCompiler pjrt_gpu_compiler;
+  xla::StreamExecutorGpuCompiler pjrt_gpu_compiler(xla::CudaId());
   // Create a trivial topology, which won't be used.
   xla::StreamExecutorGpuTopologyDescription topology(xla::CudaId(),
                                                      xla::CudaName(), nullptr);
diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD
index 08ef804..4216232 100644
--- a/third_party/xla/xla/pjrt/gpu/BUILD
+++ b/third_party/xla/xla/pjrt/gpu/BUILD
@@ -472,6 +472,7 @@
         "//xla/pjrt:maybe_owning_mlir_module",
         "//xla/pjrt:mlir_to_hlo",
         "//xla/pjrt:pjrt_client",
+        "//xla/pjrt:pjrt_common",
         "//xla/pjrt:pjrt_compiler",
         "//xla/pjrt:pjrt_executable",
         "//xla/pjrt:stream_executor_executable",
@@ -486,6 +487,7 @@
         "//xla/service:local_service_utils",
         "//xla/service:platform_util",
         "//xla/stream_executor:platform",
+        "//xla/stream_executor:platform_id",
         "//xla/stream_executor:platform_manager",
         "//xla/tsl/platform:errors",
         "//xla/tsl/platform:statusor",
@@ -549,11 +551,15 @@
         ":se_gpu_pjrt_compiler_impl",
         "//xla/hlo/builder:xla_computation",
         "//xla/pjrt:maybe_owning_mlir_module",
+        "//xla/pjrt:pjrt_abi_version",
+        "//xla/pjrt:pjrt_common",
         "//xla/pjrt:pjrt_compiler",
         "//xla/pjrt:pjrt_executable",
         "//xla/service:compiler",
         "//xla/stream_executor:platform",
+        "//xla/stream_executor:platform_id",
         "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/synchronization",
         "@llvm-project//mlir:IR",
@@ -592,6 +598,7 @@
         "//xla/mlir_hlo",
         "//xla/pjrt:maybe_owning_mlir_module",
         "//xla/pjrt:pjrt_client",
+        "//xla/pjrt:pjrt_common",
         "//xla/pjrt:pjrt_compiler",
         "//xla/pjrt:pjrt_executable",
         "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
@@ -619,11 +626,13 @@
         ":se_gpu_pjrt_compiler_impl",
         "//xla:literal",
         "//xla:literal_util",
+        "//xla:shape_util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/builder:xla_computation",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/parser:hlo_parser",
         "//xla/mlir_hlo",
+        "//xla/pjrt:compiled_memory_stats",
         "//xla/pjrt:maybe_owning_mlir_module",
         "//xla/pjrt:pjrt_client",
         "//xla/pjrt:pjrt_compiler",
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
index 65a26a9..0f5bbf8 100644
--- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
+++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
@@ -37,6 +37,7 @@
 #include "xla/pjrt/maybe_owning_mlir_module.h"
 #include "xla/pjrt/mlir_to_hlo.h"
 #include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_common.h"
 #include "xla/pjrt/pjrt_compiler.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/stream_executor_executable.h"
@@ -112,8 +113,9 @@
 }  // namespace
 
 StreamExecutorGpuCompiler::StreamExecutorGpuCompiler(
-    stream_executor::Platform::Id platform_id)
-    : requested_platform_id_(platform_id) {}
+    PjRtPlatformId pjrt_platform_id, stream_executor::Platform::Id platform_id)
+    : requested_platform_id_(platform_id),
+      pjrt_platform_id_(pjrt_platform_id) {}
 
 absl::StatusOr<Compiler*> StreamExecutorGpuCompiler::GetOrCreateCompiler() {
   absl::MutexLock lock(compiler_mutex_);
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.h
index f6f5209..41a909f 100644
--- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.h
+++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.h
@@ -26,20 +26,24 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "xla/hlo/builder/xla_computation.h"
 #include "xla/pjrt/maybe_owning_mlir_module.h"
+#include "xla/pjrt/pjrt_common.h"
 #include "xla/pjrt/pjrt_compiler.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/service/compiler.h"
 #include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/platform_id.h"
 
 namespace xla {
 // Implements the interfaces that are needed for the registered compiler.
 class StreamExecutorGpuCompiler : public PjRtCompiler {
  public:
   // Constructs a compiler for the default "gpu" platform.
-  explicit StreamExecutorGpuCompiler() = default;
+  explicit StreamExecutorGpuCompiler(PjRtPlatformId pjrt_platform_id)
+      : pjrt_platform_id_(pjrt_platform_id) {}
 
   // Constructs a compiler for the given platform.
-  explicit StreamExecutorGpuCompiler(stream_executor::Platform::Id platform_id);
+  explicit StreamExecutorGpuCompiler(PjRtPlatformId pjrt_platform_id,
+                                     stream_executor::PlatformId platform_id);
 
   // Setting CompileOptions.TargetConfig field will trigger deviceless
   // compilation, which will not query the GPU attached to the machine.
@@ -58,6 +62,8 @@
       CompileOptions options, MaybeOwningMlirModule module,
       const PjRtTopologyDescription& topology, PjRtClient* client) override;
 
+  PjRtPlatformId pjrt_platform_id() const { return pjrt_platform_id_; }
+
  private:
   std::optional<stream_executor::Platform::Id> requested_platform_id_;
   mutable absl::Mutex compiler_mutex_;
@@ -67,6 +73,8 @@
   // GPU platform if none is specified). If one does not exist, creates one. The
   // compiler is cached for subsequent calls.
   absl::StatusOr<Compiler*> GetOrCreateCompiler();
+
+  PjRtPlatformId pjrt_platform_id_;
 };
 }  // namespace xla
 #endif  // XLA_PJRT_GPU_SE_GPU_PJRT_COMPILER_H_
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc
index 6bcf079..912f2ea 100644
--- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc
+++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc
@@ -32,9 +32,11 @@
 #include "xla/hlo/builder/xla_computation.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/parser/hlo_parser.h"
+#include "xla/layout.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "xla/pjrt/compiled_memory_stats.h"
 #include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
 #include "xla/pjrt/gpu/se_gpu_pjrt_compiler.h"
 #include "xla/pjrt/maybe_owning_mlir_module.h"
@@ -43,6 +45,8 @@
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
 #include "xla/service/compiler.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
 #include "xla/tests/literal_test_util.h"
 #include "xla/tsl/platform/statusor.h"
 #include "xla/xla_data.pb.h"
@@ -88,13 +92,14 @@
 }
 
 TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileMlirAndLoad) {
-  TF_ASSERT_OK_AND_ASSIGN(auto client,
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
                           GetStreamExecutorGpuClient(GpuClientOptions()));
   auto se_client = absl::WrapUnique(
       tensorflow::down_cast<StreamExecutorGpuClient*>(client.release()));
   Compiler::GpuTargetConfig gpu_target_config = xla::Compiler::GpuTargetConfig(
       se_client->client()->backend().default_stream_executor());
-  StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id());
+  StreamExecutorGpuCompiler compiler(se_client->platform_id(),
+                                     se_client->client()->platform()->id());
 
   auto context = std::make_unique<mlir::MLIRContext>();
   context->loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();
@@ -122,13 +127,14 @@
 }
 
 TEST(StreamExecutorGpuCompilerTest, SuccessAotCompileXlaAndLoad) {
-  TF_ASSERT_OK_AND_ASSIGN(auto client,
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
                           GetStreamExecutorGpuClient(GpuClientOptions()));
   auto se_client = absl::WrapUnique(
       tensorflow::down_cast<StreamExecutorGpuClient*>(client.release()));
   Compiler::GpuTargetConfig gpu_target_config{
       se_client->client()->backend().default_stream_executor()};
-  StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id());
+  StreamExecutorGpuCompiler compiler(se_client->platform_id(),
+                                     se_client->client()->platform()->id());
 
   TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation,
                           GetXlaComputation(kProgram));
@@ -151,11 +157,12 @@
 }
 
 TEST(StreamExecutorGpuCompilerTest, SuccessLoadFromSerializedExecutable) {
-  TF_ASSERT_OK_AND_ASSIGN(auto client,
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
                           GetStreamExecutorGpuClient(GpuClientOptions()));
   auto se_client = absl::WrapUnique(
       tensorflow::down_cast<StreamExecutorGpuClient*>(client.release()));
-  StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id());
+  StreamExecutorGpuCompiler compiler(se_client->platform_id(),
+                                     se_client->client()->platform()->id());
   xla::CompileOptions opts;
   opts.gpu_target_config = Compiler::GpuTargetConfig(
       se_client->client()->backend().default_stream_executor());
@@ -189,11 +196,12 @@
 })";
 
 TEST(StreamExecutorGpuCompilerTest, SuccessSerializeDeserialize) {
-  TF_ASSERT_OK_AND_ASSIGN(auto client,
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
                           GetStreamExecutorGpuClient(GpuClientOptions()));
   auto se_client = absl::WrapUnique(
       tensorflow::down_cast<StreamExecutorGpuClient*>(client.release()));
-  StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id());
+  StreamExecutorGpuCompiler compiler(se_client->platform_id(),
+                                     se_client->client()->platform()->id());
   xla::CompileOptions opts;
   opts.gpu_target_config = Compiler::GpuTargetConfig(
       se_client->client()->backend().default_stream_executor());
@@ -233,11 +241,12 @@
   }
 )";
 TEST(StreamExecutorGpuCompilerTest, UnloadedExecutableMemoryStats) {
-  TF_ASSERT_OK_AND_ASSIGN(auto client,
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
                           GetStreamExecutorGpuClient(GpuClientOptions()));
   auto se_client = absl::WrapUnique(
       tensorflow::down_cast<StreamExecutorGpuClient*>(client.release()));
-  StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id());
+  StreamExecutorGpuCompiler compiler(se_client->platform_id(),
+                                     se_client->client()->platform()->id());
   xla::CompileOptions options;
   options.gpu_target_config = Compiler::GpuTargetConfig(
       se_client->client()->backend().default_stream_executor());
@@ -282,11 +291,12 @@
       rhs_batch_dims={2}, rhs_contracting_dims={0}
   })";
 
-  TF_ASSERT_OK_AND_ASSIGN(auto client,
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
                           GetStreamExecutorGpuClient(GpuClientOptions()));
   auto se_client = absl::WrapUnique(
       tensorflow::down_cast<StreamExecutorGpuClient*>(client.release()));
-  StreamExecutorGpuCompiler compiler(se_client->client()->platform()->id());
+  StreamExecutorGpuCompiler compiler(se_client->platform_id(),
+                                     se_client->client()->platform()->id());
   TF_ASSERT_OK_AND_ASSIGN(const PjRtTopologyDescription* topology,
                           se_client->GetTopologyDescription());
   TF_ASSERT_OK_AND_ASSIGN(
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_cuda_registration.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_cuda_registration.cc
index 2cc7dee..0420f90 100644
--- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_cuda_registration.cc
+++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_cuda_registration.cc
@@ -25,9 +25,9 @@
 namespace xla {
 
 STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
-  PjRtRegisterDefaultCompiler(CudaName(),
-                              std::make_unique<StreamExecutorGpuCompiler>(
-                                  stream_executor::cuda::kCudaPlatformId));
+  PjRtRegisterDefaultCompiler(
+      CudaName(), std::make_unique<StreamExecutorGpuCompiler>(
+                      CudaId(), stream_executor::cuda::kCudaPlatformId));
   CHECK_OK(StreamExecutorPlatformIdMapping::Global().AddMapping(
       stream_executor::cuda::kCudaPlatformId, CudaId()));
 });
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_rocm_registration.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_rocm_registration.cc
index 7001896..0074151 100644
--- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_rocm_registration.cc
+++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_rocm_registration.cc
@@ -25,9 +25,9 @@
 namespace xla {
 
 STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
-  PjRtRegisterDefaultCompiler(RocmName(),
-                              std::make_unique<StreamExecutorGpuCompiler>(
-                                  stream_executor::rocm::kROCmPlatformId));
+  PjRtRegisterDefaultCompiler(
+      RocmName(), std::make_unique<StreamExecutorGpuCompiler>(
+                      RocmId(), stream_executor::rocm::kROCmPlatformId));
   CHECK_OK(StreamExecutorPlatformIdMapping::Global().AddMapping(
       stream_executor::rocm::kROCmPlatformId, RocmId()));
 });
diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc
index 4233188..10a31ca 100644
--- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc
+++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc
@@ -41,6 +41,7 @@
 #include "xla/pjrt/gpu/se_gpu_topology_description.h"
 #include "xla/pjrt/maybe_owning_mlir_module.h"
 #include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_common.h"
 #include "xla/pjrt/pjrt_compiler.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
@@ -85,7 +86,7 @@
 }
 
 TEST(StreamExecutorGpuCompilerTest, NoClientXla) {
-  StreamExecutorGpuCompiler compiler;
+  StreamExecutorGpuCompiler compiler(CudaId());
   StreamExecutorGpuTopologyDescription topology(
       CudaId(), CudaName(), GetGpuTopology(kFakeDeviceName, 1, 1, 2, 10));
 
@@ -96,7 +97,7 @@
 }
 
 TEST(StreamExecutorGpuCompilerTest, TopologyNotSameXla) {
-  StreamExecutorGpuCompiler compiler;
+  StreamExecutorGpuCompiler compiler(CudaId());
   StreamExecutorGpuTopologyDescription topology(
       CudaId(), CudaName(), GetGpuTopology(kFakeDeviceName, 1, 1, 2, 10));
 
@@ -109,7 +110,7 @@
 }
 
 TEST(StreamExecutorGpuCompilerTest, SuccessXla) {
-  StreamExecutorGpuCompiler compiler;
+  StreamExecutorGpuCompiler compiler(CudaId());
 
   TF_ASSERT_OK_AND_ASSIGN(auto client,
                           GetStreamExecutorGpuClient(GpuClientOptions()));
@@ -132,7 +133,7 @@
 }
 
 TEST(StreamExecutorGpuCompilerTest, NoClientMlir) {
-  StreamExecutorGpuCompiler compiler;
+  StreamExecutorGpuCompiler compiler(CudaId());
 
   auto context = std::make_unique<mlir::MLIRContext>();
   context->loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();
@@ -152,7 +153,7 @@
 }
 
 TEST(StreamExecutorGpuCompilerTest, TopologyNotSameMlir) {
-  StreamExecutorGpuCompiler compiler;
+  StreamExecutorGpuCompiler compiler(CudaId());
 
   auto context = std::make_unique<mlir::MLIRContext>();
   context->loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();
@@ -173,7 +174,7 @@
 }
 
 TEST(StreamExecutorGpuCompilerTest, SuccessMlir) {
-  StreamExecutorGpuCompiler compiler;
+  StreamExecutorGpuCompiler compiler(CudaId());
 
   auto context = std::make_unique<mlir::MLIRContext>();
   context->loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();
@@ -203,7 +204,7 @@
 }
 
 TEST(StreamExecutorGpuCompilerTest, SuccessMlirCanBeSerialized) {
-  StreamExecutorGpuCompiler compiler;
+  StreamExecutorGpuCompiler compiler(CudaId());
 
   auto context = std::make_unique<mlir::MLIRContext>();
   context->loadDialect<mlir::mhlo::MhloDialect, mlir::func::FuncDialect>();
@@ -245,5 +246,11 @@
       LiteralTestUtil::Equal(LiteralUtil::CreateR0(2), *result_literal));
 }
 
+TEST(StreamExecutorGpuCompilerTest, PlatformId) {
+  constexpr PjRtPlatformId kPlatformId = PjRtPlatformId(1234);
+  StreamExecutorGpuCompiler compiler(kPlatformId);
+  EXPECT_EQ(compiler.pjrt_platform_id(), kPlatformId);
+}
+
 }  // namespace
 }  // namespace xla