Add the waitqueue.notify instruction
diff --git a/scripts/gen-s-parser.py b/scripts/gen-s-parser.py index 4160ad9..0af02f9 100755 --- a/scripts/gen-s-parser.py +++ b/scripts/gen-s-parser.py
@@ -207,6 +207,7 @@ # atomic instructions ("memory.atomic.notify", "makeAtomicNotify()"), ("waitqueue.wait", "makeWaitQueueWait()"), + ("waitqueue.notify", "makeWaitQueueNotify()"), ("memory.atomic.wait32", "makeAtomicWait(Type::i32)"), ("memory.atomic.wait64", "makeAtomicWait(Type::i64)"), ("atomic.fence", "makeAtomicFence()"),
diff --git a/src/gen-s-parser.inc b/src/gen-s-parser.inc index f0616d7..fde7fe6 100644 --- a/src/gen-s-parser.inc +++ b/src/gen-s-parser.inc
@@ -5784,12 +5784,23 @@ default: goto parse_error; } } - case 'w': - if (op == "waitqueue.wait"sv) { - CHECK_ERR(makeWaitQueueWait(ctx, pos, annotations)); - return Ok{}; + case 'w': { + switch (buf[10]) { + case 'n': + if (op == "waitqueue.notify"sv) { + CHECK_ERR(makeWaitQueueNotify(ctx, pos, annotations)); + return Ok{}; + } + goto parse_error; + case 'w': + if (op == "waitqueue.wait"sv) { + CHECK_ERR(makeWaitQueueWait(ctx, pos, annotations)); + return Ok{}; + } + goto parse_error; + default: goto parse_error; } - goto parse_error; + } default: goto parse_error; } parse_error:
diff --git a/src/interpreter/interpreter.cpp b/src/interpreter/interpreter.cpp index d1f198c..d8fb51b 100644 --- a/src/interpreter/interpreter.cpp +++ b/src/interpreter/interpreter.cpp
@@ -284,6 +284,7 @@ Flow visitResumeThrow(ResumeThrow* curr) { WASM_UNREACHABLE("TODO"); } Flow visitStackSwitch(StackSwitch* curr) { WASM_UNREACHABLE("TODO"); } Flow visitWaitQueueWait(WaitQueueWait* curr) { WASM_UNREACHABLE("TODO"); } + Flow visitWaitQueueNotify(WaitQueueNotify* curr) { WASM_UNREACHABLE("TODO"); } }; } // anonymous namespace
diff --git a/src/ir/ReFinalize.cpp b/src/ir/ReFinalize.cpp index e41f427..93fc904 100644 --- a/src/ir/ReFinalize.cpp +++ b/src/ir/ReFinalize.cpp
@@ -203,6 +203,9 @@ } void ReFinalize::visitStackSwitch(StackSwitch* curr) { curr->finalize(); } void ReFinalize::visitWaitQueueWait(WaitQueueWait* curr) { curr->finalize(); } +void ReFinalize::visitWaitQueueNotify(WaitQueueNotify* curr) { + curr->finalize(); +} void ReFinalize::visitExport(Export* curr) { WASM_UNREACHABLE("unimp"); } void ReFinalize::visitGlobal(Global* curr) { WASM_UNREACHABLE("unimp"); }
diff --git a/src/ir/child-typer.h b/src/ir/child-typer.h index ef77cbc..9f445fe 100644 --- a/src/ir/child-typer.h +++ b/src/ir/child-typer.h
@@ -1366,6 +1366,14 @@ note(&curr->value, Type(Type::BasicType::i32)); note(&curr->timeout, Type(Type::BasicType::i64)); } + + void visitWaitQueueNotify(WaitQueueNotify* curr) { + note(&curr->waitqueue, + Type(HeapType(Struct(std::vector{ + Field(Field::PackedType::WaitQueue, Mutability::Immutable)})), + NonNullable)); + note(&curr->count, Type(Type::BasicType::i32)); + } }; } // namespace wasm
diff --git a/src/ir/cost.h b/src/ir/cost.h index e292a4f..58de2e1 100644 --- a/src/ir/cost.h +++ b/src/ir/cost.h
@@ -122,6 +122,9 @@ return AtomicCost + visit(curr->waitqueue) + visit(curr->value) + visit(curr->timeout); } + CostType visitWaitQueueNotify(WaitQueueNotify* curr) { + return AtomicCost + visit(curr->waitqueue) + visit(curr->count); + } CostType visitAtomicNotify(AtomicNotify* curr) { return AtomicCost + visit(curr->ptr) + visit(curr->notifyCount); }
diff --git a/src/ir/effects.h b/src/ir/effects.h index b12ffb1..b4f960d 100644 --- a/src/ir/effects.h +++ b/src/ir/effects.h
@@ -1172,6 +1172,29 @@ parent.readsMutableStruct = true; } } + + void visitWaitQueueNotify(WaitQueueNotify* curr) { + parent.isAtomic = true; + + // An unreachable child has no heap type to inspect. + if (curr->waitqueue->type == Type::unreachable) { + return; + } + // A bottom reference has no struct fields to inspect; conservatively + // treat the instruction as trapping. + if (curr->waitqueue->type.isNull()) { + parent.trap = true; + return; + } + + // field 0 must exist and be a packed waitqueue if this is valid Wasm. + if (curr->waitqueue->type.getHeapType() + .getStruct() + .fields.at(0) + .mutable_ == Mutable) { + parent.readsMutableStruct = true; + } + } }; public:
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index 5d838aa..9bebac5 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp
@@ -449,6 +449,9 @@ void visitWaitQueueWait(WaitQueueWait* curr) { info.note(curr->waitqueue->type); } + void visitWaitQueueNotify(WaitQueueNotify* curr) { + info.note(curr->waitqueue->type); + } void visitBlock(Block* curr) { info.noteControlFlow(Signature(Type::none, curr->type)); }
diff --git a/src/ir/possible-contents.cpp b/src/ir/possible-contents.cpp index 3ba8aae..fccef6e 100644 --- a/src/ir/possible-contents.cpp +++ b/src/ir/possible-contents.cpp
@@ -1402,6 +1402,7 @@ addRoot(curr); } void visitWaitQueueWait(WaitQueueWait* curr) { addRoot(curr); } + void visitWaitQueueNotify(WaitQueueNotify* curr) { addRoot(curr); } void visitFunction(Function* func) { // Functions with a result can flow a value out from their body.
diff --git a/src/ir/subtype-exprs.h b/src/ir/subtype-exprs.h index 16d66b3..99c6325 100644 --- a/src/ir/subtype-exprs.h +++ b/src/ir/subtype-exprs.h
@@ -604,6 +604,12 @@ Field::PackedType::WaitQueue, Immutable)})), NonNullable)); } + void visitWaitQueueNotify(WaitQueueNotify* curr) { + self()->noteSubtype(curr->waitqueue, + Type(HeapType(Struct(std::vector{Field( + Field::PackedType::WaitQueue, Immutable)})), + NonNullable)); + } }; } // namespace wasm
diff --git a/src/parser/contexts.h b/src/parser/contexts.h index 3deb304..0257154 100644 --- a/src/parser/contexts.h +++ b/src/parser/contexts.h
@@ -960,6 +960,9 @@ Result<> makeWaitQueueWait(Index, const std::vector<Annotation>&) { return Ok{}; } + Result<> makeWaitQueueNotify(Index, const std::vector<Annotation>&) { + return Ok{}; + } }; struct NullCtx : NullTypeParserCtx, NullInstrParserCtx { @@ -2956,6 +2959,11 @@ const std::vector<Annotation>& annotations) { return withLoc(pos, irBuilder.makeWaitQueueWait()); } + + Result<> makeWaitQueueNotify(Index pos, + const std::vector<Annotation>& annotations) { + return withLoc(pos, irBuilder.makeWaitQueueNotify()); + } }; } // namespace wasm::WATParser
diff --git a/src/parser/parsers.h b/src/parser/parsers.h index 08ffaa5..4b01900 100644 --- a/src/parser/parsers.h +++ b/src/parser/parsers.h
@@ -2847,6 +2847,13 @@ return ctx.makeWaitQueueWait(pos, annotations); } +template<typename Ctx> +Result<> makeWaitQueueNotify(Ctx& ctx, + Index pos, + const std::vector<Annotation>& annotations) { + return ctx.makeWaitQueueNotify(pos, annotations); +} + // ======= // Modules // =======
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index a13a6fe..609cbbc 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp
@@ -381,6 +381,11 @@ visitExpression(curr); } } + void visitWaitQueueNotify(WaitQueueNotify* curr) { + if (!maybePrintUnreachableReplacement(curr, curr->type)) { + visitExpression(curr); + } + } // Module-level visitors void handleSignature(Function* curr, bool printImplicitNames = false); @@ -2667,6 +2672,10 @@ void visitWaitQueueWait(WaitQueueWait* curr) { printMedium(o, "waitqueue.wait"); } + + void visitWaitQueueNotify(WaitQueueNotify* curr) { + printMedium(o, "waitqueue.notify"); + } }; void PrintSExpression::setModule(Module* module) {
diff --git a/src/passes/TypeGeneralizing.cpp b/src/passes/TypeGeneralizing.cpp index 2e6610a..ddabb44 100644 --- a/src/passes/TypeGeneralizing.cpp +++ b/src/passes/TypeGeneralizing.cpp
@@ -900,6 +900,7 @@ void visitResumeThrow(ResumeThrow* curr) { WASM_UNREACHABLE("TODO"); } void visitStackSwitch(StackSwitch* curr) { WASM_UNREACHABLE("TODO"); } void visitWaitQueueWait(WaitQueueWait* curr) { WASM_UNREACHABLE("TODO"); } + void visitWaitQueueNotify(WaitQueueNotify* curr) { WASM_UNREACHABLE("TODO"); } }; struct TypeGeneralizing : WalkerPass<PostWalker<TypeGeneralizing>> {
diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 332e725..f1b6606 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h
@@ -703,6 +703,7 @@ AtomicFence = 0x03, Pause = 0x04, WaitQueueWait = 0x05, + WaitQueueNotify = 0x06, I32AtomicLoad = 0x10, I64AtomicLoad = 0x11,
diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 74fd2ba..0b70bfe 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h
@@ -1374,6 +1374,15 @@ return ret; } + WaitQueueNotify* makeWaitQueueNotify(Expression* waitqueue, + Expression* count) { + auto* ret = wasm.allocator.alloc<WaitQueueNotify>(); + ret->waitqueue = waitqueue; + ret->count = count; + ret->finalize(); + return ret; + } + // Additional helpers Drop* makeDrop(Expression* value) {
diff --git a/src/wasm-delegations-fields.def b/src/wasm-delegations-fields.def index 84746c3..f194283 100644 --- a/src/wasm-delegations-fields.def +++ b/src/wasm-delegations-fields.def
@@ -893,6 +893,11 @@ DELEGATE_FIELD_CHILD(WaitQueueWait, waitqueue) DELEGATE_FIELD_CASE_END(WaitQueueWait) +DELEGATE_FIELD_CASE_START(WaitQueueNotify) +DELEGATE_FIELD_CHILD(WaitQueueNotify, count) +DELEGATE_FIELD_CHILD(WaitQueueNotify, waitqueue) +DELEGATE_FIELD_CASE_END(WaitQueueNotify) + DELEGATE_FIELD_MAIN_END
diff --git a/src/wasm-delegations.def b/src/wasm-delegations.def index 25144d6..d5a23b2 100644 --- a/src/wasm-delegations.def +++ b/src/wasm-delegations.def
@@ -116,5 +116,6 @@ DELEGATE(ResumeThrow); DELEGATE(StackSwitch); DELEGATE(WaitQueueWait); +DELEGATE(WaitQueueNotify); #undef DELEGATE
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 710975f..cdafcc8 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h
@@ -2935,6 +2935,9 @@ Flow visitWaitQueueWait(WaitQueueWait* curr) { return Flow(NONCONSTANT_FLOW); } + Flow visitWaitQueueNotify(WaitQueueNotify* curr) { + return Flow(NONCONSTANT_FLOW); + } void trap(std::string_view why) override { throw NonconstantException(); } @@ -4924,6 +4927,9 @@ Flow visitWaitQueueWait(WaitQueueWait* curr) { WASM_UNREACHABLE("waitqueue not implemented"); } + Flow visitWaitQueueNotify(WaitQueueNotify* curr) { + WASM_UNREACHABLE("waitqueue not implemented"); + } void trap(std::string_view why) override { // Traps break all current continuations - they will never be resumable.
diff --git a/src/wasm-ir-builder.h b/src/wasm-ir-builder.h index d011353..9cd3a59 100644 --- a/src/wasm-ir-builder.h +++ b/src/wasm-ir-builder.h
@@ -286,6 +286,7 @@ } Result<> makeStackSwitch(HeapType ct, Name tag); Result<> makeWaitQueueWait(); + Result<> makeWaitQueueNotify(); // Private functions that must be public for technical reasons. Result<> visitExpression(Expression*);
diff --git a/src/wasm.h b/src/wasm.h index 46af167..59cbf0d 100644 --- a/src/wasm.h +++ b/src/wasm.h
@@ -761,6 +761,7 @@ // Id for the stack switching `switch` StackSwitchId, WaitQueueWaitId, + WaitQueueNotifyId, NumExpressionIds }; Id _id; @@ -2168,6 +2169,18 @@ void finalize(); }; +class WaitQueueNotify + : public SpecificExpression<Expression::WaitQueueNotifyId> { +public: + WaitQueueNotify() = default; + WaitQueueNotify(MixedArena& allocator) : WaitQueueNotify() {} + + Expression* waitqueue = nullptr; + Expression* count = nullptr; + + void finalize(); +}; + // Globals struct Named {
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 2659db8..9d3427a 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp
@@ -3916,6 +3916,9 @@ case BinaryConsts::WaitQueueWait: { return builder.makeWaitQueueWait(); } + case BinaryConsts::WaitQueueNotify: { + return builder.makeWaitQueueNotify(); + } } return Err{"unknown atomic operation " + std::to_string(op)}; }
diff --git a/src/wasm/wasm-ir-builder.cpp b/src/wasm/wasm-ir-builder.cpp index 8c10ee2..97e1717 100644 --- a/src/wasm/wasm-ir-builder.cpp +++ b/src/wasm/wasm-ir-builder.cpp
@@ -668,6 +668,12 @@ ConstraintCollector{builder, children}.visitWaitQueueWait(curr); return popConstrainedChildren(children); } + + Result<> visitWaitQueueNotify(WaitQueueNotify* curr) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitWaitQueueNotify(curr); + return popConstrainedChildren(children); + } }; Result<> IRBuilder::visit(Expression* curr) { @@ -2668,6 +2674,13 @@ return Ok{}; } +Result<> IRBuilder::makeWaitQueueNotify() { + WaitQueueNotify curr(wasm.allocator); + CHECK_ERR(ChildPopper{*this}.visitWaitQueueNotify(&curr)); + push(builder.makeWaitQueueNotify(curr.waitqueue, curr.count)); + return Ok{}; +} + void IRBuilder::applyAnnotations(Expression* expr, const CodeAnnotation& annotation) { if (annotation.branchLikely) {
diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 75c99ad..7ca2e87 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp
@@ -2916,6 +2916,11 @@ << static_cast<int8_t>(BinaryConsts::WaitQueueWait); } +void BinaryInstWriter::visitWaitQueueNotify(WaitQueueNotify* curr) { + o << static_cast<int8_t>(BinaryConsts::AtomicPrefix) + << static_cast<int8_t>(BinaryConsts::WaitQueueNotify); +} + void BinaryInstWriter::emitScopeEnd(Expression* curr) { assert(!breakStack.empty()); breakStack.pop_back();
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index e2bc36b..8f017b8 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp
@@ -570,6 +570,7 @@ void visitResumeThrow(ResumeThrow* curr); void visitStackSwitch(StackSwitch* curr); void visitWaitQueueWait(WaitQueueWait* curr); + void visitWaitQueueNotify(WaitQueueNotify* curr); void visitFunction(Function* curr); @@ -4283,6 +4284,25 @@ "waitqueue.wait timeout must be an i64"); } +void FunctionValidator::visitWaitQueueNotify(WaitQueueNotify* curr) { + shouldBeTrue( + !getModule() || getModule()->features.hasSharedEverything(), + curr, + "waitqueue.notify requires shared-everything [--enable-shared-everything]"); + shouldBeSubType(curr->waitqueue->type, + Type(HeapType(Struct(std::vector{ + Field(Field::PackedType::WaitQueue, Immutable)})), + NonNullable), + curr, + "waitqueue.notify waitqueue must be a subtype of (struct " + "(field waitqueue))"); + // Tolerate an unreachable count operand, e.g. `(unreachable)`. + shouldBeEqualOrFirstIsUnreachable(curr->count->type, + Type(Type::BasicType::i32), + curr, + "waitqueue.notify count must be an i32"); +} + void FunctionValidator::visitFunction(Function* curr) { FeatureSet features; // Check for things like having a rec group with GC enabled. The type we're
diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 4133a0b..4408a87 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp
@@ -2019,4 +2019,13 @@ void WaitQueueWait::finalize() { type = Type::i32; } +void WaitQueueNotify::finalize() { + type = Type::i32; + // If either child is unreachable, this expression is unreachable too. + if (waitqueue->type == Type::unreachable || + count->type == Type::unreachable) { + type = Type::unreachable; + } +} + } // namespace wasm
diff --git a/src/wasm2js.h b/src/wasm2js.h index fb176ad..a6573e9 100644 --- a/src/wasm2js.h +++ b/src/wasm2js.h
@@ -2468,6 +2468,10 @@ unimplemented(curr); WASM_UNREACHABLE("unimp"); } + Ref visitWaitQueueNotify(WaitQueueNotify* curr) { + unimplemented(curr); + WASM_UNREACHABLE("unimp"); + } private: Ref makePointer(Expression* ptr, Address offset) {
diff --git a/test/lit/waitqueue.wast b/test/lit/waitqueue.wast index e778b9d..d7ddbad 100644 --- a/test/lit/waitqueue.wast +++ b/test/lit/waitqueue.wast
@@ -9,14 +9,23 @@ ;; RTRIP-NEXT: )) (global $g (ref $t) (struct.new $t (i32.const 0))) - ;; RTRIP: (func $f (type $1) (result i32) + ;; RTRIP: (func $waitqueue.wait (type $1) (result i32) ;; RTRIP-NEXT: (waitqueue.wait ;; RTRIP-NEXT: (global.get $g) ;; RTRIP-NEXT: (i32.const 0) ;; RTRIP-NEXT: (i64.const 0) ;; RTRIP-NEXT: ) ;; RTRIP-NEXT: ) - (func $f (result i32) + (func $waitqueue.wait (result i32) (waitqueue.wait (global.get $g) (i32.const 0) (i64.const 0)) ) + ;; RTRIP: (func $waitqueue.notify (type $1) (result i32) + ;; RTRIP-NEXT: (waitqueue.notify + ;; RTRIP-NEXT: (global.get $g) + ;; RTRIP-NEXT: (i32.const 0) + ;; RTRIP-NEXT: ) + ;; RTRIP-NEXT: ) + (func $waitqueue.notify (result i32) + (waitqueue.notify (global.get $g) (i32.const 0)) + ) )
diff --git a/test/spec/waitqueue.wast b/test/spec/waitqueue.wast index a0331d8..da871dd 100644 --- a/test/spec/waitqueue.wast +++ b/test/spec/waitqueue.wast
@@ -26,6 +26,24 @@ ) "waitqueue.wait timeout must be an i64" ) +(assert_invalid + (module + (func (param $count i32) (result i32) + (waitqueue.notify (i32.const 0) (local.get $count)) + ) + ) "waitqueue.notify waitqueue must be a subtype of (struct (field waitqueue))" +) + +(assert_invalid + (module + (type $t (struct (field waitqueue))) + (global $g (ref $t) (struct.new $t (i32.const 0))) + (func (param $count i32) (result i32) + (waitqueue.notify (global.get $g) (i64.const 0)) + ) + ) "waitqueue.notify count must be an i32" +) + (module (type $t (struct (field waitqueue))) @@ -34,4 +52,8 @@ (func (export "waitqueue.wait") (param $expected i32) (param $timeout i64) (result i32) (waitqueue.wait (global.get $g) (local.get $expected) (local.get $timeout)) ) + + (func (export "waitqueue.notify") (param $count i32) (result i32) + (waitqueue.notify (global.get $g) (local.get $count)) + ) )