blob: 9dc544aa934606d93f67900e6a83489caa1c37c5 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/devtools/aida_client.h"
#include <string>
#include <variant>
#include "base/check_is_test.h"
#include "base/containers/fixed_flat_set.h"
#include "base/json/json_string_value_serializer.h"
#include "base/json/string_escape.h"
#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "base/metrics/user_metrics.h"
#include "base/no_destructor.h"
#include "base/strings/string_util.h"
#include "chrome/browser/browser_features.h"
#include "chrome/browser/browser_process.h"
#include "chrome/browser/signin/identity_manager_factory.h"
#include "chrome/common/pref_names.h"
#include "components/prefs/scoped_user_pref_update.h"
#include "components/signin/public/identity_manager/identity_manager.h"
#include "components/variations/service/variations_service.h"
#include "components/variations/service/variations_service_utils.h"
#include "google_apis/gaia/gaia_constants.h"
#include "net/base/load_flags.h"
// Lowercase two-letter country codes (presumably ISO 3166-1 alpha-2 — TODO
// confirm) for which IsLoggingDisabledByGeo() returns true, which in turn
// forces `disallow_logging` on in CanUseAida(). Lookups are exact-match, so
// entries must stay lowercase to match GetCountryCode()'s lowercased output.
constexpr auto kLoggingDisallowedCountries =
base::MakeFixedFlatSet<std::string_view>(
{"at", "be", "bg", "cy", "cz", "de", "dk", "ee", "es", "fi", "fr",
"gb", "gr", "hr", "hu", "ie", "is", "it", "li", "lt", "lu", "lv",
"mt", "nl", "no", "pl", "pt", "ro", "se", "si", "sk"});
// Lowercase two-letter country codes in which AIDA may be used. Any code not
// in this set causes IsAidaBlockedByGeo() to report "blocked". That includes
// an empty code, which GetCountryCode() returns when country info is
// unavailable. Entries must stay lowercase to match GetCountryCode()'s output.
constexpr auto kAidaSupportedCountries =
base::MakeFixedFlatSet<std::string_view>(
{"ae", "ag", "ai", "am", "ao", "ar", "as", "at", "au", "aw", "az", "bb",
"bd", "be", "bf", "bg", "bh", "bi", "bj", "bl", "bm", "bn", "bo", "bq",
"br", "bs", "bt", "bw", "bz", "ca", "cc", "cd", "cf", "cg", "ch", "ci",
"ck", "cl", "cm", "co", "cr", "cv", "cw", "cx", "cy", "cz", "de", "dj",
"dk", "dm", "do", "dz", "ec", "ee", "eg", "eh", "er", "es", "et", "fi",
"fj", "fk", "fm", "fr", "ga", "gb", "gd", "ge", "gg", "gh", "gi", "gm",
"gn", "gq", "gr", "gs", "gt", "gu", "gw", "gy", "hm", "hn", "hr", "ht",
"hu", "id", "ie", "il", "im", "in", "io", "iq", "is", "it", "je", "jm",
"jo", "jp", "ke", "kg", "kh", "ki", "km", "kn", "kr", "kw", "ky", "kz",
"la", "lb", "lc", "li", "lk", "lr", "ls", "lt", "lu", "lv", "ly", "ma",
"mg", "mh", "ml", "mn", "mp", "mr", "ms", "mt", "mu", "mv", "mw", "mx",
"my", "mz", "na", "nc", "ne", "nf", "ng", "ni", "nl", "no", "np", "nr",
"nu", "nz", "om", "pa", "pe", "pg", "ph", "pk", "pl", "pm", "pn", "pr",
"ps", "pt", "pw", "py", "qa", "ro", "rw", "sa", "sb", "sc", "sd", "se",
"sg", "sh", "si", "sk", "sl", "sn", "so", "sr", "ss", "st", "sv", "sz",
"tc", "td", "tg", "th", "tj", "tk", "tl", "tm", "tn", "to", "tr", "tt",
"tv", "tw", "tz", "ug", "um", "us", "uy", "uz", "vc", "ve", "vg", "vi",
"vn", "vu", "wf", "ws", "ye", "za", "zm", "zw"});
AidaClient::AidaClient(Profile* profile)
: profile_(*profile),
aida_scope_(GaiaConstants::kAidaOAuth2Scope) {}
AidaClient::~AidaClient() = default;
// Looks up the extended AccountInfo of |profile|'s primary account at
// ConsentLevel::kSignin. Yields std::nullopt when the profile has no
// IdentityManager or when no primary account is set.
std::optional<AccountInfo> AccountInfoForProfile(Profile* profile) {
  auto* const identity = IdentityManagerFactory::GetForProfile(profile);
  if (identity == nullptr) {
    return std::nullopt;
  }
  const CoreAccountId primary_account =
      identity->GetPrimaryAccountId(signin::ConsentLevel::kSignin);
  if (primary_account.empty()) {
    return std::nullopt;
  }
  return identity->FindExtendedAccountInfoByAccountId(primary_account);
}
bool IsAidaBlockedByAge(std::optional<AccountInfo> account_info) {
if (!account_info.has_value()) {
return true;
}
return account_info.value()
.capabilities.can_use_devtools_generative_ai_features() !=
signin::Tribool::kTrue;
}
std::unique_ptr<std::string>& GetCountryCodeOverride() {
static base::NoDestructor<std::unique_ptr<std::string>> country_code_override(
nullptr);
return *country_code_override;
}
std::string GetCountryCode() {
if (GetCountryCodeOverride()) {
return *GetCountryCodeOverride();
}
std::string country_code =
base::ToLowerASCII(variations::GetCurrentCountryCode(
g_browser_process->variations_service()));
DLOG_IF(WARNING, country_code.empty()) << "Couldn't get country info.";
return country_code;
}
// True when |country_code| appears in kLoggingDisallowedCountries. The match
// is exact, so callers should pass an already-lowercased code.
bool IsLoggingDisabledByGeo(std::string country_code) {
  const bool logging_disallowed =
      kLoggingDisallowedCountries.contains(country_code);
  return logging_disallowed;
}
// True when |country_code| is NOT in kAidaSupportedCountries. An empty or
// unrecognised code therefore counts as blocked.
bool IsAidaBlockedByGeo(std::string country_code) {
  const bool supported = kAidaSupportedCountries.contains(country_code);
  return !supported;
}
// Computes whether |profile| may use AIDA and why it may be blocked.
//
// On Google Chrome branded builds, the result combines:
// - the account capability check (IsAidaBlockedByAge);
// - the kDevToolsGenAiSettings enterprise policy pref;
// - the geo checks on GetCountryCode().
// `blocked` is true when any of the age, policy or geo checks blocks.
// `disallow_logging` is true when policy selects kAllowWithoutLogging or the
// country disallows logging.
//
// On unbranded builds AIDA is never available: `available` is false and
// `blocked` is true. The remaining fields keep their default values.
AidaClient::Availability AidaClient::CanUseAida(Profile* profile) {
  Availability result;
#if BUILDFLAG(GOOGLE_CHROME_BRANDING)
  // AidaClient is only available on branded builds.
  result.available = true;
  result.blocked_by_age = IsAidaBlockedByAge(AccountInfoForProfile(profile));

  // Read the policy pref once. All policy-derived fields below then come from
  // one consistent snapshot instead of three separate lookups.
  const int policy_pref =
      profile->GetPrefs()->GetInteger(prefs::kDevToolsGenAiSettings);
  result.blocked_by_enterprise_policy =
      policy_pref ==
      static_cast<int>(DevToolsGenAiEnterprisePolicyValue::kDisable);

  const std::string country_code = GetCountryCode();
  result.blocked_by_geo = IsAidaBlockedByGeo(country_code);
  result.disallow_logging =
      policy_pref == static_cast<int>(DevToolsGenAiEnterprisePolicyValue::
                                          kAllowWithoutLogging) ||
      IsLoggingDisabledByGeo(country_code);

  result.blocked = result.blocked_by_age ||
                   result.blocked_by_enterprise_policy || result.blocked_by_geo;
  result.enterprise_policy_value =
      static_cast<DevToolsGenAiEnterprisePolicyValue>(policy_pref);
  return result;
#else
  // AidaClient is only available on branded builds.
  result.available = false;
  result.blocked = true;
  return result;
#endif
}
// Test-only: makes GetCountryCode() return |country_code| verbatim until the
// returned ScopedOverride is destroyed. At most one override may be active at
// a time; installing a second one CHECK-fails.
AidaClient::ScopedOverride AidaClient::OverrideCountryForTesting(
    std::string country_code) {
  CHECK(!GetCountryCodeOverride());
  // |country_code| is a by-value sink parameter, so move it rather than copy.
  GetCountryCodeOverride() =
      std::make_unique<std::string>(std::move(country_code));
  return std::make_unique<base::ScopedClosureRunner>(
      base::BindOnce([]() { GetCountryCodeOverride().reset(); }));
}
// Test-only: replaces the OAuth2 scope used the next time an access token is
// fetched. A cached token that is still unexpired is not invalidated.
void AidaClient::OverrideAidaScopeForTesting(const std::string& aida_scope) {
  aida_scope_.assign(aida_scope);
}
// Discards the cached access token, so the next PrepareRequestOrFail() call
// fetches a new one. The stored expiration is left as is; an empty token is
// never treated as valid.
void AidaClient::RemoveAccessToken() {
  access_token_.erase();
}
// Hands |callback| either a ready-to-send AIDA ResourceRequest or a JSON error
// string.
// - A cached token that is non-empty and unexpired is reused immediately.
// - Otherwise an OAuth2 access-token fetch (Mode::kImmediate) is started for
//   the primary account at ConsentLevel::kSignin. AccessTokenFetchFinished()
//   then completes the request.
// NOTE(review): a second call while a fetch is pending reassigns
// |access_token_fetcher_|. That presumably destroys the pending fetcher and
// drops the earlier |callback| without running it — confirm, and consider
// queueing callbacks.
void AidaClient::PrepareRequestOrFail(
    base::OnceCallback<
        void(std::variant<network::ResourceRequest, std::string>)> callback) {
  const bool cached_token_valid =
      !access_token_.empty() && base::Time::Now() < access_token_expiration_;
  if (cached_token_valid) {
    PrepareAidaRequest(std::move(callback));
    return;
  }

  auto* const identity = IdentityManagerFactory::GetForProfile(&*profile_);
  if (identity == nullptr) {
    std::move(callback).Run(R"({"error": "IdentityManager is not available"})");
    return;
  }

  const CoreAccountId primary_account =
      identity->GetPrimaryAccountId(signin::ConsentLevel::kSignin);
  // base::Unretained is presumably safe because |this| owns the fetcher.
  // Destroying the fetcher should cancel the callback — TODO confirm.
  auto on_token_ready =
      base::BindOnce(&AidaClient::AccessTokenFetchFinished,
                     base::Unretained(this), std::move(callback));
  access_token_fetcher_ = identity->CreateAccessTokenFetcherForAccount(
      primary_account, "AIDA client", signin::ScopeSet{aida_scope_},
      std::move(on_token_ready), signin::AccessTokenFetcher::Mode::kImmediate);
}
// Completion handler for the fetch started by PrepareRequestOrFail().
// On success, caches the token and its expiration time, then forwards to
// PrepareAidaRequest(). On failure, runs |callback| with an error JSON whose
// "detail" value is the JSON-quoted GoogleServiceAuthError::ToString() text.
void AidaClient::AccessTokenFetchFinished(
    base::OnceCallback<
        void(std::variant<network::ResourceRequest, std::string>)> callback,
    GoogleServiceAuthError error,
    signin::AccessTokenInfo access_token_info) {
  if (error.state() == GoogleServiceAuthError::NONE) {
    access_token_ = access_token_info.token;
    access_token_expiration_ = access_token_info.expiration_time;
    PrepareAidaRequest(std::move(callback));
    return;
  }

  const std::string quoted_detail = base::GetQuotedJSONString(error.ToString());
  std::move(callback).Run(base::ReplaceStringPlaceholders(
      R"({"error": "Cannot get OAuth credentials", "detail": $1})",
      {quoted_detail}, nullptr));
}
// Builds the base AIDA request and passes it to |callback|. The request is a
// POST with:
// - net::LOAD_DISABLE_CACHE set;
// - credentials mode kOmit;
// - an `Authorization: Bearer <token>` header built from the cached token.
// No URL is set here; presumably the caller fills it in — TODO confirm.
// CHECK-fails if no token is cached.
void AidaClient::PrepareAidaRequest(
    base::OnceCallback<
        void(std::variant<network::ResourceRequest, std::string>)> callback) {
  CHECK(!access_token_.empty());
  network::ResourceRequest request;
  request.method = "POST";
  request.load_flags = net::LOAD_DISABLE_CACHE;
  request.credentials_mode = network::mojom::CredentialsMode::kOmit;
  const std::string bearer_header = "Bearer " + access_token_;
  request.headers.SetHeader(net::HttpRequestHeaders::kAuthorization,
                            bearer_header);
  std::move(callback).Run(std::move(request));
}