blob: 5c592a9f26e6de4642e717e930fc8d509a3bf227 [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/base/instance_identity_token_getter_impl.h"
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include "base/base64.h"
#include "base/environment.h"
#include "base/functional/bind.h"
#include "base/json/json_writer.h"
#include "base/memory/ref_counted.h"
#include "base/run_loop.h"
#include "base/strings/stringprintf.h"
#include "base/test/task_environment.h"
#include "net/http/http_status_code.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace remoting {
namespace {
constexpr char kTestAudience[] = "audience_for_testing";
// Matches the URL generated for requests from ComputeEngineServiceClient.
constexpr char kHttpMetadataRequestUrl[] =
"http://metadata.google.internal/computeMetadata/v1/instance/"
"service-accounts/default/identity?audience=audience_for_testing&"
"format=full";
// The environment variable used to override the default metadata server host
// when code is run on a Compute Engine instance.
constexpr char kGceMetadataHostVarName[] = "GCE_METADATA_HOST";
// Constants used for testing Compute Engine Metadata server overrides.
constexpr char kTestMetadataServerHost[] = "override.google.internal";
constexpr char kOverriddenMetadataRequestUrl[] =
"http://override.google.internal/computeMetadata/v1/instance/"
"service-accounts/default/identity?audience=audience_for_testing&"
"format=full";
// This JWT is valid but decoding will fail if kStrict is used because the
// length of the header and payload are not divisible by 4 as the padding has
// been stripped which is how tokens received from the metadata server are
// formatted.
constexpr std::string_view kJwtWithoutPadding =
// Header
"eyJhbGciOiJSUzI1NiIsImtpZCI6ImtpZC1zaWduYXR1cmUiLCJ0eXAiOiJKV1QifQ."
// Payload
"eyJhdWQiOiJhdWRpZW5jZV9mb3JfdGVzdGluZyIsImF6cCI6IjEyMzQ1IiwiZXhwIjo1NzQw"
"MTcsImdvb2dsZSI6eyJjb21wdXRlX2VuZ2luZSI6eyJpbnN0YW5jZV9pZCI6IjEyMzQ1Njc4"
"OSIsImluc3RhbmNlX25hbWUiOiJteS1pbnN0YW5jZSIsInByb2plY3RfaWQiOiJteS1wcm9q"
"ZWN0IiwicHJvamVjdF9udW1iZXIiOjU0MzIxLCJ6b25lIjoidXMtd2lsZC13ZXN0MS16In19"
"LCJpYXQiOjU3MDQxNywiaXNzIjoiRzAwZ2xlIiwic3ViIjoiMTIzNDUifQ."
// Signature
"SIGNATURE!!";
base::Value::Dict CreateJwtHeaderDict() {
return base::Value::Dict()
.Set("typ", "JWT")
.Set("kid", "kid-signature")
.Set("alg", "RS256");
}
base::Value::Dict CreateJwtPayloadDict(base::Time now = base::Time::Now()) {
auto compute_engine_dict = base::Value::Dict()
.Set("instance_id", "123456789")
.Set("instance_name", "my-instance")
.Set("project_id", "my-project")
.Set("project_number", 54321)
.Set("zone", "us-wild-west1-z");
return base::Value::Dict()
.Set("iss", "G00gle")
.Set("iat", static_cast<int>(now.InSecondsFSinceUnixEpoch()))
.Set("exp", static_cast<int>(
(now + base::Minutes(60)).InSecondsFSinceUnixEpoch()))
.Set("aud", kTestAudience)
.Set("sub", "12345")
.Set("azp", "12345")
.Set("google", base::Value::Dict().Set("compute_engine",
std::move(compute_engine_dict)));
}
std::string Base64EncodeDict(base::Value::Dict dict) {
return base::Base64Encode(*base::WriteJson(std::move(dict)));
}
std::string GetBase64EncodedHeader() {
return Base64EncodeDict(CreateJwtHeaderDict());
}
std::string GetBase64EncodedPayload(base::Time now = base::Time::Now()) {
return Base64EncodeDict(CreateJwtPayloadDict(now));
}
std::string GenerateValidJwt(std::string header, std::string payload) {
return header + "." + payload + "." + "signature";
}
std::string GenerateValidJwt(base::Time now = base::Time::Now()) {
return GenerateValidJwt(GetBase64EncodedHeader(), GetBase64EncodedPayload());
}
struct TokenParams {
TokenParams(std::string header, std::string payload, std::string test_name)
: header(std::move(header)),
payload(std::move(payload)),
test_name(std::move(test_name)) {}
std::string header;
std::string payload;
std::string test_name;
};
} // namespace
class InstanceIdentityTokenGetterImplTest
: public testing::Test,
public testing::WithParamInterface<TokenParams> {
public:
InstanceIdentityTokenGetterImplTest();
~InstanceIdentityTokenGetterImplTest() override;
void SetUp() override;
void OnTokenRetrieved(std::string_view token);
protected:
void RunUntilQuit();
void SetTokenResponse(
std::string_view response_body,
std::string_view metadata_server_url = kHttpMetadataRequestUrl);
void SetErrorResponse(net::HttpStatusCode status);
void ResetQuitClosure();
void ClearTokenResponse();
void FastForwardBy(base::TimeDelta duration);
void SetMetadataServerEnvVar(std::string_view metadata_server_host);
InstanceIdentityTokenGetterImpl& instance_identity_token_getter() {
return *instance_identity_token_getter_;
}
void set_pending_callback_count(int count) {
pending_callback_count_ = count;
}
const std::optional<std::string>& token() { return token_; }
void clear_token() { token_.reset(); }
size_t url_loader_request_count() {
return test_url_loader_factory_.total_requests();
}
std::string_view valid_jwt() { return valid_jwt_; }
private:
int pending_callback_count_ = 1;
std::optional<std::string> token_;
// Generate once and store so this token can be used for comparisons.
std::string valid_jwt_;
base::RepeatingClosure quit_closure_;
base::test::TaskEnvironment task_environment_{
base::test::TaskEnvironment::MainThreadType::IO,
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
std::unique_ptr<base::Environment> environment_ = base::Environment::Create();
network::TestURLLoaderFactory test_url_loader_factory_;
scoped_refptr<network::SharedURLLoaderFactory> shared_url_loader_factory_;
std::unique_ptr<InstanceIdentityTokenGetterImpl>
instance_identity_token_getter_;
};
InstanceIdentityTokenGetterImplTest::InstanceIdentityTokenGetterImplTest() =
default;
InstanceIdentityTokenGetterImplTest::~InstanceIdentityTokenGetterImplTest() =
default;
void InstanceIdentityTokenGetterImplTest::SetUp() {
valid_jwt_ = GenerateValidJwt();
shared_url_loader_factory_ = test_url_loader_factory_.GetSafeWeakWrapper();
// Unset any pre-existing environment vars before constructing the
// InstanceIdentityTokenGetterImpl so the externally set values are not used.
environment_->UnSetVar(kGceMetadataHostVarName);
instance_identity_token_getter_ =
std::make_unique<InstanceIdentityTokenGetterImpl>(
kTestAudience, shared_url_loader_factory_);
quit_closure_ = task_environment_.QuitClosure();
}
void InstanceIdentityTokenGetterImplTest::RunUntilQuit() {
task_environment_.RunUntilQuit();
}
void InstanceIdentityTokenGetterImplTest::SetTokenResponse(
std::string_view response_body,
std::string_view metadata_server_url) {
ClearTokenResponse();
test_url_loader_factory_.AddResponse(metadata_server_url, response_body);
}
void InstanceIdentityTokenGetterImplTest::ClearTokenResponse() {
test_url_loader_factory_.ClearResponses();
}
void InstanceIdentityTokenGetterImplTest::SetErrorResponse(
net::HttpStatusCode status) {
ClearTokenResponse();
test_url_loader_factory_.AddResponse(kHttpMetadataRequestUrl,
/*content=*/std::string(), status);
}
void InstanceIdentityTokenGetterImplTest::ResetQuitClosure() {
ASSERT_TRUE(quit_closure_.is_null());
quit_closure_ = task_environment_.QuitClosure();
}
void InstanceIdentityTokenGetterImplTest::OnTokenRetrieved(
std::string_view token) {
// If the callback has been run previously, make sure each callback receives
// the same value.
if (token_.has_value()) {
EXPECT_EQ(*token_, token);
} else {
token_ = token;
}
pending_callback_count_--;
if (pending_callback_count_ == 0) {
std::move(quit_closure_).Run();
}
}
void InstanceIdentityTokenGetterImplTest::FastForwardBy(
base::TimeDelta duration) {
task_environment_.FastForwardBy(duration);
}
void InstanceIdentityTokenGetterImplTest::SetMetadataServerEnvVar(
std::string_view metadata_server_host) {
environment_->SetVar(kGceMetadataHostVarName,
std::string(metadata_server_host));
// Recreate the InstanceIdentityTokenGetterImpl so the new value is used.
instance_identity_token_getter_ =
std::make_unique<InstanceIdentityTokenGetterImpl>(
kTestAudience, shared_url_loader_factory_);
}
TEST_F(InstanceIdentityTokenGetterImplTest, SingleRequest) {
SetTokenResponse(valid_jwt());
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_EQ(*token(), valid_jwt());
ASSERT_EQ(url_loader_request_count(), 1U);
}
TEST_F(InstanceIdentityTokenGetterImplTest,
SingleRequestWithCustomMetadataServer) {
SetMetadataServerEnvVar(kTestMetadataServerHost);
SetTokenResponse(valid_jwt(), kOverriddenMetadataRequestUrl);
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_EQ(*token(), valid_jwt());
ASSERT_EQ(url_loader_request_count(), 1U);
}
TEST_F(InstanceIdentityTokenGetterImplTest, JwtWithoutPadding) {
// Base64 decode will fail if kStrict is used.
SetTokenResponse(kJwtWithoutPadding);
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_EQ(*token(), kJwtWithoutPadding);
ASSERT_EQ(url_loader_request_count(), 1U);
}
TEST_F(InstanceIdentityTokenGetterImplTest, MultipleRequests) {
const int kQueuedCallbackCount = 10;
set_pending_callback_count(kQueuedCallbackCount);
for (int i = 0; i < kQueuedCallbackCount; i++) {
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
}
SetTokenResponse(valid_jwt());
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_EQ(*token(), valid_jwt());
ASSERT_EQ(url_loader_request_count(), 1U);
}
TEST_F(InstanceIdentityTokenGetterImplTest, CachedTokenReturned) {
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
SetTokenResponse(valid_jwt());
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_EQ(*token(), valid_jwt());
// Call a second time and verify a token is provided w/o calling the service.
ClearTokenResponse();
ResetQuitClosure();
set_pending_callback_count(1);
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_EQ(*token(), valid_jwt());
ASSERT_EQ(url_loader_request_count(), 1U);
}
TEST_F(InstanceIdentityTokenGetterImplTest, CachedTokenIgnored) {
auto first_jwt_response = valid_jwt();
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
SetTokenResponse(first_jwt_response);
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_EQ(*token(), first_jwt_response);
// Call again and verify a token is provided after calling the service.
FastForwardBy(base::Hours(1));
ResetQuitClosure();
set_pending_callback_count(1);
clear_token();
// Generate a new JWT with updated timestamp.
auto second_jwt_response = GenerateValidJwt();
SetTokenResponse(second_jwt_response);
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_EQ(*token(), second_jwt_response);
ASSERT_EQ(url_loader_request_count(), 2U);
}
TEST_F(InstanceIdentityTokenGetterImplTest, ServiceFailureReturnsEmptyString) {
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
SetErrorResponse(net::HTTP_BAD_REQUEST);
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_TRUE(token()->empty());
}
TEST_F(InstanceIdentityTokenGetterImplTest, UnsignedJwtReturnsEmptyToken) {
// Set response to a token which is missing the signature.
SetTokenResponse(GetBase64EncodedHeader() + "." + GetBase64EncodedPayload());
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_TRUE(token()->empty());
}
TEST_P(InstanceIdentityTokenGetterImplTest, InvalidJwtReturnsEmptyToken) {
SetTokenResponse(GetParam().header + "." + GetParam().payload + ".signature");
instance_identity_token_getter().RetrieveToken(
base::BindOnce(&InstanceIdentityTokenGetterImplTest::OnTokenRetrieved,
base::Unretained(this)));
RunUntilQuit();
ASSERT_TRUE(token().has_value());
ASSERT_TRUE(token()->empty());
}
INSTANTIATE_TEST_SUITE_P(
InstanceIdentityTokenGetterImplTest,
InstanceIdentityTokenGetterImplTest,
::testing::Values(
TokenParams("", "", "EmptyHeaderAndPayload"),
TokenParams("a." + GetBase64EncodedHeader(),
GetBase64EncodedPayload(),
"ExtraTokenSegment"),
TokenParams(GetBase64EncodedHeader() + "====",
GetBase64EncodedPayload() + "====",
"ExtraPadding"),
TokenParams("", GetBase64EncodedPayload(), "EmptyHeader"),
TokenParams(GetBase64EncodedHeader(), "", "EmptyPayload"),
TokenParams(Base64EncodeDict(base::Value::Dict()),
GetBase64EncodedPayload(),
"HeaderIsAnEmptyDict"),
TokenParams(GetBase64EncodedHeader(),
Base64EncodeDict(base::Value::Dict()),
"PayloadIsAnEmptyDict"),
TokenParams(*base::WriteJson(CreateJwtHeaderDict()),
GetBase64EncodedPayload(),
"HeaderNotBase64Encoded"),
TokenParams(GetBase64EncodedHeader(),
*base::WriteJson(CreateJwtPayloadDict()),
"PayloadNotBase64Encoded"),
TokenParams(base::Base64Encode("I'm JSON!"),
GetBase64EncodedPayload(),
"HeaderIsNotValidJson"),
TokenParams(GetBase64EncodedHeader(),
base::Base64Encode("I'm JSON!"),
"PayloadIsNotValidJson"),
TokenParams(Base64EncodeDict(base::Value::Dict()
.Set("alg", "RS256")
.Set("typ", "JWT")),
GetBase64EncodedPayload(),
"HeaderIsMissingMembers"),
TokenParams(GetBase64EncodedHeader(),
Base64EncodeDict(base::Value::Dict().Set("iss", "blergh")),
"PayloadIsMissingMembers"),
TokenParams(GetBase64EncodedHeader(),
Base64EncodeDict(CreateJwtPayloadDict().SetByDottedPath(
"google",
base::Value::Dict())),
"NoComputeEngineDict"),
TokenParams(GetBase64EncodedHeader(),
Base64EncodeDict(CreateJwtPayloadDict().SetByDottedPath(
"google.compute_engine",
base::Value::Dict())),
"EmptyComputeEngineDict"),
TokenParams(GetBase64EncodedHeader(),
Base64EncodeDict(CreateJwtPayloadDict().SetByDottedPath(
"google.compute_engine",
base::Value::Dict().Set("instance_id",
"test-instance-id"))),
"ComputeEngineDictMissingValues")),
[](const testing::TestParamInfo<
InstanceIdentityTokenGetterImplTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace remoting