[IFRT] Migrate `Array::pjrt_layout()` callers to interpret `nullptr` as a default layout
`Array::pjrt_layout()` will be changed to return `nullptr` to indicate a default layout, where the callers can obtain the corresponding concrete default layout by using `Client::GetDefaultPjRtLayout()`.
This change adds `nullptr` handling preemptively before the new `Array::pjrt_layout()` semantics becomes effective so that the existing code works as before.
Tests using `Array::pjrt_layout()` method calls are minimally updated to add a non-nullness check. They will be updated as `Array::pjrt_layout()` actually returns `nullptr`.
PiperOrigin-RevId: 817893146
diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc
index baa0a1d..c51adef 100644
--- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc
+++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc
@@ -580,6 +580,19 @@
DCHECK(this);
absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> layout_ptr =
pjrt_layout();
+ if (layout_ptr.ok() && *layout_ptr == nullptr) {
+ layout_ptr =
+ [&]() -> absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> {
+ TF_ASSIGN_OR_RETURN(xla::ifrt::Shape shard_shape,
+ sharding_->GetShardShape(std::get<Shape>(shape_)));
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const xla::PjRtLayout> layout,
+ client_->GetDefaultPjRtLayout(dtype_, shard_shape.dims(),
+ sharding_->devices()->devices().front(),
+ sharding_->memory_kind()));
+ return layout;
+ }();
+ }
std::string layout_str =
layout_ptr.ok() ? (*layout_ptr)->ToString() : "<unknown>";
diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc
index 2115d81..f6ef50f 100644
--- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc
+++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc
@@ -1387,6 +1387,16 @@
arrays[i]->shared_ptr_sharding()->WithDeviceAssignment(
dst_devices, memory_kind));
TF_ASSIGN_OR_RETURN(auto new_layout, arrays[i]->pjrt_layout());
+ if (new_layout == nullptr) {
+ TF_ASSIGN_OR_RETURN(
+ xla::ifrt::Shape shard_shape,
+ arrays[i]->sharding().GetShardShape(arrays[i]->shape()));
+ TF_ASSIGN_OR_RETURN(
+ new_layout, GetDefaultPjRtLayout(
+ arrays[i]->dtype(), shard_shape.dims(),
+ arrays[i]->sharding().devices()->devices().front(),
+ arrays[i]->sharding().memory_kind()));
+ }
TF_ASSIGN_OR_RETURN(
new_arrays.emplace_back(),
PjRtArray::Create(this, arrays[i]->dtype(), arrays[i]->shape(),
diff --git a/third_party/xla/xla/python/pjrt_ifrt/reshard_impl_test_lib.cc b/third_party/xla/xla/python/pjrt_ifrt/reshard_impl_test_lib.cc
index 51693dd..b7fb830 100644
--- a/third_party/xla/xla/python/pjrt_ifrt/reshard_impl_test_lib.cc
+++ b/third_party/xla/xla/python/pjrt_ifrt/reshard_impl_test_lib.cc
@@ -402,6 +402,7 @@
// Make sure that the destination layout is actually different from the source
// layout in order to ensure the test coverage.
TF_ASSERT_OK_AND_ASSIGN(const auto src_layout, src_array->pjrt_layout());
+ ASSERT_NE(src_layout, nullptr);
ASSERT_NE(src_layout->xla_layout(), dst_array_spec.layout->xla_layout());
TF_ASSERT_OK_AND_ASSIGN(
@@ -415,6 +416,7 @@
// Verify that the destination array is created with the user-provided layout.
TF_ASSERT_OK_AND_ASSIGN(const auto dst_layout, dst_array->pjrt_layout());
+ ASSERT_NE(dst_layout, nullptr);
EXPECT_EQ(dst_layout->xla_layout(), dst_array_spec.layout->xla_layout());
EXPECT_THAT(CopyArrayToLiteral(dst_array),