// blob: b4807e1697cf318b71ad060bc24e43f3b20da74f [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/video_effects/calculators/inference_calculator_webgpu.h"
#include <memory>
#include <optional>
#include "base/no_destructor.h"
#include "base/notreached.h"
#include "services/on_device_model/ml/chrome_ml_api.h"
#include "services/on_device_model/ml/chrome_ml_holder.h"
#include "services/video_effects/calculators/mediapipe_webgpu_utils.h"
#include "services/video_effects/calculators/video_effects_graph_config.h"
#include "third_party/abseil-cpp/absl/status/status.h"
#include "third_party/abseil-cpp/absl/strings/str_format.h"
#include "third_party/dawn/include/dawn/webgpu_cpp.h"
#include "third_party/dawn/include/dawn/wire/WireClient.h"
#include "third_party/mediapipe/src/mediapipe/framework/calculator_context.h"
#include "third_party/mediapipe/src/mediapipe/framework/calculator_contract.h"
#include "third_party/mediapipe/src/mediapipe/framework/calculator_registry.h"
#include "third_party/mediapipe/src/mediapipe/framework/packet.h"
#include "third_party/mediapipe/src/mediapipe/gpu/gpu_buffer.h"
#include "third_party/mediapipe/src/mediapipe/gpu/gpu_buffer_format.h"
#include "third_party/mediapipe/src/mediapipe/gpu/webgpu/webgpu_service.h"
#include "third_party/mediapipe/src/mediapipe/gpu/webgpu/webgpu_texture_view.h"
namespace {
// Width, height and pixel format of the GpuBuffer that Process() allocates to
// receive the inference result.
// NOTE(review): the source constants are named kInferenceInput*, yet in this
// file they size the *output* buffer - confirm this is intended.
constexpr uint32_t kBufferWidth = video_effects::kInferenceInputBufferWidth;
constexpr uint32_t kBufferHeight = video_effects::kInferenceInputBufferHeight;
// The WebGPU texture format constant translated to the MediaPipe buffer format
// enum at compile time.
constexpr mediapipe::GpuBufferFormat kBufferFormat =
video_effects::WebGpuTextureFormatToGpuBufferFormat(
video_effects::kInferenceInputBufferFormat);
// Returns the ChromeML API table, or nullptr when
// `ml::ChromeMLHolder::Create()` did not produce a holder. The holder is
// created on the first call and stored in a `base::NoDestructor`, so the same
// instance is reused by every subsequent call.
const ChromeMLAPI* GetChromeMlApi() {
  static base::NoDestructor<std::unique_ptr<ml::ChromeMLHolder>> holder{
      ml::ChromeMLHolder::Create()};
  const std::unique_ptr<ml::ChromeMLHolder>& owned_holder = *holder;
  return owned_holder ? &owned_holder->api() : nullptr;
}
// Fatal-error callback installed into ChromeML by Open() (for both
// SetFatalErrorFn and SetFatalErrorNonGpuFn). Being invoked at all is treated
// as a NOTREACHED() failure; `msg` is the message supplied by the caller.
DISABLE_CFI_DLSYM
void ChromeMLFatalErrorFnImpl(const char* msg) {
  // Streaming a null `const char*` is undefined behavior, so substitute a
  // placeholder rather than trusting the caller to always pass a message.
  NOTREACHED() << "ChromeMLFatalErrorFn invoked, msg="
               << (msg ? msg : "(null)");
}
} // namespace
namespace video_effects {
// Construction and destruction are compiler-generated. The inference engine
// itself is created in Open() and released in Close(), not here.
InferenceCalculatorWebGpu::InferenceCalculatorWebGpu() = default;
InferenceCalculatorWebGpu::~InferenceCalculatorWebGpu() = default;
// static
// Declares what the calculator needs from the graph:
//   - the WebGPU service;
//   - a StaticConfig input side packet;
//   - a RuntimeConfig input stream and a GpuBuffer input stream;
//   - a GpuBuffer output stream.
// Always returns OK.
absl::Status InferenceCalculatorWebGpu::GetContract(
    mediapipe::CalculatorContract* cc) {
  cc->UseService(mediapipe::kWebGpuService);

  auto& side_packets = cc->InputSidePackets();
  side_packets.Tag(kStaticConfigInputSidePacketStreamTag)
      .Set<video_effects::StaticConfig>();

  auto& input_streams = cc->Inputs();
  input_streams.Tag(kRuntimeConfigInputStreamTag)
      .Set<video_effects::RuntimeConfig>();
  input_streams.Tag(kInputTextureStreamTag).Set<mediapipe::GpuBuffer>();

  auto& output_streams = cc->Outputs();
  output_streams.Tag(kOutputTextureStreamTag).Set<mediapipe::GpuBuffer>();

  return absl::OkStatus();
}
// Prepares the calculator for processing:
//   1. hands the Dawn wire-client procs and the fatal-error callbacks to
//      ChromeML;
//   2. reads the background segmentation model from the StaticConfig side
//      packet (CHECK-fails if it is empty);
//   3. creates the inference engine for the WebGPU service's device.
// Returns an internal error if the engine could not be created, OK otherwise.
// CHECK-fails if the ChromeML API is unavailable.
DISABLE_CFI_DLSYM absl::Status InferenceCalculatorWebGpu::Open(
    mediapipe::CalculatorContext* cc) {
  const ChromeMLAPI* const api = GetChromeMlApi();
  CHECK(api);

  api->InitDawnProcs(dawn::wire::client::GetProcs());
  api->SetFatalErrorNonGpuFn(&ChromeMLFatalErrorFnImpl);
  api->SetFatalErrorFn(&ChromeMLFatalErrorFnImpl);

  const StaticConfig& static_config =
      cc->InputSidePackets()
          .Tag(kStaticConfigInputSidePacketStreamTag)
          .Get<StaticConfig>();
  // Bind the model bytes once; they are passed to the engine as a raw
  // (pointer, size) pair below.
  const auto& model_bytes = static_config.background_segmentation_model();
  CHECK(!model_bytes.empty());

  auto device = cc->Service(mediapipe::kWebGpuService).GetObject().device();
  auto adapter = device.GetAdapter();
  wgpu::AdapterInfo adapter_info;
  adapter.GetInfo(&adapter_info);

  inference_engine_ = api->CreateInferenceEngine(
      adapter_info, device.Get(),
      reinterpret_cast<const char*>(model_bytes.data()), model_bytes.size());
  if (inference_engine_) {
    return absl::OkStatus();
  }
  return absl::InternalError("Failed to create the inference engine!");
}
// Runs inference on the current input texture and emits the result.
//
// Reads the RuntimeConfig and the input GpuBuffer packets, allocates a
// kBufferWidth x kBufferHeight output GpuBuffer in kBufferFormat, asks
// ChromeML to run inference from the input texture into the output texture,
// and publishes the output buffer on the output stream at the input frame's
// timestamp.
//
// Returns an internal error when:
//   - the runtime configuration packet is empty;
//   - blur is not enabled in the runtime configuration;
//   - the input frame packet is empty;
//   - ChromeML reports that inference failed.
// CHECK-fails if the ChromeML API is unavailable.
DISABLE_CFI_DLSYM absl::Status InferenceCalculatorWebGpu::Process(
    mediapipe::CalculatorContext* cc) {
  auto* ml_api = GetChromeMlApi();
  CHECK(ml_api);

  mediapipe::Packet& config_packet =
      cc->Inputs().Tag(kRuntimeConfigInputStreamTag).Value();
  if (config_packet.IsEmpty()) {
    return absl::InternalError("Runtime configuration not present!");
  }
  if (config_packet.Get<RuntimeConfig>().blur_state != BlurState::kEnabled) {
    return absl::InternalError(
        "Blur is disabled, the calculator should not even run!");
  }

  mediapipe::Packet& input_frame =
      cc->Inputs().Tag(kInputTextureStreamTag).Value();
  // Calling Get<>() on an empty packet is a fatal error, so report a missing
  // frame through the status, the same way a missing runtime config is
  // reported.
  if (input_frame.IsEmpty()) {
    return absl::InternalError("Input frame not present!");
  }

  // The packet stays alive for the duration of this call, so a reference is
  // enough; there is no need to copy the buffer handle.
  const mediapipe::GpuBuffer& input_buffer =
      input_frame.Get<mediapipe::GpuBuffer>();
  auto input_buffer_view =
      input_buffer.GetReadView<mediapipe::WebGpuTextureView>();

  mediapipe::GpuBuffer output_buffer(kBufferWidth, kBufferHeight,
                                     kBufferFormat);
  auto output_buffer_view =
      output_buffer.GetWriteView<mediapipe::WebGpuTextureView>();

  if (!ml_api->RunInference(inference_engine_,
                            input_buffer_view.texture().Get(),
                            output_buffer_view.texture().Get())) {
    return absl::InternalError("Processing failed");
  }

  cc->Outputs()
      .Tag(kOutputTextureStreamTag)
      .AddPacket(
          mediapipe::MakePacket<mediapipe::GpuBuffer>(std::move(output_buffer))
              .At(input_frame.Timestamp()));
  return absl::OkStatus();
}
// Releases the inference engine created in Open(), if there is one, and
// resets the handle so that a repeated Close() does nothing further.
// Always returns OK; CHECK-fails if the ChromeML API is unavailable.
DISABLE_CFI_DLSYM
absl::Status InferenceCalculatorWebGpu::Close(
    mediapipe::CalculatorContext* cc) {
  const ChromeMLAPI* const api = GetChromeMlApi();
  CHECK(api);

  // Nothing to tear down when no engine is held.
  if (!inference_engine_) {
    return absl::OkStatus();
  }

  api->DestroyInferenceEngine(inference_engine_);
  inference_engine_ = 0;
  return absl::OkStatus();
}
// Registers InferenceCalculatorWebGpu through MediaPipe's REGISTER_CALCULATOR
// macro (presumably so graph configs can refer to it by name - see
// calculator_registry.h).
REGISTER_CALCULATOR(InferenceCalculatorWebGpu)
} // namespace video_effects