// blob: f2ef831246fadd568b3277941392c1269544d85b [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/legion/secure_channel_impl.h"
#include <memory>
#include <vector>
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/task/sequenced_task_runner.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "components/legion/attestation_handler.h"
#include "components/legion/legion_common.h"
#include "components/legion/secure_session.h"
#include "components/legion/transport.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/oak/chromium/proto/session/session.pb.h"
#include "third_party/oak/chromium/proto/session/session.test.h"
namespace legion {
namespace {
// Copies the elements of `bytes` into a std::string, one character per
// element, preserving order.
std::string BytesToString(const Request& bytes) {
  std::string text(bytes.begin(), bytes.end());
  return text;
}
// Copies the characters of `str` into a `Request`, one element per character,
// preserving order.
Request StringToBytes(const std::string& str) {
  Request bytes(str.begin(), str.end());
  return bytes;
}
using ::testing::_;
using ::testing::Return;
using ::testing::StrictMock;
using oak::session::v1::AttestRequest;
using oak::session::v1::AttestResponse;
using oak::session::v1::EqualsSessionRequest;
using oak::session::v1::HandshakeRequest;
using oak::session::v1::HandshakeResponse;
using oak::session::v1::SessionRequest;
class MockTransport : public Transport {
public:
MOCK_METHOD(void,
SetResponseCallback,
(ResponseCallback callback),
(override));
MOCK_METHOD(void,
Send,
(const oak::session::v1::SessionRequest& request),
(override));
};
// Constants used by FakeSecureSession.
//
// Prefix that FakeSecureSession::EncryptSync prepends to a plaintext and that
// FakeSecureSession::DecryptSync requires and strips from a ciphertext.
constexpr char kEncryptedPrefix[] = "encrypted: ";
// Ephemeral public key value that makes
// FakeSecureSession::ProcessHandshakeResponseSync report failure.
constexpr char kInvalidPublicKey[] = "invalid public key";
// Plaintext that makes FakeSecureSession::EncryptSync return std::nullopt.
constexpr char kEncryptionMustFail[] = "encrypt: must fail";
// Ciphertext that makes FakeSecureSession::DecryptSync return std::nullopt.
constexpr char kDecryptionMustFail[] = "decrypt: must fail";
class FakeSecureSession : public SecureSession {
public:
FakeSecureSession() = default;
~FakeSecureSession() override = default;
FakeSecureSession(const FakeSecureSession&) = default;
FakeSecureSession& operator=(const FakeSecureSession&) = default;
FakeSecureSession(FakeSecureSession&&) = default;
FakeSecureSession& operator=(FakeSecureSession&&) = default;
void GetHandshakeMessage(
SecureSession::GetHandshakeMessageOnceCallback callback) override {
oak::session::v1::HandshakeRequest handshake_request;
auto task_runner = base::SequencedTaskRunner::GetCurrentDefault();
task_runner->PostTask(
FROM_HERE,
base::BindOnce(std::move(callback), std::move(handshake_request)));
}
// Runs callback with `true` if response's ephemeral public key is NOT equal
// to `kInvalidPublicKey`.
void ProcessHandshakeResponse(
const oak::session::v1::HandshakeResponse& response,
SecureSession::ProcessHandshakeResponseOnceCallback callback) override {
bool result = ProcessHandshakeResponseSync(response);
auto task_runner = base::SequencedTaskRunner::GetCurrentDefault();
task_runner->PostTask(FROM_HERE,
base::BindOnce(std::move(callback), result));
}
bool ProcessHandshakeResponseSync(
const oak::session::v1::HandshakeResponse& response) {
if (response.has_noise_handshake_message()) {
const auto& message = response.noise_handshake_message();
if (message.ephemeral_public_key() == kInvalidPublicKey) {
return false;
}
}
return true;
}
// Runs callback with `std::nullopt` if message is equal to
// `kEncryptionMustFail`.
//
// Otherwise adds "encrypted: " prefix to the message.
void Encrypt(const Request& message,
SecureSession::EncryptOnceCallback callback) override {
auto result = EncryptSync(message);
auto task_runner = base::SequencedTaskRunner::GetCurrentDefault();
task_runner->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), std::move(result)));
}
std::optional<oak::session::v1::EncryptedMessage> EncryptSync(
const Request& message) {
std::string message_str = BytesToString(message);
if (message_str == kEncryptionMustFail) {
return std::nullopt;
}
CHECK(!message_str.starts_with(kEncryptedPrefix));
message_str = kEncryptedPrefix + message_str;
oak::session::v1::EncryptedMessage encrypted_message;
encrypted_message.set_ciphertext(message_str.data(), message_str.size());
return encrypted_message;
}
// Runs callback with `std::nullopt` if message's ciphertext is equal to
// `kDecryptionMustFail`.
//
// Expects that message has "encrypted: " prefix and removes that prefix.
void Decrypt(const oak::session::v1::EncryptedMessage& message,
SecureSession::DecryptOnceCallback callback) override {
auto result = DecryptSync(message);
auto task_runner = base::SequencedTaskRunner::GetCurrentDefault();
task_runner->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), std::move(result)));
}
std::optional<Request> DecryptSync(
const oak::session::v1::EncryptedMessage& message) {
Request message_bytes(message.ciphertext().begin(),
message.ciphertext().end());
std::string message_str = BytesToString(message_bytes);
if (message_str == kDecryptionMustFail) {
return std::nullopt;
}
CHECK(message_str.starts_with(kEncryptedPrefix));
message_str = message_str.substr(strlen(kEncryptedPrefix));
return Request(message_str.begin(), message_str.end());
}
};
class MockAttestationHandler : public AttestationHandler {
public:
MOCK_METHOD(std::optional<oak::session::v1::AttestRequest>,
GetAttestationRequest,
(),
(override));
MOCK_METHOD(bool,
VerifyAttestationResponse,
(const oak::session::v1::AttestResponse& evidence),
(override));
};
// Fixture that builds a SecureChannelImpl on top of a strict MockTransport, a
// FakeSecureSession and a strict MockAttestationHandler. Ownership of the
// three collaborators moves into the channel; the fixture keeps non-owning
// pointers so tests can set expectations on them.
class SecureChannelImplTest : public ::testing::Test {
 protected:
  SecureChannelImplTest() {
    auto owned_transport = std::make_unique<StrictMock<MockTransport>>();
    auto owned_session = std::make_unique<FakeSecureSession>();
    auto owned_attestation =
        std::make_unique<StrictMock<MockAttestationHandler>>();

    transport_ = owned_transport.get();
    secure_session_ = owned_session.get();
    attestation_handler_ = owned_attestation.get();

    // Capture the callback registered on the transport so tests can inject
    // responses through `response_callback_`. The expectation must be in place
    // before the channel is constructed (presumably the constructor performs
    // the registration).
    EXPECT_CALL(*transport_, SetResponseCallback(_))
        .WillOnce(testing::SaveArg<0>(&response_callback_));

    secure_channel_ = std::make_unique<SecureChannelImpl>(
        std::move(owned_transport), std::move(owned_session),
        std::move(owned_attestation));
  }

  // Verifies and clears the mock expectations at the end of each test.
  void TearDown() override {
    testing::Mock::VerifyAndClearExpectations(transport_);
    testing::Mock::VerifyAndClearExpectations(attestation_handler_);
  }

  // Scripts a successful attestation exchange followed by a successful
  // handshake exchange on the mocks.
  void SetUpHandshakeAndAttestation();

  base::test::TaskEnvironment task_environment_;

  // Channel under test; owns the objects the pointers below refer to.
  std::unique_ptr<SecureChannelImpl> secure_channel_;

  // Non-owning views of the collaborators owned by `secure_channel_`.
  raw_ptr<MockTransport> transport_;
  raw_ptr<FakeSecureSession> secure_session_;
  raw_ptr<MockAttestationHandler> attestation_handler_;

  // Callback saved from MockTransport::SetResponseCallback.
  Transport::ResponseCallback response_callback_;
};
void SecureChannelImplTest::SetUpHandshakeAndAttestation() {
oak::session::v1::SessionRequest expected_attestation_request;
expected_attestation_request.mutable_attest_request();
oak::session::v1::SessionRequest expected_handshake_request;
expected_handshake_request.mutable_handshake_request();
EXPECT_CALL(*attestation_handler_, GetAttestationRequest())
.WillOnce(Return(expected_attestation_request.attest_request()));
EXPECT_CALL(*transport_,
Send(EqualsSessionRequest(expected_attestation_request)))
.WillOnce([&]() {
oak::session::v1::SessionResponse response;
response.mutable_attest_response();
response_callback_.Run(response);
});
EXPECT_CALL(*attestation_handler_, VerifyAttestationResponse(_))
.WillOnce(Return(true));
EXPECT_CALL(*transport_,
Send(EqualsSessionRequest(expected_handshake_request)))
.WillOnce([&]() {
oak::session::v1::SessionResponse response;
response.mutable_handshake_response();
response_callback_.Run(response);
});
}
// Happy path: attestation and handshake succeed, the request goes out with the
// fake "encrypted: " prefix, and the reply is delivered to the response
// callback with that prefix stripped.
TEST_F(SecureChannelImplTest, WriteAndEstablishSessionSucceeds) {
  SetUpHandshakeAndAttestation();

  oak::session::v1::SessionRequest wire_request;
  wire_request.mutable_encrypted_message()->set_ciphertext(
      "encrypted: secret request");

  EXPECT_CALL(*transport_, Send(EqualsSessionRequest(wire_request)))
      .WillOnce([this]() {
        oak::session::v1::SessionResponse wire_response;
        wire_response.mutable_encrypted_message()->set_ciphertext(
            "encrypted: secret response");
        response_callback_.Run(wire_response);
      });

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));

  const auto& reply = reply_future.Get();
  ASSERT_TRUE(reply.has_value());
  EXPECT_EQ(BytesToString(reply.value()), "secret response");
}
// Tests that a failure which closes the channel is surfaced through the
// response callback. The failure is triggered by the attestation handler
// producing no attestation request.
// NOTE(review): this body is identical to GetAttestationRequestFails; confirm
// whether a distinct "channel closed" scenario was intended here.
TEST_F(SecureChannelImplTest, ChannelClosedIsReported) {
  EXPECT_CALL(*attestation_handler_, GetAttestationRequest())
      .WillOnce(Return(std::nullopt));

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));

  const auto& reply = reply_future.Get();
  ASSERT_FALSE(reply.has_value());
  EXPECT_EQ(reply.error(), ErrorCode::kAttestationFailed);
}
// Tests that a rejected attestation response (the handler's verification
// returns false) is reported as ErrorCode::kAttestationFailed.
TEST_F(SecureChannelImplTest, AttestationErrorFailsWrite) {
  oak::session::v1::SessionRequest attest_session_request;
  attest_session_request.mutable_attest_request();
  oak::session::v1::SessionResponse attest_session_response;
  attest_session_response.mutable_attest_response();

  EXPECT_CALL(*attestation_handler_, GetAttestationRequest())
      .WillOnce(Return(attest_session_request.attest_request()));
  // The action runs while this test body is still executing, so referring to
  // the local response by reference is safe.
  EXPECT_CALL(*transport_, Send(EqualsSessionRequest(attest_session_request)))
      .WillOnce([this, &attest_session_response]() {
        response_callback_.Run(attest_session_response);
      });
  EXPECT_CALL(*attestation_handler_, VerifyAttestationResponse(_))
      .WillOnce(Return(false));

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));

  const auto& reply = reply_future.Get();
  ASSERT_FALSE(reply.has_value());
  EXPECT_EQ(reply.error(), ErrorCode::kAttestationFailed);
}
// Tests that a transport error delivered in reply to the attestation request
// is reported as ErrorCode::kAttestationFailed. The error is posted as a task
// rather than delivered synchronously from within Send().
TEST_F(SecureChannelImplTest, TransportErrorDuringAttestationFailsRequest) {
  oak::session::v1::SessionRequest attest_session_request;
  attest_session_request.mutable_attest_request();

  EXPECT_CALL(*attestation_handler_, GetAttestationRequest())
      .WillOnce(Return(attest_session_request.attest_request()));
  EXPECT_CALL(*transport_, Send(EqualsSessionRequest(attest_session_request)))
      .WillOnce([this]() {
        // A copy of the saved callback is bound into the posted task.
        base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
            FROM_HERE,
            base::BindOnce(
                [](Transport::ResponseCallback on_response) {
                  on_response.Run(
                      base::unexpected(Transport::TransportError::kError));
                },
                response_callback_));
      });

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));

  const auto& reply = reply_future.Get();
  ASSERT_FALSE(reply.has_value());
  EXPECT_EQ(reply.error(), ErrorCode::kAttestationFailed);
}
// Tests that a transport error delivered in reply to the handshake request,
// after attestation has succeeded, is reported as
// ErrorCode::kHandshakeFailed.
TEST_F(SecureChannelImplTest, TransportErrorDuringHandshakeFailsRequest) {
  // Phase 1: attestation succeeds.
  oak::session::v1::SessionRequest attest_session_request;
  attest_session_request.mutable_attest_request();
  EXPECT_CALL(*attestation_handler_, GetAttestationRequest())
      .WillOnce(Return(attest_session_request.attest_request()));
  EXPECT_CALL(*transport_, Send(EqualsSessionRequest(attest_session_request)))
      .WillOnce([this]() {
        oak::session::v1::SessionResponse attest_session_response;
        attest_session_response.mutable_attest_response();
        response_callback_.Run(attest_session_response);
      });
  EXPECT_CALL(*attestation_handler_, VerifyAttestationResponse(_))
      .WillOnce(Return(true));

  // Phase 2: the handshake request is answered with a transport error, posted
  // as a task rather than delivered synchronously from within Send().
  oak::session::v1::SessionRequest handshake_session_request;
  handshake_session_request.mutable_handshake_request();
  EXPECT_CALL(*transport_,
              Send(EqualsSessionRequest(handshake_session_request)))
      .WillOnce([this]() {
        base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
            FROM_HERE,
            base::BindOnce(
                [](Transport::ResponseCallback on_response) {
                  on_response.Run(
                      base::unexpected(Transport::TransportError::kError));
                },
                response_callback_));
      });

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));

  const auto& reply = reply_future.Get();
  ASSERT_FALSE(reply.has_value());
  EXPECT_EQ(reply.error(), ErrorCode::kHandshakeFailed);
}
// Tests that a transport error delivered in reply to an encrypted request,
// once the session is established, is reported as ErrorCode::kNetworkError.
TEST_F(SecureChannelImplTest, TransportErrorAfterSessionEstablished) {
  SetUpHandshakeAndAttestation();

  oak::session::v1::SessionRequest wire_request;
  wire_request.mutable_encrypted_message()->set_ciphertext(
      "encrypted: secret request");

  EXPECT_CALL(*transport_, Send(EqualsSessionRequest(wire_request)))
      .WillOnce([this]() {
        // The error is posted as a task rather than delivered synchronously
        // from within Send().
        base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
            FROM_HERE,
            base::BindOnce(
                [](Transport::ResponseCallback on_response) {
                  on_response.Run(
                      base::unexpected(Transport::TransportError::kError));
                },
                response_callback_));
      });

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));

  const auto& reply = reply_future.Get();
  ASSERT_FALSE(reply.has_value());
  EXPECT_EQ(reply.error(), ErrorCode::kNetworkError);
}
// Tests that the attestation handler producing no attestation request
// (std::nullopt) is reported as ErrorCode::kAttestationFailed.
TEST_F(SecureChannelImplTest, GetAttestationRequestFails) {
  EXPECT_CALL(*attestation_handler_, GetAttestationRequest())
      .WillOnce(Return(std::nullopt));

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));

  const auto& reply = reply_future.Get();
  ASSERT_FALSE(reply.has_value());
  EXPECT_EQ(reply.error(), ErrorCode::kAttestationFailed);
}
// Tests that a handshake response rejected by the secure session is reported
// as ErrorCode::kHandshakeFailed.
TEST_F(SecureChannelImplTest, ProcessHandshakeResponseFails) {
  oak::session::v1::SessionRequest attest_session_request;
  attest_session_request.mutable_attest_request();
  oak::session::v1::SessionResponse attest_session_response;
  attest_session_response.mutable_attest_response();

  oak::session::v1::SessionRequest handshake_session_request;
  handshake_session_request.mutable_handshake_request();

  // FakeSecureSession rejects a handshake response whose ephemeral public key
  // equals `kInvalidPublicKey`, which makes the client-side handshake fail.
  oak::session::v1::SessionResponse handshake_session_response;
  handshake_session_response.mutable_handshake_response()
      ->mutable_noise_handshake_message()
      ->set_ephemeral_public_key(kInvalidPublicKey);

  EXPECT_CALL(*attestation_handler_, GetAttestationRequest())
      .WillOnce(Return(attest_session_request.attest_request()));
  // Both actions run while this test body is still executing, so referring to
  // the local responses by reference is safe.
  EXPECT_CALL(*transport_, Send(EqualsSessionRequest(attest_session_request)))
      .WillOnce([this, &attest_session_response]() {
        response_callback_.Run(attest_session_response);
      });
  EXPECT_CALL(*attestation_handler_, VerifyAttestationResponse(_))
      .WillOnce(Return(true));
  EXPECT_CALL(*transport_,
              Send(EqualsSessionRequest(handshake_session_request)))
      .WillOnce([this, &handshake_session_response]() {
        response_callback_.Run(handshake_session_response);
      });

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));

  const auto& reply = reply_future.Get();
  ASSERT_FALSE(reply.has_value());
  EXPECT_EQ(reply.error(), ErrorCode::kHandshakeFailed);
}
// Tests that an encryption failure after session establishment is reported as
// ErrorCode::kEncryptionFailed. Writing `kEncryptionMustFail` makes
// FakeSecureSession return no ciphertext; no encrypted Send() is expected.
TEST_F(SecureChannelImplTest, EncryptRequestFails) {
  SetUpHandshakeAndAttestation();

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes(kEncryptionMustFail)));

  const auto& reply = reply_future.Get();
  ASSERT_FALSE(reply.has_value());
  EXPECT_EQ(reply.error(), ErrorCode::kEncryptionFailed);
}
// Tests that a server reply which cannot be decrypted is reported as
// ErrorCode::kDecryptionFailed. The reply's ciphertext is
// `kDecryptionMustFail`, which FakeSecureSession refuses to decrypt.
TEST_F(SecureChannelImplTest, DecryptResponseFails) {
  SetUpHandshakeAndAttestation();

  oak::session::v1::SessionRequest wire_request;
  wire_request.mutable_encrypted_message()->set_ciphertext(
      "encrypted: secret request");

  EXPECT_CALL(*transport_, Send(EqualsSessionRequest(wire_request)))
      .WillOnce([this]() {
        oak::session::v1::SessionResponse wire_response;
        wire_response.mutable_encrypted_message()->set_ciphertext(
            kDecryptionMustFail);
        response_callback_.Run(wire_response);
      });

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));

  const auto& reply = reply_future.Get();
  ASSERT_FALSE(reply.has_value());
  EXPECT_EQ(reply.error(), ErrorCode::kDecryptionFailed);
}
// Tests that a default-constructed (empty) SessionResponse received after
// session establishment is reported as ErrorCode::kNetworkError.
TEST_F(SecureChannelImplTest, EmptyResponseFailsRequest) {
  SetUpHandshakeAndAttestation();

  oak::session::v1::SessionRequest wire_request;
  wire_request.mutable_encrypted_message()->set_ciphertext(
      "encrypted: secret request");

  EXPECT_CALL(*transport_, Send(EqualsSessionRequest(wire_request)))
      .WillOnce([this]() {
        // Reply with a response that has no fields set.
        response_callback_.Run(oak::session::v1::SessionResponse());
      });

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));

  const auto& reply = reply_future.Get();
  ASSERT_FALSE(reply.has_value());
  EXPECT_EQ(reply.error(), ErrorCode::kNetworkError);
}
// Tests that `Write` returns false once an earlier failure has been reported.
// The first write fails because the attestation handler produces no
// attestation request; the second write must then be rejected.
TEST_F(SecureChannelImplTest, WriteInClosedState) {
  EXPECT_CALL(*attestation_handler_, GetAttestationRequest())
      .WillOnce(Return(std::nullopt));

  base::test::TestFuture<base::expected<Response, ErrorCode>> reply_future;
  secure_channel_->SetResponseCallback(reply_future.GetRepeatingCallback());

  // The first write is accepted; its failure arrives through the callback.
  EXPECT_TRUE(secure_channel_->Write(StringToBytes("secret request")));
  ASSERT_FALSE(reply_future.Get().has_value());

  // The second write is refused immediately.
  EXPECT_FALSE(secure_channel_->Write(StringToBytes("secret request")));
}
} // namespace
} // namespace legion