blob: 3794e7baf9a7cfe68322351880af629725a683f2 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/webnn/coreml/context_impl_coreml.h"
#import <CoreML/CoreML.h>
#include "base/sequence_checker.h"
#include "gpu/command_buffer/service/shared_image/shared_image_manager.h"
#include "services/webnn/coreml/graph_builder_coreml.h"
#include "services/webnn/coreml/graph_impl_coreml.h"
#include "services/webnn/coreml/tensor_impl_coreml.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/webnn_types.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/scoped_sequence.h"
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
namespace webnn::coreml {
// Constructs the CoreML-backed WebNN context.
//
// Every argument is forwarded unchanged to the WebNNContextImpl base class.
// The only CoreML-specific input is the backend's context properties, which
// are obtained from GraphBuilderCoreml::GetContextProperties().
ContextImplCoreml::ContextImplCoreml(
mojo::PendingAssociatedReceiver<mojom::WebNNContext> receiver,
WebNNContextProviderImpl* context_provider,
mojom::CreateContextOptionsPtr options,
gpu::CommandBufferId command_buffer_id,
std::unique_ptr<ScopedSequence> sequence,
scoped_refptr<gpu::SchedulerTaskRunner> task_runner)
: WebNNContextImpl(std::move(receiver),
context_provider,
// CoreML backend capabilities advertised by this context.
GraphBuilderCoreml::GetContextProperties(),
std::move(options),
command_buffer_id,
std::move(sequence),
std::move(task_runner)) {}
// Defaulted destructor; no CoreML-specific teardown is performed here.
ContextImplCoreml::~ContextImplCoreml() = default;
// Returns a weak reference to this context, typed as the WebNNContextImpl
// base class. Must be called on the sequence this context is bound to
// (enforced in debug builds by the sequence checker).
base::WeakPtr<WebNNContextImpl> ContextImplCoreml::AsWeakPtr() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // The factory vends a pointer to the derived type; binding it to the
  // base-typed local performs the upcast explicitly.
  base::WeakPtr<WebNNContextImpl> weak_self = weak_factory_.GetWeakPtr();
  return weak_self;
}
// Builds a CoreML graph for `graph_info` by delegating to
// GraphImplCoreml::CreateAndBuild.
//
// Ownership of the receiver, graph description, compute resource info and
// constant operand maps is handed over to the builder. The builder also gets a
// fresh clone of this context's creation options and this context's
// properties. `callback` is passed through to CreateAndBuild, which is
// responsible for running it.
void ContextImplCoreml::CreateGraphImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
    mojom::GraphInfoPtr graph_info,
    WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
    base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
        constant_operands,
    base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
    CreateGraphImplCallback callback) {
  // The builder receives its own copy of the options rather than a reference
  // into this context.
  auto options_for_graph = options().Clone();

  GraphImplCoreml::CreateAndBuild(std::move(receiver),
                                  /*context=*/this,
                                  std::move(graph_info),
                                  std::move(compute_resource_info),
                                  std::move(constant_operands),
                                  std::move(constant_tensor_operands),
                                  std::move(options_for_graph),
                                  properties(),
                                  std::move(callback));
}
// Creates a CoreML-backed tensor described by `tensor_info`.
//
// Two usages are not yet supported by this backend and are rejected with
// kNotSupportedError:
//   - kGraphConstant, which is checked first.
//   - kWebGpuInterop.
// Otherwise the tensor is created by TensorImplCoreml::Create, which receives
// a weak reference to this context.
base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr>
ContextImplCoreml::CreateTensorImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
    mojom::TensorInfoPtr tensor_info) {
  // Builds the "not supported" error returned for unimplemented usages.
  auto not_supported = [](const char* message) {
    return base::unexpected(mojom::Error::New(
        mojom::Error::Code::kNotSupportedError, message));
  };

  // Only read before `tensor_info` is moved below.
  const auto& requested_usage = tensor_info->usage;

  // TODO(crbug.com/332350952): implement constant tensors for CoreML.
  if (requested_usage.Has(MLTensorUsageFlags::kGraphConstant)) {
    return not_supported("Creation of constant tensors is not supported.");
  }

  // TODO(crbug.com/345352987): implement WebGPU interop tensors for CoreML
  // backend.
  if (requested_usage.Has(MLTensorUsageFlags::kWebGpuInterop)) {
    return not_supported("WebGPU Interop is not supported.");
  }

  return TensorImplCoreml::Create(std::move(receiver), AsWeakPtr(),
                                  std::move(tensor_info));
}
// Creates a CoreML-backed tensor that wraps the shared image identified by
// `mailbox`.
//
// A WebNN tensor representation is produced from the context provider's
// shared image manager. The manager must exist; a null manager triggers a
// CHECK failure. Memory is tracked with the shared context state's memory
// type tracker. If no representation can be produced, a kUnknownError is
// returned. Otherwise the representation is handed to TensorImplCoreml::Create.
base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr>
ContextImplCoreml::CreateTensorFromMailboxImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
    mojom::TensorInfoPtr tensor_info,
    gpu::Mailbox mailbox) {
  gpu::SharedImageManager* const image_manager =
      context_provider()->shared_image_manager();
  CHECK(image_manager);

  // TODO(crbug.com/345352987): give WebNN its own memory source and tracker.
  if (std::unique_ptr<gpu::WebNNTensorRepresentation> shared_representation =
          image_manager->ProduceWebNNTensor(
              mailbox, context_provider()
                           ->shared_context_state()
                           ->memory_type_tracker())) {
    return TensorImplCoreml::Create(std::move(receiver), AsWeakPtr(),
                                    std::move(tensor_info),
                                    std::move(shared_representation));
  }

  // The mailbox could not be turned into a WebNN tensor representation.
  return base::unexpected(mojom::Error::New(mojom::Error::Code::kUnknownError,
                                            "Failed to create tensor."));
}
} // namespace webnn::coreml