blob: 50d6d6758014d65465d55e4a6d30760377ab5d0f [file] [log] [blame]
// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "mojo/public/cpp/bindings/message.h"
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <algorithm>
#include <utility>
#include "base/bind.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/strings/stringprintf.h"
#include "base/threading/thread_local.h"
#include "mojo/public/cpp/bindings/associated_group_controller.h"
#include "mojo/public/cpp/bindings/lib/array_internal.h"
namespace mojo {
namespace {
// Thread-local pointer to the current internal::MessageDispatchContext. It is
// set by that class's constructor, restored to the enclosing context by its
// destructor, and read by MessageDispatchContext::current().
base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>::
DestructorAtExit g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER;
// Thread-local pointer to the current SyncMessageResponseContext. It is set by
// that class's constructor, restored to the enclosing context by its
// destructor, and read by SyncMessageResponseContext::current().
base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::
DestructorAtExit g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER;
// Target of the bad-message callbacks bound in this file: reports |error|
// against the |message| that was handed to the callback.
void DoNotifyBadMessage(Message message, const std::string& error) {
  message.NotifyBadMessage(error);
}
} // namespace
// Creates a message with no buffer and no handles.
Message::Message() = default;
// Move constructor: takes over |other|'s buffer, its handles and its
// associated endpoint handles.
Message::Message(Message&& other)
    : buffer_(std::move(other.buffer_)),
      handles_(std::move(other.handles_)),
      associated_endpoint_handles_(
          std::move(other.associated_endpoint_handles_)) {
}
// Closes every still-valid handle held in |handles_| before the members are
// destroyed.
Message::~Message() {
  CloseHandles();
}
// Move assignment: releases whatever this message currently holds (closing its
// handles), then takes over |other|'s buffer, handles and associated endpoint
// handles. |other| is left holding the empty state produced by Reset().
Message& Message::operator=(Message&& other) {
  // Self-move must be a no-op. Without this guard, Reset() would close our own
  // handles and drop our buffer, and the swaps below would then do nothing,
  // silently destroying the message.
  if (this == &other)
    return *this;

  Reset();
  std::swap(other.buffer_, buffer_);
  std::swap(other.handles_, handles_);
  std::swap(other.associated_endpoint_handles_, associated_endpoint_handles_);
  return *this;
}
// Returns the message to its empty state: closes the valid handles, forgets
// all handles and associated endpoint handles, and releases the buffer.
void Message::Reset() {
  // Close first; clearing |handles_| alone would only forget the handles.
  CloseHandles();
  handles_.clear();

  associated_endpoint_handles_.clear();
  buffer_.reset();
}
// Gives a buffer-less message a fresh internal::MessageBuffer constructed from
// |capacity| and |zero_initialized|. Must not be called on a message that
// already has a buffer.
void Message::Initialize(size_t capacity, bool zero_initialized) {
  DCHECK(!buffer_);

  buffer_.reset(new internal::MessageBuffer(capacity, zero_initialized));
}
// Gives a buffer-less message a buffer built from an existing Mojo |message|
// and |num_bytes|. The contents of |*handles| are exchanged with this
// message's handle list, so |*handles| receives whatever |handles_| held
// before the call.
void Message::InitializeFromMojoMessage(ScopedMessageHandle message,
                                        uint32_t num_bytes,
                                        std::vector<Handle>* handles) {
  DCHECK(!buffer_);

  buffer_.reset(new internal::MessageBuffer(std::move(message), num_bytes));
  handles_.swap(*handles);
}
// Returns a pointer to the first payload byte.
//  - Version 2 and later: the location stored in the header's |payload|
//    pointer, which must not be null.
//  - Earlier versions: the byte |header()->num_bytes| past the start of the
//    data.
const uint8_t* Message::payload() const {
  if (version() >= 2) {
    DCHECK(!header_v2()->payload.is_null());
    return static_cast<const uint8_t*>(header_v2()->payload.Get());
  }

  return data() + header()->num_bytes;
}
// Returns the payload size in bytes.
//  - Before version 2: everything in the data after |header()->num_bytes|.
//  - Version 2 and later: the span from the header's |payload| pointer up to
//    the |payload_interface_ids| array, or up to the end of the data when that
//    array pointer is null.
uint32_t Message::payload_num_bytes() const {
  DCHECK_GE(data_num_bytes(), header()->num_bytes);

  size_t byte_count = 0;
  if (version() >= 2) {
    auto begin_address =
        reinterpret_cast<uintptr_t>(header_v2()->payload.Get());
    auto end_address =
        reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get());
    // No interface-id array: the payload runs to the end of the data.
    if (end_address == 0)
      end_address = reinterpret_cast<uintptr_t>(data() + data_num_bytes());
    DCHECK_GE(end_address, begin_address);
    byte_count = end_address - begin_address;
  } else {
    byte_count = data_num_bytes() - header()->num_bytes;
  }

  DCHECK_LE(byte_count, std::numeric_limits<uint32_t>::max());
  return static_cast<uint32_t>(byte_count);
}
// Returns the number of entries in the header's |payload_interface_ids|
// array, or 0 when the message version is below 2 or the array pointer is
// null.
uint32_t Message::payload_num_interface_ids() const {
  if (version() < 2)
    return 0;

  auto* id_array = header_v2()->payload_interface_ids.Get();
  if (!id_array)
    return 0;

  return static_cast<uint32_t>(id_array->size());
}
// Returns the storage of the header's |payload_interface_ids| array, or
// nullptr when the message version is below 2 or the array pointer is null.
const uint32_t* Message::payload_interface_ids() const {
  if (version() < 2)
    return nullptr;

  auto* id_array = header_v2()->payload_interface_ids.Get();
  if (!id_array)
    return nullptr;

  return id_array->storage();
}
// Converts this message into a ScopedMessageHandle.
//
// With no handles attached, the buffer's own message is returned directly.
// Otherwise a new message is allocated with the handles attached, the data is
// copied into it, and this object's handle list and buffer are released.
//
// CHECK-fails if allocating the new message or obtaining its buffer does not
// return MOJO_RESULT_OK.
ScopedMessageHandle Message::TakeMojoMessage() {
  // If there are associated endpoints transferred,
  // SerializeAssociatedEndpointHandles() must be called before this method.
  DCHECK(associated_endpoint_handles_.empty());

  if (handles_.empty())  // Fast path for the common case: No handles.
    return buffer_->TakeMessage();

  // Allocate a new message with space for the handles, then copy the buffer
  // contents into it.
  //
  // TODO(rockot): We could avoid this copy by extending GetSerializedSize()
  // behavior to collect handles. It's unoptimized for now because it's much
  // more common to have messages with no handles.
  //
  // |handles_| is known to be non-empty here (the empty case returned above),
  // so its data pointer is passed unconditionally.
  ScopedMessageHandle new_message;
  MojoResult rv = AllocMessage(
      data_num_bytes(),
      reinterpret_cast<const MojoHandle*>(handles_.data()),
      handles_.size(),
      MOJO_ALLOC_MESSAGE_FLAG_NONE,
      &new_message);
  CHECK_EQ(rv, MOJO_RESULT_OK);

  // The handles were passed to AllocMessage() above; forget them here so that
  // a later CloseHandles() does not close them.
  handles_.clear();

  void* new_buffer = nullptr;
  rv = GetMessageBuffer(new_message.get(), &new_buffer);
  CHECK_EQ(rv, MOJO_RESULT_OK);

  memcpy(new_buffer, data(), data_num_bytes());
  buffer_.reset();

  return new_message;
}
// Forwards |error| to the underlying buffer's NotifyBadMessage(). The message
// must have a buffer.
void Message::NotifyBadMessage(const std::string& error) {
  DCHECK(buffer_);

  buffer_->NotifyBadMessage(error);
}
void Message::CloseHandles() {
for (std::vector<Handle>::iterator it = handles_.begin();
it != handles_.end(); ++it) {
if (it->is_valid())
CloseRaw(*it);
}
}
// Writes the pending associated endpoint handles into the message as an array
// of interface ids. Does nothing when there are none.
//
// Each handle is passed to |group_controller|->AssociateInterface() and the
// returned id is stored in a newly allocated uint32_t array referenced by the
// v2 header's |payload_interface_ids|. The handle list is emptied afterwards.
// Requires a version >= 2 message whose |payload_interface_ids| is still null.
void Message::SerializeAssociatedEndpointHandles(
    AssociatedGroupController* group_controller) {
  if (associated_endpoint_handles_.empty())
    return;

  DCHECK_GE(version(), 2u);
  DCHECK(header_v2()->payload_interface_ids.is_null());

  const size_t handle_count = associated_endpoint_handles_.size();
  auto* id_array = internal::Array_Data<uint32_t>::New(handle_count, buffer());
  header_v2()->payload_interface_ids.Set(id_array);

  size_t slot = 0;
  for (ScopedInterfaceEndpointHandle& endpoint :
       associated_endpoint_handles_) {
    DCHECK(endpoint.pending_association());
    const auto interface_id =
        group_controller->AssociateInterface(std::move(endpoint));
    id_array->storage()[slot] = interface_id;
    ++slot;
  }

  associated_endpoint_handles_.clear();
}
// Rebuilds |associated_endpoint_handles_| from the interface ids stored in the
// message. Any previously held endpoint handles are discarded first.
//
// For each stored id a handle is requested from
// |group_controller|->CreateLocalEndpointHandle() and appended (valid or not),
// and the id in the message is overwritten with kInvalidInterfaceId.
//
// Returns false if any valid interface id produced an invalid handle; returns
// true otherwise, including when the message carries no ids.
bool Message::DeserializeAssociatedEndpointHandles(
    AssociatedGroupController* group_controller) {
  associated_endpoint_handles_.clear();

  const uint32_t id_count = payload_num_interface_ids();
  if (id_count == 0)
    return true;

  associated_endpoint_handles_.reserve(id_count);
  uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage();

  bool all_succeeded = true;
  for (uint32_t index = 0; index < id_count; ++index) {
    uint32_t& id = ids[index];
    auto endpoint = group_controller->CreateLocalEndpointHandle(id);

    // |id| itself is valid but handle creation failed. In that case, mark
    // deserialization as failed but continue to deserialize the rest of the
    // handles.
    if (!endpoint.is_valid() && IsValidInterfaceId(id))
      all_succeeded = false;

    associated_endpoint_handles_.push_back(std::move(endpoint));
    id = kInvalidInterfaceId;
  }

  return all_succeeded;
}
// Nothing to set up.
PassThroughFilter::PassThroughFilter() = default;
// Nothing to tear down.
PassThroughFilter::~PassThroughFilter() = default;
// Accepts every message unconditionally; |message| is not inspected.
bool PassThroughFilter::Accept(Message* message) {
  return true;
}
// Remembers the context that was current on this thread and installs this
// object as the new current context.
SyncMessageResponseContext::SyncMessageResponseContext()
    : outer_context_(current()) {
  g_tls_sync_response_context.Get().Set(this);
}
// Reinstates the context that was current when this one was constructed. This
// object must still be the current context at this point.
SyncMessageResponseContext::~SyncMessageResponseContext() {
  DCHECK_EQ(current(), this);

  g_tls_sync_response_context.Get().Set(outer_context_);
}
// static
// Returns the value held in the thread-local sync-response-context slot.
SyncMessageResponseContext* SyncMessageResponseContext::current() {
  auto& tls_slot = g_tls_sync_response_context.Get();
  return tls_slot.Get();
}
void SyncMessageResponseContext::ReportBadMessage(const std::string& error) {
GetBadMessageCallback().Run(error);
}
// Returns the callback that reports a bad message against |response_|. The
// callback is created on first use, binding DoNotifyBadMessage with
// |response_| handed over through base::Passed(), and is cached in
// |bad_message_callback_| for later calls.
const ReportBadMessageCallback&
SyncMessageResponseContext::GetBadMessageCallback() {
  if (!bad_message_callback_.is_null())
    return bad_message_callback_;

  bad_message_callback_ =
      base::Bind(&DoNotifyBadMessage, base::Passed(&response_));
  return bad_message_callback_;
}
// Reads one message from |handle| into |message|.
//
// The first read is attempted without handle storage. If it returns
// MOJO_RESULT_RESOURCE_EXHAUSTED, a handle vector of the reported size is
// allocated and the read is retried with it.
//
// Returns MOJO_RESULT_OK after initializing |message| from what was read;
// any other read result is returned as is and |message| is left untouched.
MojoResult ReadMessage(MessagePipeHandle handle, Message* message) {
  std::vector<Handle> handles;
  ScopedMessageHandle mojo_message;
  uint32_t num_bytes = 0;
  uint32_t num_handles = 0;

  MojoResult result = ReadMessageNew(handle,
                                     &mojo_message,
                                     &num_bytes,
                                     nullptr,
                                     &num_handles,
                                     MOJO_READ_MESSAGE_FLAG_NONE);

  if (result == MOJO_RESULT_RESOURCE_EXHAUSTED) {
    // Retry with storage for the handles.
    DCHECK_GT(num_handles, 0u);
    handles.resize(num_handles);
    result = ReadMessageNew(handle,
                            &mojo_message,
                            &num_bytes,
                            reinterpret_cast<MojoHandle*>(handles.data()),
                            &num_handles,
                            MOJO_READ_MESSAGE_FLAG_NONE);
  }

  if (result != MOJO_RESULT_OK)
    return result;

  message->InitializeFromMojoMessage(
      std::move(mojo_message), num_bytes, &handles);
  return MOJO_RESULT_OK;
}
// Runs the bad-message callback of the current MessageDispatchContext with
// |error|. A dispatch context must be current on this thread.
void ReportBadMessage(const std::string& error) {
  auto* dispatch_context = internal::MessageDispatchContext::current();
  DCHECK(dispatch_context);

  dispatch_context->GetBadMessageCallback().Run(error);
}
// Returns a copy of the bad-message callback of the current
// MessageDispatchContext. A dispatch context must be current on this thread.
ReportBadMessageCallback GetBadMessageCallback() {
  auto* dispatch_context = internal::MessageDispatchContext::current();
  DCHECK(dispatch_context);

  return dispatch_context->GetBadMessageCallback();
}
namespace internal {
// Defaulted constructor, defined out of line in this file.
MessageHeaderV2::MessageHeaderV2() = default;
// Remembers the dispatch context that was current on this thread, stores
// |message|, and installs this object as the new current context.
MessageDispatchContext::MessageDispatchContext(Message* message)
    : outer_context_(current()),
      message_(message) {
  g_tls_message_dispatch_context.Get().Set(this);
}
// Reinstates the dispatch context that was current when this one was
// constructed. This object must still be the current context at this point.
MessageDispatchContext::~MessageDispatchContext() {
  DCHECK_EQ(current(), this);

  g_tls_message_dispatch_context.Get().Set(outer_context_);
}
// static
// Returns the value held in the thread-local dispatch-context slot.
MessageDispatchContext* MessageDispatchContext::current() {
  auto& tls_slot = g_tls_message_dispatch_context.Get();
  return tls_slot.Get();
}
// Returns the callback that reports a bad message against the message being
// dispatched. The callback is created on first use, binding
// DoNotifyBadMessage with |message_| handed over through base::Passed(), and
// is cached in |bad_message_callback_| for later calls.
const ReportBadMessageCallback&
MessageDispatchContext::GetBadMessageCallback() {
  if (!bad_message_callback_.is_null())
    return bad_message_callback_;

  bad_message_callback_ =
      base::Bind(&DoNotifyBadMessage, base::Passed(message_));
  return bad_message_callback_;
}
// static
// Moves |*message| into the |response_| of the current
// SyncMessageResponseContext. Does nothing, leaving |*message| untouched, when
// no such context is current on this thread.
void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) {
  SyncMessageResponseContext* const current_context =
      SyncMessageResponseContext::current();
  if (!current_context)
    return;

  current_context->response_ = std::move(*message);
}
} // namespace internal
} // namespace mojo