// blob: f894110238c85bf2d2d153fdc4d743932b13f325
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/cryptauth/secure_channel.h"
#include <string>
#include "base/bind.h"
#include "base/memory/ptr_util.h"
#include "base/memory/weak_ptr.h"
#include "components/cryptauth/fake_authenticator.h"
#include "components/cryptauth/fake_connection.h"
#include "components/cryptauth/fake_cryptauth_service.h"
#include "components/cryptauth/fake_secure_context.h"
#include "components/cryptauth/fake_secure_message_delegate.h"
#include "components/cryptauth/remote_device_test_util.h"
#include "components/cryptauth/wire_message.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace cryptauth {
namespace {
// User ID assigned to the test RemoteDevice. A plain character array is used
// instead of a namespace-scope std::string so that no static initializer or
// exit-time destructor is emitted.
const char test_user_id[] = "testUserId";
struct SecureChannelStatusChange {
SecureChannelStatusChange(
const SecureChannel::Status& old_status,
const SecureChannel::Status& new_status)
: old_status(old_status), new_status(new_status) {}
SecureChannel::Status old_status;
SecureChannel::Status new_status;
};
// A single message delivered to TestObserver::OnMessageReceived().
struct ReceivedMessage {
  ReceivedMessage(const std::string& received_feature,
                  const std::string& received_payload)
      : feature(received_feature), payload(received_payload) {}

  // Feature the message was addressed to.
  std::string feature;
  // Message body as handed to the observer.
  std::string payload;
};
// SecureChannel::Observer that records every status change and every received
// message so the fixture can verify them later. It expects to be notified only
// by the channel passed to its constructor.
class TestObserver : public SecureChannel::Observer {
 public:
  // |secure_channel| is not owned; it is only compared against the channel
  // passed to each notification.
  explicit TestObserver(SecureChannel* secure_channel)
      : secure_channel_(secure_channel) {}

  // SecureChannel::Observer:
  void OnSecureChannelStatusChanged(
      SecureChannel* secure_channel,
      const SecureChannel::Status& old_status,
      const SecureChannel::Status& new_status) override {
    // EXPECT rather than DCHECK so the check also runs in release builds.
    EXPECT_EQ(secure_channel_, secure_channel);
    connection_status_changes_.push_back(
        SecureChannelStatusChange(old_status, new_status));
  }

  void OnMessageReceived(SecureChannel* secure_channel,
                         const std::string& feature,
                         const std::string& payload) override {
    // EXPECT rather than DCHECK so the check also runs in release builds.
    EXPECT_EQ(secure_channel_, secure_channel);
    received_messages_.push_back(ReceivedMessage(feature, payload));
  }

  // All status changes observed so far, in notification order.
  std::vector<SecureChannelStatusChange>& connection_status_changes() {
    return connection_status_changes_;
  }

  // All messages received so far, in notification order.
  std::vector<ReceivedMessage>& received_messages() {
    return received_messages_;
  }

 private:
  SecureChannel* secure_channel_;
  std::vector<SecureChannelStatusChange> connection_status_changes_;
  std::vector<ReceivedMessage> received_messages_;
};
// Authenticator factory installed by the fixture. Every authenticator it builds
// is a FakeAuthenticator, and the most recent one is exposed through
// last_instance() so tests can complete or fail authentication by hand.
class TestAuthenticatorFactory : public DeviceToDeviceAuthenticator::Factory {
 public:
  TestAuthenticatorFactory() = default;

  std::unique_ptr<Authenticator> BuildInstance(
      cryptauth::Connection* connection,
      const std::string& account_id,
      std::unique_ptr<cryptauth::SecureMessageDelegate>
          secure_message_delegate) override {
    auto* authenticator = new FakeAuthenticator();
    last_instance_ = authenticator;
    return std::unique_ptr<Authenticator>(authenticator);
  }

  // Authenticator returned by the latest BuildInstance() call, or null if none
  // has been built yet. Not owned: ownership went to the caller of
  // BuildInstance().
  Authenticator* last_instance() { return last_instance_; }

 private:
  Authenticator* last_instance_ = nullptr;
};
// Builds the RemoteDevice used by every test: the first generated test device,
// with its user_id overridden by |test_user_id|.
RemoteDevice CreateTestRemoteDevice() {
  auto generated_devices = GenerateTestRemoteDevices(1);
  RemoteDevice device = generated_devices[0];
  device.user_id = test_user_id;
  return device;
}
// Thin SecureChannel subclass through which the fixture constructs the channel
// under test.
// NOTE(review): presumably this exists because SecureChannel's constructor is
// not public; confirm against secure_channel.h.
class TestSecureChannel : public SecureChannel {
 public:
  TestSecureChannel(std::unique_ptr<Connection> owned_connection,
                    CryptAuthService* service)
      : SecureChannel(std::move(owned_connection), service) {}
};
} // namespace
// Fixture that wires a TestSecureChannel to a FakeConnection and a
// FakeCryptAuthService, routes authenticator creation through a
// TestAuthenticatorFactory, and records all observer notifications with a
// TestObserver. Tests must verify every status change and received message
// they cause; TearDown() fails the test if any were left unverified.
class CryptAuthSecureChannelTest : public testing::Test {
 protected:
  CryptAuthSecureChannelTest()
      : test_device_(CreateTestRemoteDevice()), weak_ptr_factory_(this) {}

  void SetUp() override {
    test_authenticator_factory_ = base::MakeUnique<TestAuthenticatorFactory>();
    DeviceToDeviceAuthenticator::Factory::SetInstanceForTesting(
        test_authenticator_factory_.get());

    fake_secure_context_ = nullptr;
    fake_cryptauth_service_ = base::MakeUnique<FakeCryptAuthService>();

    fake_connection_ =
        new FakeConnection(test_device_, /* should_auto_connect */ false);
    EXPECT_FALSE(fake_connection_->observers().size());

    secure_channel_ = base::MakeUnique<TestSecureChannel>(
        base::WrapUnique(fake_connection_), fake_cryptauth_service_.get());
    // Constructing the channel registers it as the connection's only observer.
    EXPECT_EQ(static_cast<size_t>(1), fake_connection_->observers().size());
    EXPECT_EQ(secure_channel_.get(), fake_connection_->observers()[0]);

    test_observer_ = base::MakeUnique<TestObserver>(secure_channel_.get());
    secure_channel_->AddObserver(test_observer_.get());
  }

  void TearDown() override {
    // All state changes should have already been verified. This ensures that
    // no test has missed one.
    VerifyConnectionStateChanges(std::vector<SecureChannelStatusChange>());
    // Same with received messages.
    VerifyReceivedMessages(std::vector<ReceivedMessage>());
    // Same with messages being sent.
    VerifyNoMessageBeingSent();

    // |test_authenticator_factory_| is destroyed together with this fixture,
    // so clear the process-wide factory pointer rather than leave it dangling
    // for whatever runs next in this test binary.
    DeviceToDeviceAuthenticator::Factory::SetInstanceForTesting(nullptr);
  }

  // Appends |expected_changes| to the changes verified so far. Then checks
  // that the observer has seen exactly that cumulative sequence, with no more
  // and no fewer entries.
  void VerifyConnectionStateChanges(
      const std::vector<SecureChannelStatusChange>& expected_changes) {
    verified_status_changes_.insert(verified_status_changes_.end(),
                                    expected_changes.begin(),
                                    expected_changes.end());

    ASSERT_EQ(verified_status_changes_.size(),
              test_observer_->connection_status_changes().size());

    for (size_t i = 0; i < verified_status_changes_.size(); i++) {
      EXPECT_EQ(verified_status_changes_[i].old_status,
                test_observer_->connection_status_changes()[i].old_status);
      EXPECT_EQ(verified_status_changes_[i].new_status,
                test_observer_->connection_status_changes()[i].new_status);
    }
  }

  // Appends |expected_messages| to the messages verified so far. Then checks
  // that the observer has received exactly that cumulative sequence.
  void VerifyReceivedMessages(
      const std::vector<ReceivedMessage>& expected_messages) {
    verified_received_messages_.insert(verified_received_messages_.end(),
                                       expected_messages.begin(),
                                       expected_messages.end());

    ASSERT_EQ(verified_received_messages_.size(),
              test_observer_->received_messages().size());

    for (size_t i = 0; i < verified_received_messages_.size(); i++) {
      EXPECT_EQ(verified_received_messages_[i].feature,
                test_observer_->received_messages()[i].feature);
      EXPECT_EQ(verified_received_messages_[i].payload,
                test_observer_->received_messages()[i].payload);
    }
  }

  // Completes the pending authentication with a non-SUCCESS |result|.
  void FailAuthentication(Authenticator::Result result) {
    ASSERT_NE(result, Authenticator::Result::SUCCESS);

    FakeAuthenticator* authenticator = static_cast<FakeAuthenticator*>(
        test_authenticator_factory_->last_instance());
    // Authentication can only be completed once an authenticator was built.
    ASSERT_TRUE(authenticator);
    authenticator->last_callback().Run(result, nullptr);
  }

  // Completes the pending authentication successfully, handing ownership of a
  // new FakeSecureContext (tracked in |fake_secure_context_|) to the channel.
  void AuthenticateSuccessfully() {
    FakeAuthenticator* authenticator = static_cast<FakeAuthenticator*>(
        test_authenticator_factory_->last_instance());
    // Authentication can only be completed once an authenticator was built.
    ASSERT_TRUE(authenticator);

    fake_secure_context_ = new FakeSecureContext();
    authenticator->last_callback().Run(
        Authenticator::Result::SUCCESS, base::WrapUnique(fake_secure_context_));
  }

  // Drives the channel from DISCONNECTED to AUTHENTICATED and verifies every
  // intermediate status change.
  void ConnectAndAuthenticate() {
    secure_channel_->Initialize();
    VerifyConnectionStateChanges(std::vector<SecureChannelStatusChange> {
      {
        SecureChannel::Status::DISCONNECTED,
        SecureChannel::Status::CONNECTING
      }
    });

    fake_connection_->CompleteInProgressConnection(/* success */ true);
    VerifyConnectionStateChanges(std::vector<SecureChannelStatusChange> {
      {
        SecureChannel::Status::CONNECTING,
        SecureChannel::Status::CONNECTED
      },
      {
        SecureChannel::Status::CONNECTED,
        SecureChannel::Status::AUTHENTICATING
      }
    });

    AuthenticateSuccessfully();
    VerifyConnectionStateChanges(std::vector<SecureChannelStatusChange> {
      {
        SecureChannel::Status::AUTHENTICATING,
        SecureChannel::Status::AUTHENTICATED
      }
    });
  }

  // Asks the channel to send a message and verifies it reached the connection.
  void StartSendingMessage(
      const std::string& feature, const std::string& payload) {
    secure_channel_->SendMessage(feature, payload);
    VerifyMessageBeingSent(feature, payload);
  }

  // Like StartSendingMessage(), then completes the send with |success|.
  void StartAndFinishSendingMessage(
      const std::string& feature, const std::string& payload, bool success) {
    StartSendingMessage(feature, payload);
    fake_connection_->FinishSendingMessageWithSuccess(success);
  }

  void VerifyNoMessageBeingSent() {
    EXPECT_FALSE(fake_connection_->current_message());
  }

  // Verifies that the connection is currently sending a WireMessage for
  // |feature| whose payload equals |payload| after encoding by
  // |fake_secure_context_|.
  void VerifyMessageBeingSent(
      const std::string& feature, const std::string& payload) {
    WireMessage* message_being_sent = fake_connection_->current_message();
    // Fail the test cleanly, rather than dereference null inside
    // VerifyWireMessageContents(), when nothing is being sent or no secure
    // context has been established.
    ASSERT_TRUE(message_being_sent);
    ASSERT_TRUE(fake_secure_context_);

    // Note that despite the fact that |Encode()| has an asynchronous interface,
    // the implementation will call |VerifyWireMessageContents()| synchronously.
    fake_secure_context_->Encode(
        payload,
        base::Bind(&CryptAuthSecureChannelTest::VerifyWireMessageContents,
                   weak_ptr_factory_.GetWeakPtr(),
                   message_being_sent,
                   feature));
  }

  void VerifyWireMessageContents(
      WireMessage* wire_message,
      const std::string& expected_feature,
      const std::string& expected_payload) {
    EXPECT_EQ(expected_feature, wire_message->feature());
    EXPECT_EQ(expected_payload, wire_message->payload());
  }

  // Owned by secure_channel_.
  FakeConnection* fake_connection_ = nullptr;

  std::unique_ptr<FakeCryptAuthService> fake_cryptauth_service_;

  // Owned by secure_channel_ once authentication has completed successfully.
  FakeSecureContext* fake_secure_context_ = nullptr;

  std::vector<SecureChannelStatusChange> verified_status_changes_;
  std::vector<ReceivedMessage> verified_received_messages_;

  // Declared before |secure_channel_| so that it is destroyed after the
  // channel, which still holds it as a registered observer.
  std::unique_ptr<TestObserver> test_observer_;
  std::unique_ptr<SecureChannel> secure_channel_;
  std::unique_ptr<TestAuthenticatorFactory> test_authenticator_factory_;
  const RemoteDevice test_device_;

  base::WeakPtrFactory<CryptAuthSecureChannelTest> weak_ptr_factory_;

 private:
  DISALLOW_COPY_AND_ASSIGN(CryptAuthSecureChannelTest);
};
// A failed connection attempt sends the channel from CONNECTING straight back
// to DISCONNECTED; authentication is never started.
TEST_F(CryptAuthSecureChannelTest, ConnectionAttemptFails) {
  using Status = SecureChannel::Status;

  secure_channel_->Initialize();
  VerifyConnectionStateChanges({{Status::DISCONNECTED, Status::CONNECTING}});

  fake_connection_->CompleteInProgressConnection(/* success */ false);
  VerifyConnectionStateChanges({{Status::CONNECTING, Status::DISCONNECTED}});
}
// Disconnecting while the connection attempt is still pending returns the
// channel to DISCONNECTED.
TEST_F(CryptAuthSecureChannelTest, DisconnectBeforeAuthentication) {
  using Status = SecureChannel::Status;

  secure_channel_->Initialize();
  VerifyConnectionStateChanges({{Status::DISCONNECTED, Status::CONNECTING}});

  fake_connection_->Disconnect();
  VerifyConnectionStateChanges({{Status::CONNECTING, Status::DISCONNECTED}});
}
// If the authenticator reports DISCONNECTED, the channel goes from
// AUTHENTICATING to DISCONNECTED.
TEST_F(CryptAuthSecureChannelTest, AuthenticationFails_Disconnect) {
  using Status = SecureChannel::Status;

  secure_channel_->Initialize();
  VerifyConnectionStateChanges({{Status::DISCONNECTED, Status::CONNECTING}});

  // Connecting successfully immediately kicks off authentication.
  fake_connection_->CompleteInProgressConnection(/* success */ true);
  VerifyConnectionStateChanges({{Status::CONNECTING, Status::CONNECTED},
                                {Status::CONNECTED, Status::AUTHENTICATING}});

  FailAuthentication(Authenticator::Result::DISCONNECTED);
  VerifyConnectionStateChanges(
      {{Status::AUTHENTICATING, Status::DISCONNECTED}});
}
// If the authenticator reports FAILURE, the channel also goes from
// AUTHENTICATING to DISCONNECTED.
TEST_F(CryptAuthSecureChannelTest, AuthenticationFails_Failure) {
  using Status = SecureChannel::Status;

  secure_channel_->Initialize();
  VerifyConnectionStateChanges({{Status::DISCONNECTED, Status::CONNECTING}});

  // Connecting successfully immediately kicks off authentication.
  fake_connection_->CompleteInProgressConnection(/* success */ true);
  VerifyConnectionStateChanges({{Status::CONNECTING, Status::CONNECTED},
                                {Status::CONNECTED, Status::AUTHENTICATING}});

  FailAuthentication(Authenticator::Result::FAILURE);
  VerifyConnectionStateChanges(
      {{Status::AUTHENTICATING, Status::DISCONNECTED}});
}
// Disconnecting mid-send moves the channel to DISCONNECTED. The subsequent
// failed completion of that send must not produce a second transition.
TEST_F(CryptAuthSecureChannelTest, SendMessage_DisconnectWhileSending) {
  using Status = SecureChannel::Status;

  ConnectAndAuthenticate();
  StartSendingMessage("feature", "payload");

  fake_connection_->Disconnect();
  VerifyConnectionStateChanges(
      {{Status::AUTHENTICATED, Status::DISCONNECTED}});

  fake_connection_->FinishSendingMessageWithSuccess(false);
  // Verifying an empty list asserts that nothing new was observed.
  VerifyConnectionStateChanges(std::vector<SecureChannelStatusChange>());
}
// Race: the connection drops mid-send, and afterwards the send is still
// reported as successful. The late success must be ignored.
TEST_F(CryptAuthSecureChannelTest,
       SendMessage_DisconnectWhileSending_ThenSendCompletedOccurs) {
  using Status = SecureChannel::Status;

  ConnectAndAuthenticate();
  StartSendingMessage("feature", "payload");

  fake_connection_->Disconnect();
  VerifyConnectionStateChanges(
      {{Status::AUTHENTICATED, Status::DISCONNECTED}});

  fake_connection_->FinishSendingMessageWithSuccess(true);
  // Verifying an empty list asserts that nothing new was observed.
  VerifyConnectionStateChanges(std::vector<SecureChannelStatusChange>());
}
// A send that completes unsuccessfully drops an authenticated channel to
// DISCONNECTED.
TEST_F(CryptAuthSecureChannelTest, SendMessage_Failure) {
  ConnectAndAuthenticate();

  StartSendingMessage("feature", "payload");
  fake_connection_->FinishSendingMessageWithSuccess(false);

  VerifyConnectionStateChanges({{SecureChannel::Status::AUTHENTICATED,
                                 SecureChannel::Status::DISCONNECTED}});
}
// A successful send causes no status change. TearDown() confirms that no
// unverified transition occurred and that nothing is still being sent.
TEST_F(CryptAuthSecureChannelTest, SendMessage_Success) {
  ConnectAndAuthenticate();

  StartSendingMessage("feature", "payload");
  fake_connection_->FinishSendingMessageWithSuccess(/* success */ true);
}
// Messages queued while another send is in flight go out one at a time, in
// the order they were queued.
TEST_F(CryptAuthSecureChannelTest, SendMessage_MultipleMessages_Success) {
  ConnectAndAuthenticate();

  const struct {
    const char* feature;
    const char* payload;
  } kMessages[] = {
      {"feature1", "payload1"},
      {"feature2", "payload2"},
  };

  // Queue everything before any send has completed.
  for (const auto& message : kMessages)
    secure_channel_->SendMessage(message.feature, message.payload);

  // Each message should be the one in flight until its send completes.
  for (const auto& message : kMessages) {
    VerifyMessageBeingSent(message.feature, message.payload);
    fake_connection_->FinishSendingMessageWithSuccess(true);
  }
}
// When the first of two queued messages fails to send, the channel
// disconnects and the second message is never attempted.
TEST_F(CryptAuthSecureChannelTest, SendMessage_MultipleMessages_FirstFails) {
  ConnectAndAuthenticate();

  // Queue two messages; only the first is in flight at this point.
  secure_channel_->SendMessage("feature1", "payload1");
  secure_channel_->SendMessage("feature2", "payload2");
  VerifyMessageBeingSent("feature1", "payload1");

  fake_connection_->FinishSendingMessageWithSuccess(false);
  VerifyConnectionStateChanges({{SecureChannel::Status::AUTHENTICATED,
                                 SecureChannel::Status::DISCONNECTED}});

  // The queued second message must not be handed to the connection.
  VerifyNoMessageBeingSent();
}
// An incoming message is decoded and forwarded to observers.
TEST_F(CryptAuthSecureChannelTest, ReceiveMessage) {
  ConnectAndAuthenticate();

  // FakeSecureContext marks encoded text with a trailing ", but encoded"; the
  // observer should receive the payload with that suffix removed.
  fake_connection_->ReceiveMessage("feature", "payload, but encoded");
  VerifyReceivedMessages({{"feature", "payload"}});
}
// Alternating sends and receives on one channel all succeed, and each response
// reaches the observer in decoded form.
TEST_F(CryptAuthSecureChannelTest, SendAndReceiveMessages) {
  ConnectAndAuthenticate();

  // FakeSecureContext marks encoded text with a trailing ", but encoded"; the
  // observer should see each response with that suffix removed.
  const struct {
    const char* request;
    const char* encoded_response;
    const char* decoded_response;
  } kSteps[] = {
      {"request1", "response1, but encoded", "response1"},
      {"request2", "response2, but encoded", "response2"},
  };

  for (const auto& step : kSteps) {
    StartAndFinishSendingMessage("feature", step.request, /* success */ true);
    fake_connection_->ReceiveMessage("feature", step.encoded_response);
    VerifyReceivedMessages({{"feature", step.decoded_response}});
  }
}
} // namespace cryptauth