blob: 42f47d2d1a90c5187510bb99db454fb3061d62c9 [file] [log] [blame]
// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "mojo/public/cpp/bindings/message.h"
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <algorithm>
#include <utility>
#include "base/bind.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/numerics/safe_math.h"
#include "base/strings/stringprintf.h"
#include "base/threading/thread_local.h"
#include "mojo/public/cpp/bindings/associated_group_controller.h"
#include "mojo/public/cpp/bindings/lib/array_internal.h"
#include "mojo/public/cpp/bindings/lib/unserialized_message_context.h"
namespace mojo {
namespace {
base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>::
Leaky g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER;
base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::Leaky
g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER;
void DoNotifyBadMessage(Message message, const std::string& error) {
message.NotifyBadMessage(error);
}
template <typename HeaderType>
void AllocateHeaderFromBuffer(internal::Buffer* buffer, HeaderType** header) {
*header = buffer->AllocateAndGet<HeaderType>();
(*header)->num_bytes = sizeof(HeaderType);
}
void WriteMessageHeader(uint32_t name,
uint32_t flags,
size_t payload_interface_id_count,
internal::Buffer* payload_buffer) {
if (payload_interface_id_count > 0) {
// Version 2
internal::MessageHeaderV2* header;
AllocateHeaderFromBuffer(payload_buffer, &header);
header->version = 2;
header->name = name;
header->flags = flags;
// The payload immediately follows the header.
header->payload.Set(header + 1);
} else if (flags &
(Message::kFlagExpectsResponse | Message::kFlagIsResponse)) {
// Version 1
internal::MessageHeaderV1* header;
AllocateHeaderFromBuffer(payload_buffer, &header);
header->version = 1;
header->name = name;
header->flags = flags;
} else {
internal::MessageHeader* header;
AllocateHeaderFromBuffer(payload_buffer, &header);
header->version = 0;
header->name = name;
header->flags = flags;
}
}
void CreateSerializedMessageObject(uint32_t name,
uint32_t flags,
size_t payload_size,
size_t payload_interface_id_count,
std::vector<ScopedHandle>* handles,
ScopedMessageHandle* out_handle,
internal::Buffer* out_buffer) {
ScopedMessageHandle handle;
MojoResult rv = mojo::CreateMessage(&handle);
DCHECK_EQ(MOJO_RESULT_OK, rv);
DCHECK(handle.is_valid());
void* buffer;
uint32_t buffer_size;
size_t total_size = internal::ComputeSerializedMessageSize(
flags, payload_size, payload_interface_id_count);
DCHECK(base::IsValueInRangeForNumericType<uint32_t>(total_size));
DCHECK(!handles ||
base::IsValueInRangeForNumericType<uint32_t>(handles->size()));
rv = MojoAttachSerializedMessageBuffer(
handle->value(), static_cast<uint32_t>(total_size),
handles ? reinterpret_cast<MojoHandle*>(handles->data()) : nullptr,
handles ? static_cast<uint32_t>(handles->size()) : 0, &buffer,
&buffer_size);
DCHECK_EQ(MOJO_RESULT_OK, rv);
if (handles) {
// Handle ownership has been taken by MojoAttachSerializedMessageBuffer.
for (size_t i = 0; i < handles->size(); ++i)
ignore_result(handles->at(i).release());
}
internal::Buffer payload_buffer(handle.get(), buffer, buffer_size);
// Make sure we zero the memory first!
memset(payload_buffer.data(), 0, total_size);
WriteMessageHeader(name, flags, payload_interface_id_count, &payload_buffer);
*out_handle = std::move(handle);
*out_buffer = std::move(payload_buffer);
}
void SerializeUnserializedContext(MojoMessageHandle message,
uintptr_t context_value) {
auto* context =
reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
void* buffer;
uint32_t buffer_size;
MojoResult attach_result = MojoAttachSerializedMessageBuffer(
message, 0, nullptr, 0, &buffer, &buffer_size);
if (attach_result != MOJO_RESULT_OK)
return;
internal::Buffer payload_buffer(MessageHandle(message), buffer, buffer_size);
WriteMessageHeader(context->message_name(), context->message_flags(),
0 /* payload_interface_id_count */, &payload_buffer);
// We need to copy additional header data which may have been set after
// message construction, as this codepath may be reached at some arbitrary
// time between message send and message dispatch.
static_cast<internal::MessageHeader*>(buffer)->interface_id =
context->header()->interface_id;
if (context->header()->flags &
(Message::kFlagExpectsResponse | Message::kFlagIsResponse)) {
DCHECK_GE(context->header()->version, 1u);
static_cast<internal::MessageHeaderV1*>(buffer)->request_id =
context->header()->request_id;
}
internal::SerializationContext serialization_context;
context->Serialize(&serialization_context, &payload_buffer);
// TODO(crbug.com/753433): Support lazy serialization of associated endpoint
// handles. See corresponding TODO in the bindings generator for proof that
// this DCHECK is indeed valid.
DCHECK(serialization_context.associated_endpoint_handles()->empty());
if (!serialization_context.handles()->empty())
payload_buffer.AttachHandles(serialization_context.mutable_handles());
payload_buffer.Seal();
}
void DestroyUnserializedContext(uintptr_t context) {
delete reinterpret_cast<internal::UnserializedMessageContext*>(context);
}
ScopedMessageHandle CreateUnserializedMessageObject(
std::unique_ptr<internal::UnserializedMessageContext> context) {
ScopedMessageHandle handle;
MojoResult rv = mojo::CreateMessage(&handle);
DCHECK_EQ(MOJO_RESULT_OK, rv);
DCHECK(handle.is_valid());
rv = MojoAttachMessageContext(
handle->value(), reinterpret_cast<uintptr_t>(context.release()),
&SerializeUnserializedContext, &DestroyUnserializedContext);
DCHECK_EQ(MOJO_RESULT_OK, rv);
return handle;
}
} // namespace
Message::Message() = default;
Message::Message(Message&& other)
: handle_(std::move(other.handle_)),
payload_buffer_(std::move(other.payload_buffer_)),
handles_(std::move(other.handles_)),
associated_endpoint_handles_(
std::move(other.associated_endpoint_handles_)),
transferable_(other.transferable_),
serialized_(other.serialized_) {
other.transferable_ = false;
other.serialized_ = false;
}
Message::Message(std::unique_ptr<internal::UnserializedMessageContext> context)
: Message(CreateUnserializedMessageObject(std::move(context))) {}
Message::Message(uint32_t name,
uint32_t flags,
size_t payload_size,
size_t payload_interface_id_count,
std::vector<ScopedHandle>* handles) {
CreateSerializedMessageObject(name, flags, payload_size,
payload_interface_id_count, handles, &handle_,
&payload_buffer_);
transferable_ = true;
serialized_ = true;
}
Message::Message(ScopedMessageHandle handle) {
DCHECK(handle.is_valid());
uintptr_t context_value = 0;
MojoResult get_context_result = MojoGetMessageContext(
handle->value(), &context_value, MOJO_GET_MESSAGE_CONTEXT_FLAG_NONE);
if (get_context_result == MOJO_RESULT_NOT_FOUND) {
// It's a serialized message. Extract handles if possible.
uint32_t num_bytes;
void* buffer;
uint32_t num_handles = 0;
MojoResult rv = MojoGetSerializedMessageContents(
handle->value(), &buffer, &num_bytes, nullptr, &num_handles,
MOJO_GET_SERIALIZED_MESSAGE_CONTENTS_FLAG_NONE);
if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) {
handles_.resize(num_handles);
rv = MojoGetSerializedMessageContents(
handle->value(), &buffer, &num_bytes,
reinterpret_cast<MojoHandle*>(handles_.data()), &num_handles,
MOJO_GET_SERIALIZED_MESSAGE_CONTENTS_FLAG_NONE);
} else {
// No handles, so it's safe to retransmit this message if the caller
// really wants to.
transferable_ = true;
}
if (rv != MOJO_RESULT_OK) {
// Failed to deserialize handles. Leave the Message uninitialized.
return;
}
payload_buffer_ = internal::Buffer(buffer, num_bytes, num_bytes);
serialized_ = true;
} else {
DCHECK_EQ(MOJO_RESULT_OK, get_context_result);
auto* context =
reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
// Dummy data address so common header accessors still behave properly. The
// choice is V1 reflects unserialized message capabilities: we may or may
// not need to support request IDs (which require at least V1), but we never
// (for now, anyway) need to support associated interface handles (V2).
payload_buffer_ =
internal::Buffer(context->header(), sizeof(internal::MessageHeaderV1),
sizeof(internal::MessageHeaderV1));
transferable_ = true;
serialized_ = false;
}
handle_ = std::move(handle);
}
Message::~Message() = default;
Message& Message::operator=(Message&& other) {
handle_ = std::move(other.handle_);
payload_buffer_ = std::move(other.payload_buffer_);
handles_ = std::move(other.handles_);
associated_endpoint_handles_ = std::move(other.associated_endpoint_handles_);
transferable_ = other.transferable_;
other.transferable_ = false;
serialized_ = other.serialized_;
other.serialized_ = false;
return *this;
}
void Message::Reset() {
handle_.reset();
payload_buffer_.Reset();
handles_.clear();
associated_endpoint_handles_.clear();
transferable_ = false;
serialized_ = false;
}
const uint8_t* Message::payload() const {
if (version() < 2)
return data() + header()->num_bytes;
DCHECK(!header_v2()->payload.is_null());
return static_cast<const uint8_t*>(header_v2()->payload.Get());
}
uint32_t Message::payload_num_bytes() const {
DCHECK_GE(data_num_bytes(), header()->num_bytes);
size_t num_bytes;
if (version() < 2) {
num_bytes = data_num_bytes() - header()->num_bytes;
} else {
auto payload_begin =
reinterpret_cast<uintptr_t>(header_v2()->payload.Get());
auto payload_end =
reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get());
if (!payload_end)
payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes());
DCHECK_GE(payload_end, payload_begin);
num_bytes = payload_end - payload_begin;
}
DCHECK(base::IsValueInRangeForNumericType<uint32_t>(num_bytes));
return static_cast<uint32_t>(num_bytes);
}
uint32_t Message::payload_num_interface_ids() const {
auto* array_pointer =
version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0;
}
const uint32_t* Message::payload_interface_ids() const {
auto* array_pointer =
version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
return array_pointer ? array_pointer->storage() : nullptr;
}
void Message::AttachHandlesFromSerializationContext(
internal::SerializationContext* context) {
if (context->handles()->empty() &&
context->associated_endpoint_handles()->empty()) {
// No handles attached, so no extra serialization work.
return;
}
if (context->associated_endpoint_handles()->empty()) {
// Attaching only non-associated handles is easier since we don't have to
// modify the message header. Faster path for that.
payload_buffer_.AttachHandles(context->mutable_handles());
return;
}
// Allocate a new message with enough space to hold all attached handles. Copy
// this message's contents into the new one and use it to replace ourself.
//
// TODO(rockot): We could avoid the extra full message allocation by instead
// growing the buffer and carefully moving its contents around. This errs on
// the side of less complexity with probably only marginal performance cost.
uint32_t payload_size = payload_num_bytes();
mojo::Message new_message(name(), header()->flags, payload_size,
context->associated_endpoint_handles()->size(),
context->mutable_handles());
std::swap(*context->mutable_associated_endpoint_handles(),
new_message.associated_endpoint_handles_);
memcpy(new_message.payload_buffer()->AllocateAndGet(payload_size), payload(),
payload_size);
*this = std::move(new_message);
}
ScopedMessageHandle Message::TakeMojoMessage() {
// If there are associated endpoints transferred,
// SerializeAssociatedEndpointHandles() must be called before this method.
DCHECK(associated_endpoint_handles_.empty());
DCHECK(transferable_);
payload_buffer_.Seal();
auto handle = std::move(handle_);
Reset();
return handle;
}
void Message::NotifyBadMessage(const std::string& error) {
DCHECK(handle_.is_valid());
mojo::NotifyBadMessage(handle_.get(), error);
}
void Message::SerializeAssociatedEndpointHandles(
AssociatedGroupController* group_controller) {
if (associated_endpoint_handles_.empty())
return;
DCHECK_GE(version(), 2u);
DCHECK(header_v2()->payload_interface_ids.is_null());
DCHECK(payload_buffer_.is_valid());
DCHECK(handle_.is_valid());
size_t size = associated_endpoint_handles_.size();
internal::Array_Data<uint32_t>::BufferWriter handle_writer;
handle_writer.Allocate(size, &payload_buffer_);
header_v2()->payload_interface_ids.Set(handle_writer.data());
for (size_t i = 0; i < size; ++i) {
ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i];
DCHECK(handle.pending_association());
handle_writer->storage()[i] =
group_controller->AssociateInterface(std::move(handle));
}
associated_endpoint_handles_.clear();
}
bool Message::DeserializeAssociatedEndpointHandles(
AssociatedGroupController* group_controller) {
if (!serialized_)
return true;
associated_endpoint_handles_.clear();
uint32_t num_ids = payload_num_interface_ids();
if (num_ids == 0)
return true;
associated_endpoint_handles_.reserve(num_ids);
uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage();
bool result = true;
for (uint32_t i = 0; i < num_ids; ++i) {
auto handle = group_controller->CreateLocalEndpointHandle(ids[i]);
if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) {
// |ids[i]| itself is valid but handle creation failed. In that case, mark
// deserialization as failed but continue to deserialize the rest of
// handles.
result = false;
}
associated_endpoint_handles_.push_back(std::move(handle));
ids[i] = kInvalidInterfaceId;
}
return result;
}
void Message::SerializeIfNecessary() {
MojoResult rv = MojoSerializeMessage(handle_->value());
if (rv == MOJO_RESULT_FAILED_PRECONDITION)
return;
// Reconstruct this Message instance from the serialized message's handle.
*this = Message(std::move(handle_));
}
std::unique_ptr<internal::UnserializedMessageContext>
Message::TakeUnserializedContext(
const internal::UnserializedMessageContext::Tag* tag) {
DCHECK(handle_.is_valid());
uintptr_t context_value = 0;
MojoResult rv = MojoGetMessageContext(handle_->value(), &context_value,
MOJO_GET_MESSAGE_CONTEXT_FLAG_NONE);
if (rv == MOJO_RESULT_NOT_FOUND)
return nullptr;
DCHECK_EQ(MOJO_RESULT_OK, rv);
auto* context =
reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
if (context->tag() != tag)
return nullptr;
// Detach the context from the message.
rv = MojoGetMessageContext(handle_->value(), &context_value,
MOJO_GET_MESSAGE_CONTEXT_FLAG_RELEASE);
DCHECK_EQ(MOJO_RESULT_OK, rv);
DCHECK_EQ(context_value, reinterpret_cast<uintptr_t>(context));
return base::WrapUnique(context);
}
bool MessageReceiver::PrefersSerializedMessages() {
return false;
}
PassThroughFilter::PassThroughFilter() {}
PassThroughFilter::~PassThroughFilter() {}
bool PassThroughFilter::Accept(Message* message) {
return true;
}
SyncMessageResponseContext::SyncMessageResponseContext()
: outer_context_(current()) {
g_tls_sync_response_context.Get().Set(this);
}
SyncMessageResponseContext::~SyncMessageResponseContext() {
DCHECK_EQ(current(), this);
g_tls_sync_response_context.Get().Set(outer_context_);
}
// static
SyncMessageResponseContext* SyncMessageResponseContext::current() {
return g_tls_sync_response_context.Get().Get();
}
void SyncMessageResponseContext::ReportBadMessage(const std::string& error) {
GetBadMessageCallback().Run(error);
}
const ReportBadMessageCallback&
SyncMessageResponseContext::GetBadMessageCallback() {
if (bad_message_callback_.is_null()) {
bad_message_callback_ =
base::Bind(&DoNotifyBadMessage, base::Passed(&response_));
}
return bad_message_callback_;
}
MojoResult ReadMessage(MessagePipeHandle handle, Message* message) {
ScopedMessageHandle message_handle;
MojoResult rv =
ReadMessageNew(handle, &message_handle, MOJO_READ_MESSAGE_FLAG_NONE);
if (rv != MOJO_RESULT_OK)
return rv;
*message = Message(std::move(message_handle));
return MOJO_RESULT_OK;
}
void ReportBadMessage(const std::string& error) {
internal::MessageDispatchContext* context =
internal::MessageDispatchContext::current();
DCHECK(context);
context->GetBadMessageCallback().Run(error);
}
ReportBadMessageCallback GetBadMessageCallback() {
internal::MessageDispatchContext* context =
internal::MessageDispatchContext::current();
DCHECK(context);
return context->GetBadMessageCallback();
}
namespace internal {
MessageHeaderV2::MessageHeaderV2() = default;
MessageDispatchContext::MessageDispatchContext(Message* message)
: outer_context_(current()), message_(message) {
g_tls_message_dispatch_context.Get().Set(this);
}
MessageDispatchContext::~MessageDispatchContext() {
DCHECK_EQ(current(), this);
g_tls_message_dispatch_context.Get().Set(outer_context_);
}
// static
MessageDispatchContext* MessageDispatchContext::current() {
return g_tls_message_dispatch_context.Get().Get();
}
const ReportBadMessageCallback&
MessageDispatchContext::GetBadMessageCallback() {
if (bad_message_callback_.is_null()) {
bad_message_callback_ =
base::Bind(&DoNotifyBadMessage, base::Passed(message_));
}
return bad_message_callback_;
}
// static
void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) {
SyncMessageResponseContext* context = SyncMessageResponseContext::current();
if (context)
context->response_ = std::move(*message);
}
} // namespace internal
} // namespace mojo