// blob: 7f6391803724443ca9814127df14f59ba346edcf [file] [log] [blame]
// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/renderer_host/websocket_host.h"
#include "base/basictypes.h"
#include "base/memory/weak_ptr.h"
#include "base/strings/string_util.h"
#include "content/browser/renderer_host/websocket_dispatcher_host.h"
#include "content/browser/ssl/ssl_error_handler.h"
#include "content/browser/ssl/ssl_manager.h"
#include "content/common/websocket_messages.h"
#include "ipc/ipc_message_macros.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
#include "net/ssl/ssl_info.h"
#include "net/websockets/websocket_channel.h"
#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode
#include "net/websockets/websocket_handshake_request_info.h"
#include "net/websockets/websocket_handshake_response_info.h"
#include "url/origin.h"
namespace content {
namespace {
// Shorthand for the state value that every net::WebSocketEventInterface
// callback implemented in this file returns.
typedef net::WebSocketEventInterface::ChannelState ChannelState;
// Maps a content::WebSocketMessageType onto the corresponding
// net::WebSocketFrameHeader::OpCode. Only the continuation, text and binary
// message types are accepted (checked with DCHECK).
net::WebSocketFrameHeader::OpCode MessageTypeToOpCode(
    WebSocketMessageType type) {
  typedef net::WebSocketFrameHeader::OpCode OpCode;

  // The two enumerations are required to share numeric values; the conversion
  // below is a plain cast and relies on these compile-time checks.
  COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION) ==
                     net::WebSocketFrameHeader::kOpCodeContinuation,
                 enum_values_must_match_for_opcode_continuation);
  COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_TEXT) ==
                     net::WebSocketFrameHeader::kOpCodeText,
                 enum_values_must_match_for_opcode_text);
  COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_BINARY) ==
                     net::WebSocketFrameHeader::kOpCodeBinary,
                 enum_values_must_match_for_opcode_binary);

  DCHECK(type == WEB_SOCKET_MESSAGE_TYPE_CONTINUATION ||
         type == WEB_SOCKET_MESSAGE_TYPE_TEXT ||
         type == WEB_SOCKET_MESSAGE_TYPE_BINARY);

  const OpCode op_code = static_cast<OpCode>(type);
  return op_code;
}
// Inverse of MessageTypeToOpCode(): maps a data-frame opcode (continuation,
// text or binary, checked with DCHECK) back to a WebSocketMessageType.
WebSocketMessageType OpCodeToMessageType(
    net::WebSocketFrameHeader::OpCode opCode) {
  DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation ||
         opCode == net::WebSocketFrameHeader::kOpCodeText ||
         opCode == net::WebSocketFrameHeader::kOpCodeBinary);

  // The COMPILE_ASSERT() statements in MessageTypeToOpCode() guarantee that
  // both enumerations use the same values, so a cast is sufficient.
  const WebSocketMessageType message_type =
      static_cast<WebSocketMessageType>(opCode);
  return message_type;
}
// Converts the state reported by WebSocketDispatcherHost into the equivalent
// net::WebSocketEventInterface::ChannelState.
ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) {
  typedef WebSocketDispatcherHost::WebSocketHostState HostState;

  // Local aliases so the checks below can use the unqualified names.
  const HostState WEBSOCKET_HOST_ALIVE =
      WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE;
  const HostState WEBSOCKET_HOST_DELETED =
      WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED;

  // A static_cast<> is only a valid conversion while both enumerations keep
  // identical values; enforce that at compile time.
  COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_ALIVE) ==
                     net::WebSocketEventInterface::CHANNEL_ALIVE,
                 enum_values_must_match_for_state_alive);
  COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_DELETED) ==
                     net::WebSocketEventInterface::CHANNEL_DELETED,
                 enum_values_must_match_for_state_deleted);

  DCHECK(host_state == WEBSOCKET_HOST_ALIVE ||
         host_state == WEBSOCKET_HOST_DELETED);

  const ChannelState channel_state = static_cast<ChannelState>(host_state);
  return channel_state;
}
// Implementation of net::WebSocketEventInterface. Receives events from our
// WebSocketChannel object. Each event is translated to an IPC and sent to the
// renderer or child process via WebSocketDispatcherHost.
class WebSocketEventHandler : public net::WebSocketEventInterface {
 public:
  // |dispatcher| is the object every event is forwarded to, |routing_id| is
  // passed along with each forwarded event, and |render_frame_id| is handed to
  // SSLManager when a certificate error is reported.
  WebSocketEventHandler(WebSocketDispatcherHost* dispatcher,
                        int routing_id,
                        int render_frame_id);
  virtual ~WebSocketEventHandler();

  // net::WebSocketEventInterface implementation
  virtual ChannelState OnAddChannelResponse(
      bool fail,
      const std::string& selected_subprotocol,
      const std::string& extensions) OVERRIDE;
  virtual ChannelState OnDataFrame(bool fin,
                                   WebSocketMessageType type,
                                   const std::vector<char>& data) OVERRIDE;
  virtual ChannelState OnClosingHandshake() OVERRIDE;
  virtual ChannelState OnFlowControl(int64 quota) OVERRIDE;
  virtual ChannelState OnDropChannel(bool was_clean,
                                     uint16 code,
                                     const std::string& reason) OVERRIDE;
  virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE;
  virtual ChannelState OnStartOpeningHandshake(
      scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE;
  virtual ChannelState OnFinishOpeningHandshake(
      scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE;
  virtual ChannelState OnSSLCertificateError(
      scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
      const GURL& url,
      const net::SSLInfo& ssl_info,
      bool fatal) OVERRIDE;

 private:
  // Adapts SSLErrorHandler::Delegate notifications onto the
  // net::WebSocketEventInterface::SSLErrorCallbacks object it holds.
  class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate {
   public:
    // Takes ownership of |callbacks|. Declared explicit so that a scoped_ptr
    // of callbacks is never implicitly converted into a delegate.
    explicit SSLErrorHandlerDelegate(
        scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks);
    virtual ~SSLErrorHandlerDelegate();

    // Returns a weak pointer to this delegate, typed as the base interface.
    base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr();

    // SSLErrorHandler::Delegate methods
    virtual void CancelSSLRequest(const GlobalRequestID& id,
                                  int error,
                                  const net::SSLInfo* ssl_info) OVERRIDE;
    virtual void ContinueSSLRequest(const GlobalRequestID& id) OVERRIDE;

   private:
    // Receives the outcome of the certificate-error decision.
    scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_;
    // Kept as the last member so it is destroyed first.
    base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_;

    DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate);
  };

  // Target of every forwarded event. Raw pointer; presumably owned elsewhere
  // and expected to outlive this handler -- TODO confirm.
  WebSocketDispatcherHost* const dispatcher_;
  // Identifier passed to |dispatcher_| with each forwarded event.
  const int routing_id_;
  // Frame identifier passed to SSLManager on certificate errors.
  const int render_frame_id_;
  // Created (or replaced) each time OnSSLCertificateError() is called.
  scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_;

  DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
};
// Records the dispatcher and the two identifiers; no other work is done at
// construction time.
WebSocketEventHandler::WebSocketEventHandler(
    WebSocketDispatcherHost* dispatcher,
    int routing_id,
    int render_frame_id)
    : dispatcher_(dispatcher),
      routing_id_(routing_id),
      render_frame_id_(render_frame_id) {}
// Only logs the destruction; members clean themselves up.
WebSocketEventHandler::~WebSocketEventHandler() {
  DVLOG(1) << "WebSocketEventHandler destroyed routing_id="
           << routing_id_;
}
// Logs the handshake result and forwards it to the dispatcher. Returns the
// dispatcher's resulting state converted to a ChannelState.
ChannelState WebSocketEventHandler::OnAddChannelResponse(
    bool fail,
    const std::string& selected_protocol,
    const std::string& extensions) {
  DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse routing_id="
           << routing_id_ << " fail=" << fail << " selected_protocol=\""
           << selected_protocol << "\" extensions=\"" << extensions << "\"";

  const WebSocketDispatcherHost::WebSocketHostState host_state =
      dispatcher_->SendAddChannelResponse(
          routing_id_, fail, selected_protocol, extensions);
  return StateCast(host_state);
}
// Logs a received data frame, converts its opcode to a message type and
// forwards the frame to the dispatcher.
ChannelState WebSocketEventHandler::OnDataFrame(
    bool fin,
    net::WebSocketFrameHeader::OpCode type,
    const std::vector<char>& data) {
  DVLOG(3) << "WebSocketEventHandler::OnDataFrame routing_id=" << routing_id_
           << " fin=" << fin << " type=" << type << " data is "
           << data.size() << " bytes";

  const WebSocketMessageType message_type = OpCodeToMessageType(type);
  const WebSocketDispatcherHost::WebSocketHostState host_state =
      dispatcher_->SendFrame(routing_id_, fin, message_type, data);
  return StateCast(host_state);
}
// Logs the event and notifies the dispatcher that the closing handshake has
// started.
ChannelState WebSocketEventHandler::OnClosingHandshake() {
  DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake routing_id="
           << routing_id_;

  const WebSocketDispatcherHost::WebSocketHostState host_state =
      dispatcher_->NotifyClosingHandshake(routing_id_);
  return StateCast(host_state);
}
// Logs the quota update and forwards |quota| to the dispatcher.
ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) {
  DVLOG(3) << "WebSocketEventHandler::OnFlowControl routing_id="
           << routing_id_ << " quota=" << quota;

  const WebSocketDispatcherHost::WebSocketHostState host_state =
      dispatcher_->SendFlowControl(routing_id_, quota);
  return StateCast(host_state);
}
// Logs the close details and asks the dispatcher to drop the channel.
ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean,
                                                  uint16 code,
                                                  const std::string& reason) {
  DVLOG(3) << "WebSocketEventHandler::OnDropChannel routing_id="
           << routing_id_ << " was_clean=" << was_clean << " code=" << code
           << " reason=\"" << reason << "\"";

  const WebSocketDispatcherHost::WebSocketHostState host_state =
      dispatcher_->DoDropChannel(routing_id_, was_clean, code, reason);
  return StateCast(host_state);
}
// Logs the failure message and reports it to the dispatcher.
ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) {
  DVLOG(3) << "WebSocketEventHandler::OnFailChannel routing_id="
           << routing_id_ << " message=\"" << message << "\"";

  const WebSocketDispatcherHost::WebSocketHostState host_state =
      dispatcher_->NotifyFailure(routing_id_, message);
  return StateCast(host_state);
}
// Copies the opening-handshake request into a WebSocketHandshakeRequest and
// passes it to the dispatcher. Nothing is sent (and CHANNEL_ALIVE is
// returned) when the dispatcher reports that raw cookies may not be read.
ChannelState WebSocketEventHandler::OnStartOpeningHandshake(
    scoped_ptr<net::WebSocketHandshakeRequestInfo> request) {
  const bool can_read_raw_cookies = dispatcher_->CanReadRawCookies();
  DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake should_send="
           << can_read_raw_cookies;
  if (!can_read_raw_cookies)
    return WebSocketEventInterface::CHANNEL_ALIVE;

  WebSocketHandshakeRequest handshake_request;
  handshake_request.request_time = request->request_time;
  // Take the URL out of |request| instead of copying it.
  handshake_request.url.Swap(&request->url);

  for (net::HttpRequestHeaders::Iterator header(request->headers);
       header.GetNext();) {
    handshake_request.headers.push_back(
        std::make_pair(header.name(), header.value()));
  }

  // Textual form: a request line built from the URL, then the headers.
  const std::string request_line = base::StringPrintf(
      "GET %s HTTP/1.1\r\n", handshake_request.url.spec().c_str());
  handshake_request.headers_text = request_line + request->headers.ToString();

  const WebSocketDispatcherHost::WebSocketHostState host_state =
      dispatcher_->NotifyStartOpeningHandshake(routing_id_, handshake_request);
  return StateCast(host_state);
}
// Copies the opening-handshake response into a WebSocketHandshakeResponse and
// passes it to the dispatcher. Nothing is sent (and CHANNEL_ALIVE is
// returned) when the dispatcher reports that raw cookies may not be read.
ChannelState WebSocketEventHandler::OnFinishOpeningHandshake(
    scoped_ptr<net::WebSocketHandshakeResponseInfo> response) {
  const bool can_read_raw_cookies = dispatcher_->CanReadRawCookies();
  DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake should_send="
           << can_read_raw_cookies;
  if (!can_read_raw_cookies)
    return WebSocketEventInterface::CHANNEL_ALIVE;

  WebSocketHandshakeResponse handshake_response;
  handshake_response.status_code = response->status_code;
  handshake_response.response_time = response->response_time;
  // Take the URL and status text out of |response| instead of copying them.
  handshake_response.url.Swap(&response->url);
  handshake_response.status_text.swap(response->status_text);

  void* header_iter = NULL;
  std::string header_name;
  std::string header_value;
  while (response->headers->EnumerateHeaderLines(
             &header_iter, &header_name, &header_value)) {
    handshake_response.headers.push_back(
        std::make_pair(header_name, header_value));
  }

  handshake_response.headers_text =
      net::HttpUtil::ConvertHeadersBackToHTTPResponse(
          response->headers->raw_headers());

  const WebSocketDispatcherHost::WebSocketHostState host_state =
      dispatcher_->NotifyFinishOpeningHandshake(routing_id_,
                                                handshake_response);
  return StateCast(host_state);
}
// Wraps |callbacks| in a fresh SSLErrorHandlerDelegate (replacing any previous
// one) and hands the certificate error to SSLManager. Always returns
// CHANNEL_ALIVE.
ChannelState WebSocketEventHandler::OnSSLCertificateError(
    scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
    const GURL& url,
    const net::SSLInfo& ssl_info,
    bool fatal) {
  DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError routing_id="
           << routing_id_ << " url=" << url.spec()
           << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal;

  ssl_error_handler_delegate_.reset(
      new SSLErrorHandlerDelegate(callbacks.Pass()));

  // The request id does not have to be unique here, so a placeholder is used.
  GlobalRequestID request_id(-1, -1);
  SSLManager::OnSSLCertificateError(
      ssl_error_handler_delegate_->GetWeakPtr(),
      request_id,
      ResourceType::SUB_RESOURCE,
      url,
      dispatcher_->render_process_id(),
      render_frame_id_,
      ssl_info,
      fatal);

  // SSLManager::OnSSLCertificateError() is always asynchronous.
  return WebSocketEventInterface::CHANNEL_ALIVE;
}
// Takes ownership of |callbacks| and binds the weak pointer factory to this
// object.
WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate(
    scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks)
    : callbacks_(callbacks.Pass()),
      weak_ptr_factory_(this) {
}
// Nothing to do explicitly; the members release what they hold.
WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {
}
// Hands out a weak pointer to this delegate, typed as the
// SSLErrorHandler::Delegate interface.
base::WeakPtr<SSLErrorHandler::Delegate>
WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() {
  base::WeakPtr<SSLErrorHandler::Delegate> weak_self =
      weak_ptr_factory_.GetWeakPtr();
  return weak_self;
}
// Logs the cancellation and relays |error| and |ssl_info| to the stored
// callbacks. |id| is not used. |ssl_info| may be NULL, in which case the log
// shows a cert status of -1 cast to net::CertStatus.
void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest(
    const GlobalRequestID& id,
    int error,
    const net::SSLInfo* ssl_info) {
  DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest error=" << error
           << " cert_status="
           << (ssl_info ? ssl_info->cert_status
                        : static_cast<net::CertStatus>(-1));

  callbacks_->CancelSSLRequest(error, ssl_info);
}
// Logs the decision and tells the stored callbacks to continue the request.
// |id| is not used.
void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest(
    const GlobalRequestID& id) {
  DVLOG(3)
      << "SSLErrorHandlerDelegate::ContinueSSLRequest";

  callbacks_->ContinueSSLRequest();
}
} // namespace
// Stores the dispatcher, request context and routing id. The channel itself
// is created later, in OnAddChannelRequest().
WebSocketHost::WebSocketHost(int routing_id,
                             WebSocketDispatcherHost* dispatcher,
                             net::URLRequestContext* url_request_context)
    : dispatcher_(dispatcher),
      url_request_context_(url_request_context),
      routing_id_(routing_id) {
  DVLOG(1) << "WebSocketHost: created routing_id="
           << routing_id;
}
// Nothing to do explicitly; the members release what they hold.
WebSocketHost::~WebSocketHost() {
}
// Dispatches an incoming IPC |message| to the matching On*() handler of this
// host. Returns true if the message type was one of the four handled below,
// false otherwise.
bool WebSocketHost::OnMessageReceived(const IPC::Message& message) {
bool handled = true;
IPC_BEGIN_MESSAGE_MAP(WebSocketHost, message)
IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest, OnAddChannelRequest)
IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, OnSendFrame)
IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, OnFlowControl)
IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, OnDropChannel)
// Any other message type is reported back to the caller as unhandled.
IPC_MESSAGE_UNHANDLED(handled = false)
IPC_END_MESSAGE_MAP()
return handled;
}
// Creates the net::WebSocketChannel for this host, wired to a new
// WebSocketEventHandler, and starts connecting to |socket_url|. Must only be
// called while no channel exists yet (checked with DCHECK).
void WebSocketHost::OnAddChannelRequest(
    const GURL& socket_url,
    const std::vector<std::string>& requested_protocols,
    const url::Origin& origin,
    int render_frame_id) {
  DVLOG(3) << "WebSocketHost::OnAddChannelRequest routing_id=" << routing_id_
           << " socket_url=\"" << socket_url << "\" requested_protocols=\""
           << JoinString(requested_protocols, ", ") << "\" origin=\""
           << origin.string() << "\"";

  DCHECK(!channel_);

  // The channel takes ownership of the event handler.
  scoped_ptr<net::WebSocketEventInterface> handler(
      new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id));
  channel_.reset(
      new net::WebSocketChannel(handler.Pass(), url_request_context_));

  channel_->SendAddChannelRequest(socket_url, requested_protocols, origin);
}
void WebSocketHost::OnSendFrame(bool fin,
WebSocketMessageType type,
const std::vector<char>& data) {
DVLOG(3) << "WebSocketHost::OnSendFrame"
<< " routing_id=" << routing_id_ << " fin=" << fin
<< " type=" << type << " data is " << data.size() << " bytes";
DCHECK(channel_);
channel_->SendFrame(fin, MessageTypeToOpCode(type), data);
}
// Logs the quota and passes it to the channel. The channel must already exist
// (checked with DCHECK).
void WebSocketHost::OnFlowControl(int64 quota) {
  DVLOG(3) << "WebSocketHost::OnFlowControl routing_id=" << routing_id_
           << " quota=" << quota;

  DCHECK(channel_);

  channel_->SendFlowControl(quota);
}
// Logs the request and starts the closing handshake on the channel with
// |code| and |reason|. |was_clean| is only logged for now. The channel must
// already exist (checked with DCHECK).
void WebSocketHost::OnDropChannel(bool was_clean,
                                  uint16 code,
                                  const std::string& reason) {
  DVLOG(3) << "WebSocketHost::OnDropChannel routing_id=" << routing_id_
           << " was_clean=" << was_clean << " code=" << code
           << " reason=\"" << reason << "\"";

  DCHECK(channel_);

  // TODO(yhirano): Handle |was_clean| appropriately.
  channel_->StartClosingHandshake(code, reason);
}
} // namespace content