blob: 2cb565285f7f799a5bd455243e0b16b73796cd8b [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/windows_services/elevated_tracing_service/system_tracing_session.h"
#include <windows.h>
#include <utility>
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/task/sequenced_task_runner.h"
#include "base/win/win_util.h"
#include "chrome/windows_services/elevated_tracing_service/service_integration.h"
#include "chrome/windows_services/service_program/crash_reporting.h"
#include "chrome/windows_services/service_program/get_calling_process.h"
#include "chrome/windows_services/service_program/scoped_client_impersonation.h"
#include "chrome/windows_services/service_program/user_crash_state.h"
#include "components/tracing/common/etw_system_data_source_win.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/platform/named_platform_channel.h"
#include "mojo/public/cpp/platform/platform_channel_endpoint.h"
#include "mojo/public/cpp/system/handle.h"
#include "mojo/public/cpp/system/invitation.h"
#include "mojo/public/cpp/system/message_pipe.h"
#include "services/tracing/public/cpp/traced_process.h"
namespace elevated_tracing_service {
namespace {

// Runs on the main service thread: AcceptInvitation() posts it to the task
// runner that RuntimeClassInitialize() stored. Makes `pipe` the receiver end
// of the process-wide TracedProcess interface.
void OnTracedProcessReceiver(base::ProcessId client_pid,
                             mojo::ScopedMessagePipeHandle pipe) {
  // A function-local static is initialized exactly once, so the ETW system
  // data source is registered only on the first call, using that call's
  // `client_pid`; later calls skip registration entirely.
  // NOTE(review): a later session from a different client does not update the
  // registered pid — confirm this is intended (e.g. one session per service
  // process).
  [[maybe_unused]] static const bool etw_registered =
      [](base::ProcessId first_client_pid) {
        tracing::EtwSystemDataSource::Register(first_client_pid);
        return true;
      }(client_pid);

  // Drop whatever receiver was bound before so the new client replaces it.
  tracing::TracedProcess::ResetTracedProcessReceiver();
  mojo::PendingReceiver<tracing::mojom::TracedProcess> receiver(
      std::move(pipe));
  tracing::TracedProcess::OnTracedProcessRequest(std::move(receiver));
}

}  // namespace
// Defaulted constructor. The task runner that AcceptInvitation() depends on is
// supplied afterwards through RuntimeClassInitialize() (presumably via WRL's
// MakeAndInitialize — confirm at the creation site).
SystemTracingSession::SystemTracingSession() = default;
// Second-phase initialization (presumably invoked through WRL's
// MakeAndInitialize — confirm). Stores `task_runner`, which AcceptInvitation()
// later uses to post the traced-process binding to the main service thread and
// passes to the crash handler.
//
// Returns E_INVALIDARG if `task_runner` is null. The success path of
// AcceptInvitation() dereferences the stored runner, so a null value would
// otherwise crash later, far from its origin. Returns S_OK otherwise.
HRESULT SystemTracingSession::RuntimeClassInitialize(
    scoped_refptr<base::SequencedTaskRunner> task_runner) {
  if (!task_runner) {
    return E_INVALIDARG;
  }
  task_runner_ = std::move(task_runner);
  return S_OK;
}
// Invoked on an arbitrary RPC thread.
//
// Connects to the client's Mojo named channel `server_name`, accepts the
// elevated invitation sent over it and, on success, makes this instance the
// active tracing session.
//
// On success returns S_OK and sets `*pid` to this service process's id. Once
// `pid` has been validated, every failure leaves `*pid` equal to
// base::kNullProcessId and returns one of:
//   - E_INVALIDARG: `pid` or `server_name` is null or empty, the channel could
//     not be connected, or the invitation could not be accepted.
//   - kErrorSessionAlreadyActive: this instance already holds a session.
//   - the result reported by ScopedClientImpersonation if impersonation fails.
//   - kErrorCouldNotObtainCallingProcess, kErrorCouldNotGetCallingProcessPid,
//     kErrorCouldNotOpenCallingProcess: the caller's process could not be
//     resolved, its pid read, or reopened.
//   - kErrorSessionInProgress: another session is already active.
// NOTE(review): `session_` is read and written here without synchronization;
// presumably a single instance is never driven concurrently — confirm.
HRESULT SystemTracingSession::AcceptInvitation(const wchar_t* server_name,
                                               DWORD* pid) {
  const bool missing_argument = pid == nullptr || server_name == nullptr ||
                                server_name[0] == L'\0';
  if (missing_argument) {
    return E_INVALIDARG;
  }
  *pid = base::kNullProcessId;

  if (session_) {
    return kErrorSessionAlreadyActive;
  }

  base::Process client_process;
  std::unique_ptr<UserCrashState> user_crash_state;
  {
    // Act as the client only while resolving its process and its per-user
    // crash-handling state; leaving this scope ends the impersonation.
    ScopedClientImpersonation impersonation;
    if (!impersonation.is_valid()) {
      return impersonation.result();
    }
    client_process = GetCallingProcess();
    if (!client_process.IsValid()) {
      return kErrorCouldNotObtainCallingProcess;
    }
    user_crash_state = UserCrashState::Create(impersonation, client_process);
  }

  // Swap the handle from GetCallingProcess() for one opened with only
  // SYNCHRONIZE access.
  const auto client_pid = client_process.Pid();
  if (client_pid == base::kNullProcessId) {
    return kErrorCouldNotGetCallingProcessPid;
  }
  base::Process synchronize_only =
      base::Process::OpenWithAccess(client_pid, SYNCHRONIZE);
  if (!synchronize_only.IsValid()) {
    return kErrorCouldNotOpenCallingProcess;
  }
  client_process = std::move(synchronize_only);

  // Try to claim the active-session slot; this fails if some other session
  // already holds it.
  auto session = SessionRegistry::RegisterActiveSession(
      CastToUnknown(), std::move(client_process));
  if (!session) {
    return kErrorSessionInProgress;
  }

  // Start the crash handler with a user-specific database when per-user crash
  // state could be created.
  // NOTE(review): this happens before the Mojo channel is validated below, so
  // a call that goes on to fail with E_INVALIDARG has still started the
  // handler — confirm this ordering is intended.
  if (user_crash_state) {
    windows_services::StartCrashHandler(
        std::move(user_crash_state),
        /*directory_name=*/elevated_tracing_service::GetStorageDirBasename(),
        /*process_type=*/"elevated-tracing-service", task_runner_);
  }

  auto endpoint = mojo::NamedPlatformChannel::ConnectToServer(server_name);
  if (!endpoint.is_valid()) {
    return E_INVALIDARG;  // Invalid channel name.
  }
  mojo::IncomingInvitation invitation = mojo::IncomingInvitation::Accept(
      std::move(endpoint), MOJO_ACCEPT_INVITATION_FLAG_ELEVATED);
  if (!invitation.is_valid()) {
    return E_INVALIDARG;
  }

  // The invitation's pipe named 0 carries the TracedProcess interface; it is
  // bound on the main service thread.
  mojo::ScopedMessagePipeHandle traced_process_pipe =
      invitation.ExtractMessagePipe(/*name=*/0);
  task_runner_->PostTask(FROM_HERE,
                         base::BindOnce(&OnTracedProcessReceiver, client_pid,
                                        std::move(traced_process_pipe)));

  // Only on full success does this instance keep the registration. Every early
  // return above destroyed `session` instead (presumably releasing the
  // registration — confirm in SessionRegistry).
  session_ = std::move(session);
  *pid = ::GetCurrentProcessId();
  return S_OK;
}
// Defaulted destructor. Destroying the instance destroys `session_` (set by a
// successful AcceptInvitation()), which presumably ends this instance's
// registration as the active session — confirm against SessionRegistry.
SystemTracingSession::~SystemTracingSession() = default;
} // namespace elevated_tracing_service