blob: 2dcaa5f607306d5c76843f5fc635079e45edc59e [file] [edit]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h"
#include <memory>
#include <optional>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/compiler/jit/device_compilation_cluster_signature.h"
#include "tensorflow/compiler/jit/pjrt_device_compiler_client.h"
#include "tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h"
#include "tensorflow/compiler/jit/xla_compiler_options_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_compiler.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/service/compiler.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
#include "tensorflow/core/tfrt/graph_executor/export_mlir.h"
#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
#include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_util.h"
#include "tensorflow/core/tfrt/saved_model/utils/serialize_utils.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tensorflow/core/tpu/virtual_device.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/file_system_helper.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
#include "tfrt/bef/bef_buffer.h" // from @tf_runtime
#include "tfrt/bef_executor/bef_file.h" // from @tf_runtime
#include "tfrt/host_context/resource_context.h" // from @tf_runtime
namespace tensorflow::tfrt_stub {
namespace {
// Adjusts, in place, the compile options reachable through
// `options.graph_execution_options` before AOT compilation starts.
void UpdateCompileOptions(AotOptions& options) {
  auto& execution_options = *options.graph_execution_options;
  auto& compile_options = execution_options.compile_options;

  // DecomposeResourceGather does not work well with GPU (b/232819415), so
  // DecomposeResourceOpsPass stays disabled for now when TFRT GPU is enabled.
  if (execution_options.enable_tfrt_gpu) {
    compile_options.decompose_resource_ops = false;
  }

  // Fusing get-resource ops in hoisting is enabled exactly when MLRT is off.
  compile_options.fuse_get_resource_ops_in_hoisting =
      !execution_options.enable_mlrt;
}
// Compiles the TF function `function` to HLO against a virtual GPU device.
//
// On return, `*options` holds the XlaCompiler options that were generated for
// the compilation and `**compilation_result` has been filled in by the
// compiler. Returns a non-OK status if the compiler reports an error or if it
// produces no computation, so that on success callers may dereference
// `(*compilation_result)->computation`.
absl::Status CompileTfGraphToHlo(
    const FunctionLibraryDefinition* flib_def, const NameAttrList& function,
    int graph_def_version, const std::vector<XlaCompiler::Argument>& args,
    bool has_ref_vars, bool may_alias_resource_update,
    XlaCompiler::Options* options,
    XlaCompiler::CompilationResult** compilation_result) {
  // Construct a GPU device.
  DeviceAttributes device_proto;
  device_proto.set_name("/job:localhost/replica:0/task:0/device:GPU:0");
  device_proto.set_device_type(DEVICE_GPU);
  auto device =
      std::make_unique<VirtualDevice>(tensorflow::Env::Default(), device_proto);

  XlaPlatformInfo platform_info(DEVICE_GPU, se::cuda::kCudaPlatformId, nullptr,
                                nullptr, nullptr);
  *options = GenerateCompilerOptionsForPjRt(
      flib_def, graph_def_version, device.get(), platform_info, nullptr);
  // Set device type correctly so that compilation can find kernels.
  options->device_type = DeviceType("XLA_GPU_JIT");

  XlaCompiler::CompileOptions compile_options =
      GenerateCompileOptions(has_ref_vars, may_alias_resource_update);
  TfGraphToHloCompiler compiler(*options);
  absl::Status compilation_status =
      compiler.Compile(compile_options, function, args, *compilation_result);

  // Any compiler error is a failure, regardless of what was written into the
  // result so far.
  if (!compilation_status.ok()) {
    LOG(ERROR) << compilation_status;
    return compilation_status;
  }
  // Callers dereference `computation` unconditionally on success, so an OK
  // status without a computation must also be reported as an error.
  if ((*compilation_result)->computation == nullptr) {
    absl::Status no_computation = absl::InternalError(
        absl::StrCat("XLA compilation of function ", function.name(),
                     " produced no computation."));
    LOG(ERROR) << no_computation;
    return no_computation;
  }
  return absl::OkStatus();
}
// Signature tensor names have the form "${node_name}:${output_index}", e.g.
// "x:0". Returns the `node_name` part, i.e. everything before the last ':'.
// The output index may have any number of digits. A name that contains no ':'
// is returned unchanged.
std::string GetNodeName(const std::string& signature_node_name) {
  const std::string::size_type colon_pos = signature_node_name.rfind(':');
  if (colon_pos == std::string::npos) {
    return signature_node_name;
  }
  return signature_node_name.substr(0, colon_pos);
}
// Overwrites the shapes of the inputs of signature `signature_name` in
// `meta_graph_def` with the shapes given in `input_shapes` (keyed by signature
// input name).
//
// Both the signature's tensor info and the matching graph nodes are updated:
// for each node backing one of the inputs, the first entry of its
// "_output_shapes" attribute and its "shape" attribute are rewritten when
// those attributes are present.
//
// Returns NotFound if the signature does not exist or if `input_shapes` names
// an input that the signature does not have. In the latter case
// `meta_graph_def` may already have been partially updated.
absl::Status UpdateGraphDefWithInputShapes(
    MetaGraphDef& meta_graph_def,
    const absl::flat_hash_map<std::string, tensorflow::TensorShapeProto>&
        input_shapes,
    const std::string& signature_name) {
  auto& signature_defs = *meta_graph_def.mutable_signature_def();
  auto signature_it = signature_defs.find(signature_name);
  if (signature_it == signature_defs.end()) {
    return absl::NotFoundError(
        absl::StrCat("Signature not found: ", signature_name));
  }
  auto& signature_inputs = *signature_it->second.mutable_inputs();

  // Maps from graph node name to its tensor shape.
  absl::flat_hash_map<std::string, tensorflow::TensorShapeProto>
      graph_input_shapes;
  graph_input_shapes.reserve(input_shapes.size());
  for (const auto& [input_name, shape] : input_shapes) {
    // Use `find` rather than `operator[]` so that an unknown input name is
    // reported instead of silently inserting an empty entry in the signature.
    auto input_it = signature_inputs.find(input_name);
    if (input_it == signature_inputs.end()) {
      return absl::NotFoundError(
          absl::StrCat("Input not found in signature ", signature_name, ": ",
                       input_name));
    }
    auto& tensor_info = input_it->second;
    *tensor_info.mutable_tensor_shape() = shape;
    graph_input_shapes[GetNodeName(tensor_info.name())] = shape;
  }

  // Update GraphDef node shapes.
  for (NodeDef& node : *meta_graph_def.mutable_graph_def()->mutable_node()) {
    const auto shape_it = graph_input_shapes.find(node.name());
    if (shape_it == graph_input_shapes.end()) {
      continue;
    }
    const tensorflow::TensorShapeProto& shape = shape_it->second;
    auto& node_attrs = *node.mutable_attr();

    if (auto attr_it = node_attrs.find("_output_shapes");
        attr_it != node_attrs.end()) {
      auto* output_shapes = attr_it->second.mutable_list()->mutable_shape();
      // Indexing element 0 of an empty list is out of bounds, so append in
      // that case instead.
      if (output_shapes->empty()) {
        *output_shapes->Add() = shape;
      } else {
        *output_shapes->Mutable(0) = shape;
      }
    }
    if (auto attr_it = node_attrs.find("shape"); attr_it != node_attrs.end()) {
      *attr_it->second.mutable_shape() = shape;
    }
  }
  return absl::OkStatus();
}
// Constructs function and args in place using `xla_func_def`.
void ConstructFunctionAndArgs(const std::string& name,
const FunctionDef& xla_func_def,
NameAttrList& function,
std::vector<XlaCompiler::Argument>& args) {
function.set_name(name);
*function.mutable_attr() = xla_func_def.attr();
args.resize(xla_func_def.signature().input_arg_size());
for (const auto& attr : xla_func_def.arg_attr()) {
XlaCompiler::Argument arg;
const int index = attr.first;
arg.name = index;
TensorShapeProto shape_proto =
attr.second.attr().at("_output_shapes").list().shape(0);
arg.shape = shape_proto;
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = xla_func_def.signature().input_arg(index).type();
arg.initialized = true;
args[index] = arg;
}
}
} // namespace
// Default constructor: `graph_execution_options` starts out as nullptr. The
// AOT entry points in this file read through it, so callers need to populate
// it before invoking them.
AotOptions::AotOptions() : graph_execution_options(nullptr) {}
// Ahead-of-time compiles the SavedModel stored in `input_model_dir`.
//
// Reads the MetaGraphDef selected by `aot_options.tags`, imports it into MLIR
// and lowers it either to MLRT bytecode (when
// `graph_execution_options->enable_mlrt` is set) or to a BEF buffer. The
// returned AotResult holds that buffer together with the FunctionDefs of all
// XLA functions reported by the lowering.
//
// Returns InvalidArgument if `aot_options.graph_execution_options` is not set,
// Internal if the lowering yields an empty buffer, NotFound if a reported XLA
// function is missing from the function library, and otherwise propagates the
// errors of the individual steps.
absl::StatusOr<AotResult> AotCompileSavedModel(
    absl::string_view input_model_dir, AotOptions aot_options) {
  // A default-constructed AotOptions carries no graph execution options, and
  // everything below reads through them.
  if (aot_options.graph_execution_options == nullptr) {
    return absl::InvalidArgumentError(
        "AotOptions::graph_execution_options must be set.");
  }

  TF_ASSIGN_OR_RETURN(tensorflow::MetaGraphDef meta_graph_def,
                      ReadSavedModel(input_model_dir, aot_options.tags));

  UpdateTpuTargetByBridgeCompatibility(*aot_options.graph_execution_options,
                                       meta_graph_def.graph_def());
  UpdateCompileOptions(aot_options);

  mlir::DialectRegistry registry;
  RegisterMlirDialect(
      registry,
      aot_options.graph_execution_options->compile_options.backend_compiler);
  mlir::MLIRContext context(registry);

  tensorflow::SessionOptions session_options =
      CreateDefaultSessionOptions(*aot_options.graph_execution_options);
  session_options.config.mutable_experimental()->set_optimize_for_static_graph(
      true);
  LOG_FIRST_N(INFO, 10) << "SessionOptions: "
                        << session_options.config.DebugString();
  LOG_FIRST_N(INFO, 10) << "GraphExecutionOptions: "
                        << *aot_options.graph_execution_options;

  const ::tensorflow::FunctionDefLibrary& fdef_lib =
      meta_graph_def.graph_def().library();
  ASSIGN_OR_RETURN_IN_IMPORT(
      std::unique_ptr<tensorflow::tfrt_stub::FallbackState> fallback_state,
      FallbackState::CreateWithMockGpuDevice(session_options, fdef_lib));
  ASSIGN_OR_RETURN_IN_IMPORT(
      mlir::OwningOpRef<mlir::ModuleOp> mlir_module,
      ImportSavedModel(&context, meta_graph_def, *fallback_state,
                       std::string(input_model_dir),
                       /*import_user_signatures=*/true,
                       aot_options.graph_execution_options
                           ->run_placer_grappler_on_functions));

  auto resource_context = std::make_unique<tfrt::ResourceContext>();
  ModelRuntimeContext model_context(&*aot_options.graph_execution_options,
                                    std::string(input_model_dir),
                                    resource_context.get());

  CallableOptions callable_options =
      CombineSignatureDefs(meta_graph_def.signature_def());
  model_context.set_graph_def(&meta_graph_def.graph_def());
  model_context.set_callable_options(&callable_options);
  TF_RETURN_IF_ERROR(
      aot_options.graph_execution_options->runtime->CreateRuntimeResources(
          model_context));
  // The graph def and callable options are only needed by
  // `CreateRuntimeResources`; clear them so that `model_context` does not keep
  // pointing at these function-local objects.
  model_context.set_graph_def(nullptr);
  model_context.set_callable_options(nullptr);

  tfrt::BefBuffer bef;
  std::vector<std::string> xla_function_names;
  mlrt::bc::Buffer bytecode_buffer;
  if (aot_options.graph_execution_options->enable_mlrt) {
    mlir::OwningOpRef<mlir::ModuleOp> module_with_op_keys;
    ASSIGN_OR_RETURN_IN_COMPILE(
        bytecode_buffer,
        tensorflow::mlrt_compiler::ConvertTfMlirToBytecode(
            aot_options.graph_execution_options->compile_options,
            *fallback_state, mlir_module.get(), model_context,
            &module_with_op_keys, &xla_function_names));
    if (bytecode_buffer.empty()) {
      LOG(ERROR) << "MLRT byte buffer is empty.";
      return absl::InternalError("bytecode_buffer is empty.");
    }
  } else {
    RETURN_IF_ERROR_IN_COMPILE(tensorflow::ConvertTfMlirToBef(
        aot_options.graph_execution_options->compile_options, mlir_module.get(),
        &bef, model_context, fallback_state.get(), &xla_function_names));
    if (bef.empty()) {
      LOG(ERROR) << "BEF byte buffer is empty.";
      return absl::InternalError("BefBuffer is empty.");
    }
  }

  // Collect the definitions of all XLA functions reported by the lowering.
  const FunctionLibraryDefinition& flib_def = fallback_state->func_lib_def();
  std::vector<FunctionDef> xla_functions;
  xla_functions.reserve(xla_function_names.size());
  for (const std::string& name : xla_function_names) {
    const FunctionDef* xla_func_def = flib_def.Find(name);
    if (xla_func_def == nullptr) {
      return absl::NotFoundError(
          absl::StrCat("XLA function ", name, " not found in library."));
    }
    xla_functions.push_back(*xla_func_def);
  }

  if (aot_options.graph_execution_options->enable_mlrt) {
    return AotResult{std::move(bytecode_buffer), std::move(xla_functions)};
  }
  return AotResult{std::move(bef), std::move(xla_functions)};
}
// Compiles `function` to HLO and then to a GPU PjRtExecutable for the target
// described by `gpu_target_config`. `**compilation_result` receives the
// intermediate HLO compilation result. Errors from either stage, or from
// parsing `gpu_target_config`, are returned to the caller.
absl::StatusOr<std::unique_ptr<xla::PjRtExecutable>>
AotCompileToGpuPjRtExecutable(
    const FunctionLibraryDefinition* flib_def, const NameAttrList& function,
    int graph_def_version, const std::vector<XlaCompiler::Argument>& args,
    bool has_ref_vars, bool may_alias_resource_update,
    const stream_executor::GpuTargetConfigProto& gpu_target_config,
    XlaCompiler::CompilationResult** compilation_result) {
  // Stage 1: TF function -> HLO. This also fills in `xla_options`.
  XlaCompiler::Options xla_options;
  TF_RETURN_IF_ERROR(CompileTfGraphToHlo(
      flib_def, function, graph_def_version, args, has_ref_vars,
      may_alias_resource_update, &xla_options, compilation_result));

  TF_ASSIGN_OR_RETURN(
      xla::Compiler::GpuTargetConfig target_config,
      xla::Compiler::GpuTargetConfig::FromProto(gpu_target_config));

  // Stage 2: HLO -> PjRt executable.
  xla::StreamExecutorGpuCompiler gpu_compiler(xla::CudaId());
  // A trivial topology is sufficient here; it won't be used.
  xla::StreamExecutorGpuTopologyDescription trivial_topology(
      xla::CudaId(), xla::CudaName(), nullptr);

  XlaCompiler::CompilationResult& hlo_result = **compilation_result;
  xla::CompileOptions compile_options =
      GetPjRtCompileOptions(xla_options, hlo_result);
  compile_options.gpu_target_config = target_config;

  return gpu_compiler.Compile(compile_options, *hlo_result.computation,
                              trivial_topology, nullptr);
}
// Compiles `function` to HLO, then compiles and loads it through a
// StreamExecutor GPU client, and returns the serialized form of the loaded
// executable. `**compilation_result` receives the intermediate HLO compilation
// result.
absl::StatusOr<std::string> AotCompileToGpuPjRtLoadedExecutableWithDevice(
    const FunctionLibraryDefinition* flib_def, const NameAttrList& function,
    int graph_def_version, const std::vector<XlaCompiler::Argument>& args,
    bool has_ref_vars, bool may_alias_resource_update,
    XlaCompiler::CompilationResult** compilation_result) {
  // Obtain a GPU client and re-wrap it under its concrete type.
  TF_ASSIGN_OR_RETURN(auto pjrt_client,
                      xla::GetStreamExecutorGpuClient(xla::GpuClientOptions()));
  std::unique_ptr<xla::StreamExecutorGpuClient> gpu_client = absl::WrapUnique(
      absl::down_cast<xla::StreamExecutorGpuClient*>(pjrt_client.release()));

  // TF function -> HLO. This also fills in `xla_options`.
  XlaCompiler::Options xla_options;
  TF_RETURN_IF_ERROR(CompileTfGraphToHlo(
      flib_def, function, graph_def_version, args, has_ref_vars,
      may_alias_resource_update, &xla_options, compilation_result));

  // HLO -> loaded executable -> serialized bytes.
  XlaCompiler::CompilationResult& hlo_result = **compilation_result;
  const xla::CompileOptions compile_options =
      GetPjRtCompileOptions(xla_options, hlo_result);
  TF_ASSIGN_OR_RETURN(
      auto loaded_executable,
      gpu_client->CompileAndLoad(*hlo_result.computation, compile_options));
  return gpu_client->SerializeExecutable(*loaded_executable);
}
// AOT-compiles every XLA function exported when lowering the signature
// `signature_name` of `meta_graph_def` with its input shapes overridden by
// `input_shapes`.
//
// Works on a private copy of `meta_graph_def`, imports the signature into MLIR,
// runs the conversion to collect the XLA function names, and compiles each of
// them to a serialized GPU executable. The result maps the
// DeviceCompilationClusterSignature built for each function to its serialized
// executable.
//
// Returns InvalidArgument if `aot_options.graph_execution_options` is not set,
// NotFound if a reported XLA function is missing from the function library,
// and otherwise propagates the errors of the individual steps.
absl::StatusOr<AotResult::ExecutableMap> AotCompileXlaFunctionsInMetaGraphDef(
    const MetaGraphDef& meta_graph_def, const std::string& signature_name,
    const absl::flat_hash_map<std::string, tensorflow::TensorShapeProto>&
        input_shapes,
    const tensorflow::FunctionDefLibrary& fdef_lib,
    const tensorflow::SessionOptions& session_options,
    const mlir::DialectRegistry& registry, const AotOptions& aot_options,
    absl::string_view input_model_dir, ModelRuntimeContext& model_context) {
  // A default-constructed AotOptions carries no graph execution options, and
  // the import and conversion below read through them.
  if (aot_options.graph_execution_options == nullptr) {
    return absl::InvalidArgumentError(
        "AotOptions::graph_execution_options must be set.");
  }

  // Make a copy since we need to modify the graph.
  MetaGraphDef input_meta_graph_def = meta_graph_def;
  TF_RETURN_IF_ERROR(UpdateGraphDefWithInputShapes(
      input_meta_graph_def, input_shapes, signature_name));

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<tensorflow::tfrt_stub::FallbackState> fallback_state,
      FallbackState::CreateWithMockGpuDevice(session_options, fdef_lib));

  // Import the graph corresponding to `signature_name` into MLIR module.
  mlir::MLIRContext context(registry);
  TF_ASSIGN_OR_RETURN(
      mlir::OwningOpRef<mlir::ModuleOp> mlir_module,
      ImportSavedModel(
          &context, input_meta_graph_def, *fallback_state,
          std::string(input_model_dir),
          /*import_user_signatures=*/true,
          aot_options.graph_execution_options->run_placer_grappler_on_functions,
          {signature_name}));

  // Runs bridge pass. The emit callback is a no-op that reports success; only
  // the exported XLA function names are consumed below.
  std::vector<std::string> xla_function_names;
  RETURN_IF_ERROR_IN_COMPILE(ConvertTfMlirToRuntimeExecutable(
      aot_options.graph_execution_options->compile_options, mlir_module.get(),
      [](mlir::PassManager& pm, mlir::ModuleOp module,
         const tensorflow::TfrtPipelineOptions& options) {
        return absl::OkStatus();
      },
      model_context, fallback_state.get(), &xla_function_names));

  AotResult::ExecutableMap result;
  const FunctionLibraryDefinition& flib_def = fallback_state->func_lib_def();
  // Compiles every exported XLA function.
  for (const std::string& name : xla_function_names) {
    const FunctionDef* xla_func_def = flib_def.Find(name);
    if (xla_func_def == nullptr) {
      return absl::NotFoundError(
          absl::StrCat("XLA function ", name, " not found in library."));
    }

    NameAttrList func_attr_list;
    std::vector<XlaCompiler::Argument> args;
    ConstructFunctionAndArgs(name, *xla_func_def, func_attr_list, args);

    // The callee takes a pointer-to-pointer, so hand it the address of a
    // pointer to this stack-allocated result.
    XlaCompiler::CompilationResult out_compilation_result;
    XlaCompiler::CompilationResult* compilation_result =
        &out_compilation_result;
    TF_ASSIGN_OR_RETURN(
        std::string serialized_executable,
        AotCompileToGpuPjRtLoadedExecutableWithDevice(
            &flib_def, func_attr_list,
            input_meta_graph_def.graph_def().versions().producer(), args,
            /*has_ref_vars=*/false, /*may_alias_resource_update=*/false,
            &compilation_result));
    TF_ASSIGN_OR_RETURN(
        auto signature,
        DeviceCompilationClusterSignature::Build(func_attr_list, args));
    // Both values are dead after this point; move them instead of copying the
    // serialized executable.
    result.emplace(std::move(signature), std::move(serialized_executable));
  }
  return result;
}
} // namespace tensorflow::tfrt_stub