| // Copyright (c) 2011 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "net/socket/web_socket_server_socket.h" |
| |
| #include <algorithm> |
| #include <deque> |
| #include <limits> |
| #include <map> |
| #include <vector> |
| |
| #if defined(OS_WIN) |
| #include <winsock2.h> // for htonl |
| #else |
| #include <arpa/inet.h> |
| #endif |
| |
| #include "base/basictypes.h" |
| #include "base/logging.h" |
| #include "base/md5.h" |
| #include "base/memory/ref_counted.h" |
| #include "base/memory/scoped_ptr.h" |
| #include "base/message_loop.h" |
| #include "base/string_util.h" |
| #include "base/task.h" |
| #include "googleurl/src/gurl.h" |
| #include "net/base/completion_callback.h" |
| #include "net/base/io_buffer.h" |
| #include "net/base/net_errors.h" |
| |
| namespace { |
| |
| const size_t kHandshakeLimitBytes = 1 << 14; |
| |
| const char kCrOctet = '\r'; |
| COMPILE_ASSERT(kCrOctet == '\x0d', ASCII); |
| const char kLfOctet = '\n'; |
| COMPILE_ASSERT(kLfOctet == '\x0a', ASCII); |
| const char kSpaceOctet = ' '; |
| COMPILE_ASSERT(kSpaceOctet == '\x20', ASCII); |
| const char kCommaOctet = ','; |
| COMPILE_ASSERT(kCommaOctet == '\x2c', ASCII); |
| |
| const char kCRLF[] = { kCrOctet, kLfOctet, 0 }; |
| const char kCRLFCRLF[] = { kCrOctet, kLfOctet, kCrOctet, kLfOctet, 0 }; |
| |
| const char kPlainHostFieldName[] = "Host"; |
| const char kPlainOriginFieldName[] = "Origin"; |
| const char kOriginFieldName[] = "Sec-WebSocket-Origin"; |
| const char kProtocolFieldName[] = "Sec-WebSocket-Protocol"; |
| const char kVersionFieldName[] = "Sec-WebSocket-Version"; |
| const char kLocationFieldName[] = "Sec-WebSocket-Location"; |
| const char kKey1FieldName[] = "Sec-WebSocket-Key1"; |
| const char kKey2FieldName[] = "Sec-WebSocket-Key2"; |
| |
| int CountSpaces(const std::string& s) { |
| return std::count(s.begin(), s.end(), kSpaceOctet); |
| } |
| |
| // Returns true on success. |
| bool FetchDecimalDigits(const std::string& s, uint32* result) { |
| *result = 0; |
| bool got_something = false; |
| for (size_t i = 0; i < s.size(); ++i) { |
| if (IsAsciiDigit(s[i])) { |
| got_something = true; |
| if (*result > std::numeric_limits<uint32>::max() / 10) |
| return false; |
| *result *= 10; |
| int digit = s[i] - '0'; |
| if (*result > std::numeric_limits<uint32>::max() - digit) |
| return false; |
| *result += digit; |
| } |
| } |
| return got_something; |
| } |
| |
| // Returns number of fetched subprotocols or negative error code. |
| int FetchSubprotocolList( |
| const std::string& s, std::vector<std::string>* subprotocol_list) { |
| subprotocol_list->clear(); |
| subprotocol_list->push_back(std::string()); |
| for (size_t i = 0; i < s.size(); ++i) { |
| if (s[i] > '\x20' && s[i] < '\x7f' && s[i] != kCommaOctet) |
| subprotocol_list->back() += s[i]; |
| else if (!subprotocol_list->back().empty()) { |
| if (subprotocol_list->size() < 16) |
| subprotocol_list->push_back(std::string()); |
| else |
| return net::ERR_LIMIT_VIOLATION; |
| } |
| } |
| if (subprotocol_list->back().empty()) |
| subprotocol_list->pop_back(); |
| if (subprotocol_list->empty()) |
| return net::ERR_WS_PROTOCOL_ERROR; |
| |
| { |
| std::vector<std::string> tmp(*subprotocol_list); |
| std::sort(tmp.begin(), tmp.end()); |
| if (tmp.end() != std::unique(tmp.begin(), tmp.end())) |
| return net::ERR_WS_PROTOCOL_ERROR; |
| } |
| return subprotocol_list->size(); |
| } |
| |
| class WebSocketServerSocketImpl : public net::WebSocketServerSocket { |
| public: |
| WebSocketServerSocketImpl(net::Socket* transport_socket, Delegate* delegate) |
| : phase_(PHASE_NYMPH), |
| frame_bytes_remaining_(0), |
| transport_socket_(transport_socket), |
| delegate_(delegate), |
| handshake_buf_(new net::IOBuffer(kHandshakeLimitBytes)), |
| fill_handshake_buf_(new net::DrainableIOBuffer( |
| handshake_buf_, kHandshakeLimitBytes)), |
| process_handshake_buf_(new net::DrainableIOBuffer( |
| handshake_buf_, kHandshakeLimitBytes)), |
| transport_read_callback_(NewCallback( |
| this, &WebSocketServerSocketImpl::OnRead)), |
| transport_write_callback_(NewCallback( |
| this, &WebSocketServerSocketImpl::OnWrite)), |
| is_transport_read_pending_(false), |
| is_transport_write_pending_(false), |
| method_factory_(this) { |
| DCHECK(transport_socket); |
| DCHECK(delegate); |
| } |
| |
| virtual ~WebSocketServerSocketImpl() { |
| std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); |
| if (it != pending_reqs_.end() && |
| it->type == PendingReq::TYPE_READ && |
| it->io_buf != NULL && |
| it->io_buf->data() != NULL && |
| it->callback != 0) { |
| it->callback->Run(0); // Report EOF. |
| } |
| } |
| |
| private: |
| enum Phase { |
| // Before Accept() is called. |
| PHASE_NYMPH, |
| |
| // After Accept() is called and until handshake success/fail. |
| PHASE_HANDSHAKE, |
| |
| // Processing data stream. |
| PHASE_FRAME_OUTSIDE, // Outside data frame. |
| PHASE_FRAME_INSIDE, // Inside text frame. |
| PHASE_FRAME_LENGTH, // Reading length of binary frame. |
| PHASE_FRAME_SKIP, // Skipping binary frame. |
| |
| // After termination. |
| PHASE_SHUT |
| }; |
| |
| struct PendingReq { |
| enum Type { |
| // Frame delimiters or handshake (as opposed to user data). |
| TYPE_METADATA = 1 << 0, |
| // Read request. |
| TYPE_READ = 1 << 1, |
| // Write request. |
| TYPE_WRITE = 1 << 2, |
| |
| TYPE_READ_METADATA = TYPE_READ | TYPE_METADATA, |
| TYPE_WRITE_METADATA = TYPE_WRITE | TYPE_METADATA |
| }; |
| |
| PendingReq(Type type, net::DrainableIOBuffer* io_buf, |
| net::CompletionCallback* callback) |
| : type(type), |
| io_buf(io_buf), |
| callback(callback) { |
| switch (type) { |
| case PendingReq::TYPE_READ: |
| case PendingReq::TYPE_WRITE: |
| case PendingReq::TYPE_READ_METADATA: |
| case PendingReq::TYPE_WRITE_METADATA: { |
| DCHECK(io_buf); |
| break; |
| } |
| default: { |
| NOTREACHED(); |
| break; |
| } |
| } |
| } |
| |
| Type type; |
| scoped_refptr<net::DrainableIOBuffer> io_buf; |
| net::CompletionCallback* callback; |
| }; |
| |
| // Socket implementation. |
| virtual int Read(net::IOBuffer* buf, int buf_len, |
| net::CompletionCallback* callback) OVERRIDE { |
| if (buf_len == 0) |
| return 0; |
| if (buf == NULL || buf_len < 0) { |
| NOTREACHED(); |
| return net::ERR_INVALID_ARGUMENT; |
| } |
| while (int bytes_remaining = fill_handshake_buf_->BytesConsumed() - |
| process_handshake_buf_->BytesConsumed()) { |
| DCHECK(!is_transport_read_pending_); |
| DCHECK(GetPendingReq(PendingReq::TYPE_READ) == pending_reqs_.end()); |
| switch (phase_) { |
| case PHASE_FRAME_OUTSIDE: |
| case PHASE_FRAME_INSIDE: |
| case PHASE_FRAME_LENGTH: |
| case PHASE_FRAME_SKIP: { |
| int n = std::min(bytes_remaining, buf_len); |
| int rv = ProcessDataFrames( |
| process_handshake_buf_->data(), n, buf->data(), buf_len); |
| process_handshake_buf_->DidConsume(n); |
| if (rv == 0) { |
| // ProcessDataFrames may return zero for non-empty buffer if it |
| // contains only frame delimiters without real data. In this case: |
| // try again and do not just return zero (zero stands for EOF). |
| continue; |
| } |
| return rv; |
| } |
| case PHASE_SHUT: { |
| return 0; |
| } |
| case PHASE_NYMPH: |
| case PHASE_HANDSHAKE: |
| default: { |
| NOTREACHED(); |
| return net::ERR_UNEXPECTED; |
| } |
| } |
| } |
| switch (phase_) { |
| case PHASE_FRAME_OUTSIDE: |
| case PHASE_FRAME_INSIDE: |
| case PHASE_FRAME_LENGTH: |
| case PHASE_FRAME_SKIP: { |
| pending_reqs_.push_back(PendingReq( |
| PendingReq::TYPE_READ, |
| new net::DrainableIOBuffer(buf, buf_len), |
| callback)); |
| ConsiderTransportRead(); |
| break; |
| } |
| case PHASE_SHUT: { |
| return 0; |
| } |
| case PHASE_NYMPH: |
| case PHASE_HANDSHAKE: |
| default: { |
| NOTREACHED(); |
| return net::ERR_UNEXPECTED; |
| } |
| } |
| return net::ERR_IO_PENDING; |
| } |
| |
| virtual int Write(net::IOBuffer* buf, int buf_len, |
| net::CompletionCallback* callback) OVERRIDE { |
| if (buf_len == 0) |
| return 0; |
| if (buf == NULL || buf_len < 0) { |
| NOTREACHED(); |
| return net::ERR_INVALID_ARGUMENT; |
| } |
| DCHECK_EQ(std::find(buf->data(), buf->data() + buf_len, '\xff'), |
| buf->data() + buf_len); |
| switch (phase_) { |
| case PHASE_FRAME_OUTSIDE: |
| case PHASE_FRAME_INSIDE: |
| case PHASE_FRAME_LENGTH: |
| case PHASE_FRAME_SKIP: { |
| break; |
| } |
| case PHASE_SHUT: { |
| return net::ERR_SOCKET_NOT_CONNECTED; |
| } |
| case PHASE_NYMPH: |
| case PHASE_HANDSHAKE: |
| default: { |
| NOTREACHED(); |
| return net::ERR_UNEXPECTED; |
| } |
| } |
| |
| net::IOBuffer* frame_start = new net::IOBuffer(1); |
| frame_start->data()[0] = '\x00'; |
| pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, |
| new net::DrainableIOBuffer(frame_start, 1), |
| NULL)); |
| |
| pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE, |
| new net::DrainableIOBuffer(buf, buf_len), |
| callback)); |
| |
| net::IOBuffer* frame_end = new net::IOBuffer(1); |
| frame_end->data()[0] = '\xff'; |
| pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, |
| new net::DrainableIOBuffer(frame_end, 1), |
| NULL)); |
| |
| ConsiderTransportWrite(); |
| return net::ERR_IO_PENDING; |
| } |
| |
| virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { |
| return transport_socket_->SetReceiveBufferSize(size); |
| } |
| |
| virtual bool SetSendBufferSize(int32 size) OVERRIDE { |
| return transport_socket_->SetSendBufferSize(size); |
| } |
| |
| // WebSocketServerSocket implementation. |
| virtual int Accept(net::CompletionCallback* callback) { |
| if (phase_ != PHASE_NYMPH) |
| return net::ERR_UNEXPECTED; |
| phase_ = PHASE_HANDSHAKE; |
| pending_reqs_.push_front(PendingReq( |
| PendingReq::TYPE_READ_METADATA, fill_handshake_buf_.get(), callback)); |
| ConsiderTransportRead(); |
| return net::ERR_IO_PENDING; |
| } |
| |
| std::deque<PendingReq>::iterator GetPendingReq(PendingReq::Type type) { |
| for (std::deque<PendingReq>::iterator it = pending_reqs_.begin(); |
| it != pending_reqs_.end(); ++it) { |
| if (it->type & type) |
| return it; |
| } |
| return pending_reqs_.end(); |
| } |
| |
| void ConsiderTransportRead() { |
| if (pending_reqs_.empty()) |
| return; |
| if (is_transport_read_pending_) |
| return; |
| std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); |
| if (it == pending_reqs_.end()) |
| return; |
| if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) { |
| NOTREACHED(); |
| return; |
| } |
| is_transport_read_pending_ = true; |
| int rv = transport_socket_->Read( |
| it->io_buf.get(), it->io_buf->BytesRemaining(), |
| transport_read_callback_.get()); |
| if (rv != net::ERR_IO_PENDING) { |
| // PostTask rather than direct call in order to: |
| // (1) guarantee calling callback after returning from Read(); |
| // (2) avoid potential stack overflow; |
| MessageLoop::current()->PostTask(FROM_HERE, |
| method_factory_.NewRunnableMethod( |
| &WebSocketServerSocketImpl::OnRead, rv)); |
| } |
| } |
| |
| void ConsiderTransportWrite() { |
| if (is_transport_write_pending_) |
| return; |
| if (pending_reqs_.empty()) |
| return; |
| std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE); |
| if (it == pending_reqs_.end()) |
| return; |
| if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) { |
| NOTREACHED(); |
| Shut(net::ERR_UNEXPECTED); |
| return; |
| } |
| is_transport_write_pending_ = true; |
| int rv = transport_socket_->Write( |
| it->io_buf.get(), it->io_buf->BytesRemaining(), |
| transport_write_callback_.get()); |
| if (rv != net::ERR_IO_PENDING) { |
| // PostTask rather than direct call in order to: |
| // (1) guarantee calling callback after returning from Read(); |
| // (2) avoid potential stack overflow; |
| MessageLoop::current()->PostTask(FROM_HERE, |
| method_factory_.NewRunnableMethod( |
| &WebSocketServerSocketImpl::OnWrite, rv)); |
| } |
| } |
| |
| void Shut(int result) { |
| if (result > 0 || result == net::ERR_IO_PENDING) |
| result = net::ERR_UNEXPECTED; |
| if (result != 0) { |
| while (!pending_reqs_.empty()) { |
| PendingReq& req = pending_reqs_.front(); |
| if (req.callback) |
| req.callback->Run(result); |
| pending_reqs_.pop_front(); |
| } |
| transport_socket_.reset(); // terminate underlying connection. |
| } |
| phase_ = PHASE_SHUT; |
| } |
| |
| // Callbacks for transport socket. |
| void OnRead(int result) { |
| if (!is_transport_read_pending_) { |
| NOTREACHED(); |
| Shut(net::ERR_UNEXPECTED); |
| return; |
| } |
| is_transport_read_pending_ = false; |
| |
| if (result <= 0) { |
| Shut(result); |
| return; |
| } |
| |
| std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); |
| if (it == pending_reqs_.end() || |
| it->io_buf == NULL || |
| it->io_buf->data() == NULL) { |
| NOTREACHED(); |
| Shut(net::ERR_UNEXPECTED); |
| return; |
| } |
| if ((phase_ == PHASE_HANDSHAKE) == (it->type == PendingReq::TYPE_READ)) { |
| NOTREACHED(); |
| Shut(net::ERR_UNEXPECTED); |
| return; |
| } |
| |
| switch (phase_) { |
| case PHASE_HANDSHAKE: { |
| if (it != pending_reqs_.begin() || it->io_buf != fill_handshake_buf_) { |
| NOTREACHED(); |
| Shut(net::ERR_UNEXPECTED); |
| return; |
| } |
| fill_handshake_buf_->DidConsume(result); |
| // ProcessHandshake invalidates iterators for |pending_reqs_| |
| int rv = ProcessHandshake(); |
| if (rv > 0) { |
| process_handshake_buf_->DidConsume(rv); |
| phase_ = PHASE_FRAME_OUTSIDE; |
| net::CompletionCallback* cb = pending_reqs_.front().callback; |
| pending_reqs_.pop_front(); |
| ConsiderTransportWrite(); // Schedule answer handshake. |
| if (cb) |
| cb->Run(0); |
| } else if (rv == net::ERR_IO_PENDING) { |
| if (fill_handshake_buf_->BytesRemaining() < 1) |
| Shut(net::ERR_LIMIT_VIOLATION); |
| } else if (rv < 0) { |
| Shut(rv); |
| } else { |
| Shut(net::ERR_UNEXPECTED); |
| } |
| break; |
| } |
| case PHASE_FRAME_OUTSIDE: |
| case PHASE_FRAME_INSIDE: |
| case PHASE_FRAME_LENGTH: |
| case PHASE_FRAME_SKIP: { |
| int rv = ProcessDataFrames( |
| it->io_buf->data(), result, |
| it->io_buf->data(), it->io_buf->BytesRemaining()); |
| if (rv < 0) { |
| Shut(rv); |
| return; |
| } |
| if (rv > 0 || phase_ == PHASE_SHUT) { |
| net::CompletionCallback* cb = it->callback; |
| pending_reqs_.erase(it); |
| if (cb) |
| cb->Run(rv); |
| } |
| break; |
| } |
| case PHASE_NYMPH: |
| default: { |
| NOTREACHED(); |
| Shut(net::ERR_UNEXPECTED); |
| break; |
| } |
| } |
| ConsiderTransportRead(); |
| } |
| |
| void OnWrite(int result) { |
| if (!is_transport_write_pending_) { |
| NOTREACHED(); |
| Shut(net::ERR_UNEXPECTED); |
| return; |
| } |
| is_transport_write_pending_ = false; |
| |
| if (result < 0) { |
| Shut(result); |
| return; |
| } |
| |
| std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE); |
| if (it == pending_reqs_.end() || |
| it->io_buf == NULL || |
| it->io_buf->data() == NULL) { |
| NOTREACHED(); |
| Shut(net::ERR_UNEXPECTED); |
| return; |
| } |
| DCHECK_LE(result, it->io_buf->BytesRemaining()); |
| it->io_buf->DidConsume(result); |
| if (it->io_buf->BytesRemaining() == 0) { |
| net::CompletionCallback* cb = it->callback; |
| int bytes_written = it->io_buf->BytesConsumed(); |
| DCHECK_GT(bytes_written, 0); |
| pending_reqs_.erase(it); |
| if (cb) |
| cb->Run(bytes_written); |
| } |
| ConsiderTransportWrite(); |
| } |
| |
| // Returns (positive) number of consumed bytes on success. |
| // Returns ERR_IO_PENDING in case of incomplete input. |
| // Returns ERR_WS_PROTOCOL_ERROR or ERR_LIMIT_VIOLATION in case of failure to |
| // reasonably parse input. |
| int ProcessHandshake() { |
| static const char kGetPrefix[] = "GET "; |
| static const char kKeyValueDelimiter[] = ": "; |
| |
| class Fields { |
| public: |
| bool Has(const std::string& name) { |
| return map_.find(StringToLowerASCII(name)) != map_.end(); |
| } |
| |
| std::string Get(const std::string& name) { |
| return Has(name) ? map_[StringToLowerASCII(name)] : std::string(); |
| } |
| |
| void Set(const std::string& name, const std::string& value) { |
| map_[StringToLowerASCII(name)] = StringToLowerASCII(value); |
| } |
| |
| private: |
| std::map<std::string, std::string> map_; |
| } fields; |
| |
| char* buf = process_handshake_buf_->data(); |
| size_t buf_size = fill_handshake_buf_->BytesConsumed(); |
| |
| if (buf_size < 1) |
| return net::ERR_IO_PENDING; |
| if (!std::equal(buf, buf + std::min(buf_size, strlen(kGetPrefix)), |
| kGetPrefix)) { |
| // Data head does not match what is expected. |
| return net::ERR_WS_PROTOCOL_ERROR; |
| } |
| if (buf_size >= kHandshakeLimitBytes) |
| return net::ERR_LIMIT_VIOLATION; |
| char* buf_end = buf + buf_size; |
| |
| if (buf_size < strlen(kGetPrefix)) |
| return net::ERR_IO_PENDING; |
| char* resource_begin = buf + strlen(kGetPrefix); |
| char* resource_end = std::find(resource_begin, buf_end, kSpaceOctet); |
| if (resource_end == buf_end) |
| return net::ERR_IO_PENDING; |
| std::string resource(resource_begin, resource_end); |
| if (!IsStringUTF8(resource) || |
| resource.find_first_of(kCRLF) != std::string::npos) { |
| return net::ERR_WS_PROTOCOL_ERROR; |
| } |
| char* term_pos = std::search( |
| buf, buf_end, kCRLFCRLF, kCRLFCRLF + strlen(kCRLFCRLF)); |
| char key3[8]; // Notation (key3) matches websocket RFC. |
| size_t message_len = buf_end - term_pos; |
| if (message_len < sizeof(key3) + strlen(kCRLFCRLF)) |
| return net::ERR_IO_PENDING; |
| term_pos += strlen(kCRLFCRLF); |
| memcpy(key3, term_pos, sizeof(key3)); |
| term_pos += sizeof(key3); |
| // First line is "GET resource" line, so skip it. |
| char* pos = std::search(buf, term_pos, kCRLF, kCRLF + strlen(kCRLF)); |
| if (pos == term_pos) |
| return net::ERR_WS_PROTOCOL_ERROR; |
| for (;;) { |
| pos += strlen(kCRLF); |
| if (term_pos - pos < |
| static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) { |
| return net::ERR_WS_PROTOCOL_ERROR; |
| } |
| if (term_pos - pos == |
| static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) { |
| break; |
| } |
| char* next_pos = std::search( |
| pos, term_pos, kKeyValueDelimiter, |
| kKeyValueDelimiter + strlen(kKeyValueDelimiter)); |
| if (next_pos == term_pos) |
| return net::ERR_WS_PROTOCOL_ERROR; |
| std::string key(pos, next_pos); |
| if (!IsStringASCII(key) || |
| key.find_first_of(kCRLF) != std::string::npos) { |
| return net::ERR_WS_PROTOCOL_ERROR; |
| } |
| pos = std::search(next_pos += strlen(kKeyValueDelimiter), term_pos, |
| kCRLF, kCRLF + strlen(kCRLF)); |
| if (pos == term_pos) |
| return net::ERR_WS_PROTOCOL_ERROR; |
| if (!key.empty()) { |
| std::string value(next_pos, pos); |
| if (!IsStringASCII(value) || |
| value.find_first_of(kCRLF) != std::string::npos) { |
| return net::ERR_WS_PROTOCOL_ERROR; |
| } |
| fields.Set(key, value); |
| } |
| } |
| |
| // Values of Upgrade and Connection fields are hardcoded in the protocol. |
| if (fields.Get("Upgrade") != "websocket" || |
| fields.Get("Connection") != "upgrade") { |
| return net::ERR_WS_PROTOCOL_ERROR; |
| } |
| if (fields.Has(kVersionFieldName)) { |
| NOTIMPLEMENTED(); // new protocol. |
| return net::ERR_NOT_IMPLEMENTED; |
| } |
| |
| if (!fields.Has(kPlainOriginFieldName)) |
| return net::ERR_CONNECTION_REFUSED; |
| // Normalize (e.g. w.r.t. leading slashes) origin. |
| GURL origin = GURL(fields.Get(kPlainOriginFieldName)).GetOrigin(); |
| if (!origin.is_valid()) |
| return net::ERR_WS_PROTOCOL_ERROR; |
| std::string normalized_origin = origin.spec(); |
| |
| if (!fields.Has(kPlainHostFieldName)) |
| return net::ERR_CONNECTION_REFUSED; |
| |
| std::vector<std::string> subprotocol_list; |
| if (fields.Has(kProtocolFieldName)) { |
| int rv = FetchSubprotocolList( |
| fields.Get(kProtocolFieldName), &subprotocol_list); |
| if (rv < 0) |
| return rv; |
| DCHECK(subprotocol_list.end() == std::find( |
| subprotocol_list.begin(), subprotocol_list.end(), "")); |
| } |
| |
| std::string location; |
| std::string subprotocol; |
| if (!delegate_->ValidateWebSocket(resource, |
| normalized_origin, |
| fields.Get(kPlainHostFieldName), |
| subprotocol_list, |
| &location, |
| &subprotocol)) { |
| return net::ERR_CONNECTION_REFUSED; |
| } |
| if (subprotocol_list.empty()) { |
| DCHECK(subprotocol.empty()); |
| } else { |
| if (!subprotocol.empty()) { |
| if (subprotocol_list.end() == std::find( |
| subprotocol_list.begin(), subprotocol_list.end(), subprotocol)) { |
| NOTREACHED() << "delegate must pick subprotocol from given list"; |
| return net::ERR_UNEXPECTED; |
| } |
| } |
| } |
| |
| uint32 key_number1 = 0; |
| uint32 key_number2 = 0; |
| if (!FetchDecimalDigits(fields.Get(kKey1FieldName), &key_number1) || |
| !FetchDecimalDigits(fields.Get(kKey2FieldName), &key_number2)) { |
| return net::ERR_WS_PROTOCOL_ERROR; |
| } |
| |
| // We limit incoming header size so following numbers shall not be too high. |
| int spaces1 = CountSpaces(fields.Get(kKey1FieldName)); |
| int spaces2 = CountSpaces(fields.Get(kKey2FieldName)); |
| if (spaces1 == 0 || |
| spaces2 == 0 || |
| key_number1 % spaces1 != 0 || |
| key_number2 % spaces2 != 0) { |
| return net::ERR_WS_PROTOCOL_ERROR; |
| } |
| |
| char challenge[4 + 4 + sizeof(key3)]; |
| int32 part1 = htonl(key_number1 / spaces1); |
| int32 part2 = htonl(key_number2 / spaces2); |
| memcpy(challenge, &part1, 4); |
| memcpy(challenge + 4, &part2, 4); |
| memcpy(challenge + 4 + 4, key3, sizeof(key3)); |
| base::MD5Digest challenge_response; |
| base::MD5Sum(challenge, sizeof(challenge), &challenge_response); |
| |
| // Concocting response handshake. |
| class Buffer { |
| public: |
| Buffer() |
| : io_buf_(new net::IOBuffer(kHandshakeLimitBytes)), |
| bytes_written_(0), |
| is_ok_(true) { |
| } |
| |
| bool Write(const void* p, int len) { |
| DCHECK(p); |
| DCHECK_GE(len, 0); |
| if (!is_ok_) |
| return false; |
| if (bytes_written_ + len > kHandshakeLimitBytes) { |
| NOTREACHED(); |
| is_ok_ = false; |
| return false; |
| } |
| memcpy(io_buf_->data() + bytes_written_, p, len); |
| bytes_written_ += len; |
| return true; |
| } |
| |
| bool WriteLine(const char* p) { |
| return Write(p, strlen(p)) && Write(kCRLF, strlen(kCRLF)); |
| } |
| |
| operator net::DrainableIOBuffer*() { |
| return new net::DrainableIOBuffer(io_buf_, bytes_written_); |
| } |
| |
| bool is_ok() { return is_ok_; } |
| |
| private: |
| net::IOBuffer* io_buf_; |
| size_t bytes_written_; |
| bool is_ok_; |
| } buffer; |
| |
| buffer.WriteLine("HTTP/1.1 101 WebSocket Protocol Handshake"); |
| buffer.WriteLine("Upgrade: WebSocket"); |
| buffer.WriteLine("Connection: Upgrade"); |
| |
| { |
| // Take care of Location field. |
| char tmp[2048]; |
| int rv = base::snprintf(tmp, sizeof(tmp), |
| "%s: %s", |
| kLocationFieldName, |
| location.c_str()); |
| if (rv <= 0 || rv + 0u >= sizeof(tmp)) |
| return net::ERR_LIMIT_VIOLATION; |
| buffer.WriteLine(tmp); |
| } |
| { |
| // Take care of Origin field. |
| char tmp[2048]; |
| int rv = base::snprintf(tmp, sizeof(tmp), |
| "%s: %s", |
| kOriginFieldName, |
| fields.Get(kPlainOriginFieldName).c_str()); |
| if (rv <= 0 || rv + 0u >= sizeof(tmp)) |
| return net::ERR_LIMIT_VIOLATION; |
| buffer.WriteLine(tmp); |
| } |
| if (!subprotocol.empty()) { |
| char tmp[2048]; |
| int rv = base::snprintf(tmp, sizeof(tmp), |
| "%s: %s", |
| kProtocolFieldName, |
| subprotocol.c_str()); |
| if (rv <= 0 || rv + 0u >= sizeof(tmp)) |
| return net::ERR_LIMIT_VIOLATION; |
| buffer.WriteLine(tmp); |
| } |
| buffer.WriteLine(""); |
| buffer.Write(&challenge_response, sizeof(challenge_response)); |
| |
| if (!buffer.is_ok()) |
| return net::ERR_LIMIT_VIOLATION; |
| |
| pending_reqs_.push_back(PendingReq( |
| PendingReq::TYPE_WRITE_METADATA, buffer, NULL)); |
| DCHECK_GT(term_pos - buf, 0); |
| return term_pos - buf; |
| } |
| |
| // Removes frame delimiters and returns net number of data bytes (or error). |
| // |out| may be equal to |buf|, in that case it is in-place operation. |
| int ProcessDataFrames(char* buf, int buf_len, char* out, int out_len) { |
| if (out_len < buf_len) { |
| NOTREACHED(); |
| return net::ERR_UNEXPECTED; |
| } |
| int out_pos = 0; |
| for (char* p = buf; p < buf + buf_len; ++p) { |
| switch (phase_) { |
| case PHASE_FRAME_INSIDE: { |
| if (*p == '\x00') |
| return net::ERR_WS_PROTOCOL_ERROR; |
| if (*p == '\xff') |
| phase_ = PHASE_FRAME_OUTSIDE; |
| else |
| out[out_pos++] = *p; |
| break; |
| } |
| case PHASE_FRAME_OUTSIDE: { |
| if (*p == '\x00') { |
| phase_ = PHASE_FRAME_INSIDE; |
| } else if (*p == '\xff') { |
| phase_ = PHASE_FRAME_LENGTH; |
| frame_bytes_remaining_ = 0; |
| } |
| else { |
| return net::ERR_WS_PROTOCOL_ERROR; |
| } |
| break; |
| } |
| case PHASE_FRAME_LENGTH: { |
| static const int kValueBits = 7; |
| static const char kValueMask = (1 << kValueBits) - 1; |
| frame_bytes_remaining_ <<= kValueBits; |
| frame_bytes_remaining_ += (*p & kValueMask); |
| if (*p & ~kValueMask) { |
| // Check that next byte would not overflow. |
| if (frame_bytes_remaining_ > |
| (std::numeric_limits<int>::max() - ((1 << 7) - 1)) >> 7) { |
| return net::ERR_LIMIT_VIOLATION; |
| } |
| } else { |
| if (frame_bytes_remaining_ == 0) { |
| phase_ = PHASE_SHUT; |
| return out_pos; |
| } else { |
| phase_ = PHASE_FRAME_SKIP; |
| } |
| } |
| break; |
| } |
| case PHASE_FRAME_SKIP: { |
| DCHECK_GE(frame_bytes_remaining_, 1); |
| frame_bytes_remaining_ -= 1; |
| if (frame_bytes_remaining_ < 1) |
| phase_ = PHASE_FRAME_OUTSIDE; |
| break; |
| } |
| default: { |
| NOTREACHED(); |
| } |
| } |
| } |
| return out_pos; |
| } |
| |
| // State machinery. |
| Phase phase_; |
| |
| // Counts frame length for PHASE_FRAME_LENGTH and PHASE_FRAME_SKIP. |
| int frame_bytes_remaining_; |
| |
| // Underlying socket. |
| scoped_ptr<net::Socket> transport_socket_; |
| |
| // Validation is performed via delegate. |
| Delegate* delegate_; |
| |
| // IOBuffer used to communicate with transport at initial stage. |
| scoped_refptr<net::IOBuffer> handshake_buf_; |
| scoped_refptr<net::DrainableIOBuffer> fill_handshake_buf_; |
| scoped_refptr<net::DrainableIOBuffer> process_handshake_buf_; |
| |
| // Pending io requests we need to complete. |
| std::deque<PendingReq> pending_reqs_; |
| |
| // Callbacks from transport to us. |
| scoped_ptr<net::CompletionCallback> transport_read_callback_; |
| scoped_ptr<net::CompletionCallback> transport_write_callback_; |
| |
| // Whether transport requests are pending. |
| bool is_transport_read_pending_; |
| bool is_transport_write_pending_; |
| |
| ScopedRunnableMethodFactory<WebSocketServerSocketImpl> method_factory_; |
| |
| DISALLOW_COPY_AND_ASSIGN(WebSocketServerSocketImpl); |
| }; |
| |
| } // namespace |
| |
| namespace net { |
| |
| WebSocketServerSocket* CreateWebSocketServerSocket( |
| Socket* transport_socket, WebSocketServerSocket::Delegate* delegate) { |
| return new WebSocketServerSocketImpl(transport_socket, delegate); |
| } |
| |
| WebSocketServerSocket::~WebSocketServerSocket() { |
| } |
| |
| } // namespace net; |