// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/sequence_checker.h"

#include <stddef.h>

#include <memory>
#include <string>

#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/callback_forward.h"
#include "base/macros.h"
#include "base/message_loop/message_loop.h"
#include "base/sequence_token.h"
#include "base/single_thread_task_runner.h"
#include "base/test/gtest_util.h"
#include "base/test/sequenced_worker_pool_owner.h"
#include "base/threading/simple_thread.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace base {

namespace {

// Number of worker threads in each SequencedWorkerPool created by these tests.
constexpr size_t kNumWorkerThreads{3};

// Runs a callback on another thread.
class RunCallbackThread : public SimpleThread {
 public:
  explicit RunCallbackThread(const Closure& callback)
      : SimpleThread("RunCallbackThread"), callback_(callback) {
    Start();
    Join();
  }

 private:
  // SimpleThread:
  void Run() override { callback_.Run(); }

  const Closure callback_;

  DISALLOW_COPY_AND_ASSIGN(RunCallbackThread);
};

class SequenceCheckerTest : public testing::Test {
 protected:
  SequenceCheckerTest() : pool_owner_(kNumWorkerThreads, "test") {}

  void PostToSequencedWorkerPool(const Closure& callback,
                                 const std::string& token_name) {
    pool_owner_.pool()->PostNamedSequencedWorkerTask(token_name, FROM_HERE,
                                                     callback);
  }

  void FlushSequencedWorkerPoolForTesting() {
    pool_owner_.pool()->FlushForTesting();
  }

 private:
  MessageLoop message_loop_;  // Needed by SequencedWorkerPool to function.
  SequencedWorkerPoolOwner pool_owner_;

  DISALLOW_COPY_AND_ASSIGN(SequenceCheckerTest);
};

// Expects |sequence_checker| (which must be non-null) to accept two
// consecutive calls from the current sequence. If the checker was not yet
// bound to a sequence, the first call binds it to the current one; the second
// call verifies that the binding holds.
void ExpectCalledOnValidSequence(SequenceCheckerImpl* sequence_checker) {
  ASSERT_TRUE(sequence_checker);

  for (int query = 0; query < 2; ++query)
    EXPECT_TRUE(sequence_checker->CalledOnValidSequence());
}

// Like ExpectCalledOnValidSequence(), but runs the checks with
// |sequence_token| set as the current thread's SequenceToken; the previous
// token setting is restored when this function returns.
void ExpectCalledOnValidSequenceWithSequenceToken(
    SequenceCheckerImpl* sequence_checker,
    SequenceToken sequence_token) {
  ScopedSetSequenceTokenForCurrentThread scoped_token(sequence_token);
  ExpectCalledOnValidSequence(sequence_checker);
}

// Expects |sequence_checker| (which must be non-null) to reject a call from
// the current sequence.
void ExpectNotCalledOnValidSequence(SequenceCheckerImpl* sequence_checker) {
  ASSERT_TRUE(sequence_checker);
  const bool is_valid_sequence = sequence_checker->CalledOnValidSequence();
  EXPECT_FALSE(is_valid_sequence);
}

}  // namespace

TEST_F(SequenceCheckerTest, CallsAllowedOnSameThreadNoSequenceToken) {
  // With no SequenceToken set, a checker accepts calls from the thread that
  // constructed it.
  SequenceCheckerImpl checker;
  const bool called_on_valid_sequence = checker.CalledOnValidSequence();
  EXPECT_TRUE(called_on_valid_sequence);
}

TEST_F(SequenceCheckerTest, CallsAllowedOnSameThreadSameSequenceToken) {
  // Set a fresh SequenceToken for the rest of the test, then construct and
  // query the checker under that same token.
  const SequenceToken token = SequenceToken::Create();
  ScopedSetSequenceTokenForCurrentThread scoped_token(token);

  SequenceCheckerImpl checker;
  EXPECT_TRUE(checker.CalledOnValidSequence());
}

TEST_F(SequenceCheckerTest, CallsDisallowedOnDifferentThreadsNoSequenceToken) {
  SequenceCheckerImpl checker;

  // Without a SequenceToken the checker belongs to the constructing thread, so
  // a query from another thread must be rejected.
  const Closure expect_invalid =
      Bind(&ExpectNotCalledOnValidSequence, Unretained(&checker));
  RunCallbackThread other_thread(expect_invalid);
}

TEST_F(SequenceCheckerTest, CallsAllowedOnDifferentThreadsSameSequenceToken) {
  const SequenceToken shared_token(SequenceToken::Create());

  // Construct and query the checker on the main thread under |shared_token|.
  ScopedSetSequenceTokenForCurrentThread scoped_token(shared_token);
  SequenceCheckerImpl checker;
  EXPECT_TRUE(checker.CalledOnValidSequence());

  // A different thread running under the same token is on the same sequence,
  // so it must be accepted as well.
  const Closure expect_valid_with_token =
      Bind(&ExpectCalledOnValidSequenceWithSequenceToken, Unretained(&checker),
           shared_token);
  RunCallbackThread other_thread(expect_valid_with_token);
}

TEST_F(SequenceCheckerTest, CallsDisallowedOnSameThreadDifferentSequenceToken) {
  std::unique_ptr<SequenceCheckerImpl> checker;

  // Construct the checker while a first SequenceToken is set.
  {
    ScopedSetSequenceTokenForCurrentThread first_token(
        SequenceToken::Create());
    checker.reset(new SequenceCheckerImpl);
  }

  // Same thread, but an unrelated SequenceToken: must be rejected.
  {
    ScopedSetSequenceTokenForCurrentThread other_token(
        SequenceToken::Create());
    EXPECT_FALSE(checker->CalledOnValidSequence());
  }

  // Same thread with no SequenceToken at all: must also be rejected.
  EXPECT_FALSE(checker->CalledOnValidSequence());
}

TEST_F(SequenceCheckerTest, DetachFromSequence) {
  std::unique_ptr<SequenceCheckerImpl> checker;

  // Construct the checker while an initial SequenceToken is set.
  {
    ScopedSetSequenceTokenForCurrentThread initial_token(
        SequenceToken::Create());
    checker.reset(new SequenceCheckerImpl);
  }

  checker->DetachFromSequence();

  // After DetachFromSequence(), a call made under a brand-new SequenceToken
  // must be accepted.
  {
    ScopedSetSequenceTokenForCurrentThread new_token(SequenceToken::Create());
    EXPECT_TRUE(checker->CalledOnValidSequence());
  }
}

TEST_F(SequenceCheckerTest, DetachFromSequenceNoSequenceToken) {
  SequenceCheckerImpl checker;
  checker.DetachFromSequence();

  // Once detached, the checker accepts the next caller, here another thread,
  // which binds the checker to itself.
  const Closure expect_valid =
      Bind(&ExpectCalledOnValidSequence, Unretained(&checker));
  RunCallbackThread other_thread(expect_valid);

  // The checker now belongs to that other thread, so the main thread is
  // rejected.
  const bool valid_on_main_thread = checker.CalledOnValidSequence();
  EXPECT_FALSE(valid_on_main_thread);
}

TEST_F(SequenceCheckerTest, SequencedWorkerPool_SameSequenceTokenValid) {
  SequenceCheckerImpl checker;
  checker.DetachFromSequence();

  // Two tasks on the same named sequence must both be accepted.
  const Closure expect_valid =
      Bind(&ExpectCalledOnValidSequence, Unretained(&checker));
  for (int task = 0; task < 2; ++task)
    PostToSequencedWorkerPool(expect_valid, "A");
  FlushSequencedWorkerPoolForTesting();
}

TEST_F(SequenceCheckerTest, SequencedWorkerPool_DetachSequenceTokenValid) {
  SequenceCheckerImpl sequence_checker;
  const Closure expect_valid =
      Bind(&ExpectCalledOnValidSequence, Unretained(&sequence_checker));

  // Each round detaches the checker and then lets a different named sequence
  // claim it. Both tasks on that sequence must be accepted.
  for (const char* token_name : {"A", "B"}) {
    sequence_checker.DetachFromSequence();
    PostToSequencedWorkerPool(expect_valid, token_name);
    PostToSequencedWorkerPool(expect_valid, token_name);
    FlushSequencedWorkerPoolForTesting();
  }
}

TEST_F(SequenceCheckerTest,
       SequencedWorkerPool_DifferentSequenceTokensInvalid) {
  SequenceCheckerImpl checker;
  checker.DetachFromSequence();

  // Sequence "A" claims the detached checker...
  const Closure expect_valid =
      Bind(&ExpectCalledOnValidSequence, Unretained(&checker));
  PostToSequencedWorkerPool(expect_valid, "A");
  PostToSequencedWorkerPool(expect_valid, "A");
  FlushSequencedWorkerPoolForTesting();

  // ...so every task on sequence "B" must be rejected.
  const Closure expect_invalid =
      Bind(&ExpectNotCalledOnValidSequence, Unretained(&checker));
  PostToSequencedWorkerPool(expect_invalid, "B");
  PostToSequencedWorkerPool(expect_invalid, "B");
  FlushSequencedWorkerPoolForTesting();
}

TEST_F(SequenceCheckerTest,
       SequencedWorkerPool_WorkerPoolAndSimpleThreadInvalid) {
  SequenceCheckerImpl checker;
  checker.DetachFromSequence();

  // Let sequence "A" of the worker pool claim the detached checker.
  const Closure expect_valid =
      Bind(&ExpectCalledOnValidSequence, Unretained(&checker));
  PostToSequencedWorkerPool(expect_valid, "A");
  PostToSequencedWorkerPool(expect_valid, "A");
  FlushSequencedWorkerPoolForTesting();

  // The main thread is not running sequence "A", so it must be rejected.
  const bool valid_on_main_thread = checker.CalledOnValidSequence();
  EXPECT_FALSE(valid_on_main_thread);
}

TEST_F(SequenceCheckerTest,
       SequencedWorkerPool_TwoDifferentWorkerPoolsInvalid) {
  SequenceCheckerImpl sequence_checker;
  sequence_checker.DetachFromSequence();

  // Let sequence "A" of the fixture's pool claim the detached checker.
  PostToSequencedWorkerPool(
      Bind(&ExpectCalledOnValidSequence, Unretained(&sequence_checker)), "A");
  PostToSequencedWorkerPool(
      Bind(&ExpectCalledOnValidSequence, Unretained(&sequence_checker)), "A");
  FlushSequencedWorkerPoolForTesting();

  // A sequence with the same name in a second, independent pool must still be
  // rejected by the checker.
  SequencedWorkerPoolOwner second_pool_owner(kNumWorkerThreads, "test2");
  // Check the post result: if the task were silently dropped, the expectation
  // inside it would never run and the test would pass vacuously.
  EXPECT_TRUE(second_pool_owner.pool()->PostNamedSequencedWorkerTask(
      "A", FROM_HERE,
      BindOnce(&ExpectNotCalledOnValidSequence, Unretained(&sequence_checker))));
  second_pool_owner.pool()->FlushForTesting();
}

namespace {

// This fixture is a helper for unit testing the sequence checker macros. It
// exists because ExpectDeathOnOtherSequence() and
// ExpectNoDeathOnOtherSequenceAfterDetach() cannot be inlined as lambdas:
// binding |Unretained(&my_sequence_checker_)| would not compile on non-DCHECK
// builds, where |my_sequence_checker_| is not defined.
class SequenceCheckerMacroTest : public SequenceCheckerTest {
 public:
  SequenceCheckerMacroTest() = default;

  // Runs DCHECK_CALLED_ON_VALID_SEQUENCE() on |my_sequence_checker_|. On
  // DCHECK builds, expects that check to fail (a DCHECK death), so the caller
  // must be on a sequence other than the one the checker is bound to. On
  // non-DCHECK builds the macro is a no-op; this then only verifies that the
  // macro, with a streamed message, compiles and runs.
  void ExpectDeathOnOtherSequence() {
#if DCHECK_IS_ON()
    EXPECT_DCHECK_DEATH({
      DCHECK_CALLED_ON_VALID_SEQUENCE(my_sequence_checker_) << "Error message.";
    });
#else
    // Happily no-ops on non-dcheck builds.
    DCHECK_CALLED_ON_VALID_SEQUENCE(my_sequence_checker_) << "Error message.";
#endif
  }

  // Runs DCHECK_CALLED_ON_VALID_SEQUENCE() on |my_sequence_checker_| and
  // expects it to pass. Intended to be called after DETACH_FROM_SEQUENCE() so
  // that the calling sequence may claim the checker.
  void ExpectNoDeathOnOtherSequenceAfterDetach() {
    DCHECK_CALLED_ON_VALID_SEQUENCE(my_sequence_checker_) << "Error message.";
  }

 protected:
  // The checker under test, declared through SEQUENCE_CHECKER() so that the
  // DCHECK_CALLED_ON_VALID_SEQUENCE() and DETACH_FROM_SEQUENCE() macros can
  // operate on it.
  SEQUENCE_CHECKER(my_sequence_checker_);

 private:
  DISALLOW_COPY_AND_ASSIGN(SequenceCheckerMacroTest);
};

}  // namespace

TEST_F(SequenceCheckerMacroTest, Macros) {
  // The fixture's checker was constructed on the main thread, so a check from
  // a worker-pool sequence is expected to DCHECK (on DCHECK builds).
  const Closure expect_death =
      Bind(&SequenceCheckerMacroTest::ExpectDeathOnOtherSequence,
           Unretained(this));
  PostToSequencedWorkerPool(expect_death, "A");
  FlushSequencedWorkerPoolForTesting();

  // After detaching, sequence "A" may claim the checker, so the same kind of
  // check must now pass.
  DETACH_FROM_SEQUENCE(my_sequence_checker_);

  const Closure expect_no_death =
      Bind(&SequenceCheckerMacroTest::ExpectNoDeathOnOtherSequenceAfterDetach,
           Unretained(this));
  PostToSequencedWorkerPool(expect_no_death, "A");
  FlushSequencedWorkerPoolForTesting();
}

}  // namespace base
