blob: efdbb3204c29e2cafc8a8c8eb8258c9c0ec69290 [file] [log] [blame]
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/protocol/negotiating_authenticator.h"
#include <algorithm>
#include <sstream>
#include "base/logging.h"
#include "base/string_split.h"
#include "crypto/rsa_private_key.h"
#include "remoting/protocol/channel_authenticator.h"
#include "remoting/protocol/v2_authenticator.h"
#include "third_party/libjingle/source/talk/xmllite/xmlelement.h"
namespace remoting {
namespace protocol {
namespace {
const buzz::StaticQName kMethodAttributeQName = { "", "method" };
const buzz::StaticQName kSupportedMethodsAttributeQName =
{ "", "supported-methods" };
const char kSupportedMethodsSeparator = ',';
} // namespace
// static
bool NegotiatingAuthenticator::IsNegotiableMessage(
const buzz::XmlElement* message) {
return message->HasAttr(kSupportedMethodsAttributeQName);
}
// static
scoped_ptr<Authenticator> NegotiatingAuthenticator::CreateForClient(
const std::string& authentication_tag,
const std::string& shared_secret,
const std::vector<AuthenticationMethod>& methods) {
scoped_ptr<NegotiatingAuthenticator> result(
new NegotiatingAuthenticator(MESSAGE_READY));
result->authentication_tag_ = authentication_tag;
result->shared_secret_ = shared_secret;
DCHECK(!methods.empty());
for (std::vector<AuthenticationMethod>::const_iterator it = methods.begin();
it != methods.end(); ++it) {
result->AddMethod(*it);
}
return scoped_ptr<Authenticator>(result.Pass());
}
// static
scoped_ptr<Authenticator> NegotiatingAuthenticator::CreateForHost(
const std::string& local_cert,
const crypto::RSAPrivateKey& local_private_key,
const std::string& shared_secret_hash,
AuthenticationMethod::HashFunction hash_function) {
scoped_ptr<NegotiatingAuthenticator> result(
new NegotiatingAuthenticator(WAITING_MESSAGE));
result->local_cert_ = local_cert;
result->local_private_key_.reset(local_private_key.Copy());
result->shared_secret_hash_ = shared_secret_hash;
result->AddMethod(AuthenticationMethod::Spake2(hash_function));
return scoped_ptr<Authenticator>(result.Pass());
}
NegotiatingAuthenticator::NegotiatingAuthenticator(
Authenticator::State initial_state)
: certificate_sent_(false),
current_method_(AuthenticationMethod::Invalid()),
state_(initial_state),
rejection_reason_(INVALID_CREDENTIALS) {
}
NegotiatingAuthenticator::~NegotiatingAuthenticator() {
}
Authenticator::State NegotiatingAuthenticator::state() const {
return state_;
}
Authenticator::RejectionReason
NegotiatingAuthenticator::rejection_reason() const {
return rejection_reason_;
}
void NegotiatingAuthenticator::ProcessMessage(const buzz::XmlElement* message) {
DCHECK_EQ(state(), WAITING_MESSAGE);
std::string method_attr = message->Attr(kMethodAttributeQName);
AuthenticationMethod method = AuthenticationMethod::FromString(method_attr);
// Check if the remote end specified a method that is not supported locally.
if (method.is_valid() &&
std::find(methods_.begin(), methods_.end(), method) == methods_.end()) {
method = AuthenticationMethod::Invalid();
}
// If the remote peer did not specify auth method or specified unknown
// method then select the first known method from the supported-methods
// attribute.
if (!method.is_valid()) {
std::string supported_methods_attr =
message->Attr(kSupportedMethodsAttributeQName);
if (supported_methods_attr.empty()) {
// Message contains neither method nor supported-methods attributes.
state_ = REJECTED;
rejection_reason_ = PROTOCOL_ERROR;
return;
}
// Find the first mutually-supported method in the peer's list of
// supported-methods.
std::vector<std::string> supported_methods_strs;
base::SplitString(supported_methods_attr, kSupportedMethodsSeparator,
&supported_methods_strs);
for (std::vector<std::string>::iterator it = supported_methods_strs.begin();
it != supported_methods_strs.end(); ++it) {
AuthenticationMethod list_value = AuthenticationMethod::FromString(*it);
if (list_value.is_valid() &&
std::find(methods_.begin(),
methods_.end(), list_value) != methods_.end()) {
// Found common method.
method = list_value;
break;
}
}
if (!method.is_valid()) {
// Failed to find a common auth method.
state_ = REJECTED;
rejection_reason_ = PROTOCOL_ERROR;
return;
}
// Drop the current message because we've chosen a different
// method.
state_ = MESSAGE_READY;
}
DCHECK(method.is_valid());
// Replace current authenticator if the method has changed.
if (method != current_method_) {
current_method_ = method;
CreateAuthenticator(state_);
}
if (state_ == WAITING_MESSAGE) {
current_authenticator_->ProcessMessage(message);
state_ = current_authenticator_->state();
if (state_ == REJECTED)
rejection_reason_ = current_authenticator_->rejection_reason();
}
}
scoped_ptr<buzz::XmlElement> NegotiatingAuthenticator::GetNextMessage() {
DCHECK_EQ(state(), MESSAGE_READY);
bool add_supported_methods_attr = false;
// Initialize current method in case it is not initialized
// yet. Normally happens only on client.
if (!current_method_.is_valid()) {
CHECK(!methods_.empty());
// Initially try the first method.
current_method_ = methods_[0];
CreateAuthenticator(MESSAGE_READY);
add_supported_methods_attr = true;
}
scoped_ptr<buzz::XmlElement> result =
current_authenticator_->GetNextMessage();
state_ = current_authenticator_->state();
DCHECK_NE(state_, REJECTED);
result->AddAttr(kMethodAttributeQName, current_method_.ToString());
if (add_supported_methods_attr) {
std::stringstream supported_methods(std::stringstream::out);
for (std::vector<AuthenticationMethod>::iterator it = methods_.begin();
it != methods_.end(); ++it) {
if (it != methods_.begin())
supported_methods << kSupportedMethodsSeparator;
supported_methods << it->ToString();
}
result->AddAttr(kSupportedMethodsAttributeQName, supported_methods.str());
}
return result.Pass();
}
void NegotiatingAuthenticator::AddMethod(const AuthenticationMethod& method) {
DCHECK(method.is_valid());
methods_.push_back(method);
}
scoped_ptr<ChannelAuthenticator>
NegotiatingAuthenticator::CreateChannelAuthenticator() const {
DCHECK_EQ(state(), ACCEPTED);
return current_authenticator_->CreateChannelAuthenticator();
}
bool NegotiatingAuthenticator::is_host_side() const {
return local_private_key_.get() != NULL;
}
void NegotiatingAuthenticator::CreateAuthenticator(State initial_state) {
if (is_host_side()) {
current_authenticator_ = V2Authenticator::CreateForHost(
local_cert_, *local_private_key_.get(),
shared_secret_hash_, initial_state);
} else {
current_authenticator_ = V2Authenticator::CreateForClient(
AuthenticationMethod::ApplyHashFunction(
current_method_.hash_function(),
authentication_tag_, shared_secret_),
initial_state);
}
}
} // namespace protocol
} // namespace remoting