Coro test: rename generator to Generator, replace operator() with next(), and add an input iterator (begin/end)
diff --git a/test/gtest/coro.cpp b/test/gtest/coro.cpp index 3e991a0..7e67fb8 100644 --- a/test/gtest/coro.cpp +++ b/test/gtest/coro.cpp
@@ -8,18 +8,18 @@ using namespace wasm; -template<typename T> struct generator { +template<typename T> struct Generator { struct promise_type; std::coroutine_handle<promise_type> handle; - ~generator() { handle.destroy(); } + ~Generator() { handle.destroy(); } struct promise_type { std::optional<T> value; - generator<T> get_return_object() { - return generator<T>{ + Generator<T> get_return_object() { + return Generator<T>{ std::coroutine_handle<promise_type>::from_promise(*this)}; } @@ -37,7 +37,7 @@ void unhandled_exception() { WASM_UNREACHABLE("unhandled exception"); } }; - std::optional<T> operator()() { + std::optional<T> next() { if (!handle.done()) { handle.resume(); } @@ -45,9 +45,42 @@ handle.promise().value.reset(); return ret; } + + struct Iter { + using value_type = T; + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using reference = const T&; + + Generator<T>* generator = nullptr; + std::optional<T> curr = std::nullopt; + + const T& operator*() const { return *curr; } + + const T* operator->() const { return &*curr; } + + Iter& operator++() { + curr = generator->next(); + return *this; + } + + Iter operator++(int) { + auto it = *this; + ++(*this); + return it; + } + + bool operator==(const Iter& other) const { return !curr && !other.curr; } + }; + + static_assert(std::input_iterator<Iter>); + + Iter begin() { return ++Iter{this}; } + + Iter end() { return Iter{this}; } }; -generator<Expression**> walkExpressionPtrs(Expression*& curr) { +Generator<Expression**> walkExpressionPtrs(Expression*& curr) { struct Task { Expression** currp; bool done; @@ -94,16 +127,12 @@ #include "wasm-delegations-fields.def" } - - co_return; } -generator<Expression*> walkExpression(Expression* curr) { - auto walker = walkExpressionPtrs(curr); - while (auto exprp = walker()) { - co_yield** exprp; +Generator<Expression*> walkExpression(Expression* curr) { + for (auto& expr : walkExpressionPtrs(curr)) { + co_yield *expr; } - 
co_return; } TEST(Coro, Traversal) { @@ -123,24 +152,24 @@ { auto walker = walkExpressionPtrs(expr); - ASSERT_EQ(walker(), std::optional{&add->left}); - ASSERT_EQ(walker(), std::optional{&add->right}); - ASSERT_EQ(walker(), std::optional{&mul->left}); - ASSERT_EQ(walker(), std::optional{&sub->left}); - ASSERT_EQ(walker(), std::optional{&sub->right}); - ASSERT_EQ(walker(), std::optional{&mul->right}); - ASSERT_EQ(walker(), std::optional{&expr}); - ASSERT_EQ(walker(), std::nullopt); + ASSERT_EQ(walker.next(), std::optional{&add->left}); + ASSERT_EQ(walker.next(), std::optional{&add->right}); + ASSERT_EQ(walker.next(), std::optional{&mul->left}); + ASSERT_EQ(walker.next(), std::optional{&sub->left}); + ASSERT_EQ(walker.next(), std::optional{&sub->right}); + ASSERT_EQ(walker.next(), std::optional{&mul->right}); + ASSERT_EQ(walker.next(), std::optional{&expr}); + ASSERT_EQ(walker.next(), std::nullopt); } { auto walker = walkExpression(expr); - ASSERT_EQ(walker(), std::optional{add->left}); - ASSERT_EQ(walker(), std::optional{add->right}); - ASSERT_EQ(walker(), std::optional{add}); - ASSERT_EQ(walker(), std::optional{sub->left}); - ASSERT_EQ(walker(), std::optional{sub->right}); - ASSERT_EQ(walker(), std::optional{sub}); - ASSERT_EQ(walker(), std::optional{mul}); + ASSERT_EQ(walker.next(), std::optional{add->left}); + ASSERT_EQ(walker.next(), std::optional{add->right}); + ASSERT_EQ(walker.next(), std::optional{add}); + ASSERT_EQ(walker.next(), std::optional{sub->left}); + ASSERT_EQ(walker.next(), std::optional{sub->right}); + ASSERT_EQ(walker.next(), std::optional{sub}); + ASSERT_EQ(walker.next(), std::optional{mul}); } }