next() and iterator
diff --git a/test/gtest/coro.cpp b/test/gtest/coro.cpp
index 3e991a0..7e67fb8 100644
--- a/test/gtest/coro.cpp
+++ b/test/gtest/coro.cpp
@@ -8,18 +8,18 @@
using namespace wasm;
-template<typename T> struct generator {
+template<typename T> struct Generator {
struct promise_type;
std::coroutine_handle<promise_type> handle;
- ~generator() { handle.destroy(); }
+ ~Generator() { handle.destroy(); }
struct promise_type {
std::optional<T> value;
- generator<T> get_return_object() {
- return generator<T>{
+ Generator<T> get_return_object() {
+ return Generator<T>{
std::coroutine_handle<promise_type>::from_promise(*this)};
}
@@ -37,7 +37,7 @@
void unhandled_exception() { WASM_UNREACHABLE("unhandled exception"); }
};
- std::optional<T> operator()() {
+ std::optional<T> next() {
if (!handle.done()) {
handle.resume();
}
@@ -45,9 +45,42 @@
handle.promise().value.reset();
return ret;
}
+
+ // Single-pass input iterator over the values produced by a Generator.
+ // `curr` holds the value most recently obtained from `generator->next()`;
+ // an Iter whose `curr` is empty is an end iterator.
+ struct Iter {
+ using value_type = T;
+ using iterator_category = std::input_iterator_tag;
+ using difference_type = std::ptrdiff_t;
+ using reference = const T&;
+
+ Generator<T>* generator = nullptr;
+ std::optional<T> curr = std::nullopt;
+
+ // Precondition: not an end iterator (`curr` holds a value).
+ const T& operator*() const { return *curr; }
+
+ const T* operator->() const { return &*curr; }
+
+ // Pulls the next value from the generator; the iterator becomes an end
+ // iterator once `next()` returns std::nullopt.
+ Iter& operator++() {
+ curr = generator->next();
+ return *this;
+ }
+
+ // The returned copy keeps its own copy of the old value, so it stays
+ // dereferenceable even though the generator has advanced.
+ Iter operator++(int) {
+ auto it = *this;
+ ++(*this);
+ return it;
+ }
+
+ // Two iterators are equal when both are at the end or both are not. This
+ // keeps `==` reflexive (`it == it` is true for a dereferenceable iterator),
+ // while comparison against `end()` behaves exactly as before.
+ bool operator==(const Iter& other) const {
+ return curr.has_value() == other.curr.has_value();
+ }
+ };
+
+ static_assert(std::input_iterator<Iter>);
+
+ // Returns an iterator holding the next value pulled from this generator.
+ // Every call consumes a value via next(), so a Generator can be iterated
+ // only once.
+ Iter begin() { return Iter{this, next()}; }
+
+ // The end iterator holds no current value.
+ Iter end() { return Iter{this, std::nullopt}; }
};
-generator<Expression**> walkExpressionPtrs(Expression*& curr) {
+Generator<Expression**> walkExpressionPtrs(Expression*& curr) {
struct Task {
Expression** currp;
bool done;
@@ -94,16 +127,12 @@
#include "wasm-delegations-fields.def"
}
-
- co_return;
}
-generator<Expression*> walkExpression(Expression* curr) {
- auto walker = walkExpressionPtrs(curr);
- while (auto exprp = walker()) {
- co_yield** exprp;
+// Yields the Expression* stored in each Expression** slot produced by
+// walkExpressionPtrs(curr), in the same order those slots are produced.
+Generator<Expression*> walkExpression(Expression* curr) {
+ for (auto& expr : walkExpressionPtrs(curr)) {
+ co_yield *expr;
}
- co_return;
}
TEST(Coro, Traversal) {
@@ -123,24 +152,24 @@
{
auto walker = walkExpressionPtrs(expr);
- ASSERT_EQ(walker(), std::optional{&add->left});
- ASSERT_EQ(walker(), std::optional{&add->right});
- ASSERT_EQ(walker(), std::optional{&mul->left});
- ASSERT_EQ(walker(), std::optional{&sub->left});
- ASSERT_EQ(walker(), std::optional{&sub->right});
- ASSERT_EQ(walker(), std::optional{&mul->right});
- ASSERT_EQ(walker(), std::optional{&expr});
- ASSERT_EQ(walker(), std::nullopt);
+ ASSERT_EQ(walker.next(), std::optional{&add->left});
+ ASSERT_EQ(walker.next(), std::optional{&add->right});
+ ASSERT_EQ(walker.next(), std::optional{&mul->left});
+ ASSERT_EQ(walker.next(), std::optional{&sub->left});
+ ASSERT_EQ(walker.next(), std::optional{&sub->right});
+ ASSERT_EQ(walker.next(), std::optional{&mul->right});
+ ASSERT_EQ(walker.next(), std::optional{&expr});
+ ASSERT_EQ(walker.next(), std::nullopt);
}
{
auto walker = walkExpression(expr);
- ASSERT_EQ(walker(), std::optional{add->left});
- ASSERT_EQ(walker(), std::optional{add->right});
- ASSERT_EQ(walker(), std::optional{add});
- ASSERT_EQ(walker(), std::optional{sub->left});
- ASSERT_EQ(walker(), std::optional{sub->right});
- ASSERT_EQ(walker(), std::optional{sub});
- ASSERT_EQ(walker(), std::optional{mul});
+ ASSERT_EQ(walker.next(), std::optional{add->left});
+ ASSERT_EQ(walker.next(), std::optional{add->right});
+ ASSERT_EQ(walker.next(), std::optional{add});
+ ASSERT_EQ(walker.next(), std::optional{sub->left});
+ ASSERT_EQ(walker.next(), std::optional{sub->right});
+ ASSERT_EQ(walker.next(), std::optional{sub});
+ ASSERT_EQ(walker.next(), std::optional{mul});
}
}