// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/coreml/graph_builder.h"

#include <fstream>
#include <optional>
#include <string_view>

#include "base/bits.h"
#include "base/containers/fixed_flat_set.h"
#include "base/containers/span.h"
#include "base/files/file.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/json/json_file_value_serializer.h"
#include "base/metrics/histogram_macros.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h"
#include "base/timer/elapsed_timer.h"
#include "base/types/expected.h"
#include "base/types/expected_macros.h"
#include "base/unguessable_token.h"
#include "base/uuid.h"
#include "base/values.h"
#include "mojo/public/cpp/base/big_buffer.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "third_party/coremltools/mlmodel/format/FeatureTypes.pb.h"
#include "third_party/coremltools/mlmodel/format/MIL.pb.h"

namespace webnn::coreml {

using mojom::Operand;
using mojom::Operation;

namespace {

const char kWriteFileErrorMessage[] = "Failed to write constant to file.";

const base::FilePath::CharType kMlPackageExtension[] =
    FILE_PATH_LITERAL(".mlpackage");
const base::FilePath::CharType kMlPackageDataDir[] = FILE_PATH_LITERAL("Data");
const base::FilePath::CharType kMlPackageWeightsDir[] =
    FILE_PATH_LITERAL("weights");
const base::FilePath::CharType kMlPackageWeightsFileName[] =
    FILE_PATH_LITERAL("weights.bin");
const base::FilePath::CharType kMlPackageModelFileName[] =
    FILE_PATH_LITERAL("model.mlmodel");
const base::FilePath::CharType kManifestFileName[] =
    FILE_PATH_LITERAL("Manifest.json");

// Information in model package Manifest.json file.
const char kManifestItemAuthorKey[] = "author";
const char kManifestItemAuthorValue[] = "Chromium";
const char kManifestItemDescriptionKey[] = "description";
const char kManifestModelDescriptionValue[] = "CoreML Model Specification";
const char kManifestWeightsDescriptionValue[] = "CoreML Model Weights";
const char kManifestItemNameKey[] = "name";
const char kManifestItemPathKey[] = "path";
const char kManifestModelValue[] = "model.mlmodel";
const char kManifestWeightsValue[] = "weights";
const char kManifestItemInfoEntriesKey[] = "itemInfoEntries";
const char kManifestVersionKey[] = "fileFormatVersion";
const char kManifestVersionValue[] = "1.0.0";
const char kManifestModelIdentifierKey[] = "rootModelIdentifier";

// model op related consts.
const char kPlaceholderOuputName[] = "placeholder_output";
const char kOpConstTypeName[] = "const";
const char kOpAddTypeName[] = "add";
const char kOpMultiplyTypeName[] = "mul";
const char kOpDivideTypeName[] = "real_div";
const char kOpSubtractTypeName[] = "sub";
const char kOpMaximumTypeName[] = "maximum";
const char kOpMinimumTypeName[] = "minimum";
const char kOpPowerTypeName[] = "pow";

// Hard coded path used in the model file to point at the weight path.
const char kWeightsRelativeFilePath[] = "@model_path/weights/weights.bin";

// Maps to types defined in
// https://github.com/apple/coremltools/blob/b416f36054af9ca9d10b2d74ba215d0454677ca0/mlmodel/src/MILBlob/Blob/BlobDataType.hpp#L14
enum class BlobDataType : uint32_t {
  Float16 = 1,
  Float32 = 2,
  UInt8 = 3,
  Int8 = 4,
  BFloat16 = 5,
  Int16 = 6,
  UInt16 = 7,
};

// The weights format follows the definition in
// https://github.com/apple/coremltools/blob/b416f36054af9ca9d10b2d74ba215d0454677ca0/mlmodel/src/MILBlob/Blob/StorageFormat.hpp#L14-L78
// which defines the sentinel, alignment, header, and metadata structures.

// Default sentinel for validation for metadata.
constexpr uint64_t BlobMetadataSentinel = 0xDEADBEEF;

// All entries in the weight file need to be 64 bytes aligned, including the
// header, metadata and the weights.
constexpr uint64_t kWeightAlignment = 64;

struct WeightHeader {
  uint32_t count = 0;    // Number of constant values stored in the weight file.
  uint32_t version = 2;  // The default version that this format supports.

  uint64_t padding = 0;  // Paddings added to be 64 bytes aligned.
  uint64_t padding1 = 0;
  uint64_t padding2 = 0;
  uint64_t padding3 = 0;
  uint64_t padding4 = 0;
  uint64_t padding5 = 0;
  uint64_t padding6 = 0;
};

static_assert(sizeof(WeightHeader) == 64, "WeightHeader must be 64 bytes");

struct WeightMetadata {
  WeightMetadata(BlobDataType mil_data_type, uint64_t size_in_bytes,
                 uint64_t offset)
      : mil_data_type(mil_data_type),
        size_in_bytes(size_in_bytes),
        offset(offset) {}

  uint32_t sentinel = BlobMetadataSentinel;
  BlobDataType mil_data_type;
  uint64_t size_in_bytes;
  uint64_t offset;  // offset of the actual weight blob, after the metadata.

  uint64_t padding = 0;  // Paddings added to be 64 bytes aligned.
  uint64_t padding1 = 0;
  uint64_t padding2 = 0;
  uint64_t padding3 = 0;
  uint64_t padding4 = 0;
};

static_assert(sizeof(WeightMetadata) == 64, "WeightMetadata must be 64 bytes");

std::optional<BlobDataType> OperandTypeToDataTypeInWeightFile(
    mojom::Operand::DataType data_type) {
  switch (data_type) {
    case mojom::Operand::DataType::kFloat16:
      return BlobDataType::Float16;
    case mojom::Operand::DataType::kFloat32:
      return BlobDataType::Float32;
    case mojom::Operand::DataType::kUint8:
      return BlobDataType::UInt8;
    case mojom::Operand::DataType::kInt8:
      return BlobDataType::Int8;
    case mojom::Operand::DataType::kInt32:
    case mojom::Operand::DataType::kUint32:
    case mojom::Operand::DataType::kInt64:
    case mojom::Operand::DataType::kUint64:
      return std::nullopt;
  }
}

std::string GetCoreMLNameFromOperand(uint64_t operand_id,
                                     const Operand& operand) {
  // CoreML doesn't allow op output names to start with numbers, so "var_"
  // prefixes are added.
  switch (operand.kind) {
    case Operand::Kind::kInput:
      CHECK(operand.name.has_value());
      return GetCoreMLNameFromInput(operand.name.value());
    case Operand::Kind::kConstant:
      return "var_" + base::NumberToString(operand_id);
    case Operand::Kind::kOutput:
      if (operand.name.has_value()) {
        return GetCoreMLNameFromOutput(operand.name.value());
      } else {
        // Intermediate outputs don't have names so use operand_id instead.
        return "var_" + base::NumberToString(operand_id);
      }
  }
}

CoreML::Specification::MILSpec::DataType OperandTypeToMILDataType(
    mojom::Operand::DataType data_type) {
  switch (data_type) {
    case mojom::Operand::DataType::kFloat32:
      return CoreML::Specification::MILSpec::DataType::FLOAT32;
    case mojom::Operand::DataType::kFloat16:
      return CoreML::Specification::MILSpec::DataType::FLOAT16;
    case mojom::Operand::DataType::kInt32:
      return CoreML::Specification::MILSpec::DataType::INT32;
    case mojom::Operand::DataType::kUint32:
      return CoreML::Specification::MILSpec::DataType::UINT32;
    case mojom::Operand::DataType::kInt64:
      return CoreML::Specification::MILSpec::DataType::INT64;
    case mojom::Operand::DataType::kUint64:
      return CoreML::Specification::MILSpec::DataType::UINT64;
    case mojom::Operand::DataType::kInt8:
      return CoreML::Specification::MILSpec::DataType::INT8;
    case mojom::Operand::DataType::kUint8:
      return CoreML::Specification::MILSpec::DataType::UINT8;
  }
}

CoreML::Specification::MILSpec::Value CreateStringValue(
    std::string_view value) {
  CoreML::Specification::MILSpec::Value scalar_value{};
  scalar_value.mutable_type()->mutable_tensortype()->set_datatype(
      CoreML::Specification::MILSpec::DataType::STRING);
  scalar_value.mutable_immediatevalue()
      ->mutable_tensor()
      ->mutable_strings()
      ->add_values(value.data());
  return scalar_value;
}

base::unexpected<mojom::ErrorPtr> NewNotSupportedError(std::string message) {
  return base::unexpected(mojom::Error::New(
      mojom::Error::Code::kNotSupportedError, std::move(message)));
}

base::unexpected<mojom::ErrorPtr> NewUnknownError(std::string message) {
  return base::unexpected(
      mojom::Error::New(mojom::Error::Code::kUnknownError, std::move(message)));
}

}  // namespace

std::string GetCoreMLNameFromInput(std::string_view input_name) {
  // Prefix is added to user provided names to avoid collision with intermediate
  // operands' names
  return base::StrCat({"input_", input_name});
}

std::string GetCoreMLNameFromOutput(std::string_view output_name) {
  // Prefix is added to user provided names to avoid collision with intermediate
  // operands' names
  return base::StrCat({"output_", output_name});
}

// static
[[nodiscard]] base::expected<std::unique_ptr<GraphBuilder>, mojom::ErrorPtr>
GraphBuilder::CreateAndBuild(const mojom::GraphInfo& graph_info,
                             const base::FilePath& working_directory) {
  // Use a random string for the model package directory, because MLModel
  // compileModelAtURL creates a folder directly in the NSTemporaryDirectory
  // with the name of the .mlmodel file. Using a random string will avoid any
  // potential name collision of that dir.
  base::FilePath ml_package_dir =
      working_directory.AppendASCII(base::UnguessableToken::Create().ToString())
          .AddExtension(kMlPackageExtension);

  base::FilePath data_dir = ml_package_dir.Append(kMlPackageDataDir);

  auto graph_builder = base::WrapUnique(new GraphBuilder(
      graph_info, std::move(ml_package_dir),
      data_dir.Append(kMlPackageModelFileName),
      data_dir.Append(kMlPackageWeightsDir).Append(kMlPackageWeightsFileName)));

  RETURN_IF_ERROR(graph_builder->BuildCoreMLModel());

  if (!graph_builder->SerializeModel()) {
    return NewUnknownError("Failed to serialize CoreML model.");
  }

  return graph_builder;
}

GraphBuilder::GraphBuilder(const mojom::GraphInfo& graph_info,
                           base::FilePath ml_package_dir,
                           base::FilePath model_file_path,
                           base::FilePath weights_file_path)
    : graph_info_(graph_info),
      ml_package_dir_(std::move(ml_package_dir)),
      model_file_path_(std::move(model_file_path)),
      weights_file_path_(std::move(weights_file_path)) {}

GraphBuilder::~GraphBuilder() = default;

[[nodiscard]] base::expected<void, mojom::ErrorPtr>
GraphBuilder::BuildCoreMLModel() {
  CHECK_EQ(ml_model_.specificationversion(), 0);
  // Based on comment in Model.proto
  //  * 7 : iOS 16, macOS 13, tvOS 16, watchOS 9 (Core ML 6)
  //  * - FLOAT16 array data type
  //  * - GRAYSCALE_FLOAT16 image color space.
  // use the model specification version supported on macOS 13 which is
  // version 7.
  ml_model_.set_specificationversion(7);
  ml_model_.set_isupdatable(false);

  program_ = ml_model_.mutable_mlprogram();
  program_->set_version(1);

  // Creates a Program with a single main function, and a single block within
  // the function. The block contains all the ops right now.
  // TODO(https://crbug.com/327216253): figure out when to use CoreML7 for some
  // ops.
  auto& main_function = (*program_->mutable_functions())["main"];
  // CoreML6 means specification version 7.
  main_function.set_opset("CoreML6");
  auto& block = (*main_function.mutable_block_specializations())["CoreML6"];

  // Add inputs.
  for (auto& input_id : graph_info_->input_operands) {
    RETURN_IF_ERROR(AddInput(input_id, main_function));
  }

  if (graph_info_->input_operands.empty()) {
    AddPlaceholderInput(main_function, block);
  }

  RETURN_IF_ERROR(SetupMlPackageDirStructure());

  base::ElapsedTimer ml_weights_write_timer;
  RETURN_IF_ERROR(WriteWeightsToFile(block));
  UMA_HISTOGRAM_MEDIUM_TIMES("WebNN.CoreML.TimingMs.MLWeightsWrite",
                             ml_weights_write_timer.Elapsed());

  // Add operations.
  for (auto& operation : graph_info_->operations) {
    switch (operation->which()) {
      case mojom::Operation::Tag::kElementWiseBinary: {
        RETURN_IF_ERROR(AddOperationForBinary(
            *operation->get_element_wise_binary(), block));
        break;
      }
      case mojom::Operation::Tag::kArgMinMax:
      case mojom::Operation::Tag::kBatchNormalization:
      case mojom::Operation::Tag::kClamp:
      case mojom::Operation::Tag::kConv2d:
      case mojom::Operation::Tag::kConcat:
      case mojom::Operation::Tag::kElementWiseUnary:
      case mojom::Operation::Tag::kElu:
      case mojom::Operation::Tag::kExpand:
      case mojom::Operation::Tag::kGather:
      case mojom::Operation::Tag::kGemm:
      case mojom::Operation::Tag::kGru:
      case mojom::Operation::Tag::kHardSigmoid:
      case mojom::Operation::Tag::kHardSwish:
      case mojom::Operation::Tag::kLayerNormalization:
      case mojom::Operation::Tag::kInstanceNormalization:
      case mojom::Operation::Tag::kLeakyRelu:
      case mojom::Operation::Tag::kLinear:
      case mojom::Operation::Tag::kLstm:
      case mojom::Operation::Tag::kLstmCell:
      case mojom::Operation::Tag::kMatmul:
      case mojom::Operation::Tag::kPad:
      case mojom::Operation::Tag::kPool2d:
      case mojom::Operation::Tag::kPrelu:
      case mojom::Operation::Tag::kReduce:
      case mojom::Operation::Tag::kRelu:
      case mojom::Operation::Tag::kResample2d:
      case mojom::Operation::Tag::kReshape:
      case mojom::Operation::Tag::kSigmoid:
      case mojom::Operation::Tag::kSlice:
      case mojom::Operation::Tag::kSoftmax:
      case mojom::Operation::Tag::kSoftplus:
      case mojom::Operation::Tag::kSoftsign:
      case mojom::Operation::Tag::kSplit:
      case mojom::Operation::Tag::kTanh:
      case mojom::Operation::Tag::kTranspose:
      case mojom::Operation::Tag::kTriangular:
      case mojom::Operation::Tag::kWhere:
        return NewNotSupportedError("This operator is not implemented.");
    }
  }

  // Add output.
  for (auto& output_id : graph_info_->output_operands) {
    block.add_outputs(GetCoreMLNameFromOperand(
        output_id, GetOperand(output_id)));
    RETURN_IF_ERROR(AddOutput(output_id));
  }
  return base::ok();
}

bool GraphBuilder::SerializeModel() {
  base::ElapsedTimer ml_model_write_timer;
  // This will always overwrite if there is an existing file.
  std::fstream model_file(model_file_path_.value(),
                          std::ios::out | std::ios::binary);
  bool result = ml_model_.SerializeToOstream(&model_file);
  UMA_HISTOGRAM_MEDIUM_TIMES("WebNN.CoreML.TimingMs.MLModelWrite",
                             ml_model_write_timer.Elapsed());
  return result;
}

base::expected<void, mojom::ErrorPtr> GraphBuilder::WriteWeightsToFile(
    CoreML::Specification::MILSpec::Block& block) {
  base::File weights_file(weights_file_path_,
                          base::File::FLAG_CREATE | base::File::FLAG_WRITE);

  uint64_t current_offset = 0;
  WeightHeader header{
      static_cast<uint32_t>(graph_info_->constant_id_to_buffer_map.size())};
  if (!weights_file.WriteAtCurrentPosAndCheck(
          base::byte_span_from_ref(header))) {
    return NewUnknownError(kWriteFileErrorMessage);
  }
  current_offset += sizeof(header);

  for (auto& [key, buffer] : graph_info_->constant_id_to_buffer_map) {
    const Operand& operand = GetOperand(key);
    if (operand.dimensions.empty()) {
      AddConstantImmediateValue(key, block);
      continue;
    }

    std::optional<BlobDataType> weight_type =
        OperandTypeToDataTypeInWeightFile(operand.data_type);
    if (!weight_type.has_value()) {
      return NewNotSupportedError("Unsupported constant type.");
    }

    WeightMetadata metadata(weight_type.value(), buffer.size(),
                            current_offset + sizeof(metadata));

    if (!weights_file.WriteAtCurrentPosAndCheck(
            base::byte_span_from_ref(metadata))) {
      return NewUnknownError(kWriteFileErrorMessage);
    }

    if (!weights_file.WriteAtCurrentPosAndCheck(base::make_span(buffer))) {
      return NewUnknownError(kWriteFileErrorMessage);
    }

    AddConstantFileValue(key, current_offset, operand, block);
    current_offset += sizeof(metadata);
    current_offset += buffer.size();
    current_offset = base::bits::AlignUp(current_offset, kWeightAlignment);
    if (!weights_file.Seek(base::File::Whence::FROM_BEGIN, current_offset)) {
      return NewUnknownError(kWriteFileErrorMessage);
    }
  }
  return base::ok();
}

const mojom::Operand& GraphBuilder::GetOperand(uint64_t operand_id) const {
  return *graph_info_->id_to_operand_map.at(operand_id);
}

const GraphBuilder::OperandInfo* GraphBuilder::FindInputOperandInfo(
    const std::string& input_name) const {
  auto id = input_name_to_id_map_.find(input_name);
  if (id == input_name_to_id_map_.end()) {
    return nullptr;
  }
  return GetOperandInfo(id->second);
}

const base::FilePath& GraphBuilder::GetModelFilePath() {
  return ml_package_dir_;
}

void GraphBuilder::AddPlaceholderInput(
    CoreML::Specification::MILSpec::Function& main_function,
    CoreML::Specification::MILSpec::Block& block) {
  auto* mutable_description = ml_model_.mutable_description();
  auto* feature_description = mutable_description->add_input();

  auto* feature_type = feature_description->mutable_type();
  auto* array_feature_type = feature_type->mutable_multiarraytype();
  array_feature_type->set_datatype(
      CoreML::Specification::ArrayFeatureType_ArrayDataType::
          ArrayFeatureType_ArrayDataType_FLOAT16);

  array_feature_type->add_shape(1);
  feature_description->mutable_name()->assign(kPlaceholderInputName);

  const mojom::Operand operand{mojom::Operand::Kind::kInput,
                               mojom::Operand::DataType::kFloat16,
                               {1},
                               kPlaceholderInputName};

  CoreML::Specification::MILSpec::NamedValueType& input_for_main_function =
      *main_function.add_inputs();
  input_for_main_function.set_name(kPlaceholderInputName);
  auto& value_type = *input_for_main_function.mutable_type();
  PopulateValueType(operand, value_type);

  // The model compute only succeeds when the placeholder is used in one op.
  CoreML::Specification::MILSpec::Operation* placeholder_op =
      block.add_operations();
  (*placeholder_op->mutable_inputs())["x"].add_arguments()->set_name(
      kPlaceholderInputName);
  (*placeholder_op->mutable_inputs())["y"].add_arguments()->set_name(
      kPlaceholderInputName);
  placeholder_op->set_type(kOpAddTypeName);
  CoreML::Specification::MILSpec::NamedValueType& outputs =
      *placeholder_op->add_outputs();
  outputs.set_name(kPlaceholderOuputName);
  auto& output_value_type = *outputs.mutable_type();
  PopulateValueType(operand, output_value_type);
}

[[nodiscard]] base::expected<void, mojom::ErrorPtr> GraphBuilder::AddInput(
    uint64_t input_id,
    CoreML::Specification::MILSpec::Function& main_function) {
  auto* mutable_description = ml_model_.mutable_description();
  auto* feature_description = mutable_description->add_input();
  const Operand& operand = GetOperand(input_id);
  RETURN_IF_ERROR(
      PopulateFeatureDescription(input_id, operand, *feature_description));

  PopulateNamedValueType(input_id, operand, *main_function.add_inputs());

  CHECK(input_name_to_id_map_.try_emplace(operand.name.value(), input_id)
            .second);
  return base::ok();
}

[[nodiscard]] base::expected<void, mojom::ErrorPtr> GraphBuilder::AddOutput(
    uint64_t output_id) {
  const auto output_iterator = id_to_op_input_info_map_.find(output_id);
  CHECK(output_iterator != id_to_op_input_info_map_.end());
  const Operand& operand = GetOperand(output_id);
  auto* mutable_description = ml_model_.mutable_description();
  auto* feature_description = mutable_description->add_output();
  RETURN_IF_ERROR(
      PopulateFeatureDescription(output_id, operand, *feature_description));
  return base::ok();
}

base::expected<void, mojom::ErrorPtr> GraphBuilder::AddOperationForBinary(
    const mojom::ElementWiseBinary& operation,
    CoreML::Specification::MILSpec::Block& block) {
  CoreML::Specification::MILSpec::Operation* op = block.add_operations();

  auto input_lhs = id_to_op_input_info_map_.at(operation.lhs_operand_id);
  auto input_rhs = id_to_op_input_info_map_.at(operation.rhs_operand_id);
  // Input keys (x, y) and supported types are defined in coremltools.
  // https://github.com/apple/coremltools/blob/b416f36054af9ca9d10b2d74ba215d0454677ca0/coremltools/converters/mil/mil/ops/defs/iOS15/elementwise_binary.py#L33
  static constexpr auto kSupportedBinaryOpsTypes =
      base::MakeFixedFlatSet<CoreML::Specification::MILSpec::DataType>(
          {CoreML::Specification::MILSpec::DataType::FLOAT16,
           CoreML::Specification::MILSpec::DataType::FLOAT32,
           CoreML::Specification::MILSpec::DataType::INT32});

  if (!kSupportedBinaryOpsTypes.contains(input_lhs.mil_data_type) ||
      !kSupportedBinaryOpsTypes.contains(input_rhs.mil_data_type)) {
    return NewNotSupportedError("Unsupported input datatype.");
  }

  (*op->mutable_inputs())["x"].add_arguments()->set_name(input_lhs.coreml_name);
  (*op->mutable_inputs())["y"].add_arguments()->set_name(input_rhs.coreml_name);

  switch (operation.kind) {
    case mojom::ElementWiseBinary::Kind::kAdd: {
      op->set_type(kOpAddTypeName);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kDiv: {
      op->set_type(kOpDivideTypeName);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kMul: {
      op->set_type(kOpMultiplyTypeName);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kSub: {
      op->set_type(kOpSubtractTypeName);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kMax: {
      op->set_type(kOpMaximumTypeName);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kMin: {
      op->set_type(kOpMinimumTypeName);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kPow: {
      op->set_type(kOpPowerTypeName);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kEqual:
    case mojom::ElementWiseBinary::Kind::kGreater:
    case mojom::ElementWiseBinary::Kind::kGreaterOrEqual:
    case mojom::ElementWiseBinary::Kind::kLesser:
    case mojom::ElementWiseBinary::Kind::kLesserOrEqual:
      return NewNotSupportedError("Unimplemented Binary Operator.");
  }

  PopulateNamedValueType(
      operation.output_operand_id,
      GetOperand(operation.output_operand_id),
      *op->add_outputs());

  return base::ok();
}

void GraphBuilder::AddConstantImmediateValue(
    uint32_t constant_id, CoreML::Specification::MILSpec::Block& block) {
  const Operand& operand = GetOperand(constant_id);
  auto* op = block.add_operations();
  PopulateNamedValueType(constant_id, operand, *op->add_outputs());
  op->set_type(kOpConstTypeName);

  google::protobuf::Map<std::string, ::CoreML::Specification::MILSpec::Value>&
      attributes = *op->mutable_attributes();
  attributes["name"] =
      CreateStringValue(id_to_op_input_info_map_.at(constant_id).coreml_name);
  CoreML::Specification::MILSpec::Value immediate_value{};
  PopulateValueType(operand, *immediate_value.mutable_type());
  auto* data = immediate_value.mutable_immediatevalue()->mutable_tensor();
  const mojo_base::BigBuffer& buffer =
      graph_info_->constant_id_to_buffer_map.at(constant_id);

  switch (operand.data_type) {
    case mojom::Operand::DataType::kFloat32:
      data->mutable_floats()->add_values(
          *reinterpret_cast<const float*>(buffer.data()));
      break;
    // As per
    // https://github.com/apple/coremltools/blob/bba83f43859e087d50c7d764cb132e7d4b427611/coremltools/converters/mil/backend/mil/helper.py#L23,
    // these types are stored in bytes.
    case mojom::Operand::DataType::kFloat16:
    case mojom::Operand::DataType::kInt8:
    case mojom::Operand::DataType::kUint8:
    case mojom::Operand::DataType::kUint32:
      data->mutable_bytes()->mutable_values()->assign(
          buffer.data(), buffer.data() + buffer.size());
      break;
    case mojom::Operand::DataType::kInt32:
      data->mutable_ints()->add_values(
          *reinterpret_cast<const int*>(buffer.data()));
      break;
    case mojom::Operand::DataType::kInt64:
    case mojom::Operand::DataType::kUint64:
      data->mutable_longints()->add_values(
          *reinterpret_cast<const long*>(buffer.data()));
      break;
  }
  attributes["val"] = std::move(immediate_value);
}

void GraphBuilder::AddConstantFileValue(
    uint32_t constant_id, uint64_t offset, const mojom::Operand& operand,
    CoreML::Specification::MILSpec::Block& block) {
  auto* op = block.add_operations();
  PopulateNamedValueType(constant_id, operand, *op->add_outputs());
  op->set_type(kOpConstTypeName);
  // Blob path is defined in generic Operation.attributes.
  // This follows the actual data structure in
  // https://github.com/apple/coremltools/blob/bba83f43859e087d50c7d764cb132e7d4b427611/coremltools/converters/mil/backend/mil/load.py#L60.
  auto& attributes = *op->mutable_attributes();
  attributes["name"] =
      CreateStringValue(id_to_op_input_info_map_.at(constant_id).coreml_name);
  CoreML::Specification::MILSpec::Value blob_value{};
  PopulateValueType(operand, *blob_value.mutable_type());
  CoreML::Specification::MILSpec::Value::BlobFileValue* blob =
      blob_value.mutable_blobfilevalue();
  blob->set_filename(kWeightsRelativeFilePath);
  blob->set_offset(offset);
  attributes["val"] = std::move(blob_value);
}

[[nodiscard]] const GraphBuilder::OperandInfo* GraphBuilder::GetOperandInfo(
    uint64_t operand_id) const {
  const auto input_iterator = id_to_op_input_info_map_.find(operand_id);
  CHECK(input_iterator != id_to_op_input_info_map_.end());
  return &input_iterator->second;
}

base::expected<void, mojom::ErrorPtr> GraphBuilder::PopulateFeatureDescription(
    uint64_t operand_id, const mojom::Operand& operand,
    ::CoreML::Specification::FeatureDescription& feature_description) {
  auto* feature_type = feature_description.mutable_type();
  auto* array_feature_type = feature_type->mutable_multiarraytype();
  switch (operand.data_type) {
    case mojom::Operand::DataType::kFloat32:
      array_feature_type->set_datatype(
          CoreML::Specification::ArrayFeatureType_ArrayDataType::
              ArrayFeatureType_ArrayDataType_FLOAT32);
      break;
    case mojom::Operand::DataType::kFloat16:
      array_feature_type->set_datatype(
          CoreML::Specification::ArrayFeatureType_ArrayDataType::
              ArrayFeatureType_ArrayDataType_FLOAT16);
      break;
    case mojom::Operand::DataType::kInt32:
      array_feature_type->set_datatype(
          CoreML::Specification::ArrayFeatureType_ArrayDataType::
              ArrayFeatureType_ArrayDataType_INT32);
      break;
    case mojom::Operand::DataType::kUint32:
    case mojom::Operand::DataType::kInt64:
    case mojom::Operand::DataType::kUint64:
    case mojom::Operand::DataType::kInt8:
    case mojom::Operand::DataType::kUint8:
      return NewNotSupportedError("Unsupported input datatype.");
  }
  // FeatureDescriptions are about input and output features, WebNN allows
  // scalar operands to have empty dimensions. At the input and output layers
  // these can be treated as a 1D tensor to satisfy CoreML's requirement of
  // having atleast 1 dimension.
  if (operand.dimensions.empty()) {
    array_feature_type->add_shape(1);
  } else {
    for (int dimension : operand.dimensions) {
      array_feature_type->add_shape(dimension);
    }
  }
  feature_description.mutable_name()->assign(
      GetCoreMLNameFromOperand(operand_id, operand));
  return base::ok();
}

void GraphBuilder::PopulateNamedValueType(
    uint64_t operand_id, const mojom::Operand& operand,
    CoreML::Specification::MILSpec::NamedValueType& named_value_type) {
  named_value_type.set_name(GetCoreMLNameFromOperand(operand_id, operand));
  auto& value_type = *named_value_type.mutable_type();
  PopulateValueType(operand, value_type);

  // WebNN allows 0D scalar operands to have empty dimensions.
  // At the input and output nodes, these can be treated as a 1D tensor to
  // satisfy CoreML's requirement of having at least 1 dimension.
  CHECK(id_to_op_input_info_map_
            .try_emplace(operand_id,
                         OperandInfo(named_value_type.name(),
                                     operand.dimensions.empty()
                                         ? std::vector<uint32_t>({1})
                                         : operand.dimensions,
                                     operand.data_type,
                                     value_type.tensortype().datatype()))
            .second);
}

void GraphBuilder::PopulateValueType(
    const mojom::Operand& operand,
    CoreML::Specification::MILSpec::ValueType& value_type) {
  auto* tensor_type = value_type.mutable_tensortype();
  auto mil_data_type = OperandTypeToMILDataType(operand.data_type);
  tensor_type->set_datatype(mil_data_type);
  tensor_type->set_rank(operand.dimensions.empty() ? 1
                                                   : operand.dimensions.size());
  if (operand.dimensions.empty()) {
    tensor_type->set_rank(1);
    tensor_type->add_dimensions()->mutable_constant()->set_size(1);
  } else {
    tensor_type->set_rank(operand.dimensions.size());
    for (int dimension : operand.dimensions) {
      tensor_type->add_dimensions()->mutable_constant()->set_size(dimension);
    }
  }
}

base::expected<void, mojom::ErrorPtr>
GraphBuilder::SetupMlPackageDirStructure() {
  if (!base::CreateDirectory(ml_package_dir_)) {
    return NewUnknownError("Fail to create .mlpackage directory.");
  }
  base::FilePath data_dir = ml_package_dir_.Append(kMlPackageDataDir);
  if (!base::CreateDirectory(data_dir)) {
    return NewUnknownError("Fail to create .mlpackage/Data directory.");
  }

  base::FilePath weights_dir = data_dir.Append(kMlPackageWeightsDir);
  if (!base::CreateDirectory(weights_dir)) {
    return NewUnknownError("Fail to create .mlpackage/Data/weights directory.");
  }

  // Creates a Manifest.json file that contains the package information. The
  // coremltools definition is here
  // https://github.com/apple/coremltools/blob/169d0ac7657c60e0d96e08612727ac51ab68c431/modelpackage/src/ModelPackage.hpp.
  base::Value::Dict metadata;
  base::Value::Dict item_info_entries;
  base::Value::Dict model_info;
  model_info.Set(kManifestItemAuthorKey, kManifestItemAuthorValue);
  model_info.Set(kManifestItemDescriptionKey, kManifestModelDescriptionValue);
  model_info.Set(kManifestItemNameKey, kManifestModelValue);
  model_info.Set(kManifestItemPathKey, kManifestModelValue);
  // Follows coremltools to use uuid for model identifier and weights
  // identifier.
  // https://github.com/apple/coremltools/blob/169d0ac7657c60e0d96e08612727ac51ab68c431/modelpackage/src/ModelPackage.cpp#L374
  std::string model_identifier(
      base::Uuid::GenerateRandomV4().AsLowercaseString());
  item_info_entries.Set(model_identifier, std::move(model_info));

  base::Value::Dict weights_info;
  weights_info.Set(kManifestItemAuthorKey, kManifestItemAuthorValue);
  weights_info.Set(kManifestItemDescriptionKey,
                   kManifestWeightsDescriptionValue);
  weights_info.Set(kManifestItemNameKey, kManifestModelValue);
  weights_info.Set(kManifestItemPathKey, kManifestWeightsValue);
  item_info_entries.Set(base::Uuid::GenerateRandomV4().AsLowercaseString(),
                        std::move(weights_info));

  metadata.Set(kManifestItemInfoEntriesKey, std::move(item_info_entries));
  metadata.Set(kManifestVersionKey, kManifestVersionValue);
  metadata.Set(kManifestModelIdentifierKey, model_identifier);
  JSONFileValueSerializer serializer(ml_package_dir_.Append(kManifestFileName));
  if (!serializer.Serialize(std::move(metadata))) {
    return NewUnknownError("Fail to create Manifest.json for mlpackage.");
  }

  return base::ok();
}

GraphBuilder::OperandInfo::OperandInfo(
    std::string coreml_name, std::vector<uint32_t> dimensions,
    mojom::Operand::DataType data_type,
    CoreML::Specification::MILSpec::DataType mil_data_type)
    : coreml_name(std::move(coreml_name)),
      dimensions(std::move(dimensions)),
      data_type(data_type),
      mil_data_type(std::move(mil_data_type)) {}

GraphBuilder::OperandInfo::OperandInfo() = default;
GraphBuilder::OperandInfo::~OperandInfo() = default;
GraphBuilder::OperandInfo::OperandInfo(OperandInfo&) = default;
GraphBuilder::OperandInfo::OperandInfo(OperandInfo&&) = default;

}  // namespace webnn::coreml
