| // Copyright 2013 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "google_apis/gcm/engine/connection_handler_impl.h" |
| |
| #include <utility> |
| |
| #include "base/bind.h" |
| #include "base/location.h" |
| #include "base/threading/thread_task_runner_handle.h" |
| #include "google/protobuf/io/coded_stream.h" |
| #include "google/protobuf/io/zero_copy_stream_impl_lite.h" |
| #include "google_apis/gcm/base/mcs_util.h" |
| #include "google_apis/gcm/base/socket_stream.h" |
| #include "google_apis/gcm/protocol/mcs.pb.h" |
| #include "net/base/net_errors.h" |
| #include "net/socket/stream_socket.h" |
| |
| using namespace google::protobuf::io; |
| |
| namespace gcm { |
| |
namespace {

// # of bytes an MCS version packet consumes.
constexpr int kVersionPacketLen = 1;
// # of bytes a tag packet consumes.
constexpr int kTagPacketLen = 1;
// Min/max # of bytes a length packet consumes. A Varint32 can consume up to 5
// bytes (the msb in each byte is reserved for denoting whether more bytes
// follow). Although the protocol currently only allows 4KiB payloads, and the
// socket stream buffer is only 8KiB, some applications may have larger
// messages. When a payload is larger than 4KiB, a temporary in-memory buffer
// is used instead of the normal in-place socket stream buffer.
constexpr int kSizePacketLenMin = 1;
constexpr int kSizePacketLenMax = 5;

// The normal limit for a data packet is 4KiB. Any data packet at or above this
// size is accumulated in the temporary in-memory buffer instead.
constexpr int kDefaultDataPacketLimit = 1024 * 4;

// The current MCS protocol version.
constexpr int kMCSVersion = 41;

}  // namespace
| |
| ConnectionHandlerImpl::ConnectionHandlerImpl( |
| base::TimeDelta read_timeout, |
| const ProtoReceivedCallback& read_callback, |
| const ProtoSentCallback& write_callback, |
| const ConnectionChangedCallback& connection_callback) |
| : read_timeout_(read_timeout), |
| handshake_complete_(false), |
| message_tag_(0), |
| message_size_(0), |
| read_callback_(read_callback), |
| write_callback_(write_callback), |
| connection_callback_(connection_callback), |
| size_packet_so_far_(0), |
| weak_ptr_factory_(this) { |
| } |
| |
| ConnectionHandlerImpl::~ConnectionHandlerImpl() { |
| } |
| |
| void ConnectionHandlerImpl::Init( |
| const mcs_proto::LoginRequest& login_request, |
| mojo::ScopedDataPipeConsumerHandle receive_stream, |
| mojo::ScopedDataPipeProducerHandle send_stream) { |
| DCHECK(!read_callback_.is_null()); |
| DCHECK(!write_callback_.is_null()); |
| DCHECK(!connection_callback_.is_null()); |
| |
| // Invalidate any previously outstanding reads. |
| weak_ptr_factory_.InvalidateWeakPtrs(); |
| |
| handshake_complete_ = false; |
| message_tag_ = 0; |
| message_size_ = 0; |
| input_stream_.reset(new SocketInputStream(std::move(receive_stream))); |
| output_stream_.reset(new SocketOutputStream(std::move(send_stream))); |
| |
| Login(login_request); |
| } |
| |
| void ConnectionHandlerImpl::Reset() { |
| CloseConnection(); |
| } |
| |
| bool ConnectionHandlerImpl::CanSendMessage() const { |
| return handshake_complete_ && output_stream_.get() && |
| output_stream_->GetState() == SocketOutputStream::EMPTY; |
| } |
| |
| void ConnectionHandlerImpl::SendMessage( |
| const google::protobuf::MessageLite& message) { |
| DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); |
| DCHECK(handshake_complete_); |
| |
| { |
| CodedOutputStream coded_output_stream(output_stream_.get()); |
| DVLOG(1) << "Writing proto of size " << message.ByteSize(); |
| int tag = GetMCSProtoTag(message); |
| DCHECK_NE(tag, -1); |
| coded_output_stream.WriteRaw(&tag, 1); |
| coded_output_stream.WriteVarint32(message.ByteSize()); |
| message.SerializeToCodedStream(&coded_output_stream); |
| } |
| |
| if (output_stream_->Flush( |
| base::Bind(&ConnectionHandlerImpl::OnMessageSent, |
| weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { |
| OnMessageSent(); |
| } |
| } |
| |
| void ConnectionHandlerImpl::Login( |
| const google::protobuf::MessageLite& login_request) { |
| DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); |
| |
| const char version_byte[1] = {kMCSVersion}; |
| const char login_request_tag[1] = {kLoginRequestTag}; |
| { |
| CodedOutputStream coded_output_stream(output_stream_.get()); |
| coded_output_stream.WriteRaw(version_byte, 1); |
| coded_output_stream.WriteRaw(login_request_tag, 1); |
| coded_output_stream.WriteVarint32(login_request.ByteSize()); |
| login_request.SerializeToCodedStream(&coded_output_stream); |
| } |
| |
| if (output_stream_->Flush( |
| base::Bind(&ConnectionHandlerImpl::OnMessageSent, |
| weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, base::BindOnce(&ConnectionHandlerImpl::OnMessageSent, |
| weak_ptr_factory_.GetWeakPtr())); |
| } |
| |
| read_timeout_timer_.Start(FROM_HERE, |
| read_timeout_, |
| base::Bind(&ConnectionHandlerImpl::OnTimeout, |
| weak_ptr_factory_.GetWeakPtr())); |
| WaitForData(MCS_VERSION_TAG_AND_SIZE); |
| } |
| |
| void ConnectionHandlerImpl::OnMessageSent() { |
| if (!output_stream_.get()) { |
| // The connection has already been closed. Just return. |
| DCHECK(!input_stream_.get()); |
| DCHECK(!read_timeout_timer_.IsRunning()); |
| return; |
| } |
| |
| if (output_stream_->GetState() != SocketOutputStream::EMPTY) { |
| int last_error = output_stream_->last_error(); |
| CloseConnection(); |
| // If the socket stream had an error, plumb it up, else plumb up FAILED. |
| if (last_error == net::OK) |
| last_error = net::ERR_FAILED; |
| connection_callback_.Run(last_error); |
| return; |
| } |
| |
| write_callback_.Run(); |
| } |
| |
| void ConnectionHandlerImpl::GetNextMessage() { |
| DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() || |
| SocketInputStream::READY == input_stream_->GetState()); |
| message_tag_ = 0; |
| message_size_ = 0; |
| |
| WaitForData(MCS_TAG_AND_SIZE); |
| } |
| |
| void ConnectionHandlerImpl::WaitForData(ProcessingState state) { |
| DVLOG(1) << "Waiting for MCS data: state == " << state; |
| |
| if (!input_stream_) { |
| // The connection has already been closed. Just return. |
| DCHECK(!output_stream_.get()); |
| DCHECK(!read_timeout_timer_.IsRunning()); |
| return; |
| } |
| |
| if (input_stream_->GetState() != SocketInputStream::EMPTY && |
| input_stream_->GetState() != SocketInputStream::READY) { |
| // An error occurred. |
| int last_error = output_stream_->last_error(); |
| CloseConnection(); |
| // If the socket stream had an error, plumb it up, else plumb up FAILED. |
| if (last_error == net::OK) |
| last_error = net::ERR_FAILED; |
| connection_callback_.Run(last_error); |
| return; |
| } |
| |
| // Used to determine whether a Socket::Read is necessary. |
| int min_bytes_needed = 0; |
| // Used to limit the size of the Socket::Read. |
| int max_bytes_needed = 0; |
| |
| switch(state) { |
| case MCS_VERSION_TAG_AND_SIZE: |
| min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin; |
| max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax; |
| break; |
| case MCS_TAG_AND_SIZE: |
| min_bytes_needed = kTagPacketLen + kSizePacketLenMin; |
| max_bytes_needed = kTagPacketLen + kSizePacketLenMax; |
| break; |
| case MCS_SIZE: |
| min_bytes_needed = size_packet_so_far_ + 1; |
| max_bytes_needed = kSizePacketLenMax; |
| break; |
| case MCS_PROTO_BYTES: |
| read_timeout_timer_.Reset(); |
| if (message_size_ < kDefaultDataPacketLimit) { |
| // No variability in the message size, set both to the same. |
| min_bytes_needed = message_size_; |
| max_bytes_needed = message_size_; |
| } else { |
| int bytes_left = message_size_ - payload_input_buffer_.size(); |
| if (bytes_left > kDefaultDataPacketLimit) |
| bytes_left = kDefaultDataPacketLimit; |
| min_bytes_needed = bytes_left; |
| max_bytes_needed = bytes_left; |
| } |
| break; |
| } |
| DCHECK_GE(max_bytes_needed, min_bytes_needed); |
| |
| int unread_byte_count = input_stream_->UnreadByteCount(); |
| if (min_bytes_needed > unread_byte_count && |
| input_stream_->Refresh( |
| base::Bind(&ConnectionHandlerImpl::WaitForData, |
| weak_ptr_factory_.GetWeakPtr(), |
| state), |
| max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) { |
| return; |
| } |
| |
| // Check for refresh errors. |
| if (input_stream_->GetState() != SocketInputStream::READY) { |
| // An error occurred. |
| int last_error = input_stream_->last_error(); |
| CloseConnection(); |
| // If the socket stream had an error, plumb it up, else plumb up FAILED. |
| if (last_error == net::OK) |
| last_error = net::ERR_FAILED; |
| connection_callback_.Run(last_error); |
| return; |
| } |
| |
| // Check whether read is complete, or needs to be continued ( |
| // SocketInputStream::Refresh can finish without reading all the data). |
| if (input_stream_->UnreadByteCount() < min_bytes_needed) { |
| DVLOG(1) << "Socket read finished prematurely. Waiting for " |
| << min_bytes_needed - input_stream_->UnreadByteCount() |
| << " more bytes."; |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, |
| base::BindOnce(&ConnectionHandlerImpl::WaitForData, |
| weak_ptr_factory_.GetWeakPtr(), MCS_PROTO_BYTES)); |
| return; |
| } |
| |
| // Received enough bytes, process them. |
| DVLOG(1) << "Processing MCS data: state == " << state; |
| switch(state) { |
| case MCS_VERSION_TAG_AND_SIZE: |
| OnGotVersion(); |
| break; |
| case MCS_TAG_AND_SIZE: |
| OnGotMessageTag(); |
| break; |
| case MCS_SIZE: |
| OnGotMessageSize(); |
| break; |
| case MCS_PROTO_BYTES: |
| OnGotMessageBytes(); |
| break; |
| } |
| } |
| |
| void ConnectionHandlerImpl::OnGotVersion() { |
| uint8_t version = 0; |
| { |
| CodedInputStream coded_input_stream(input_stream_.get()); |
| coded_input_stream.ReadRaw(&version, 1); |
| } |
| // TODO(zea): remove this when the server is ready. |
| if (version < kMCSVersion && version != 38) { |
| LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version); |
| connection_callback_.Run(net::ERR_FAILED); |
| return; |
| } |
| |
| input_stream_->RebuildBuffer(); |
| |
| // Process the LoginResponse message tag. |
| OnGotMessageTag(); |
| } |
| |
| void ConnectionHandlerImpl::OnGotMessageTag() { |
| if (input_stream_->GetState() != SocketInputStream::READY) { |
| LOG(ERROR) << "Failed to receive protobuf tag."; |
| read_callback_.Run(std::unique_ptr<google::protobuf::MessageLite>()); |
| return; |
| } |
| |
| { |
| CodedInputStream coded_input_stream(input_stream_.get()); |
| coded_input_stream.ReadRaw(&message_tag_, 1); |
| } |
| |
| DVLOG(1) << "Received proto of type " |
| << static_cast<unsigned int>(message_tag_); |
| |
| if (!read_timeout_timer_.IsRunning()) { |
| read_timeout_timer_.Start(FROM_HERE, |
| read_timeout_, |
| base::Bind(&ConnectionHandlerImpl::OnTimeout, |
| weak_ptr_factory_.GetWeakPtr())); |
| } |
| OnGotMessageSize(); |
| } |
| |
// Decodes the varint32 payload size of the current message from
// |input_stream_| into |message_size_|.
//
// The size field may be split across socket reads. Two cases arise when
// decoding fails:
// - Fewer than kSizePacketLenMax bytes were available. The bytes the decoder
//   consumed are pushed back, |size_packet_so_far_| records how many there
//   were, and reading resumes in the MCS_SIZE state.
// - kSizePacketLenMax or more bytes were available. The size is treated as
//   malformed and ERR_FILE_TOO_BIG is reported via |connection_callback_|.
//
// On success, a non-empty payload is awaited via MCS_PROTO_BYTES. An empty
// payload is handed straight to OnGotMessageBytes().
void ConnectionHandlerImpl::OnGotMessageSize() {
  if (input_stream_->GetState() != SocketInputStream::READY) {
    LOG(ERROR) << "Failed to receive message size.";
    // NOTE(review): a null message goes to |read_callback_| here rather than
    // an error to |connection_callback_|; presumably the receiver treats a
    // null message as a read failure — confirm with the callback's owner.
    read_callback_.Run(std::unique_ptr<google::protobuf::MessageLite>());
    return;
  }

  // Bytes buffered before decoding; compared afterwards to see how many the
  // decoder consumed.
  int prev_byte_count = input_stream_->UnreadByteCount();
  int result = net::OK;
  bool incomplete_size_packet = false;
  {
    // Scoped so |coded_input_stream| is destroyed (returning any unread bytes
    // to |input_stream_|) before the result is acted on below.
    CodedInputStream coded_input_stream(input_stream_.get());
    if (!coded_input_stream.ReadVarint32(&message_size_)) {
      DVLOG(1) << "Expecting another message size byte.";
      if (prev_byte_count >= kSizePacketLenMax) {
        // Already had enough bytes, something else went wrong.
        LOG(ERROR) << "Failed to process message size";
        result = net::ERR_FILE_TOO_BIG;
      } else {
        // Back up by the amount read.
        int bytes_read = prev_byte_count - input_stream_->UnreadByteCount();
        input_stream_->BackUp(bytes_read);
        // WaitForData(MCS_SIZE) waits for at least one more byte than this.
        size_packet_so_far_ = bytes_read;
        incomplete_size_packet = true;
      }
    }
  }

  if (result != net::OK) {
    connection_callback_.Run(result);
    return;
  } else if (incomplete_size_packet) {
    WaitForData(MCS_SIZE);
    return;
  }

  DVLOG(1) << "Proto size: " << message_size_;
  // The size is complete: reset partial-size tracking and the large-payload
  // accumulation buffer for this new message.
  size_packet_so_far_ = 0;
  payload_input_buffer_.clear();

  if (message_size_ > 0)
    WaitForData(MCS_PROTO_BYTES);
  else
    OnGotMessageBytes();
}
| |
| void ConnectionHandlerImpl::OnGotMessageBytes() { |
| read_timeout_timer_.Stop(); |
| std::unique_ptr<google::protobuf::MessageLite> protobuf( |
| BuildProtobufFromTag(message_tag_)); |
| // Messages with no content are valid; just use the default protobuf for |
| // that tag. |
| if (protobuf.get() && message_size_ == 0) { |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, base::BindOnce(&ConnectionHandlerImpl::GetNextMessage, |
| weak_ptr_factory_.GetWeakPtr())); |
| read_callback_.Run(std::move(protobuf)); |
| return; |
| } |
| |
| if (input_stream_->GetState() != SocketInputStream::READY) { |
| LOG(ERROR) << "Failed to extract protobuf bytes of type " |
| << static_cast<unsigned int>(message_tag_); |
| // Reset the connection. |
| connection_callback_.Run(net::ERR_FAILED); |
| return; |
| } |
| |
| if (!protobuf.get()) { |
| LOG(ERROR) << "Received message of invalid type " |
| << static_cast<unsigned int>(message_tag_); |
| connection_callback_.Run(net::ERR_INVALID_ARGUMENT); |
| return; |
| } |
| |
| int result = net::OK; |
| if (message_size_ < kDefaultDataPacketLimit) { |
| CodedInputStream coded_input_stream(input_stream_.get()); |
| if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { |
| LOG(ERROR) << "Unable to parse GCM message of type " |
| << static_cast<unsigned int>(message_tag_); |
| result = net::ERR_FAILED; |
| } |
| } else { |
| // Copy any data in the input stream onto the end of the buffer. |
| const void* data_ptr = NULL; |
| int size = 0; |
| input_stream_->Next(&data_ptr, &size); |
| payload_input_buffer_.insert(payload_input_buffer_.end(), |
| static_cast<const uint8_t*>(data_ptr), |
| static_cast<const uint8_t*>(data_ptr) + size); |
| DCHECK_LE(payload_input_buffer_.size(), message_size_); |
| |
| if (payload_input_buffer_.size() == message_size_) { |
| ArrayInputStream buffer_input_stream(payload_input_buffer_.data(), |
| payload_input_buffer_.size()); |
| CodedInputStream coded_input_stream(&buffer_input_stream); |
| if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { |
| LOG(ERROR) << "Unable to parse GCM message of type " |
| << static_cast<unsigned int>(message_tag_); |
| result = net::ERR_FAILED; |
| } |
| } else { |
| // Continue reading data. |
| DVLOG(1) << "Continuing data read. Buffer size is " |
| << payload_input_buffer_.size() |
| << ", expecting " << message_size_; |
| input_stream_->RebuildBuffer(); |
| |
| read_timeout_timer_.Start(FROM_HERE, |
| read_timeout_, |
| base::Bind(&ConnectionHandlerImpl::OnTimeout, |
| weak_ptr_factory_.GetWeakPtr())); |
| WaitForData(MCS_PROTO_BYTES); |
| return; |
| } |
| } |
| |
| if (result != net::OK) { |
| // Reset the connection. |
| connection_callback_.Run(result); |
| return; |
| } |
| |
| input_stream_->RebuildBuffer(); |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, base::BindOnce(&ConnectionHandlerImpl::GetNextMessage, |
| weak_ptr_factory_.GetWeakPtr())); |
| if (message_tag_ == kLoginResponseTag) { |
| if (handshake_complete_) { |
| LOG(ERROR) << "Unexpected login response."; |
| } else { |
| handshake_complete_ = true; |
| DVLOG(1) << "GCM Handshake complete."; |
| connection_callback_.Run(net::OK); |
| } |
| } |
| read_callback_.Run(std::move(protobuf)); |
| } |
| |
| void ConnectionHandlerImpl::OnTimeout() { |
| LOG(ERROR) << "Timed out waiting for GCM Protocol buffer."; |
| CloseConnection(); |
| connection_callback_.Run(net::ERR_TIMED_OUT); |
| } |
| |
| void ConnectionHandlerImpl::CloseConnection() { |
| DVLOG(1) << "Closing connection."; |
| read_timeout_timer_.Stop(); |
| handshake_complete_ = false; |
| message_tag_ = 0; |
| message_size_ = 0; |
| size_packet_so_far_ = 0; |
| payload_input_buffer_.clear(); |
| input_stream_.reset(); |
| output_stream_.reset(); |
| weak_ptr_factory_.InvalidateWeakPtrs(); |
| } |
| |
| } // namespace gcm |