blob: 210f9bf76deb8e00676c8e7637a35d453a6c5d51 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/test.pb.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/signin/identity_test_environment_profile_adaptor.h"
#include "chrome/browser/ui/browser.h"
#include "chrome/test/base/in_process_browser_test.h"
#include "chrome/test/base/ui_test_utils.h"
#include "components/optimization_guide/core/model_execution/model_execution_manager.h"
#include "components/optimization_guide/core/model_execution/optimization_guide_model_execution_error.h"
#include "components/optimization_guide/core/optimization_guide_constants.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_logger.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/model_quality_service.pb.h"
#include "content/public/test/browser_test.h"
#include "content/public/test/browser_test_utils.h"
#include "net/dns/mock_host_resolver.h"
#include "testing/gmock/include/gmock/gmock.h"
namespace optimization_guide {
using base::test::TestMessage;
namespace {
enum class ModelExecutionRemoteResponseType {
kSuccessful = 0,
kUnsuccessful = 1,
kMalformed = 2,
kErrorFiltered = 3,
kUnsupportedLanguage = 4,
};
proto::ExecuteResponse BuildComposeResponse(const std::string& output) {
proto::ComposeResponse compose_response;
compose_response.set_output(output);
proto::ExecuteResponse execute_response;
proto::Any* any_metadata = execute_response.mutable_response_metadata();
any_metadata->set_type_url("type.googleapis.com/" +
compose_response.GetTypeName());
compose_response.SerializeToString(any_metadata->mutable_value());
auto response_data =
optimization_guide::ParsedAnyMetadata<proto::ComposeResponse>(
*any_metadata);
EXPECT_TRUE(response_data);
return execute_response;
}
proto::ExecuteResponse BuildTestErrorExecuteResponse(
const proto::ErrorState& state) {
proto::ExecuteResponse execute_response;
execute_response.mutable_error_response()->set_error_state(state);
return execute_response;
}
} // namespace
class ModelExecutionBrowserTestBase : public InProcessBrowserTest {
public:
ModelExecutionBrowserTestBase() = default;
~ModelExecutionBrowserTestBase() override = default;
ModelExecutionBrowserTestBase(const ModelExecutionBrowserTestBase&) = delete;
ModelExecutionBrowserTestBase& operator=(
const ModelExecutionBrowserTestBase&) = delete;
void SetUp() override {
InitializeFeatureList();
model_execution_server_ = std::make_unique<net::EmbeddedTestServer>(
net::EmbeddedTestServer::TYPE_HTTPS);
net::EmbeddedTestServer::ServerCertificateConfig cert_config;
cert_config.dns_names = {
GURL(kOptimizationGuideServiceModelExecutionDefaultURL).host()};
model_execution_server_->SetSSLConfig(cert_config);
model_execution_server_->RegisterRequestHandler(base::BindRepeating(
&ModelExecutionBrowserTestBase::HandleGetModelExecutionRequest,
base::Unretained(this)));
ASSERT_TRUE(model_execution_server_->Start());
InProcessBrowserTest::SetUp();
}
void SetUpOnMainThread() override {
InProcessBrowserTest::SetUpOnMainThread();
identity_test_env_adaptor_ =
std::make_unique<IdentityTestEnvironmentProfileAdaptor>(
browser()->profile());
host_resolver()->AddRule("*", "127.0.0.1");
}
void SetUpInProcessBrowserTestFixture() override {
create_services_subscription_ =
BrowserContextDependencyManager::GetInstance()
->RegisterCreateServicesCallbackForTesting(
base::BindRepeating(&ModelExecutionBrowserTestBase::
OnWillCreateBrowserContextServices,
base::Unretained(this)));
}
void SetUpCommandLine(base::CommandLine* cmd) override {
cmd->AppendSwitchASCII(
switches::kOptimizationGuideServiceModelExecutionURL,
model_execution_server_
->GetURL(
GURL(kOptimizationGuideServiceModelExecutionDefaultURL).host(),
"/")
.spec());
}
void TearDownOnMainThread() override {
EXPECT_TRUE(model_execution_server_->ShutdownAndWaitUntilComplete());
InProcessBrowserTest::TearDownOnMainThread();
}
void EnableSignin() {
identity_test_env_adaptor_->identity_test_env()
->MakePrimaryAccountAvailable("user@gmail.com",
signin::ConsentLevel::kSignin);
identity_test_env_adaptor_->identity_test_env()
->SetAutomaticIssueOfAccessTokens(true);
}
OptimizationGuideKeyedService* GetOptimizationGuideKeyedService(
Profile* profile = nullptr) {
if (!profile) {
profile = browser()->profile();
}
return OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
}
// Executes the model for the feature, waits until the response is received,
// and returns the response.
void ExecuteModel(proto::ModelExecutionFeature feature,
const google::protobuf::MessageLite& request_metadata,
Profile* profile = nullptr) {
if (!profile) {
profile = browser()->profile();
}
base::RunLoop run_loop;
GetOptimizationGuideKeyedService(profile)->ExecuteModel(
feature, request_metadata,
base::BindOnce(&ModelExecutionBrowserTestBase::OnModelExecutionResponse,
base::Unretained(this), run_loop.QuitClosure()));
run_loop.Run();
}
void SetExpectedBearerAccessToken(
const std::string& expected_bearer_access_token) {
expected_bearer_access_token_ = expected_bearer_access_token;
}
void SetResponseType(ModelExecutionRemoteResponseType response_type) {
response_type_ = response_type;
}
protected:
void OnModelExecutionResponse(
base::OnceClosure on_model_execution_closure,
OptimizationGuideModelExecutionResult result,
std::unique_ptr<ModelQualityLogEntry> log_entry) {
if (result.has_value() ||
result.error().error() == OptimizationGuideModelExecutionError::
ModelExecutionError::kFiltered ||
result.error().error() ==
OptimizationGuideModelExecutionError::ModelExecutionError::
kUnsupportedLanguage) {
EXPECT_TRUE(log_entry);
proto::LogAiDataRequest* log_ai_data_request =
log_entry.get()->log_ai_data_request();
EXPECT_NE(log_ai_data_request, nullptr);
EXPECT_EQ(log_ai_data_request->feature_case(),
proto::LogAiDataRequest::FeatureCase::kCompose);
EXPECT_TRUE(log_ai_data_request->has_compose());
EXPECT_TRUE(log_ai_data_request->mutable_compose()->has_request_data());
}
if (result.has_value()) {
EXPECT_TRUE(log_entry.get()
->log_ai_data_request()
->mutable_compose()
->has_response_data());
model_execution_result_ = base::ok(result.value());
} else {
model_execution_result_ = base::unexpected(result.error());
}
std::move(on_model_execution_closure).Run();
}
std::unique_ptr<net::test_server::HttpResponse>
HandleGetModelExecutionRequest(const net::test_server::HttpRequest& request) {
auto response = std::make_unique<net::test_server::BasicHttpResponse>();
// If the request is a GET, it corresponds to a navigation so return a
// normal response.
EXPECT_EQ(request.method, net::test_server::METHOD_POST);
EXPECT_NE(request.headers.end(), request.headers.find("X-Client-Data"));
// Access token should be set.
EXPECT_TRUE(base::Contains(request.headers,
net::HttpRequestHeaders::kAuthorization));
EXPECT_EQ(expected_bearer_access_token_,
request.headers.at(net::HttpRequestHeaders::kAuthorization));
if (response_type_ == ModelExecutionRemoteResponseType::kSuccessful) {
std::string serialized_response;
proto::ExecuteResponse execute_response =
BuildComposeResponse("foo response");
execute_response.SerializeToString(&serialized_response);
response->set_code(net::HTTP_OK);
response->set_content(serialized_response);
} else if (response_type_ ==
ModelExecutionRemoteResponseType::kUnsuccessful) {
response->set_code(net::HTTP_NOT_FOUND);
} else if (response_type_ == ModelExecutionRemoteResponseType::kMalformed) {
response->set_code(net::HTTP_OK);
response->set_content("Not a proto");
} else if (response_type_ ==
ModelExecutionRemoteResponseType::kErrorFiltered) {
std::string serialized_response;
proto::ExecuteResponse execute_response = BuildTestErrorExecuteResponse(
proto::ErrorState::ERROR_STATE_FILTERED);
execute_response.SerializeToString(&serialized_response);
response->set_code(net::HTTP_OK);
response->set_content(serialized_response);
} else if (response_type_ ==
ModelExecutionRemoteResponseType::kUnsupportedLanguage) {
std::string serialized_response;
proto::ExecuteResponse execute_response = BuildTestErrorExecuteResponse(
proto::ErrorState::ERROR_STATE_UNSUPPORTED_LANGUAGE);
execute_response.SerializeToString(&serialized_response);
response->set_code(net::HTTP_OK);
response->set_content(serialized_response);
} else {
NOTREACHED();
}
return std::move(response);
}
void OnWillCreateBrowserContextServices(content::BrowserContext* context) {
IdentityTestEnvironmentProfileAdaptor::
SetIdentityTestEnvironmentFactoriesOnBrowserContext(context);
}
// Virtualize for testing different feature configurations.
virtual void InitializeFeatureList() {}
base::test::ScopedFeatureList scoped_feature_list_;
std::unique_ptr<net::EmbeddedTestServer> model_execution_server_;
base::HistogramTester histogram_tester_;
ModelExecutionRemoteResponseType response_type_ =
ModelExecutionRemoteResponseType::kSuccessful;
// The last model execution response received.
absl::optional<OptimizationGuideModelExecutionResult> model_execution_result_;
// Identity test support.
std::unique_ptr<IdentityTestEnvironmentProfileAdaptor>
identity_test_env_adaptor_;
base::CallbackListSubscription create_services_subscription_;
// The expected authorization header holding the bearer access token.
std::string expected_bearer_access_token_;
};
class ModelExecutionDisabledBrowserTest : public ModelExecutionBrowserTestBase {
void InitializeFeatureList() override {
scoped_feature_list_.InitAndDisableFeature(
features::kOptimizationGuideModelExecution);
}
};
IN_PROC_BROWSER_TEST_F(ModelExecutionDisabledBrowserTest,
ModelExecutionDisabled) {
proto::ComposeRequest request;
request.set_user_input("a user typed this");
ExecuteModel(proto::MODEL_EXECUTION_FEATURE_COMPOSE, request);
EXPECT_TRUE(model_execution_result_.has_value());
EXPECT_FALSE(model_execution_result_->has_value());
EXPECT_EQ(OptimizationGuideModelExecutionError::ModelExecutionError::
kGenericFailure,
model_execution_result_->error().error());
EXPECT_TRUE(model_execution_result_->error().transient());
}
class ModelExecutionEnabledBrowserTest : public ModelExecutionBrowserTestBase {
void InitializeFeatureList() override {
scoped_feature_list_.InitAndEnableFeature(
features::kOptimizationGuideModelExecution);
}
};
IN_PROC_BROWSER_TEST_F(ModelExecutionEnabledBrowserTest,
ModelExecutionDisabledInIncognito) {
Browser* otr_browser = CreateIncognitoBrowser(browser()->profile());
proto::ComposeRequest request;
request.set_user_input("a user typed this");
ExecuteModel(proto::MODEL_EXECUTION_FEATURE_COMPOSE, request,
otr_browser->profile());
EXPECT_TRUE(model_execution_result_.has_value());
EXPECT_FALSE(model_execution_result_->has_value());
EXPECT_EQ(OptimizationGuideModelExecutionError::ModelExecutionError::
kGenericFailure,
model_execution_result_->error().error());
EXPECT_TRUE(model_execution_result_->error().transient());
}
IN_PROC_BROWSER_TEST_F(ModelExecutionEnabledBrowserTest,
ModelExecutionFailsNoUserSignIn) {
proto::ComposeRequest request;
request.set_user_input("a user typed this");
ExecuteModel(proto::MODEL_EXECUTION_FEATURE_COMPOSE, request);
EXPECT_TRUE(model_execution_result_.has_value());
EXPECT_FALSE(model_execution_result_->has_value());
EXPECT_EQ(OptimizationGuideModelExecutionError::ModelExecutionError::
kPermissionDenied,
model_execution_result_->error().error());
EXPECT_FALSE(model_execution_result_->error().transient());
}
IN_PROC_BROWSER_TEST_F(ModelExecutionEnabledBrowserTest,
ModelExecutionSuccess) {
EnableSignin();
SetExpectedBearerAccessToken("Bearer access_token");
proto::ComposeRequest request;
request.set_user_input("a user typed this");
ExecuteModel(proto::MODEL_EXECUTION_FEATURE_COMPOSE, request);
EXPECT_TRUE(model_execution_result_.has_value());
EXPECT_TRUE(model_execution_result_->has_value());
auto response = ParsedAnyMetadata<proto::ComposeResponse>(
model_execution_result_->value());
EXPECT_EQ("foo response", response->output());
}
IN_PROC_BROWSER_TEST_F(ModelExecutionEnabledBrowserTest,
ModelExecutionFailsForUnsuccessfulResponse) {
EnableSignin();
SetExpectedBearerAccessToken("Bearer access_token");
SetResponseType(ModelExecutionRemoteResponseType::kUnsuccessful);
proto::ComposeRequest request;
request.set_user_input("a user typed this");
ExecuteModel(proto::MODEL_EXECUTION_FEATURE_COMPOSE, request);
EXPECT_TRUE(model_execution_result_.has_value());
EXPECT_FALSE(model_execution_result_->has_value());
EXPECT_EQ(OptimizationGuideModelExecutionError::ModelExecutionError::
kGenericFailure,
model_execution_result_->error().error());
EXPECT_TRUE(model_execution_result_->error().transient());
}
IN_PROC_BROWSER_TEST_F(ModelExecutionEnabledBrowserTest,
ModelExecutionFailsForMalformedResponse) {
EnableSignin();
SetExpectedBearerAccessToken("Bearer access_token");
SetResponseType(ModelExecutionRemoteResponseType::kMalformed);
proto::ComposeRequest request;
request.set_user_input("a user typed this");
ExecuteModel(proto::MODEL_EXECUTION_FEATURE_COMPOSE, request);
EXPECT_TRUE(model_execution_result_.has_value());
EXPECT_FALSE(model_execution_result_->has_value());
EXPECT_EQ(OptimizationGuideModelExecutionError::ModelExecutionError::
kGenericFailure,
model_execution_result_->error().error());
EXPECT_TRUE(model_execution_result_->error().transient());
}
IN_PROC_BROWSER_TEST_F(ModelExecutionEnabledBrowserTest,
ModelExecutionFailsForErrorFilteredResponse) {
EnableSignin();
SetExpectedBearerAccessToken("Bearer access_token");
SetResponseType(ModelExecutionRemoteResponseType::kErrorFiltered);
proto::ComposeRequest request;
request.set_user_input("a user typed this");
ExecuteModel(proto::MODEL_EXECUTION_FEATURE_COMPOSE, request);
EXPECT_TRUE(model_execution_result_.has_value());
EXPECT_FALSE(model_execution_result_->has_value());
EXPECT_EQ(
OptimizationGuideModelExecutionError::ModelExecutionError::kFiltered,
model_execution_result_->error().error());
}
IN_PROC_BROWSER_TEST_F(ModelExecutionEnabledBrowserTest,
ModelExecutionFailsForUnsupportedLanguageResponse) {
EnableSignin();
SetExpectedBearerAccessToken("Bearer access_token");
SetResponseType(ModelExecutionRemoteResponseType::kUnsupportedLanguage);
proto::ComposeRequest request;
request.set_user_input("a user typed this");
ExecuteModel(proto::MODEL_EXECUTION_FEATURE_COMPOSE, request);
EXPECT_TRUE(model_execution_result_.has_value());
EXPECT_FALSE(model_execution_result_->has_value());
EXPECT_EQ(OptimizationGuideModelExecutionError::ModelExecutionError::
kUnsupportedLanguage,
model_execution_result_->error().error());
}
class ModelExecutionInternalsPageBrowserTest
: public ModelExecutionEnabledBrowserTest {
public:
void SetUpCommandLine(base::CommandLine* cmd) override {
ModelExecutionEnabledBrowserTest::SetUpCommandLine(cmd);
cmd->AppendSwitch(switches::kDebugLoggingEnabled);
}
void CheckInternalsLog(std::string_view message) {
auto* logger =
GetOptimizationGuideKeyedService()->GetOptimizationGuideLogger();
EXPECT_THAT(logger->recent_log_messages_,
testing::Contains(testing::Field(
&OptimizationGuideLogger::LogMessage::message,
testing::HasSubstr(message))));
}
};
IN_PROC_BROWSER_TEST_F(ModelExecutionInternalsPageBrowserTest,
LoggedInInternalsPage) {
EnableSignin();
SetExpectedBearerAccessToken("Bearer access_token");
proto::ComposeRequest request;
request.set_user_input("foo");
ExecuteModel(proto::MODEL_EXECUTION_FEATURE_COMPOSE, request);
EXPECT_TRUE(model_execution_result_.has_value());
EXPECT_TRUE(model_execution_result_->has_value());
CheckInternalsLog("ExecuteModel");
// CheckInternalsLog("TabOrganization Request");
CheckInternalsLog("OnModelExecutionResponse");
}
} // namespace optimization_guide