blob: c82538f9e7f004ed5b41654bfac7adae352a3ddd [file] [log] [blame]
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "base/task/thread_pool/test_utils.h"
#include <utility>
#include "base/bind.h"
#include "base/synchronization/condition_variable.h"
#include "base/task/thread_pool/pooled_parallel_task_runner.h"
#include "base/task/thread_pool/pooled_sequenced_task_runner.h"
#include "base/threading/scoped_blocking_call.h"
#include "base/threading/thread_restrictions.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace base {
namespace internal {
namespace test {
// The condition variable is created from |lock_|, so it has to be waited on
// with |lock_| held (which is how WaitCallsOnMainExit() uses it).
MockWorkerThreadObserver::MockWorkerThreadObserver()
    : on_main_exit_cv_(lock_.CreateConditionVariable()) {}

// Destruction blocks until every exit granted through AllowCallsOnMainExit()
// has been observed (presumably so worker threads never call into a destroyed
// observer).
MockWorkerThreadObserver::~MockWorkerThreadObserver() {
  this->WaitCallsOnMainExit();
}
// Grants |num_calls| future OnWorkerThreadMainExit() calls. Grants do not
// stack: a previous grant must have been fully consumed (counter back at 0)
// before a new one is issued.
void MockWorkerThreadObserver::AllowCallsOnMainExit(int num_calls) {
  CheckedAutoLock guard(lock_);
  EXPECT_EQ(0, allowed_calls_on_main_exit_);
  allowed_calls_on_main_exit_ = num_calls;
}
// Blocks the caller until the number of allowed-but-not-yet-seen exit calls
// reaches exactly zero.
void MockWorkerThreadObserver::WaitCallsOnMainExit() {
  CheckedAutoLock guard(lock_);
  // Re-check the counter after every wakeup; Wait() releases |lock_| while
  // blocked and re-acquires it before returning.
  while (allowed_calls_on_main_exit_ != 0) {
    on_main_exit_cv_->Wait();
  }
}
void MockWorkerThreadObserver::OnWorkerThreadMainExit() {
CheckedAutoLock auto_lock(lock_);
EXPECT_GE(allowed_calls_on_main_exit_, 0);
--allowed_calls_on_main_exit_;
if (allowed_calls_on_main_exit_ == 0)
on_main_exit_cv_->Signal();
}
// Builds a Sequence with the given |traits| and |execution_mode| that already
// holds |task|. The Sequence receives only a raw pointer to |task_runner|; the
// scoped_refptr parameter keeps the runner alive for the duration of the call.
scoped_refptr<Sequence> CreateSequenceWithTask(
    Task task,
    const TaskTraits& traits,
    scoped_refptr<TaskRunner> task_runner,
    TaskSourceExecutionMode execution_mode) {
  auto new_sequence =
      MakeRefCounted<Sequence>(traits, task_runner.get(), execution_mode);
  {
    // Keep the transaction confined to the push so it has ended before the
    // sequence is handed back to the caller.
    auto transaction = new_sequence->BeginTransaction();
    transaction.PushTask(std::move(task));
  }
  return new_sequence;
}
// Returns a parallel or sequenced pooled task runner backed by
// |mock_pooled_task_runner_delegate|, depending on |execution_mode|. Any other
// mode records a non-fatal test failure and yields nullptr.
scoped_refptr<TaskRunner> CreateTaskRunnerWithExecutionMode(
    TaskSourceExecutionMode execution_mode,
    MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate,
    const TaskTraits& traits) {
  if (execution_mode == TaskSourceExecutionMode::kParallel)
    return CreateTaskRunner(traits, mock_pooled_task_runner_delegate);
  if (execution_mode == TaskSourceExecutionMode::kSequenced) {
    return CreateSequencedTaskRunner(traits,
                                     mock_pooled_task_runner_delegate);
  }
  // Every other execution mode is unsupported by this helper.
  ADD_FAILURE() << "Unexpected ExecutionMode";
  return nullptr;
}
// Creates a PooledParallelTaskRunner with |traits| that posts through
// |mock_pooled_task_runner_delegate|.
scoped_refptr<TaskRunner> CreateTaskRunner(
    const TaskTraits& traits,
    MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate) {
  scoped_refptr<TaskRunner> parallel_runner =
      MakeRefCounted<PooledParallelTaskRunner>(
          traits, mock_pooled_task_runner_delegate);
  return parallel_runner;
}
// Creates a PooledSequencedTaskRunner with |traits| that posts through
// |mock_pooled_task_runner_delegate|.
scoped_refptr<SequencedTaskRunner> CreateSequencedTaskRunner(
    const TaskTraits& traits,
    MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate) {
  scoped_refptr<SequencedTaskRunner> sequenced_runner =
      MakeRefCounted<PooledSequencedTaskRunner>(
          traits, mock_pooled_task_runner_delegate);
  return sequenced_runner;
}
// Blocks on |event| while the blocking observer is cleared (null), so that the
// wait does not affect the max tasks of a thread group. Waiting on base sync
// primitives is explicitly allowed for the duration of the wait.
void WaitWithoutBlockingObserver(WaitableEvent* event) {
  // Declaration order matters: the observer is cleared before sync primitives
  // are allowed, and both scopes unwind in reverse order on return.
  internal::ScopedClearBlockingObserverForTesting no_blocking_observer;
  ScopedAllowBaseSyncPrimitivesForTesting sync_primitives_allowed;
  event->Wait();
}
// |task_tracker| is a by-value sink parameter, so it is moved into the member
// rather than copied. TrackedRef copies are not free (presumably they update a
// live-reference count — see TrackedRef), so this avoids a redundant copy.
// |delayed_task_manager| is a raw pointer that PostTaskWithSequence()
// dereferences for delayed tasks; presumably non-owning and expected to
// outlive this delegate — confirm at call sites.
MockPooledTaskRunnerDelegate::MockPooledTaskRunnerDelegate(
    TrackedRef<TaskTracker> task_tracker,
    DelayedTaskManager* delayed_task_manager)
    : task_tracker_(std::move(task_tracker)),
      delayed_task_manager_(delayed_task_manager) {}

MockPooledTaskRunnerDelegate::~MockPooledTaskRunnerDelegate() = default;
// Posts |task| to |sequence|. Returns false if |task_tracker_| rejects the
// post. Immediate tasks are enqueued right away; delayed tasks are handed to
// |delayed_task_manager_|, which later calls back into
// PostTaskWithSequenceNow().
bool MockPooledTaskRunnerDelegate::PostTaskWithSequence(
    Task task,
    scoped_refptr<Sequence> sequence) {
  // SetThreadGroup() has to run before anything is posted through this mock.
  DCHECK(thread_group_);
  DCHECK(task.task);
  DCHECK(sequence);

  // Let the tracker veto the post based on the sequence's shutdown behavior.
  if (!task_tracker_->WillPostTask(&task, sequence->shutdown_behavior()))
    return false;

  if (task.delayed_run_time.is_null()) {
    PostTaskWithSequenceNow(std::move(task), std::move(sequence));
    return true;
  }

  // Delayed path. The task runner must be read *before* |sequence| is moved
  // into the callback. Taking a ref here is safe because whoever posted must
  // hold a ref to the TaskRunner.
  scoped_refptr<TaskRunner> runner = sequence->task_runner();
  // Unretained: relies on this delegate outliving any pending delayed task.
  auto post_now_callback = BindOnce(
      [](scoped_refptr<Sequence> target_sequence,
         MockPooledTaskRunnerDelegate* delegate, Task ready_task) {
        delegate->PostTaskWithSequenceNow(std::move(ready_task),
                                          std::move(target_sequence));
      },
      std::move(sequence), Unretained(this));
  delayed_task_manager_->AddDelayedTask(
      std::move(task), std::move(post_now_callback), std::move(runner));
  return true;
}
// Pushes |task| onto |sequence| immediately (no delay handling). If the
// transaction reports that the sequence should be queued, the sequence is
// first registered with |task_tracker_|, and after the push it is handed to
// |thread_group_| together with the still-open transaction.
// NOTE(review): unlike its siblings, this does not DCHECK(thread_group_); it is
// dereferenced only on the queuing path.
void MockPooledTaskRunnerDelegate::PostTaskWithSequenceNow(
    Task task,
    scoped_refptr<Sequence> sequence) {
  // The transaction stays open for the whole function. It is either destroyed
  // on return or moved into PushTaskSourceAndWakeUpWorkers() below.
  auto transaction = sequence->BeginTransaction();
  // Queried before the push. NOTE(review): presumably true only when the
  // sequence is not already queued or running — confirm in Sequence.
  const bool sequence_should_be_queued = transaction.WillPushTask();
  RegisteredTaskSource task_source;
  if (sequence_should_be_queued) {
    task_source = task_tracker_->WillQueueTaskSource(sequence);
    // We shouldn't push |task| if we're not allowed to queue |task_source|.
    // In that case |task| is dropped without ever being pushed.
    if (!task_source)
      return;
  }
  transaction.PushTask(std::move(task));
  // |task_source| is non-null only if the sequence was registered above.
  if (task_source) {
    thread_group_->PushTaskSourceAndWakeUpWorkers(
        {std::move(task_source), std::move(transaction)});
  }
}
// Reports whether the current thread belongs to the single ThreadGroup this
// mock was configured with. |traits| is not consulted: the mock has exactly
// one ThreadGroup regardless of traits.
bool MockPooledTaskRunnerDelegate::IsRunningPoolWithTraits(
    const TaskTraits& traits) const {
  // SetThreadGroup() has to run before this query is meaningful.
  DCHECK(thread_group_);
  const bool on_thread_group = thread_group_->IsBoundToCurrentThread();
  return on_thread_group;
}
// Changes the priority of |task_source| to |priority| and asks |thread_group_|
// to update the task source's sort key under the same transaction.
void MockPooledTaskRunnerDelegate::UpdatePriority(
    scoped_refptr<TaskSource> task_source,
    TaskPriority priority) {
  // |thread_group_| must be initialized with SetThreadGroup() before
  // proceeding. This matches the precondition checked by the other entry
  // points that dereference |thread_group_|.
  DCHECK(thread_group_);
  auto transaction = task_source->BeginTransaction();
  transaction.UpdatePriority(priority);
  thread_group_->UpdateSortKey(
      {std::move(task_source), std::move(transaction)});
}
// Installs the ThreadGroup that posted and re-prioritized task sources are
// forwarded to. Must be called before any post; the pointer is stored as-is
// (raw, not owned by anything visible here).
void MockPooledTaskRunnerDelegate::SetThreadGroup(ThreadGroup* thread_group) {
  this->thread_group_ = thread_group;
}
// Registers |task_source| with |task_tracker| as if it were being queued, then
// immediately runs its next task through the tracker. Returns whatever
// RunAndPopNextTask() hands back.
RegisteredTaskSource QueueAndRunTaskSource(
    TaskTracker* task_tracker,
    scoped_refptr<TaskSource> task_source) {
  // Note: WillQueueTaskSource() may yield an empty RegisteredTaskSource; it is
  // forwarded to RunAndPopNextTask() unchanged either way.
  RegisteredTaskSource registered_task_source =
      task_tracker->WillQueueTaskSource(std::move(task_source));
  return task_tracker->RunAndPopNextTask(std::move(registered_task_source));
}
// Runs both phases of |task_tracker|'s shutdown back to back: StartShutdown()
// followed by CompleteShutdown().
void ShutdownTaskTracker(TaskTracker* task_tracker) {
  TaskTracker& tracker = *task_tracker;
  tracker.StartShutdown();
  tracker.CompleteShutdown();
}
} // namespace test
} // namespace internal
} // namespace base