// blob: 59cf79645cd5ac0026a36e230725267f45996cd5 [file] [log] [blame]
// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "google_apis/gcm/engine/connection_handler_impl.h"
#include <utility>
#include "base/bind.h"
#include "base/location.h"
#include "base/threading/thread_task_runner_handle.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google_apis/gcm/base/mcs_util.h"
#include "google_apis/gcm/base/socket_stream.h"
#include "google_apis/gcm/protocol/mcs.pb.h"
#include "net/base/net_errors.h"
#include "net/socket/stream_socket.h"
using namespace google::protobuf::io;
namespace gcm {
namespace {

// # of bytes a MCS version packet consumes.
const int kVersionPacketLen = 1;
// # of bytes a tag packet consumes.
const int kTagPacketLen = 1;
// Min and max # of bytes a length packet consumes. A Varint32 can consume up to
// 5 bytes (the msb in each byte is reserved for denoting whether more bytes
// follow).
// Although the protocol only allows for 4KiB payloads currently, and the socket
// stream buffer is only of size 8KiB, it's possible for certain applications to
// have larger message sizes. When the payload is 4KiB or larger, a temporary
// in-memory buffer is used instead of the normal in-place socket stream buffer.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 5;

// The normal limit for a data packet is 4KiB. Any data packet of this size or
// larger is accumulated in the temporary in-memory buffer.
const int kDefaultDataPacketLimit = 1024 * 4;

// The current MCS protocol version.
const int kMCSVersion = 41;

} // namespace
// |read_timeout| is the delay used by |read_timeout_timer_|; when it fires,
// OnTimeout() closes the connection. |read_callback| receives parsed protos,
// |write_callback| runs after a successful write, and |connection_callback|
// receives connection status codes (net::OK once the handshake completes, a
// net error otherwise).
ConnectionHandlerImpl::ConnectionHandlerImpl(
    base::TimeDelta read_timeout,
    const ProtoReceivedCallback& read_callback,
    const ProtoSentCallback& write_callback,
    const ConnectionChangedCallback& connection_callback)
    : read_timeout_(read_timeout),
      handshake_complete_(false),
      message_tag_(0),
      message_size_(0),
      read_callback_(read_callback),
      write_callback_(write_callback),
      connection_callback_(connection_callback),
      size_packet_so_far_(0),
      weak_ptr_factory_(this) {
}
// Members (streams, timer, weak pointer factory) clean up after themselves.
ConnectionHandlerImpl::~ConnectionHandlerImpl() = default;
// Binds the handler to a fresh pair of data pipes and starts the MCS login
// handshake with |login_request|. Callbacks still bound to a previous
// connection are invalidated before the new streams are created.
void ConnectionHandlerImpl::Init(
    const mcs_proto::LoginRequest& login_request,
    mojo::ScopedDataPipeConsumerHandle receive_stream,
    mojo::ScopedDataPipeProducerHandle send_stream) {
  // All three callbacks are required.
  DCHECK(!read_callback_.is_null());
  DCHECK(!write_callback_.is_null());
  DCHECK(!connection_callback_.is_null());

  // Anything pending against the old streams must become a no-op.
  weak_ptr_factory_.InvalidateWeakPtrs();

  // Start this connection from a clean protocol state.
  handshake_complete_ = false;
  message_tag_ = 0;
  message_size_ = 0;

  // Inbound bytes come from |receive_stream|; outbound bytes go to
  // |send_stream|.
  input_stream_.reset(new SocketInputStream(std::move(receive_stream)));
  output_stream_.reset(new SocketOutputStream(std::move(send_stream)));

  Login(login_request);
}
// Tears down the current connection, if any. |connection_callback_| is not
// invoked.
void ConnectionHandlerImpl::Reset() {
  CloseConnection();
}
// A message may be sent only after the login handshake has completed and the
// output stream has no data left to flush.
bool ConnectionHandlerImpl::CanSendMessage() const {
  if (!handshake_complete_ || !output_stream_)
    return false;
  return output_stream_->GetState() == SocketOutputStream::EMPTY;
}
// Serializes |message| onto the output stream as
//   <1-byte MCS tag><varint32 payload size><serialized proto>
// and flushes it. OnMessageSent() runs inline if the flush completes
// synchronously, or later via the flush callback otherwise.
void ConnectionHandlerImpl::SendMessage(
    const google::protobuf::MessageLite& message) {
  DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
  DCHECK(handshake_complete_);

  {
    CodedOutputStream coded_output_stream(output_stream_.get());
    // Compute the size once: it is both logged and written as the length
    // prefix.
    const int message_size = message.ByteSize();
    DVLOG(1) << "Writing proto of size " << message_size;
    const int tag = GetMCSProtoTag(message);
    DCHECK_NE(tag, -1);
    // Narrow the tag explicitly. Writing the first in-memory byte of an |int|
    // yields the tag only on little-endian hosts; on big-endian hosts it would
    // emit the most significant byte (0) instead.
    const uint8_t tag_byte = static_cast<uint8_t>(tag);
    coded_output_stream.WriteRaw(&tag_byte, 1);
    coded_output_stream.WriteVarint32(message_size);
    message.SerializeToCodedStream(&coded_output_stream);
  }

  if (output_stream_->Flush(
          base::Bind(&ConnectionHandlerImpl::OnMessageSent,
                     weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
    OnMessageSent();
  }
}
// Writes the MCS handshake, which is:
//   <version byte><login request tag byte><varint32 size><LoginRequest>
// It then starts the read timeout and waits for the server's version byte and
// LoginResponse header.
void ConnectionHandlerImpl::Login(
    const google::protobuf::MessageLite& login_request) {
  DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);

  // The version byte and the tag byte go out back to back.
  const char header[2] = {kMCSVersion, kLoginRequestTag};
  {
    CodedOutputStream out(output_stream_.get());
    out.WriteRaw(header, sizeof(header));
    out.WriteVarint32(login_request.ByteSize());
    login_request.SerializeToCodedStream(&out);
  }

  const bool flush_pending =
      output_stream_->Flush(base::Bind(&ConnectionHandlerImpl::OnMessageSent,
                                       weak_ptr_factory_.GetWeakPtr())) ==
      net::ERR_IO_PENDING;
  if (!flush_pending) {
    // The flush finished synchronously. Completion is reported from a posted
    // task rather than inline (presumably so OnMessageSent() never runs
    // re-entrantly inside Init()).
    base::ThreadTaskRunnerHandle::Get()->PostTask(
        FROM_HERE, base::BindOnce(&ConnectionHandlerImpl::OnMessageSent,
                                  weak_ptr_factory_.GetWeakPtr()));
  }

  read_timeout_timer_.Start(
      FROM_HERE, read_timeout_,
      base::Bind(&ConnectionHandlerImpl::OnTimeout,
                 weak_ptr_factory_.GetWeakPtr()));
  WaitForData(MCS_VERSION_TAG_AND_SIZE);
}
// Completion handler for output-stream flushes. A fully drained stream means
// success, so |write_callback_| runs. Otherwise the connection is closed and
// the stream's error, or net::ERR_FAILED if none was recorded, is reported.
void ConnectionHandlerImpl::OnMessageSent() {
  if (!output_stream_) {
    // Already closed; nothing left to report.
    DCHECK(!input_stream_);
    DCHECK(!read_timeout_timer_.IsRunning());
    return;
  }

  if (output_stream_->GetState() == SocketOutputStream::EMPTY) {
    write_callback_.Run();
    return;
  }

  // Data was left behind: tear down before notifying the owner.
  int last_error = output_stream_->last_error();
  CloseConnection();
  connection_callback_.Run(last_error == net::OK ? net::ERR_FAILED
                                                 : last_error);
}
// Forgets the previous message's tag and size, then waits for the next
// message's tag + size header.
void ConnectionHandlerImpl::GetNextMessage() {
  DCHECK(SocketInputStream::READY == input_stream_->GetState() ||
         SocketInputStream::EMPTY == input_stream_->GetState());
  message_size_ = 0;
  message_tag_ = 0;
  WaitForData(MCS_TAG_AND_SIZE);
}
// Makes sure |input_stream_| buffers enough bytes to make progress in |state|,
// refreshing from the socket when needed, and then dispatches to the matching
// OnGot*() handler. A pending or short read re-enters this method with the
// same |state|. Input-stream errors close the connection and are reported
// through |connection_callback_|.
void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
  DVLOG(1) << "Waiting for MCS data: state == " << state;

  if (!input_stream_) {
    // The connection has already been closed. Just return.
    DCHECK(!output_stream_.get());
    DCHECK(!read_timeout_timer_.IsRunning());
    return;
  }

  if (input_stream_->GetState() != SocketInputStream::EMPTY &&
      input_stream_->GetState() != SocketInputStream::READY) {
    // An error occurred on the input stream. Report the input stream's own
    // error; the output stream's error is unrelated to this failure.
    int last_error = input_stream_->last_error();
    CloseConnection();
    // If the socket stream had an error, plumb it up, else plumb up FAILED.
    if (last_error == net::OK)
      last_error = net::ERR_FAILED;
    connection_callback_.Run(last_error);
    return;
  }

  // Used to determine whether a Socket::Read is necessary.
  int min_bytes_needed = 0;
  // Used to limit the size of the Socket::Read.
  int max_bytes_needed = 0;

  switch (state) {
    case MCS_VERSION_TAG_AND_SIZE:
      min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
      max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
      break;
    case MCS_TAG_AND_SIZE:
      min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
      max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
      break;
    case MCS_SIZE:
      // Need one byte more than the partial varint already seen.
      min_bytes_needed = size_packet_so_far_ + 1;
      max_bytes_needed = kSizePacketLenMax;
      break;
    case MCS_PROTO_BYTES:
      read_timeout_timer_.Reset();
      if (message_size_ < kDefaultDataPacketLimit) {
        // No variability in the message size, set both to the same.
        min_bytes_needed = message_size_;
        max_bytes_needed = message_size_;
      } else {
        // Large payloads are accumulated in |payload_input_buffer_|, at most
        // kDefaultDataPacketLimit bytes per read.
        int bytes_left = message_size_ - payload_input_buffer_.size();
        if (bytes_left > kDefaultDataPacketLimit)
          bytes_left = kDefaultDataPacketLimit;
        min_bytes_needed = bytes_left;
        max_bytes_needed = bytes_left;
      }
      break;
  }
  DCHECK_GE(max_bytes_needed, min_bytes_needed);

  int unread_byte_count = input_stream_->UnreadByteCount();
  if (min_bytes_needed > unread_byte_count &&
      input_stream_->Refresh(
          base::Bind(&ConnectionHandlerImpl::WaitForData,
                     weak_ptr_factory_.GetWeakPtr(),
                     state),
          max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) {
    // The read is in flight; this method runs again with |state| when it
    // completes.
    return;
  }

  // Check for refresh errors.
  if (input_stream_->GetState() != SocketInputStream::READY) {
    // An error occurred.
    int last_error = input_stream_->last_error();
    CloseConnection();
    // If the socket stream had an error, plumb it up, else plumb up FAILED.
    if (last_error == net::OK)
      last_error = net::ERR_FAILED;
    connection_callback_.Run(last_error);
    return;
  }

  // Check whether read is complete, or needs to be continued (
  // SocketInputStream::Refresh can finish without reading all the data).
  if (input_stream_->UnreadByteCount() < min_bytes_needed) {
    DVLOG(1) << "Socket read finished prematurely. Waiting for "
             << min_bytes_needed - input_stream_->UnreadByteCount()
             << " more bytes.";
    // Retry in the same state. The missing bytes belong to |state|'s packet
    // (version/tag/size header or payload), so it must not be treated as
    // payload unconditionally.
    base::ThreadTaskRunnerHandle::Get()->PostTask(
        FROM_HERE,
        base::BindOnce(&ConnectionHandlerImpl::WaitForData,
                       weak_ptr_factory_.GetWeakPtr(), state));
    return;
  }

  // Received enough bytes, process them.
  DVLOG(1) << "Processing MCS data: state == " << state;
  switch (state) {
    case MCS_VERSION_TAG_AND_SIZE:
      OnGotVersion();
      break;
    case MCS_TAG_AND_SIZE:
      OnGotMessageTag();
      break;
    case MCS_SIZE:
      OnGotMessageSize();
      break;
    case MCS_PROTO_BYTES:
      OnGotMessageBytes();
      break;
  }
}
void ConnectionHandlerImpl::OnGotVersion() {
uint8_t version = 0;
{
CodedInputStream coded_input_stream(input_stream_.get());
coded_input_stream.ReadRaw(&version, 1);
}
// TODO(zea): remove this when the server is ready.
if (version < kMCSVersion && version != 38) {
LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
connection_callback_.Run(net::ERR_FAILED);
return;
}
input_stream_->RebuildBuffer();
// Process the LoginResponse message tag.
OnGotMessageTag();
}
void ConnectionHandlerImpl::OnGotMessageTag() {
if (input_stream_->GetState() != SocketInputStream::READY) {
LOG(ERROR) << "Failed to receive protobuf tag.";
read_callback_.Run(std::unique_ptr<google::protobuf::MessageLite>());
return;
}
{
CodedInputStream coded_input_stream(input_stream_.get());
coded_input_stream.ReadRaw(&message_tag_, 1);
}
DVLOG(1) << "Received proto of type "
<< static_cast<unsigned int>(message_tag_);
if (!read_timeout_timer_.IsRunning()) {
read_timeout_timer_.Start(FROM_HERE,
read_timeout_,
base::Bind(&ConnectionHandlerImpl::OnTimeout,
weak_ptr_factory_.GetWeakPtr()));
}
OnGotMessageSize();
}
// Reads the varint32 payload length of the current message into
// |message_size_|.
// - If the varint is not yet fully buffered, the bytes it consumed are backed
//   up into |input_stream_| and their count is saved in |size_packet_so_far_|.
//   WaitForData(MCS_SIZE) is then called to fetch more.
// - Once the size is known, it waits for the payload (MCS_PROTO_BYTES) or, for
//   an empty payload, handles the message immediately.
void ConnectionHandlerImpl::OnGotMessageSize() {
  if (input_stream_->GetState() != SocketInputStream::READY) {
    LOG(ERROR) << "Failed to receive message size.";
    read_callback_.Run(std::unique_ptr<google::protobuf::MessageLite>());
    return;
  }

  int prev_byte_count = input_stream_->UnreadByteCount();
  int result = net::OK;
  bool incomplete_size_packet = false;
  {
    CodedInputStream coded_input_stream(input_stream_.get());
    if (!coded_input_stream.ReadVarint32(&message_size_)) {
      DVLOG(1) << "Expecting another message size byte.";
      if (prev_byte_count >= kSizePacketLenMax) {
        // Already had enough bytes, something else went wrong. A varint32
        // never needs more than kSizePacketLenMax bytes, so waiting for more
        // data cannot help.
        LOG(ERROR) << "Failed to process message size";
        result = net::ERR_FILE_TOO_BIG;
      } else {
        // Back up by the amount read, so the partial varint is re-read from
        // the start once more bytes arrive.
        int bytes_read = prev_byte_count - input_stream_->UnreadByteCount();
        input_stream_->BackUp(bytes_read);
        size_packet_so_far_ = bytes_read;
        incomplete_size_packet = true;
      }
    }
  }

  if (result != net::OK) {
    connection_callback_.Run(result);
    return;
  } else if (incomplete_size_packet) {
    WaitForData(MCS_SIZE);
    return;
  }

  DVLOG(1) << "Proto size: " << message_size_;
  // Size fully parsed: clear per-size and per-payload scratch state.
  size_packet_so_far_ = 0;
  payload_input_buffer_.clear();

  if (message_size_ > 0)
    WaitForData(MCS_PROTO_BYTES);
  else
    OnGotMessageBytes();
}
void ConnectionHandlerImpl::OnGotMessageBytes() {
read_timeout_timer_.Stop();
std::unique_ptr<google::protobuf::MessageLite> protobuf(
BuildProtobufFromTag(message_tag_));
// Messages with no content are valid; just use the default protobuf for
// that tag.
if (protobuf.get() && message_size_ == 0) {
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(&ConnectionHandlerImpl::GetNextMessage,
weak_ptr_factory_.GetWeakPtr()));
read_callback_.Run(std::move(protobuf));
return;
}
if (input_stream_->GetState() != SocketInputStream::READY) {
LOG(ERROR) << "Failed to extract protobuf bytes of type "
<< static_cast<unsigned int>(message_tag_);
// Reset the connection.
connection_callback_.Run(net::ERR_FAILED);
return;
}
if (!protobuf.get()) {
LOG(ERROR) << "Received message of invalid type "
<< static_cast<unsigned int>(message_tag_);
connection_callback_.Run(net::ERR_INVALID_ARGUMENT);
return;
}
int result = net::OK;
if (message_size_ < kDefaultDataPacketLimit) {
CodedInputStream coded_input_stream(input_stream_.get());
if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
LOG(ERROR) << "Unable to parse GCM message of type "
<< static_cast<unsigned int>(message_tag_);
result = net::ERR_FAILED;
}
} else {
// Copy any data in the input stream onto the end of the buffer.
const void* data_ptr = NULL;
int size = 0;
input_stream_->Next(&data_ptr, &size);
payload_input_buffer_.insert(payload_input_buffer_.end(),
static_cast<const uint8_t*>(data_ptr),
static_cast<const uint8_t*>(data_ptr) + size);
DCHECK_LE(payload_input_buffer_.size(), message_size_);
if (payload_input_buffer_.size() == message_size_) {
ArrayInputStream buffer_input_stream(payload_input_buffer_.data(),
payload_input_buffer_.size());
CodedInputStream coded_input_stream(&buffer_input_stream);
if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
LOG(ERROR) << "Unable to parse GCM message of type "
<< static_cast<unsigned int>(message_tag_);
result = net::ERR_FAILED;
}
} else {
// Continue reading data.
DVLOG(1) << "Continuing data read. Buffer size is "
<< payload_input_buffer_.size()
<< ", expecting " << message_size_;
input_stream_->RebuildBuffer();
read_timeout_timer_.Start(FROM_HERE,
read_timeout_,
base::Bind(&ConnectionHandlerImpl::OnTimeout,
weak_ptr_factory_.GetWeakPtr()));
WaitForData(MCS_PROTO_BYTES);
return;
}
}
if (result != net::OK) {
// Reset the connection.
connection_callback_.Run(result);
return;
}
input_stream_->RebuildBuffer();
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(&ConnectionHandlerImpl::GetNextMessage,
weak_ptr_factory_.GetWeakPtr()));
if (message_tag_ == kLoginResponseTag) {
if (handshake_complete_) {
LOG(ERROR) << "Unexpected login response.";
} else {
handshake_complete_ = true;
DVLOG(1) << "GCM Handshake complete.";
connection_callback_.Run(net::OK);
}
}
read_callback_.Run(std::move(protobuf));
}
// Runs when |read_timeout_timer_| fires: closes the connection, then reports
// net::ERR_TIMED_OUT to the owner.
void ConnectionHandlerImpl::OnTimeout() {
  LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
  CloseConnection();
  connection_callback_.Run(net::ERR_TIMED_OUT);
}
// Tears the connection down:
// - stops the read timeout,
// - clears handshake and per-message parsing state,
// - destroys both streams,
// - invalidates weak pointers, so callbacks bound to them become no-ops.
// It does not notify |connection_callback_|.
void ConnectionHandlerImpl::CloseConnection() {
  DVLOG(1) << "Closing connection.";
  read_timeout_timer_.Stop();

  // Forget any partially parsed message and the handshake.
  handshake_complete_ = false;
  message_tag_ = 0;
  message_size_ = 0;
  size_packet_so_far_ = 0;
  payload_input_buffer_.clear();

  // Release the streams, then cut off anything still bound to this object.
  input_stream_.reset();
  output_stream_.reset();
  weak_ptr_factory_.InvalidateWeakPtrs();
}
} // namespace gcm