blob: 13596ed9e34996fde3084f855e9e98aa6f153d0b [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/execution/mock_model_provider.h"
#include <utility>
#include "base/containers/contains.h"
#include "base/functional/callback.h"
namespace segmentation_platform {
namespace {

using ::testing::_;
using ::testing::Invoke;

// Records |model_updated_callback| in |data| under |segment_id| so that a test
// can later run it to simulate a model update for that segment. The factory
// methods below bind this function (with |segment_id| and the factory's data)
// as each MockModelProvider's client-callback hook.
//
// Only the first callback recorded for a given |segment_id| is kept; a later
// call for the same segment leaves the existing entry untouched.
void StoreClientCallback(
    proto::SegmentId segment_id,
    TestModelProviderFactory::Data* data,
    const ModelProvider::ModelUpdatedCallback& model_updated_callback) {
  // try_emplace constructs the entry in place and copies the callback only
  // when |segment_id| is not present yet, instead of first materializing a
  // std::pair temporary via std::make_pair.
  data->model_providers_callbacks.try_emplace(segment_id,
                                              model_updated_callback);
}

}  // namespace
// Creates a mock provider for |segment_id|. |get_client_callback| receives
// whatever ModelUpdatedCallback the client passes to InitAndFetchModel().
MockModelProvider::MockModelProvider(
    proto::SegmentId segment_id,
    base::RepeatingCallback<void(const ModelProvider::ModelUpdatedCallback&)>
        get_client_callback)
    : ModelProvider(segment_id),
      // The parameter is a by-value sink; move it instead of copying.
      get_client_callback_(std::move(get_client_callback)) {
  // Default behavior for InitAndFetchModel(): forward the client's callback
  // to |get_client_callback_| rather than fetching a real model. Tests can
  // still override this with EXPECT_CALL/ON_CALL.
  // The action is registered on this mock object itself and only touches a
  // member, so capture |this| explicitly rather than everything by reference.
  ON_CALL(*this, InitAndFetchModel(_))
      .WillByDefault(
          Invoke([this](const ModelUpdatedCallback& model_updated_callback) {
            get_client_callback_.Run(model_updated_callback);
          }));
}
// Special members are defined out-of-line in this .cc file rather than in the
// header. NOTE(review): presumably to satisfy Chromium's style checks for
// classes with non-trivial members (Data holds the maps used below) — confirm.
MockModelProvider::~MockModelProvider() = default;
TestModelProviderFactory::Data::Data() = default;
TestModelProviderFactory::Data::~Data() = default;
// Returns a new MockModelProvider for |segment_id|. Ownership passes to the
// caller; the factory keeps only a non-owning pointer in
// |data_->model_providers| so tests can set expectations on the mock.
std::unique_ptr<ModelProvider> TestModelProviderFactory::CreateProvider(
    proto::SegmentId segment_id) {
  // When the client calls InitAndFetchModel(), its callback is recorded in
  // |data_| under |segment_id|.
  auto store_callback =
      base::BindRepeating(&StoreClientCallback, segment_id, data_);
  auto mock_provider = std::make_unique<MockModelProvider>(
      segment_id, std::move(store_callback));

  // An existing entry for |segment_id| is left as-is.
  // NOTE(review): if a second provider is ever created for the same segment
  // after the first was destroyed, this map keeps the stale pointer — confirm
  // that tests never do that.
  data_->model_providers.try_emplace(segment_id, mock_provider.get());
  return mock_provider;
}
// Returns a MockModelProvider acting as the default model for |segment_id|,
// or nullptr when the test has not listed |segment_id| in
// |data_->segments_supporting_default_model|. Ownership passes to the caller;
// a non-owning pointer is kept in |data_->default_model_providers|.
std::unique_ptr<ModelProvider> TestModelProviderFactory::CreateDefaultProvider(
    proto::SegmentId segment_id) {
  const bool has_default_model =
      base::Contains(data_->segments_supporting_default_model, segment_id);
  if (!has_default_model) {
    return nullptr;
  }

  // Client callbacks from default providers go to the same
  // |model_providers_callbacks| map as regular providers.
  auto store_callback =
      base::BindRepeating(&StoreClientCallback, segment_id, data_);
  auto mock_provider = std::make_unique<MockModelProvider>(
      segment_id, std::move(store_callback));

  // An existing entry for |segment_id| is left as-is.
  data_->default_model_providers.try_emplace(segment_id, mock_provider.get());
  return mock_provider;
}
} // namespace segmentation_platform