// blob: fba5e5eb6898a035b2830a74522de9de288f7755 [file] [log] [blame]
// Copyright 2021 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "federated/federated_client.h"
#include <string>
#include <utility>
#include <base/logging.h>
#include <base/notreached.h>
#include <base/strings/stringprintf.h>
#include <base/time/time.h>
#include "federated/device_status/device_status_monitor.h"
#include "federated/example_database.h"
#include "federated/federated_metadata.h"
#include "federated/metrics.h"
#include "federated/protos/cros_events.pb.h"
#include "federated/protos/cros_example_selector_criteria.pb.h"
#include "federated/utils.h"
namespace federated {
namespace {
// Delay before the first checkin: it happens kInitialWaitingWindow after the
// device startup to avoid resource competition.
constexpr base::TimeDelta kInitialWaitingWindow = base::Minutes(5);
// Retry delay used when no valid retry window is available from the server
// (e.g. the run failed or the database is not connected).
constexpr base::TimeDelta kDefaultRetryWindow = base::Minutes(30);
// Lower bound for the retry delay; a shorter window returned by the server is
// raised to this value so that the client does not spam the server.
constexpr base::TimeDelta kMinimalRetryWindow = base::Minutes(1);
// Upper bound on the duration of one round: once a round has been running for
// longer than this, TrainingConditionsSatisfied() reports false.
constexpr base::TimeDelta kMaximalExecutionTime = base::Minutes(10);
// TODO(b/251378482): Just a dummy impl for now, might need to log to UMA.
// Writes a one-line summary of `cros_event` (model id and event type name) to
// the INFO log, and the full debug string of the event to the verbose log.
void LogCrosEvent(const fcp::client::CrosEvent& cros_event) {
  const auto& event_type_name =
      fcp::client::CrosEvent::EventType_Name(cros_event.event_type());

  LOG(INFO) << "In LogCrosEvent, model_id is " << cros_event.model_id()
            << ", event_type is " << event_type_name;
  DVLOG(1) << "cros_event is " << cros_event.DebugString();
}
// TODO(b/251378482): Just a dummy impl for now, might need to log to UMA.
// Logs the execution session id of `cros_secagg_event`, followed by a single
// line naming which payload it carries. The payloads are checked in the order
// state transition, error, abort; only the first one present is reported, and
// an error payload is reported at ERROR severity.
void LogCrosSecAggEvent(const fcp::client::CrosSecAggEvent& cros_secagg_event) {
  LOG(INFO) << "In LogCrosSecAggEvent, session_id is "
            << cros_secagg_event.execution_session_id();

  if (cros_secagg_event.has_state_transition()) {
    LOG(INFO) << "cros_secagg_event.has_state_transition";
    return;
  }

  if (cros_secagg_event.has_error()) {
    LOG(ERROR) << "cros_secagg_event.has_error";
    return;
  }

  if (cros_secagg_event.has_abort()) {
    LOG(INFO) << "cros_secagg_event.has_abort";
    return;
  }

  LOG(INFO) << "cros_secagg_event doesn't have any event log";
}
} // namespace
// Builds the per-run context that is handed to the federated library as the
// opaque `context` pointer of the callbacks below.
//
// `client_name` is used when reporting client events to Metrics, `table_name`
// is the example table read by PrepareExamples(), and `population_name` is
// combined with the task name to form the meta record identifier.
// `device_status_monitor` and `storage_manager` are stored as given.
// NOTE(review): both pointers are presumably non-owning and must outlive this
// object -- confirm against the header.
FederatedClient::Context::Context(
const std::string& client_name,
const std::string& table_name,
const std::string& population_name,
const DeviceStatusMonitor* const device_status_monitor,
const StorageManager* const storage_manager)
: client_name_(client_name),
table_name_(table_name),
population_name_(population_name),
// The construction time is the reference point for the
// kMaximalExecutionTime check in TrainingConditionsSatisfied().
start_time_(base::Time::Now()),
device_status_monitor_(device_status_monitor),
storage_manager_(storage_manager) {}
// Defaulted destructor, defined out of line in this translation unit.
FederatedClient::Context::~Context() = default;
bool FederatedClient::Context::PrepareExamples(const char* const criteria_data,
const int criteria_data_size,
void* const context) {
auto* typed_context = static_cast<FederatedClient::Context*>(context);
const std::string& client_name = typed_context->client_name_;
const std::string& table_name = typed_context->table_name_;
fcp::client::CrosExampleSelectorCriteria criteria;
if (!criteria.ParseFromArray(criteria_data, criteria_data_size)) {
LOG(ERROR) << "Failed to parse criteria.";
Metrics::GetInstance()->LogClientEvent(
client_name, ClientEvent::kExampleSelectorCriteriaParsingError);
return false;
}
if (criteria.task_name().empty()) {
LOG(ERROR) << "No valid task_name";
Metrics::GetInstance()->LogClientEvent(client_name,
ClientEvent::kTaskNameEmptyError);
return false;
}
// Initializes `new_meta_record_`. If iterator is created successfully and the
// task starts, it keeps the largest seen example id and the associated
// example timestamp . If the task succeeds, metatable will be updated with
// `new_meta_record_`.
// Next time when running this task and if it prevents used examples, example
// selection will start from last_used_example_timestamp not inclusive.
// This is a precise breakpoint compared to last_contribution_time, because
// there may be new examples received during a computation and with timestamp
// in between last_used_example_timestamp and last_contribution_time. Such
// examples will never be used.
// Note: The identifier contains population_name instead of client name, so
// that when a client's launch stage changes, e.g. from "dev" to "dogfood",
// the examples in this client's table used in dev stage can be used in
// dogfood again.
typed_context->new_meta_record_.identifier =
base::StringPrintf("%s#%s", typed_context->population_name_.c_str(),
criteria.task_name().c_str());
typed_context->new_meta_record_.last_used_example_id = -1;
typed_context->new_meta_record_.last_used_example_timestamp =
base::Time::UnixEpoch();
std::optional<ExampleDatabase::Iterator> example_iterator =
typed_context->storage_manager_->GetExampleIterator(
table_name, typed_context->new_meta_record_.identifier, criteria);
if (!example_iterator.has_value()) {
DVLOG(1) << "Client " << client_name
<< " failed to prepare examples with table_name " << table_name;
Metrics::GetInstance()->LogClientEvent(
client_name, ClientEvent::kGetExampleIteratorError);
return false;
}
typed_context->example_iterator_ = std::move(example_iterator.value());
return true;
}
// Callback passed to the federated library through FlTaskEnvironment. Fetches
// the next record from the iterator created by PrepareExamples().
//
// On success with a record available: `*end` is false, `*data` points to a
// buffer allocated with new[] holding the serialized example (not
// NUL-terminated), and `*size` is its length; the buffer must be released
// with FreeExample(). The largest example id seen so far, with its timestamp,
// is tracked in `new_meta_record_`.
// On success with no more records: `*end` is true, `*data` is nullptr and
// `*size` is 0.
// Returns false, leaving the outputs untouched, when `context` is null or the
// iterator reports an invalid-argument error.
bool FederatedClient::Context::GetNextExample(const char** const data,
                                              int* const size,
                                              bool* const end,
                                              void* const context) {
  if (context == nullptr)
    return false;

  auto* typed_context = static_cast<FederatedClient::Context*>(context);
  const absl::StatusOr<ExampleRecord> record =
      typed_context->example_iterator_.Next();

  if (absl::IsInvalidArgument(record.status())) {
    Metrics::GetInstance()->LogClientEvent(
        typed_context->client_name_,
        ClientEvent::kGetNextExampleInvalidArgumentError);
    return false;
  }

  if (!record.ok()) {
    // Any remaining non-OK status is expected to mean the iterator is
    // exhausted.
    DCHECK(absl::IsOutOfRange(record.status()));
    // Give the outputs well-defined values so that a caller that reads them,
    // or passes `*data` to FreeExample(), never touches an indeterminate
    // pointer.
    *data = nullptr;
    *size = 0;
    *end = true;
    return true;
  }

  // The callback interface reports the length as int, so the narrowing from
  // size_t is made explicit here.
  const int example_size = static_cast<int>(record->serialized_example.size());
  char* const buffer = new char[example_size];
  record->serialized_example.copy(buffer, example_size);

  *data = buffer;
  *size = example_size;
  *end = false;

  if (record->id > typed_context->new_meta_record_.last_used_example_id) {
    typed_context->new_meta_record_.last_used_example_id = record->id;
    typed_context->new_meta_record_.last_used_example_timestamp =
        record->timestamp;
  }
  return true;
}
// Callback passed to the federated library through FlTaskEnvironment. Releases
// an example buffer with delete[], matching the new[] allocation made in
// GetNextExample(). `context` is not used.
void FederatedClient::Context::FreeExample(const char* const data,
void* const context) {
delete[] data;
}
bool FederatedClient::Context::TrainingConditionsSatisfied(
void* const context) {
if (context == nullptr)
return false;
auto* typed_context = static_cast<FederatedClient::Context*>(context);
// If time cost exceeds the limit, return false to quit this round.
base::TimeDelta time_cost = base::Time::Now() - typed_context->start_time_;
if (time_cost > kMaximalExecutionTime) {
Metrics::GetInstance()->LogClientEvent(typed_context->client_name_,
ClientEvent::kTaskTimeoutAbort);
return false;
}
const bool condition_satisfied =
typed_context->device_status_monitor_
->TrainingConditionsSatisfiedToContinue();
if (!condition_satisfied) {
Metrics::GetInstance()->LogClientEvent(
typed_context->client_name_, ClientEvent::kUnsatisfiedConditionAbort);
}
return condition_satisfied;
}
void FederatedClient::Context::PublishEvent(const char* const event,
const int size,
void* const context) {
if (context == nullptr) {
LOG(ERROR) << "PublishEvent gets nullptr context.";
return;
}
fcp::client::CrosEventLog event_log;
if (!event_log.ParseFromArray(event, size)) {
LOG(ERROR) << "Failed to parse event_log.";
return;
}
if (event_log.has_event()) {
LogCrosEvent(event_log.event());
} else if (event_log.has_secagg_event()) {
LogCrosSecAggEvent(event_log.secagg_event());
} else {
LOG(ERROR) << "event_log has no content";
}
}
// Stores everything needed to run federated tasks for one client.
//
// `run_plan` and `free_run_plan_result` are the library entry points invoked
// by RunPlan(). `service_uri`, `api_key` and `brella_lib_version` are passed
// to `run_plan` on every run. `client_config` supplies the client name, table
// name, launch stage and retry token. `device_status_monitor` is forwarded to
// the per-run Context.
// NOTE(review): `device_status_monitor` is presumably non-owning and must
// outlive this object -- confirm against the header.
FederatedClient::FederatedClient(
const FlRunPlanFn run_plan,
const FlFreeRunPlanResultFn free_run_plan_result,
const std::string& service_uri,
const std::string& api_key,
const std::string& brella_lib_version,
const ClientConfigMetadata client_config,
const DeviceStatusMonitor* const device_status_monitor)
: run_plan_(run_plan),
free_run_plan_result_(free_run_plan_result),
service_uri_(service_uri),
api_key_(api_key),
brella_lib_version_(brella_lib_version),
client_config_(client_config),
// The initial retry delay is kInitialWaitingWindow rather than
// kDefaultRetryWindow.
next_retry_delay_(kInitialWaitingWindow),
device_status_monitor_(device_status_monitor) {}
// Defaulted destructor, defined out of line in this translation unit.
FederatedClient::~FederatedClient() = default;
// Runs one federated round for this client through `run_plan_` and updates
// `next_retry_delay_` (and, when the server answered, the retry token) from
// the outcome.
//
// - No database connection: nothing is run; retry after kDefaultRetryWindow.
// - CONTRIBUTED / REJECTED_BY_SERVER: the retry token and delay returned by
//   the library are stored, with the delay raised to at least
//   kMinimalRetryWindow. On CONTRIBUTED the meta record collected by the
//   Context is timestamped and written through `storage_manager`.
// - Any other status: retry after kDefaultRetryWindow.
// The result is always released with `free_run_plan_result_`.
void FederatedClient::RunPlan(const StorageManager* const storage_manager) {
  if (!storage_manager->IsDatabaseConnected()) {
    next_retry_delay_ = kDefaultRetryWindow;
    DVLOG(1) << "StorageManager doesn't have a database connection, retry in "
             << next_retry_delay_;
    return;
  }

  DCHECK(!storage_manager->sanitized_username().empty())
      << "storage_manager->sanitized_username() is unexpectedly empty!";

  // Compose a unique population name from the client name and the launch
  // stage.
  const std::string population_name =
      base::StringPrintf("chromeos/%s/%s", client_config_.name.c_str(),
                         client_config_.launch_stage.c_str());

  // `context` and `base_dir_in_cryptohome` are referenced by pointer from
  // `env`, so both must stay alive until `run_plan_` returns.
  FederatedClient::Context context(client_config_.name,
                                   client_config_.table_name, population_name,
                                   device_status_monitor_, storage_manager);
  const std::string base_dir_in_cryptohome =
      GetBaseDir(storage_manager->sanitized_username(), client_config_.name)
          .value();

  const FlTaskEnvironment env = {
      &FederatedClient::Context::PrepareExamples,
      &FederatedClient::Context::GetNextExample,
      &FederatedClient::Context::FreeExample,
      &FederatedClient::Context::TrainingConditionsSatisfied,
      &FederatedClient::Context::PublishEvent,
      base_dir_in_cryptohome.c_str(),
      &context};

  auto scoped_metrics_recorder =
      Metrics::GetInstance()->CreateScopedMetricsRecorder(client_config_.name);

  FlRunPlanResult result = (*run_plan_)(
      env, service_uri_.c_str(), api_key_.c_str(), brella_lib_version_.c_str(),
      population_name.c_str(), client_config_.retry_token.c_str());

  // TODO(b/251378482): maybe log the event to UMA
  if (result.status == CONTRIBUTED || result.status == REJECTED_BY_SERVER) {
    // Constructing a std::string from, or streaming, a null C string is
    // undefined behaviour, so a missing token is treated as an empty one.
    const char* const retry_token =
        result.retry_token != nullptr ? result.retry_token : "";

    DVLOG(1) << "result.status = " << result.status;
    DVLOG(1) << "result.retry_token = " << retry_token;
    DVLOG(1) << "result.delay_usecs = " << result.delay_usecs;

    client_config_.retry_token = std::string(retry_token);
    next_retry_delay_ = base::Microseconds(result.delay_usecs);
    // TODO(b/239623649): result.delay_usecs may be 0 when setup is wrong, now I
    // set next_retry_delay_ to kMinimalRetryWindow to avoid spam, consider
    // stopping retry in this case because it's very likely to fail again.
    if (next_retry_delay_ < kMinimalRetryWindow)
      next_retry_delay_ = kMinimalRetryWindow;

    if (result.status == CONTRIBUTED) {
      scoped_metrics_recorder.MarkSuccess();
      Metrics::GetInstance()->LogClientEvent(client_config_.name,
                                             ClientEvent::kContributed);
      // Persist which examples were consumed by this successful round.
      context.new_meta_record().timestamp = base::Time::Now();
      storage_manager->UpdateMetaRecord(context.new_meta_record());
    } else {
      Metrics::GetInstance()->LogClientEvent(client_config_.name,
                                             ClientEvent::kRejected);
    }
  } else {
    DVLOG(1) << "Failed to checkin with the servce, result.status = "
             << result.status;
    Metrics::GetInstance()->LogClientEvent(
        client_config_.name, ClientEvent::kTaskFailedUnknownError);
    next_retry_delay_ = kDefaultRetryWindow;
  }

  (*free_run_plan_result_)(result);
}
// Sets the delay before the next run back to kDefaultRetryWindow.
// NOTE(review): this resets to kDefaultRetryWindow, not to the
// kInitialWaitingWindow used by the constructor -- confirm this is intended.
void FederatedClient::ResetRetryDelay() {
next_retry_delay_ = kDefaultRetryWindow;
}
// Returns a copy of the client name stored in the client config.
std::string FederatedClient::GetClientName() const {
return client_config_.name;
}
} // namespace federated