Add support for Promise<std::tuple<...>> which can be applied

This is useful for Promises::All and for Mojo if we wanted Mojo IPCs to
be promise based.

Design: https://docs.google.com/document/d/1l12PAJgEtlrqTXKiw6mk2cR2jP7FAfCCDr-DGIdiC9w/edit

Bug: 906125
Change-Id: Ibee1acc622e35a80acf2f264f9d1cb9308de3e02
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1617447
Commit-Queue: Alex Clarke <alexclarke@chromium.org>
Reviewed-by: Fran├žois Doray <fdoray@chromium.org>
Reviewed-by: Etienne Pierre-Doray <etiennep@chromium.org>
Cr-Commit-Position: refs/heads/master@{#663407}
diff --git a/base/task/promise/abstract_promise.cc b/base/task/promise/abstract_promise.cc
index 412db6e..3568745 100644
--- a/base/task/promise/abstract_promise.cc
+++ b/base/task/promise/abstract_promise.cc
@@ -261,11 +261,6 @@
 
   // This is likely to delete the executor.
   GetExecutor()->Execute(this);
-
-  // We need to release any AdjacencyListNodes we own to prevent memory leaks
-  // due to refcount cycles.
-  if (prerequisites_ && !IsResolvedWithPromise())
-    prerequisites_->prerequisite_list.clear();
 }
 
 bool AbstractPromise::DispatchIfNonCurriedRootSettled() {
@@ -392,19 +387,25 @@
 
   DependentList::Node* dependent_list = dependents_.ConsumeOnceForCancel();
 
-  // Release all pre-requisites to prevent memory leaks
-  if (prerequisites_) {
-    for (AdjacencyListNode& node : prerequisites_->prerequisite_list) {
-      node.prerequisite = nullptr;
-    }
-  }
-
   // Propagate cancellation to dependents.
   while (dependent_list) {
     scoped_refptr<AbstractPromise> dependent =
         std::move(dependent_list->dependent);
-    dependent->OnPrerequisiteCancelled();
     dependent_list = dependent_list->next.load(std::memory_order_relaxed);
+    dependent->OnPrerequisiteCancelled();
+  }
+
+  // We need to release any AdjacencyListNodes we own to prevent memory leaks
+  // due to refcount cycles. We can't just clear |prerequisite_list| (which
+  // contains DependentList::Node) because in the case of multiple prerequisites
+  // they may not have all be settled, which means some will want to traverse
+  // their |dependent_list| which includes this promise. This is a problem
+  // because there isn't a conveniant way of removing ourself from their
+  // |dependent_list|. It's sufficent however to simply null our references.
+  if (prerequisites_) {
+    for (AdjacencyListNode& node : prerequisites_->prerequisite_list) {
+      node.prerequisite = nullptr;
+    }
   }
 }
 
@@ -437,6 +438,11 @@
     AddAsDependentForAllPrerequisites();
   } else {
     OnResolvePostReadyDependents();
+
+    // We need to release any AdjacencyListNodes we own to prevent memory leaks
+    // due to refcount cycles.
+    if (prerequisites_)
+      prerequisites_->prerequisite_list.clear();
   }
 }
 
@@ -445,6 +451,19 @@
   DCHECK(executor_can_reject_) << from_here_.ToString();
 #endif
   OnRejectPostReadyDependents();
+
+  // We need to release any AdjacencyListNodes we own to prevent memory leaks
+  // due to refcount cycles. We can't just clear |prerequisite_list| (which
+  // contains DependentList::Node) because in the case of multiple prerequisites
+  // they may not have all be settled, which means some will want to traverse
+  // their |dependent_list| which includes this promise. This is a problem
+  // because there isn't a conveniant way of removing ourself from their
+  // |dependent_list|. It's sufficent however to simply null our references.
+  if (prerequisites_) {
+    for (AdjacencyListNode& node : prerequisites_->prerequisite_list) {
+      node.prerequisite = nullptr;
+    }
+  }
 }
 
 // static
diff --git a/base/task/promise/abstract_promise.h b/base/task/promise/abstract_promise.h
index 1ae8231..a967a02 100644
--- a/base/task/promise/abstract_promise.h
+++ b/base/task/promise/abstract_promise.h
@@ -176,12 +176,24 @@
 
   const unique_any& value() const { return FindNonCurriedAncestor()->value_; }
 
-  // Moves the value from within T::value.
-  template <typename T>
-  auto TakeInnerValue() {
-    T* ptr = unique_any_cast<T>(&FindNonCurriedAncestor()->value_);
-    DCHECK(ptr);
-    return std::move(ptr->value);
+  class ValueHandle {
+   public:
+    unique_any& value() { return value_; }
+
+    ~ValueHandle() { value_.reset(); }
+
+   private:
+    friend class AbstractPromise;
+
+    explicit ValueHandle(unique_any& value) : value_(value) {}
+
+    unique_any& value_;
+  };
+
+  ValueHandle TakeValue() {
+    AbstractPromise* non_curried_ancestor = FindNonCurriedAncestor();
+    DCHECK(non_curried_ancestor->value_.has_value());
+    return ValueHandle(non_curried_ancestor->value_);
   }
 
   // If this promise isn't curried, returns this. Otherwise follows the chain of
diff --git a/base/task/promise/abstract_promise_unittest.cc b/base/task/promise/abstract_promise_unittest.cc
index 964be98..22bdcdb 100644
--- a/base/task/promise/abstract_promise_unittest.cc
+++ b/base/task/promise/abstract_promise_unittest.cc
@@ -2115,7 +2115,6 @@
         FROM_HERE, BindLambdaForTesting([&]() {
           scoped_refptr<AbstractPromise> p =
               ThenPromise(FROM_HERE, root).With(decrement_cb);
-          p->OnResolved();
         }));
 
     // Mid way through post a task to resolve |root|.
diff --git a/base/task/promise/helpers.h b/base/task/promise/helpers.h
index 62a351b..609d510 100644
--- a/base/task/promise/helpers.h
+++ b/base/task/promise/helpers.h
@@ -5,10 +5,12 @@
 #ifndef BASE_TASK_PROMISE_HELPERS_H_
 #define BASE_TASK_PROMISE_HELPERS_H_
 
+#include <tuple>
 #include <type_traits>
 
 #include "base/bind.h"
 #include "base/callback.h"
+#include "base/parameter_pack.h"
 #include "base/task/promise/abstract_promise.h"
 #include "base/task/promise/promise_result.h"
 
@@ -152,19 +154,7 @@
 struct CallbackTraits;
 
 template <typename T>
-struct CallbackTraits<base::OnceCallback<T()>> {
-  using ResolveType = typename internal::PromiseCallbackTraits<T>::ResolveType;
-  using RejectType = typename internal::PromiseCallbackTraits<T>::RejectType;
-  using ArgType = void;
-  using ReturnType = T;
-  using SignatureType = T();
-  static constexpr AbstractPromise::Executor::ArgumentPassingType
-      argument_passing_type =
-          AbstractPromise::Executor::ArgumentPassingType::kNormal;
-};
-
-template <typename T>
-struct CallbackTraits<base::RepeatingCallback<T()>> {
+struct CallbackTraits<T()> {
   using ResolveType = typename internal::PromiseCallbackTraits<T>::ResolveType;
   using RejectType = typename internal::PromiseCallbackTraits<T>::RejectType;
   using ArgType = void;
@@ -176,7 +166,7 @@
 };
 
 template <typename T, typename Arg>
-struct CallbackTraits<base::OnceCallback<T(Arg)>> {
+struct CallbackTraits<T(Arg)> {
   using ResolveType = typename internal::PromiseCallbackTraits<T>::ResolveType;
   using RejectType = typename internal::PromiseCallbackTraits<T>::RejectType;
   using ArgType = Arg;
@@ -186,17 +176,32 @@
       argument_passing_type = UseMoveSemantics<Arg>::argument_passing_type;
 };
 
-template <typename T, typename Arg>
-struct CallbackTraits<base::RepeatingCallback<T(Arg)>> {
+template <typename T, typename... Args>
+struct CallbackTraits<T(Args...)> {
   using ResolveType = typename internal::PromiseCallbackTraits<T>::ResolveType;
   using RejectType = typename internal::PromiseCallbackTraits<T>::RejectType;
-  using ArgType = Arg;
+  using ArgType =
+      std::conditional_t<(sizeof...(Args) > 0), std::tuple<Args...>, void>;
   using ReturnType = T;
-  using SignatureType = T(Arg);
+  using SignatureType = T(Args...);
+
+  // If any arguments need move semantics, treat as if they all do.
   static constexpr AbstractPromise::Executor::ArgumentPassingType
-      argument_passing_type = UseMoveSemantics<Arg>::argument_passing_type;
+      argument_passing_type =
+          any_of({UseMoveSemantics<Args>::value...})
+              ? AbstractPromise::Executor::ArgumentPassingType::kMove
+              : AbstractPromise::Executor::ArgumentPassingType::kNormal;
 };
 
+// Adaptors for OnceCallback and RepeatingCallback
+template <typename T, typename... Args>
+struct CallbackTraits<OnceCallback<T(Args...)>>
+    : public CallbackTraits<T(Args...)> {};
+
+template <typename T, typename... Args>
+struct CallbackTraits<RepeatingCallback<T(Args...)>>
+    : public CallbackTraits<T(Args...)> {};
+
 // Helper for combining the resolve types of two promises.
 template <typename A, typename B>
 struct ResolveCombinerHelper {
@@ -276,9 +281,6 @@
   static constexpr bool valid = ResolveHelper::valid && RejectHelper::valid;
 };
 
-// TODO(alexclarke): Specialize |CallbackTraits| for callbacks with more than
-// one argument to support Promises::All.
-
 template <typename RejectStorage>
 struct EmplaceInnerHelper {
   template <typename Resolve, typename Reject>
@@ -351,7 +353,8 @@
 
  private:
   static CbArg GetImpl(AbstractPromise* arg, std::true_type should_move) {
-    return arg->TakeInnerValue<ArgStorageType>();
+    return std::move(
+        unique_any_cast<ArgStorageType>(&arg->TakeValue().value())->value);
   }
 
   static CbArg GetImpl(AbstractPromise* arg, std::false_type should_move) {
@@ -450,7 +453,85 @@
   }
 };
 
-// TODO(alexclarke): Specialize RunHelper for callbacks unpacked from a tuple.
+template <typename T>
+struct UnwrapCallback;
+
+template <typename R, typename... Args>
+struct UnwrapCallback<R(Args...)> {
+  using ArgsTuple = std::tuple<Args...>;
+};
+
+// Helper for getting callback arguments from a tuple, which works out if move
+// semantics are needed.
+template <typename Callback, typename Tuple, size_t Index>
+struct TupleArgMoveSemanticsHelper {
+  using CallbackArgsTuple =
+      typename UnwrapCallback<typename Callback::RunType>::ArgsTuple;
+  using CbArg = std::tuple_element_t<Index, CallbackArgsTuple>;
+
+  static CbArg Get(Tuple& tuple) {
+    return GetImpl(tuple, UseMoveSemantics<CbArg>());
+  }
+
+ private:
+  static CbArg GetImpl(Tuple& tuple, std::true_type should_move) {
+    return std::move(std::get<Index>(tuple));
+  }
+
+  static CbArg GetImpl(Tuple& tuple, std::false_type should_move) {
+    return std::get<Index>(tuple);
+  }
+};
+
+// Run helper for running a callbacks with the arguments unpacked from a tuple.
+template <typename CbResult,
+          typename... CbArgs,
+          typename ResolveStorage,
+          typename RejectStorage>
+struct RunHelper<OnceCallback<CbResult(CbArgs...)>,
+                 Resolved<std::tuple<CbArgs...>>,
+                 ResolveStorage,
+                 RejectStorage> {
+  using Callback = OnceCallback<CbResult(CbArgs...)>;
+  using StorageType = Resolved<std::tuple<CbArgs...>>;
+  using IndexSequence = std::index_sequence_for<CbArgs...>;
+
+  static void Run(Callback executor,
+                  AbstractPromise* arg,
+                  AbstractPromise* result) {
+    AbstractPromise::ValueHandle value = arg->TakeValue();
+    std::tuple<CbArgs...>& tuple =
+        unique_any_cast<StorageType>(&value.value())->value;
+    RunInternal(std::move(executor), tuple, result,
+                std::integral_constant<bool, std::is_void<CbResult>::value>(),
+                IndexSequence{});
+  }
+
+ private:
+  template <typename Callback, size_t... Indices>
+  static void RunInternal(Callback executor,
+                          std::tuple<CbArgs...>& tuple,
+                          AbstractPromise* result,
+                          std::false_type void_result,
+                          std::index_sequence<Indices...>) {
+    EmplaceHelper<ResolveStorage, RejectStorage>::Emplace(
+        std::move(executor).Run(
+            TupleArgMoveSemanticsHelper<Callback, std::tuple<CbArgs...>,
+                                        Indices>::Get(tuple)...));
+  }
+
+  template <typename Callback, size_t... Indices>
+  static void RunInternal(Callback executor,
+                          std::tuple<CbArgs...>& tuple,
+                          AbstractPromise* result,
+                          std::true_type void_result,
+                          std::index_sequence<Indices...>) {
+    std::move(executor).Run(
+        TupleArgMoveSemanticsHelper<Callback, std::tuple<CbArgs...>,
+                                    Indices>::Get(tuple)...);
+    result->emplace(Resolved<void>());
+  }
+};
 
 // Used by ManualPromiseResolver<> to generate callbacks.
 template <typename T, typename... Args>
diff --git a/base/task/promise/promise.h b/base/task/promise/promise.h
index d8fcffe..f2ebc83 100644
--- a/base/task/promise/promise.h
+++ b/base/task/promise/promise.h
@@ -483,26 +483,62 @@
   }
 
   typename ResolveHelper::Callback GetResolveCallback() {
+    static_assert(!std::is_same<ResolveType, NoResolve>::value,
+                  "Cant resolve a NoResolve promise");
     return ResolveHelper::GetResolveCallback(promise_.abstract_promise_);
   }
 
+  template <typename... Args>
+  auto GetResolveCallback() {
+    static_assert(!std::is_same<ResolveType, NoResolve>::value,
+                  "Cant resolve a NoResolve promise");
+    using Helper = internal::PromiseCallbackHelper<ResolveType, Args...>;
+    return Helper::GetResolveCallback(promise_.abstract_promise_);
+  }
+
   typename ResolveHelper::RepeatingCallback GetRepeatingResolveCallback() {
+    static_assert(!std::is_same<ResolveType, NoResolve>::value,
+                  "Cant resolve a NoResolve promise");
     return ResolveHelper::GetRepeatingResolveCallback(
         promise_.abstract_promise_);
   }
 
+  template <typename... Args>
+  auto GetRepeatingResolveCallback() {
+    static_assert(!std::is_same<ResolveType, NoResolve>::value,
+                  "Cant resolve a NoResolve promise");
+    using Helper = internal::PromiseCallbackHelper<ResolveType, Args...>;
+    return Helper::GetRepeatingResolveCallback(promise_.abstract_promise_);
+  }
+
   typename RejectHelper::Callback GetRejectCallback() {
     static_assert(!std::is_same<NoReject, RejectType>::value,
                   "Can't reject a NoReject promise.");
     return RejectHelper::GetRejectCallback(promise_.abstract_promise_);
   }
 
+  template <typename... Args>
+  auto GetRejectCallback() {
+    static_assert(!std::is_same<NoReject, RejectType>::value,
+                  "Can't reject a NoReject promise.");
+    using Helper = internal::PromiseCallbackHelper<RejectType, Args...>;
+    return Helper::GetRejectCallback(promise_.abstract_promise_);
+  }
+
   typename RejectHelper::RepeatingCallback GetRepeatingRejectCallback() {
     static_assert(!std::is_same<NoReject, RejectType>::value,
                   "Can't reject a NoReject promise.");
     return RejectHelper::GetRepeatingRejectCallback(promise_.abstract_promise_);
   }
 
+  template <typename... Args>
+  auto GetRepeatingRejectCallback() {
+    static_assert(!std::is_same<NoReject, RejectType>::value,
+                  "Can't reject a NoReject promise.");
+    using Helper = internal::PromiseCallbackHelper<RejectType, Args...>;
+    return Helper::GetRepeatingRejectCallback(promise_.abstract_promise_);
+  }
+
   Promise<ResolveType, RejectType>& promise() { return promise_; }
 
  private:
diff --git a/base/task/promise/promise_unittest.cc b/base/task/promise/promise_unittest.cc
index 50c0d461..f232575 100644
--- a/base/task/promise/promise_unittest.cc
+++ b/base/task/promise/promise_unittest.cc
@@ -131,6 +131,80 @@
   run_loop.Run();
 }
 
+TEST_F(PromiseTest, GetResolveCallbackMultipleArgs) {
+  ManualPromiseResolver<std::tuple<int, bool, float>> p(FROM_HERE);
+  p.GetResolveCallback<int, bool, float>().Run(123, true, 1.5f);
+
+  RunLoop run_loop;
+  p.promise().ThenOnCurrent(FROM_HERE,
+                            BindLambdaForTesting([&](int a, bool b, float c) {
+                              EXPECT_EQ(123, a);
+                              EXPECT_TRUE(b);
+                              EXPECT_EQ(1.5f, c);
+                              run_loop.Quit();
+                            }));
+
+  run_loop.Run();
+}
+
+TEST_F(PromiseTest, ResolveWithTuple) {
+  ManualPromiseResolver<void> p(FROM_HERE);
+  p.Resolve();
+
+  RunLoop run_loop;
+  p.promise()
+      .ThenOnCurrent(FROM_HERE, BindOnce([]() {
+                       return std::tuple<int, bool>(123, false);
+                     }))
+      .ThenOnCurrent(FROM_HERE, BindLambdaForTesting(
+                                    [&](const std::tuple<int, bool>& tuple) {
+                                      EXPECT_EQ(123, std::get<0>(tuple));
+                                      EXPECT_FALSE(std::get<1>(tuple));
+                                      run_loop.Quit();
+                                    }));
+
+  run_loop.Run();
+}
+
+TEST_F(PromiseTest, ResolveWithUnpackedTuple) {
+  ManualPromiseResolver<void> p(FROM_HERE);
+  p.Resolve();
+
+  RunLoop run_loop;
+  p.promise()
+      .ThenOnCurrent(FROM_HERE, BindOnce([]() {
+                       return std::tuple<int, bool>(123, false);
+                     }))
+      .ThenOnCurrent(FROM_HERE, BindLambdaForTesting([&](int a, bool b) {
+                       EXPECT_EQ(123, a);
+                       EXPECT_FALSE(b);
+                       run_loop.Quit();
+                     }));
+
+  run_loop.Run();
+}
+
+TEST_F(PromiseTest, ResolveWithUnpackedTupleMoveOnlyTypes) {
+  ManualPromiseResolver<void> p(FROM_HERE);
+  p.Resolve();
+
+  RunLoop run_loop;
+  p.promise()
+      .ThenOnCurrent(FROM_HERE, BindOnce([]() {
+                       return std::make_tuple(std::make_unique<int>(42),
+                                              std::make_unique<float>(4.2f));
+                     }))
+      .ThenOnCurrent(FROM_HERE,
+                     BindLambdaForTesting(
+                         [&](std::unique_ptr<int> a, std::unique_ptr<float> b) {
+                           EXPECT_EQ(42, *a);
+                           EXPECT_EQ(4.2f, *b);
+                           run_loop.Quit();
+                         }));
+
+  run_loop.Run();
+}
+
 TEST_F(PromiseTest, GetRejectCallbackCatch) {
   ManualPromiseResolver<int, std::string> p(FROM_HERE);
 
@@ -191,6 +265,46 @@
   run_loop.Run();
 }
 
+TEST_F(PromiseTest, ThenRejectWithTuple) {
+  ManualPromiseResolver<void> p(FROM_HERE);
+  p.Resolve();
+
+  RunLoop run_loop;
+  p.promise()
+      .ThenOnCurrent(FROM_HERE, BindOnce([]() {
+                       return Rejected<std::tuple<int, bool>>{123, false};
+                     }))
+      .CatchOnCurrent(FROM_HERE, BindLambdaForTesting(
+                                     [&](const std::tuple<int, bool>& tuple) {
+                                       EXPECT_EQ(123, std::get<0>(tuple));
+                                       EXPECT_FALSE(std::get<1>(tuple));
+                                       run_loop.Quit();
+                                     }));
+
+  run_loop.Run();
+}
+
+TEST_F(PromiseTest, GetRejectCallbackMultipleArgs) {
+  ManualPromiseResolver<int, std::tuple<bool, std::string>> p(FROM_HERE);
+
+  RunLoop run_loop;
+  p.promise().ThenOnCurrent(
+      FROM_HERE, BindLambdaForTesting([&](int result) {
+        run_loop.Quit();
+        FAIL() << "We shouldn't get here, the promise was rejected!";
+      }),
+      BindLambdaForTesting([&](const std::tuple<bool, std::string>& err) {
+        // NB we don't currently support tuple expansion for reject.
+        // Its not hard to add, but it's unclear if it will ever be used.
+        run_loop.Quit();
+        EXPECT_FALSE(std::get<0>(err));
+        EXPECT_EQ("Noes!", std::get<1>(err));
+      }));
+
+  p.GetRejectCallback<bool, std::string>().Run(false, "Noes!");
+  run_loop.Run();
+}
+
 TEST_F(PromiseTest, CatchOnCurrentReturnTypes) {
   ManualPromiseResolver<int, void> p1(FROM_HERE);
 
@@ -578,6 +692,18 @@
   run_loop.Run();
 }
 
+TEST_F(PromiseTest, ResolveThenStdTupleUnpack) {
+  RunLoop run_loop;
+  Promise<std::tuple<int, std::string>>::CreateResolved(FROM_HERE, 10,
+                                                        std::string("Hi"))
+      .ThenOnCurrent(FROM_HERE, BindLambdaForTesting([&](int a, std::string b) {
+                       EXPECT_EQ(10, a);
+                       EXPECT_EQ("Hi", b);
+                       run_loop.Quit();
+                     }));
+  run_loop.Run();
+}
+
 TEST_F(PromiseTest, ResolveAfterThen) {
   ManualPromiseResolver<int> p(FROM_HERE);