| /* |
| * Copyright (C) 2023 Apple Inc. All rights reserved. |
| * |
| * Redistribution and use in source and binary forms, with or without |
| * modification, are permitted provided that the following conditions |
| * are met: |
| * 1. Redistributions of source code must retain the above copyright |
| * notice, this list of conditions and the following disclaimer. |
| * 2. Redistributions in binary form must reproduce the above copyright |
| * notice, this list of conditions and the following disclaimer in the |
| * documentation and/or other materials provided with the distribution. |
| * |
| * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS'' |
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, |
| * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR |
| * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS |
| * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
| * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
| * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
| * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
| * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
| * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF |
| * THE POSSIBILITY OF SUCH DAMAGE. |
| */ |
| |
| #include "config.h" |
| #include "MetalFunctionWriter.h" |
| |
| #include "API.h" |
| #include "AST.h" |
| #include "ASTInterpolateAttribute.h" |
| #include "ASTStringDumper.h" |
| #include "ASTVisitor.h" |
| #include "CallGraph.h" |
| #include "Constraints.h" |
| #include "Types.h" |
| #include "WGSLShaderModule.h" |
| #include <numbers> |
| #include <wtf/HashSet.h> |
| #include <wtf/SetForScope.h> |
| #include <wtf/SortedArrayMap.h> |
| #include <wtf/text/MakeString.h> |
| #include <wtf/text/StringBuilder.h> |
| |
| namespace WGSL { |
| |
| namespace Metal { |
| |
// Metal source snippets used to bound loop execution in generated code.
// DECLARE_FORWARD_PROGRESS declares a volatile uint32_t counter and ends with a dangling
// `if (!counter)`, so it presumably prefixes the emitted loop statement.
// CHECK_FORWARD_PROGRESS increments that counter and `break`s once it reaches
// 4294967295 (UINT32_MAX).
#define DECLARE_FORWARD_PROGRESS "volatile uint32_t __wgslEnsureForwardProgress = 0; if (!__wgslEnsureForwardProgress)"
#define CHECK_FORWARD_PROGRESS "if (++__wgslEnsureForwardProgress == 4294967295u) break;"

// STRINGIFY(x) macro-expands x and then spells the result as a `"..."_s` literal.
// The extra level of indirection makes arguments expand before `#` is applied.
#define STRINGIFY_(__x) #__x##_s
#define STRINGIFY(__x) STRINGIFY_(__x)

// DEFINE_HELPER(Name, source...) declares two HelperGenerator members:
//  - emitName(), which appends the stringified `source` to m_output;
//  - didEmitName, a flag initialized to false. emitName() does not set it; presumably
//    callers set it so that each helper is emitted only once.
// Bare `\n` tokens inside `source` are copied verbatim by stringification and therefore
// become newline escapes in the emitted text.
#define DEFINE_HELPER(__name, ...) \
    void emit##__name() { m_output.append(STRINGIFY(__VA_ARGS__)); } \
    bool didEmit##__name { false };

// Defines the Metal helper `__wgsl<CapitalizedName>(T value)`. It calls __metalFunction
// after replacing any input outside [__lowerBound, __upperBound] with T(0). NaN compares
// false against both bounds, so NaN inputs are passed through unchanged. The trailing
// variadic parameter is accepted but unused.
#define DEFINE_BOUND_HELPER_RENAMED(__name, __capitalizedName, __metalFunction, __lowerBound, __upperBound, ...) \
    DEFINE_HELPER(__capitalizedName, \
        template <typename T> \
        T __wgsl##__capitalizedName(T value) \
        { \
            return __metalFunction(select(value, T(0), value < T(__lowerBound) || value > T(__upperBound))); \
        })

// DEFINE_BOUND_HELPER_RENAMED for helpers whose Metal function has the same name as the
// WGSL builtin.
#define DEFINE_BOUND_HELPER(__name, __capitalizedName, __lowerBound, __upperBound, ...) \
    DEFINE_BOUND_HELPER_RENAMED(__name, __capitalizedName, __name, __lowerBound, __upperBound, __VA_ARGS__)

// Like DEFINE_BOUND_HELPER_RENAMED, except that when __wgslMetalAppleGPUFamily < 9 the
// result is routed through a `volatile` local before being returned. The reason is not
// stated here; presumably it prevents an unwanted compiler transformation on older GPU
// families. __wgslMetalAppleGPUFamily must be visible in the generated Metal source.
#define DEFINE_VOLATILE_BOUND_HELPER_RENAMED(__name, __capitalizedName, __metalFunction, __lowerBound, __upperBound, ...) \
    DEFINE_HELPER(__capitalizedName, \
        template <typename T> \
        T __wgsl##__capitalizedName(T value) \
        { \
            if constexpr(__wgslMetalAppleGPUFamily < 9) { \n\
                volatile auto result = __metalFunction(select(value, T(0), value < T(__lowerBound) || value > T(__upperBound))); \
                return result; \
            } else { \n\
                return __metalFunction(select(value, T(0), value < T(__lowerBound) || value > T(__upperBound))); \
            }\n \
        })

// DEFINE_VOLATILE_BOUND_HELPER_RENAMED for helpers whose Metal function has the WGSL name.
#define DEFINE_VOLATILE_BOUND_HELPER(__name, __capitalizedName, __lowerBound, __upperBound, ...) \
    DEFINE_VOLATILE_BOUND_HELPER_RENAMED(__name, __capitalizedName, __name, __lowerBound, __upperBound, __VA_ARGS__)

// Defines `__wgsl<CapitalizedName>(T value)`, which forwards to __name(value) without
// range clamping. Below GPU family 9 the result goes through a `volatile` local, as in
// DEFINE_VOLATILE_BOUND_HELPER_RENAMED. The `\n` tokens make the emitted helper span
// several lines.
#define DEFINE_VOLATILE_HELPER_RENAMED(__name, __capitalizedName) \
    DEFINE_HELPER(__capitalizedName, \
        template <typename T>\n \
        auto __wgsl##__capitalizedName(T value)\n \
        {\n \
            if constexpr(__wgslMetalAppleGPUFamily < 9) { \n\
                volatile auto result = __name(value);\n \
                return result;\n \
            } else { \n\
                auto result = __name(value);\n \
                return result;\n \
            }\n \
        }\n)

// Currently an alias of DEFINE_VOLATILE_HELPER_RENAMED with identical parameters.
#define DEFINE_VOLATILE_HELPER(__name, __capitalizedName) \
    DEFINE_VOLATILE_HELPER_RENAMED(__name, __capitalizedName)
| |
| struct HelperGenerator { |
| StringBuilder& m_output; |
| |
| HelperGenerator(StringBuilder& output) |
| : m_output(output) |
| { |
| } |
| |
| DEFINE_BOUND_HELPER(acos, Acos, -1, 1) |
| DEFINE_BOUND_HELPER(asin, Asin, -1, 1) |
| DEFINE_BOUND_HELPER(acosh, Acosh, 1, numeric_limits<T>::max()) |
| DEFINE_BOUND_HELPER(atanh, Atanh, -1, 1) |
| DEFINE_BOUND_HELPER_RENAMED(inverseSqrt, InverseSqrt, rsqrt, 0, numeric_limits<T>::infinity()) |
| DEFINE_VOLATILE_BOUND_HELPER(log, Log, 0, numeric_limits<T>::infinity()) |
| DEFINE_BOUND_HELPER(log2, Log2, 0, numeric_limits<T>::infinity()) |
| DEFINE_BOUND_HELPER(sqrt, Sqrt, 0, numeric_limits<T>::infinity()) |
| DEFINE_VOLATILE_HELPER(pack_float_to_snorm2x16, PackFloatToSnorm2x16) |
| DEFINE_VOLATILE_HELPER(pack_float_to_unorm2x16, PackFloatToUnorm2x16) |
| DEFINE_VOLATILE_HELPER(pack_float_to_snorm4x8, PackFloatToSnorm4x8) |
| DEFINE_VOLATILE_HELPER(pack_float_to_unorm4x8, PackFloatToUnorm4x8) |
| |
| }; |
| |
// Drop every helper-generation macro now that HelperGenerator has been expanded.
// The previous list named DEFINE_TRIG_HELPER, which is never defined, and leaked all of
// the DEFINE_*BOUND*/DEFINE_VOLATILE* macros into the rest of the translation unit.
// DECLARE_FORWARD_PROGRESS and CHECK_FORWARD_PROGRESS are intentionally left defined
// (presumably still needed by the code generator further down).
#undef DEFINE_VOLATILE_HELPER
#undef DEFINE_VOLATILE_HELPER_RENAMED
#undef DEFINE_VOLATILE_BOUND_HELPER
#undef DEFINE_VOLATILE_BOUND_HELPER_RENAMED
#undef DEFINE_BOUND_HELPER
#undef DEFINE_BOUND_HELPER_RENAMED
#undef DEFINE_HELPER
#undef STRINGIFY
#undef STRINGIFY_
| |
| |
| class FunctionDefinitionWriter : public AST::Visitor { |
| public: |
| FunctionDefinitionWriter(ShaderModule& shaderModule, StringBuilder& stringBuilder, PrepareResult& prepareResult, const HashMap<String, ConstantValue>& constantValues, DeviceState&& deviceState) |
| : m_helperGenerator(stringBuilder) |
| , m_output(stringBuilder) |
| , m_shaderModule(shaderModule) |
| , m_prepareResult(prepareResult) |
| , m_constantValues(constantValues) |
| , m_deviceState(WTF::move(deviceState)) |
| { |
| } |
| |
| virtual ~FunctionDefinitionWriter() = default; |
| |
| using AST::Visitor::visit; |
| |
| void write(); |
| |
| void visit(AST::Attribute&) override; |
| void visit(AST::BuiltinAttribute&) override; |
| void visit(AST::LocationAttribute&) override; |
| void visit(AST::StageAttribute&) override; |
| void visit(AST::GroupAttribute&) override; |
| void visit(AST::BindingAttribute&) override; |
| void visit(AST::WorkgroupSizeAttribute&) override; |
| void visit(AST::SizeAttribute&) override; |
| void visit(AST::AlignAttribute&) override; |
| void visit(AST::InterpolateAttribute&) override; |
| void visit(AST::InvariantAttribute&) override; |
| |
| void visit(AST::Function&) override; |
| void visit(AST::Structure&) override; |
| void visit(AST::Variable&) override; |
| void visit(AST::ConstAssert&) override; |
| |
| void visit(const Type*, AST::Expression&); |
| void visit(const Type*, AST::CallExpression&); |
| |
| void visit(AST::BoolLiteral&) override; |
| void visit(AST::AbstractFloatLiteral&) override; |
| void visit(AST::AbstractIntegerLiteral&) override; |
| void visit(AST::BinaryExpression&) override; |
| void visit(AST::Expression&) override; |
| void visit(AST::FieldAccessExpression&) override; |
| void visit(AST::Float32Literal&) override; |
| void visit(AST::Float16Literal&) override; |
| void visit(AST::IdentifierExpression&) override; |
| void visit(AST::IndexAccessExpression&) override; |
| void visit(AST::PointerDereferenceExpression&) override; |
| void visit(AST::UnaryExpression&) override; |
| void visit(AST::Signed32Literal&) override; |
| void visit(AST::Unsigned32Literal&) override; |
| |
| void visit(AST::Statement&) override; |
| void visit(AST::AssignmentStatement&) override; |
| void visit(AST::CallStatement&) override; |
| void visit(AST::CompoundAssignmentStatement&) override; |
| void visit(AST::CompoundStatement&) override; |
| void visit(AST::DecrementIncrementStatement&) override; |
| void visit(AST::DiscardStatement&) override; |
| void visit(AST::IfStatement&) override; |
| void visit(AST::PhonyAssignmentStatement&) override; |
| void visit(AST::ReturnStatement&) override; |
| void visit(AST::ForStatement&) override; |
| void visit(AST::LoopStatement&) override; |
| void visit(AST::Continuing&) override; |
| void visit(AST::WhileStatement&) override; |
| void visit(AST::SwitchStatement&) override; |
| void visit(AST::BreakStatement&) override; |
| void visit(AST::ContinueStatement&) override; |
| |
| void visit(AST::Parameter&) override; |
| void visitArgumentBufferParameter(AST::Parameter&); |
| |
| void visit(const Type*, bool shouldPack = false); |
| |
| StringBuilder& stringBuilder() { return m_body; } |
| Indentation<4>& indent() { return m_indent; } |
| unsigned metalAppleGPUFamily() const { return m_deviceState.appleGPUFamily; } |
| bool shaderValidationEnabled() const { return m_deviceState.shaderValidationEnabled; } |
| |
| private: |
| void emitNecessaryHelpers(); |
| void serializeVariable(AST::Variable&); |
| void generatePackingHelpers(AST::Structure&); |
| bool emitPackedVector(const Types::Vector&, bool shouldPack); |
| |
| bool outlineConstant(const Type*, AST::Expression&); |
| void serializeConstant(const Type*, ConstantValue); |
| void serializeBinaryExpression(AST::Expression&, AST::BinaryOperation, AST::Expression&); |
| void visitStatements(AST::Statement::List&); |
| bool shouldPackType() const; |
| |
| HelperGenerator m_helperGenerator; |
| StringBuilder m_body; |
| StringBuilder m_constants; |
| StringBuilder& m_output; |
| ShaderModule& m_shaderModule; |
| Indentation<4> m_indent { 0 }; |
| uint32_t m_constID { 0 }; |
| std::optional<AST::StructureRole> m_structRole; |
| std::optional<AST::VariableRole> m_variableRole; |
| std::optional<AST::ParameterRole> m_parameterRole; |
| std::optional<ShaderStage> m_entryPointStage; |
| AST::Function* m_currentFunction { nullptr }; |
| AST::Continuing*m_continuing { nullptr }; |
| HashSet<AST::Function*> m_visitedFunctions; |
| PrepareResult& m_prepareResult; |
| const HashMap<String, ConstantValue>& m_constantValues; |
| DeviceState m_deviceState; |
| }; |
| |
| static ASCIILiteral serializeAddressSpace(AddressSpace addressSpace) |
| { |
| switch (addressSpace) { |
| case AddressSpace::Function: |
| case AddressSpace::Private: |
| return "thread"_s; |
| case AddressSpace::Workgroup: |
| return "threadgroup"_s; |
| case AddressSpace::Uniform: |
| return "constant"_s; |
| case AddressSpace::Storage: |
| return "device"_s; |
| case AddressSpace::Handle: |
| return { }; |
| } |
| } |
| |
| void FunctionDefinitionWriter::write() |
| { |
| emitNecessaryHelpers(); |
| |
| for (auto& declaration : m_shaderModule.declarations()) { |
| if (auto* structure = dynamicDowncast<AST::Structure>(declaration)) |
| visit(*structure); |
| } |
| |
| for (auto& declaration : m_shaderModule.declarations()) { |
| if (auto* structure = dynamicDowncast<AST::Structure>(declaration)) |
| generatePackingHelpers(*structure); |
| } |
| |
| m_output.append(m_body); |
| m_body.clear(); |
| |
| for (auto& entryPoint : m_shaderModule.callGraph().entrypoints()) { |
| if (m_prepareResult.entryPoints.contains(entryPoint.originalName)) |
| visit(entryPoint.function); |
| } |
| |
| m_output.append(m_constants); |
| m_output.append(m_body); |
| } |
| |
| void FunctionDefinitionWriter::emitNecessaryHelpers() |
| { |
| m_output.append("template<typename T>\n"_s, |
| "struct __UnpackedTypeImpl;\n\n"_s, |
| "template<typename T>\n"_s, |
| "struct __PackedTypeImpl;\n\n"_s, |
| "template<typename T>\n"_s, |
| "using __UnpackedType = typename __UnpackedTypeImpl<T>::Type;\n\n"_s, |
| "template<typename T>\n"_s, |
| "using __PackedType = typename __PackedTypeImpl<T>::Type;\n\n"_s); |
| |
| if (m_shaderModule.usesPackedVec3()) { |
| m_output.append( |
| m_indent, "template<typename T>\n"_s, |
| m_indent, "struct PackedVec3 {\n"_s |
| ); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append( |
| m_indent, "union { T x; T r; };\n"_s, |
| m_indent, "union { T y; T g; };\n"_s, |
| m_indent, "union { T z; T b; };\n"_s, |
| m_indent, "uint8_t __padding[sizeof(T)];\n"_s, |
| m_indent, "\n"_s, |
| m_indent, "PackedVec3() { }\n"_s, |
| m_indent, "\n"_s, |
| m_indent, "PackedVec3(packed_vec<T, 3> v) : x(v.x), y(v.y), z(v.z) { }\n"_s, |
| m_indent, "\n"_s, |
| m_indent, "operator vec<T, 3>() { return vec<T, 3>(x, y, z); }\n"_s, |
| m_indent, "operator packed_vec<T, 3>() { return packed_vec<T, 3>(x, y, z); }\n"_s, |
| m_indent, "\n"_s, |
| m_indent, "T operator[](int i) const { return i ? i == 2 ? z : y : x; }\n"_s, |
| m_indent, "device T& operator[](int i) device { return i ? i == 2 ? z : y : x; }\n"_s, |
| m_indent, "constant T& operator[](int i) constant { return i ? i == 2 ? z : y : x; }\n"_s, |
| m_indent, "thread T& operator[](int i) thread { return i ? i == 2 ? z : y : x; }\n"_s, |
| m_indent, "threadgroup T& operator[](int i) threadgroup { return i ? i == 2 ? z : y : x; }\n"_s |
| ); |
| } |
| m_output.append(m_indent, "};\n\n"_s); |
| } |
| |
| if (m_shaderModule.usesExternalTextures()) { |
| m_shaderModule.clearUsesExternalTextures(); |
| m_output.append("struct texture_external {\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "texture2d<float> FirstPlane;\n"_s, |
| m_indent, "texture2d<float> SecondPlane;\n"_s, |
| m_indent, "float3x2 UVRemapMatrix;\n"_s, |
| m_indent, "float4x3 ColorSpaceConversionMatrix;\n"_s, |
| m_indent, "uint get_width(uint lod = 0) const { return FirstPlane.get_width(lod); }\n"_s, |
| m_indent, "uint get_height(uint lod = 0) const { return FirstPlane.get_height(lod); }\n"_s); |
| } |
| m_output.append("};\n\n"_s); |
| } |
| |
| if (m_shaderModule.usesPackArray()) { |
| m_shaderModule.clearUsesPackArray(); |
| |
| m_output.append(m_indent, "template<typename T, size_t N>\n"_s, |
| m_indent, "struct __PackedTypeImpl<array<T, N>> {\n"_s, |
| m_indent, "using Type = array<__PackedType<T>, N>;\n"_s, |
| m_indent, "};\n\n"_s); |
| |
| if (m_shaderModule.usesPackedVec3()) { |
| m_output.append(m_indent, "template<typename T, size_t N>\n"_s, |
| m_indent, "struct __PackedTypeImpl<array<vec<T, 3>, N>> {"_s, |
| m_indent, "using Type = array<PackedVec3<T>, N>;"_s, |
| m_indent, "};\n\n"_s); |
| |
| m_output.append(m_indent, "template<typename T, size_t N>\n"_s, |
| m_indent, "static __attribute__((always_inline)) array<PackedVec3<T>, N> __pack(array<vec<T, 3>, N> unpacked)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "array<PackedVec3<T>, N> packed;\n"_s, |
| m_indent, "for (size_t i = 0; i < N; ++i)\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "packed[i] = PackedVec3<T>(unpacked[i]);\n"_s); |
| } |
| m_output.append(m_indent, "return packed;\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| m_output.append(m_indent, "template<typename T, size_t N>\n"_s, |
| m_indent, "static __attribute__((always_inline)) array<__PackedType<T>, N> __pack(array<T, N> unpacked)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "array<__PackedType<T>, N> packed;\n"_s, |
| m_indent, "for (size_t i = 0; i < N; ++i)\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "packed[i] = __pack(unpacked[i]);\n"_s); |
| } |
| m_output.append(m_indent, "return packed;\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| |
| } |
| |
| if (m_shaderModule.usesUnpackArray()) { |
| m_shaderModule.clearUsesUnpackArray(); |
| |
| m_output.append(m_indent, "template<typename T, size_t N>\n"_s, |
| m_indent, "struct __UnpackedTypeImpl<array<T, N>> {\n"_s, |
| m_indent, "using Type = array<__UnpackedType<T>, N>;\n"_s, |
| m_indent, "};\n\n"_s); |
| |
| if (m_shaderModule.usesPackedVec3()) { |
| m_output.append(m_indent, "template<typename T, size_t N>\n"_s, |
| m_indent, "struct __UnpackedTypeImpl<array<PackedVec3<T>, N>> {"_s, |
| m_indent, "using Type = array<vec<T, 3>, N>;"_s, |
| m_indent, "};\n\n"_s); |
| |
| m_output.append(m_indent, "template<typename T, size_t N>\n"_s, |
| m_indent, "static __attribute__((always_inline)) array<vec<T, 3>, N> __unpack(array<PackedVec3<T>, N> packed)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "array<vec<T, 3>, N> unpacked;\n"_s, |
| m_indent, "for (size_t i = 0; i < N; ++i)\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "unpacked[i] = vec<T, 3>(packed[i]);\n"_s); |
| } |
| m_output.append(m_indent, "return unpacked;\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| m_output.append(m_indent, "template<typename T, size_t N>\n"_s, |
| m_indent, "static __attribute__((always_inline)) array<__UnpackedType<T>, N> __unpack(array<T, N> packed)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "array<__UnpackedType<T>, N> unpacked;\n"_s, |
| m_indent, "for (size_t i = 0; i < N; ++i)\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "unpacked[i] = __unpack(packed[i]);\n"_s); |
| } |
| m_output.append(m_indent, "return unpacked;\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| if (m_shaderModule.usesPackVector()) { |
| m_shaderModule.clearUsesPackVector(); |
| m_output.append(m_indent, "template<typename T>\n"_s, |
| m_indent, "static __attribute__((always_inline)) packed_vec<T, 3> __pack(vec<T, 3> unpacked) { return unpacked; }\n\n"_s); |
| } |
| |
| if (m_shaderModule.usesUnpackVector()) { |
| m_shaderModule.clearUsesUnpackVector(); |
| m_output.append(m_indent, "template<typename T>\n"_s, |
| m_indent, "static __attribute__((always_inline)) vec<T, 3> __unpack(packed_vec<T, 3> packed) { return packed; }\n\n"_s); |
| |
| if (m_shaderModule.usesPackedVec3()) { |
| m_output.append(m_indent, "template<typename T>\n"_s, |
| m_indent, "static vec<T, 3> __unpack(PackedVec3<T> packed) { return packed; }\n\n"_s); |
| } |
| } |
| |
| if (m_shaderModule.usesWorkgroupUniformLoad()) { |
| m_output.append(m_indent, "template<typename T>\n"_s, |
| m_indent, (shaderValidationEnabled() ? "[[clang::optnone]] "_s : ""_s), "static T __workgroup_uniform_load(threadgroup T* const ptr)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "threadgroup_barrier(mem_flags::mem_threadgroup);\n"_s, |
| m_indent, "auto result = *ptr;\n"_s, |
| m_indent, "threadgroup_barrier(mem_flags::mem_threadgroup);\n"_s, |
| m_indent, "return result;\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| if (m_shaderModule.usesDivision()) { |
| m_output.append(m_indent, "template<typename T, typename U, typename V = conditional_t<is_scalar_v<U>, T, U>>\n"_s, |
| m_indent, "static V __wgslDiv(T lhs, U rhs)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "auto predicate = V(rhs) == V(0);\n"_s, |
| m_indent, "if constexpr (is_signed_v<U>)\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "predicate = predicate || (V(lhs) == V(numeric_limits<T>::lowest()) && V(rhs) == V(-1));\n"_s); |
| } |
| m_output.append(m_indent, "return lhs / select(V(rhs), V(1), predicate);\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| if (m_shaderModule.usesModulo()) { |
| m_output.append(m_indent, "template<typename T, typename U, typename V = conditional_t<is_scalar_v<U>, T, U>>\n"_s, |
| m_indent, "static V __wgslMod(T lhs, U rhs)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "auto predicate = V(rhs) == V(0);\n"_s, |
| m_indent, "if constexpr (is_signed_v<U>)\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "predicate = predicate || (V(lhs) == V(numeric_limits<T>::lowest()) && V(rhs) == V(-1));\n"_s); |
| } |
| m_output.append(m_indent, "return select(lhs % V(rhs), V(0), predicate);\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| |
| if (m_shaderModule.usesFrexp()) { |
| m_output.append(m_indent, "template<typename T, typename U>\n"_s, |
| m_indent, "struct __frexp_result {\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "T fract;\n"_s, |
| m_indent, "U exp;\n"_s); |
| } |
| m_output.append(m_indent, "};\n\n"_s, |
| m_indent, "template<typename T, typename U = conditional_t<is_vector_v<T>, vec<int, vec_elements<T>::value ?: 2>, int>>\n"_s, |
| m_indent, "static __frexp_result<T, U> __wgslFrexp(T value)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "__frexp_result<T, U> result;\n"_s, |
| m_indent, "result.fract = frexp(value, result.exp);\n"_s, |
| m_indent, "return result;\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| if (m_shaderModule.usesModf()) { |
| m_output.append(m_indent, "template<typename T, typename U>\n"_s, |
| m_indent, "struct __modf_result {\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "T fract;\n"_s, |
| m_indent, "U whole;\n"_s); |
| } |
| m_output.append(m_indent, "};\n\n"_s, |
| m_indent, "template<typename T>\n"_s, |
| m_indent, "static __modf_result<T, T> __wgslModf(T value)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "__modf_result<T, T> result;\n"_s, |
| m_indent, "result.fract = modf(value, result.whole);\n"_s, |
| m_indent, "return result;\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| if (m_shaderModule.usesAtomicCompareExchange()) { |
| m_output.append(m_indent, "template<typename T, typename U = bool>\n"_s, |
| m_indent, "struct __atomic_compare_exchange_result {\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "T old_value;\n"_s, |
| m_indent, "U exchanged;\n"_s); |
| } |
| m_output.append(m_indent, "};\n\n"_s, |
| m_indent, "template<typename T, typename S, typename V> __atomic_compare_exchange_result<S> __wgslAtomicCompareExchangeWeak(T atomic1, S compare, V value) {\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "auto innerCompare = compare; \n"_s, |
| m_indent, "bool exchanged = atomic_compare_exchange_weak_explicit(atomic1, &innerCompare, value, memory_order_relaxed, memory_order_relaxed); \n"_s, |
| m_indent, "return __atomic_compare_exchange_result<decltype(compare)> { innerCompare, exchanged }; \\\n"_s, |
| m_indent, "}\n"_s); |
| } |
| } |
| |
| if (m_shaderModule.usesDot()) { |
| m_output.append(m_indent, "template<typename T, unsigned N>\n"_s, |
| m_indent, "static T __wgslDot(vec<T, N> lhs, vec<T, N> rhs)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "auto result = lhs[0] * rhs[0] + lhs[1] * rhs[1];\n"_s, |
| m_indent, "if constexpr (N > 2) result += lhs[2] * rhs[2];\n"_s, |
| m_indent, "if constexpr (N > 3) result += lhs[3] * rhs[3];\n"_s, |
| m_indent, "return result;\n"_s); |
| } |
| m_output.append(m_indent, "}\n"_s); |
| } |
| |
| if (m_shaderModule.usesDot4I8Packed()) { |
| m_output.append(m_indent, "static int __wgslDot4I8Packed(uint lhs, uint rhs)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "auto vec1 = as_type<packed_char4>(lhs);"_s, |
| m_indent, "auto vec2 = as_type<packed_char4>(rhs);"_s, |
| m_indent, "return vec1[0] * vec2[0] + vec1[1] * vec2[1] + vec1[2] * vec2[2] + vec1[3] * vec2[3];"_s); |
| } |
| m_output.append(m_indent, "}\n"_s); |
| } |
| |
| if (m_shaderModule.usesDot4U8Packed()) { |
| m_output.append(m_indent, "static uint __wgslDot4U8Packed(uint lhs, uint rhs)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "auto vec1 = as_type<packed_uchar4>(lhs);"_s, |
| m_indent, "auto vec2 = as_type<packed_uchar4>(rhs);"_s, |
| m_indent, "return vec1[0] * vec2[0] + vec1[1] * vec2[1] + vec1[2] * vec2[2] + vec1[3] * vec2[3];"_s); |
| } |
| m_output.append(m_indent, "}\n"_s); |
| } |
| |
| if (m_shaderModule.usesFirstLeadingBit()) { |
| m_output.append(m_indent, "template<typename T>\n"_s, |
| m_indent, "static T __wgslFirstLeadingBit(T e)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "if constexpr (is_signed_v<T>)\n"_s, |
| m_indent, " return select(T(31 - select(clz(e), clz(~e), e < T(0))), T(-1), e == T(0) || e == T(-1));\n"_s, |
| m_indent, "else\n"_s, |
| m_indent, " return select(T(31 - clz(e)), T(-1), e == T(0));\n"_s); |
| } |
| m_output.append(m_indent, "}\n"_s); |
| } |
| |
| if (m_shaderModule.usesFirstTrailingBit()) { |
| m_output.append(m_indent, "template<typename T>\n"_s, |
| m_indent, "static T __wgslFirstTrailingBit(T e)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "return select(ctz(e), T(-1), e == T(0));\n"_s); |
| } |
| m_output.append(m_indent, "}\n"_s); |
| } |
| |
| if (m_shaderModule.usesSign()) { |
| m_output.append(m_indent, "template<typename T>\n"_s, |
| m_indent, "static T __wgslSign(T e)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "return select(select(T(-1), T(1), e > 0), T(0), e == 0);\n"_s); |
| } |
| m_output.append(m_indent, "}\n"_s); |
| } |
| |
| if (m_shaderModule.usesExtractBits()) { |
| m_output.append(m_indent, "template<typename T>\n"_s, |
| m_indent, "static T __wgslExtractBits(T e, uint offset, uint count)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "auto o = min(offset, 32u);\n"_s, |
| m_indent, "auto c = min(count, 32u - o);\n"_s, |
| m_indent, "return select((T)0, extract_bits(e, min(o, 31u), c), c);\n"_s); |
| } |
| m_output.append(m_indent, "}\n"_s); |
| } |
| |
| if (m_shaderModule.usesMin()) { |
| m_output.append(m_indent, "static uint __attribute__((always_inline)) __wgslMin(uint a, uint b)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append("return min(a, b);\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| if (m_shaderModule.usesFtoi()) { |
| m_output.append(m_indent, "template <typename T, typename S>\n"_s, |
| m_indent, "static T __wgslFtoi(S value)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "if constexpr (is_same_v<make_scalar_t<S>, half>)\n"_s); |
| m_output.append(m_indent, "return T(select(clamp(value, max(S(numeric_limits<T>::min()), numeric_limits<S>::lowest()), numeric_limits<S>::max()), S(0), isnan(value)));\n"_s); |
| m_output.append(m_indent, "else\n"_s); |
| m_output.append(m_indent, "return T(select(clamp(value, S(numeric_limits<T>::min()), S(numeric_limits<T>::max() - ((128 << (!is_signed_v<T>)) - 1))), S(0), isnan(value)));\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| if (m_shaderModule.usesInsertBits()) { |
| m_output.append(m_indent, "template <typename T>\n"_s, |
| m_indent, "static T __wgslInsertBits(T e, T newBits, unsigned offset, unsigned count)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_output.append(m_indent, "constexpr unsigned w = 8 * static_cast<unsigned>(sizeof(make_scalar_t<T>));\n"_s); |
| m_output.append(m_indent, "const unsigned o = min(offset, w);\n"_s); |
| m_output.append(m_indent, "const unsigned c = min(count, w - o);\n"_s); |
| m_output.append(m_indent, "return insert_bits(e, newBits, min(o, w - 1), c);\n"_s); |
| } |
| m_output.append(m_indent, "}\n\n"_s); |
| } |
| |
| m_shaderModule.clearUsesPackedVec3(); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::Function& functionDefinition) |
| { |
| if (!m_visitedFunctions.add(&functionDefinition).isNewEntry) |
| return; |
| |
| for (auto& callee : m_shaderModule.callGraph().callees(functionDefinition)) |
| visit(*callee.target); |
| |
| for (auto& attribute : functionDefinition.attributes()) { |
| checkErrorAndVisit(attribute); |
| m_body.append(' '); |
| } |
| |
| if (functionDefinition.maybeReturnType()) |
| visit(functionDefinition.maybeReturnType()->inferredType()); |
| else |
| m_body.append("void"_s); |
| |
| m_body.append(' ', functionDefinition.name(), '('); |
| bool first = true; |
| for (auto& parameter : functionDefinition.parameters()) { |
| if (!first) |
| m_body.append(", "_s); |
| switch (parameter.role()) { |
| case AST::ParameterRole::UserDefined: |
| case AST::ParameterRole::PackedResource: |
| checkErrorAndVisit(parameter); |
| break; |
| case AST::ParameterRole::StageIn: |
| checkErrorAndVisit(parameter); |
| m_body.append(" [[stage_in]]"_s); |
| break; |
| case AST::ParameterRole::BindGroup: |
| visitArgumentBufferParameter(parameter); |
| break; |
| } |
| first = false; |
| } |
| |
| // Clear the flag set while serializing StageAttribute |
| m_entryPointStage = std::nullopt; |
| |
| m_currentFunction = &functionDefinition; |
| m_body.append(")\n"_s); |
| checkErrorAndVisit(functionDefinition.body()); |
| |
| m_body.append("\n\n"_s); |
| |
| m_currentFunction = nullptr; |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::Structure& structDecl) |
| { |
| m_structRole = { structDecl.role() }; |
| m_body.append(m_indent, "struct "_s, structDecl.name(), " {\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| unsigned paddingID = 0; |
| bool shouldPack = structDecl.role() == AST::StructureRole::PackedResource; |
| const auto& addPadding = [&](unsigned paddingSize) { |
| ASSERT(shouldPack); |
| m_body.append(m_indent, "uint8_t __padding"_s, ++paddingID, '[', String::number(paddingSize), "]; \n"_s); |
| }; |
| |
| for (auto& member : structDecl.members()) { |
| auto& name = member.name(); |
| auto* type = member.type().inferredType(); |
| if (isPrimitive(type, Types::Primitive::TextureExternal) || isPrimitiveReference(type, Types::Primitive::TextureExternal)) { |
| decltype(std::declval<ConstantValue>().integerValue()) bindingIndex = 0; |
| for (auto& attribute : member.attributes()) { |
| if (auto* bindingAttribute = dynamicDowncast<AST::BindingAttribute>(attribute)) { |
| if (auto bindingIndexValue = bindingAttribute->binding().constantValue()) { |
| bindingIndex = bindingIndexValue->integerValue(); |
| break; |
| } |
| } |
| } |
| m_body.append(m_indent, "texture2d<float> __"_s, name, "_FirstPlane [[id("_s, bindingIndex, ")]];\n"_s, |
| m_indent, "texture2d<float> __"_s, name, "_SecondPlane [[id("_s, (bindingIndex + 1), ")]];\n"_s, |
| m_indent, "float3x2 __"_s, name, "_UVRemapMatrix [[id("_s, (bindingIndex + 2), ")]];\n"_s, |
| m_indent, "float4x3 __"_s, name, "_ColorSpaceConversionMatrix [[id("_s, (bindingIndex + 3), ")]];\n"_s); |
| continue; |
| } |
| |
| m_body.append(m_indent); |
| visit(member.type().inferredType()); |
| m_body.append(' ', name); |
| for (auto &attribute : member.attributes()) { |
| m_body.append(' '); |
| visit(attribute); |
| } |
| m_body.append(";\n"_s); |
| |
| if (shouldPack && member.padding()) |
| addPadding(member.padding()); |
| } |
| |
| if (structDecl.role() == AST::StructureRole::VertexOutput || structDecl.role() == AST::StructureRole::FragmentOutput) { |
| m_body.append('\n', m_indent, "template<typename T>\n"_s, |
| m_indent, structDecl.name(), "(const thread T& other)\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| char prefix = ':'; |
| for (auto& member : structDecl.members()) { |
| auto& name = member.name(); |
| m_body.append(m_indent, prefix, ' ', name, "(other."_s, name, ")\n"_s); |
| prefix = ','; |
| } |
| } |
| m_body.append(m_indent, "{ }\n"_s); |
| } else if (structDecl.role() == AST::StructureRole::FragmentOutputWrapper || structDecl.role() == AST::StructureRole::VertexOutputWrapper) { |
| ASSERT(structDecl.members().size() == 1); |
| auto& member = structDecl.members()[0]; |
| |
| m_body.append('\n', m_indent, "template<typename T>\n"_s, |
| m_indent, structDecl.name(), "(T value)\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_body.append(m_indent, ": "_s, member.name(), "(value)\n"_s); |
| } |
| m_body.append(m_indent, "{ }\n"_s); |
| } |
| } |
| m_body.append(m_indent, "};\n\n"_s); |
| |
| if (structDecl.role() == AST::StructureRole::PackedResource) { |
| m_body.append(m_indent, "template<> struct __PackedTypeImpl<"_s, structDecl.original()->name(), "> { using Type = "_s, structDecl.name(), "; };\n"_s); |
| m_body.append(m_indent, "template<> struct __UnpackedTypeImpl<"_s, structDecl.name(), "> { using Type = "_s, structDecl.original()->name(), "; };\n\n"_s); |
| } |
| |
| m_structRole = std::nullopt; |
| |
| if (structDecl.role() == AST::StructureRole::BindGroup) { |
| for (auto& member : structDecl.members()) { |
| auto* type = member.type().inferredType(); |
| if (auto* reference = std::get_if<Types::Reference>(type)) |
| type = reference->element; |
| if (auto maybeSize = type->maybeSize(); maybeSize && *maybeSize < std::numeric_limits<unsigned>::max()) |
| m_body.append(m_indent, "static_assert(sizeof("_s, structDecl.name(), "::"_s, member.name(), ") == "_s, *maybeSize, ");\n\n"_s); |
| } |
| } |
| |
| } |
| |
| void FunctionDefinitionWriter::generatePackingHelpers(AST::Structure& structure) |
| { |
| if (structure.role() != AST::StructureRole::PackedResource || !structure.inferredType()->isConstructible()) |
| return; |
| |
| const String& packedName = structure.name(); |
| auto unpackedName = structure.original()->name(); |
| |
| m_body.append(m_indent, "static __attribute__((always_inline)) "_s, packedName, " __pack("_s, unpackedName, " unpacked)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_body.append(m_indent, packedName, " packed;\n"_s); |
| for (auto& member : structure.members()) { |
| auto& name = member.name(); |
| if (member.type().inferredType()->packing() & (Packing::PStruct | Packing::PArray)) |
| m_body.append(m_indent, "packed."_s, name, " = __pack(unpacked."_s, name, ");\n"_s); |
| else |
| m_body.append(m_indent, "packed."_s, name, " = unpacked."_s, name, ";\n"_s); |
| } |
| m_body.append(m_indent, "return packed;\n"_s); |
| } |
| m_body.append(m_indent, "}\n\n"_s, |
| m_indent, "static "_s, unpackedName, " __unpack("_s, packedName, " packed)\n"_s, |
| m_indent, "{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| m_body.append(m_indent, unpackedName, " unpacked;\n"_s); |
| for (auto& member : structure.members()) { |
| auto& name = member.name(); |
| if (member.type().inferredType()->packing() & (Packing::PStruct | Packing::PArray)) |
| m_body.append(m_indent, "unpacked."_s, name, " = __unpack(packed."_s, name, ");\n"_s); |
| else |
| m_body.append(m_indent, "unpacked."_s, name, " = packed."_s, name, ";\n"_s); |
| } |
| m_body.append(m_indent, "return unpacked;\n"_s); |
| } |
| m_body.append(m_indent, "}\n\n"_s); |
| } |
| |
| bool FunctionDefinitionWriter::shouldPackType() const |
| { |
| if (m_structRole.has_value() && (*m_structRole == AST::StructureRole::PackedResource || *m_structRole == AST::StructureRole::BindGroup)) |
| return true; |
| if (m_variableRole.has_value() && *m_variableRole == AST::VariableRole::PackedResource) |
| return true; |
| if (m_parameterRole.has_value() && (*m_parameterRole == AST::ParameterRole::PackedResource)) |
| return true; |
| return false; |
| } |
| |
| bool FunctionDefinitionWriter::emitPackedVector(const Types::Vector& vector, bool shouldPack) |
| { |
| if (!shouldPack) |
| return false; |
| |
| // The only vectors that need to be packed are the vectors with 3 elements, |
| // because their size differs between Metal and WGSL (4 * element size vs |
| // 3 * element size, respectively) |
| if (vector.size != 3) |
| return false; |
| |
| auto& primitive = std::get<Types::Primitive>(*vector.element); |
| switch (primitive.kind) { |
| case Types::Primitive::AbstractInt: |
| case Types::Primitive::I32: |
| m_body.append("packed_int"_s, vector.size); |
| break; |
| case Types::Primitive::U32: |
| m_body.append("packed_uint"_s, vector.size); |
| break; |
| case Types::Primitive::AbstractFloat: |
| case Types::Primitive::F32: |
| m_body.append("packed_float"_s, vector.size); |
| break; |
| case Types::Primitive::F16: |
| m_body.append("packed_half"_s, vector.size); |
| break; |
| case Types::Primitive::Bool: |
| case Types::Primitive::Void: |
| case Types::Primitive::Sampler: |
| case Types::Primitive::SamplerComparison: |
| case Types::Primitive::TextureExternal: |
| case Types::Primitive::AccessMode: |
| case Types::Primitive::TexelFormat: |
| case Types::Primitive::AddressSpace: |
| RELEASE_ASSERT_NOT_REACHED(); |
| } |
| return true; |
| } |
| |
// Variable declarations are written by serializeVariable(), which emits nothing
// for `const` declarations.
void FunctionDefinitionWriter::visit(AST::Variable& variable)
{
    serializeVariable(variable);
}
| |
// const_assert has no Metal counterpart (it is presumably checked by an earlier
// compilation stage), so it is intentionally skipped here.
void FunctionDefinitionWriter::visit(AST::ConstAssert&)
{
    // const_assert should not generate any code
}
| |
| void FunctionDefinitionWriter::serializeVariable(AST::Variable& variable) |
| { |
| if (variable.flavor() == AST::VariableFlavor::Const) |
| return; |
| |
| auto variableRoleScope = SetForScope(m_variableRole, std::optional<AST::VariableRole> { variable.role() }); |
| |
| const Type* type = variable.storeType(); |
| if (isPrimitiveReference(type, Types::Primitive::TextureExternal)) { |
| ASSERT(variable.maybeInitializer()); |
| m_body.append("texture_external "_s, variable.name(), " { "_s); |
| visit(*variable.maybeInitializer()); |
| m_body.append("_FirstPlane, "_s); |
| visit(*variable.maybeInitializer()); |
| m_body.append("_SecondPlane, "_s); |
| visit(*variable.maybeInitializer()); |
| m_body.append("_UVRemapMatrix, "_s); |
| visit(*variable.maybeInitializer()); |
| m_body.append("_ColorSpaceConversionMatrix }"_s); |
| return; |
| } |
| |
| if (auto* qualifier = variable.maybeQualifier()) { |
| switch (qualifier->addressSpace()) { |
| case AddressSpace::Workgroup: |
| m_body.append("threadgroup "_s); |
| break; |
| case AddressSpace::Function: |
| case AddressSpace::Handle: |
| case AddressSpace::Private: |
| case AddressSpace::Storage: |
| case AddressSpace::Uniform: |
| break; |
| } |
| } |
| |
| visit(type); |
| m_body.append(' ', variable.name()); |
| |
| if (variable.flavor() == AST::VariableFlavor::Override) |
| return; |
| |
| if (auto* initializer = variable.maybeInitializer()) { |
| m_body.append(" = "_s); |
| visit(type, *initializer); |
| } else |
| m_body.append(" { }"_s); |
| } |
| |
// Hands a generic attribute to the AST::Visitor base implementation. The base
// presumably dispatches to the attribute-specific overloads in this class
// (NOTE(review): confirm against ASTVisitor).
void FunctionDefinitionWriter::visit(AST::Attribute& attribute)
{
    AST::Visitor::visit(attribute);
}
| |
| void FunctionDefinitionWriter::visit(AST::BuiltinAttribute& builtin) |
| { |
| // Built-in attributes are only valid for parameters. If a struct member originally |
| // had a built-in attribute it must have already been hoisted into a parameter, but |
| // we keep the original struct so we can reconstruct it. |
| if (m_structRole.has_value() && *m_structRole != AST::StructureRole::VertexOutput && *m_structRole != AST::StructureRole::VertexOutputWrapper && *m_structRole != AST::StructureRole::FragmentOutput && *m_structRole != AST::StructureRole::FragmentOutputWrapper) |
| return; |
| |
| switch (builtin.builtin()) { |
| case Builtin::FragDepth: |
| m_body.append("[[depth(any)]]"_s); |
| break; |
| case Builtin::FrontFacing: |
| m_body.append("[[front_facing]]"_s); |
| break; |
| case Builtin::GlobalInvocationId: |
| m_body.append("[[thread_position_in_grid]]"_s); |
| break; |
| case Builtin::InstanceIndex: |
| m_body.append("[[instance_id]]"_s); |
| break; |
| break; |
| case Builtin::LocalInvocationId: |
| m_body.append("[[thread_position_in_threadgroup]]"_s); |
| break; |
| case Builtin::LocalInvocationIndex: |
| m_body.append("[[thread_index_in_threadgroup]]"_s); |
| break; |
| case Builtin::NumWorkgroups: |
| m_body.append("[[threadgroups_per_grid]]"_s); |
| break; |
| case Builtin::Position: |
| m_body.append("[[position]]"_s); |
| break; |
| case Builtin::SampleIndex: |
| m_body.append("[[sample_id]]"_s); |
| break; |
| case Builtin::SampleMask: |
| m_body.append("[[sample_mask]]"_s); |
| break; |
| case Builtin::VertexIndex: |
| m_body.append("[[vertex_id]]"_s); |
| break; |
| case Builtin::WorkgroupId: |
| m_body.append("[[threadgroup_position_in_grid]]"_s); |
| break; |
| } |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::StageAttribute& stage) |
| { |
| m_entryPointStage = { stage.stage() }; |
| switch (stage.stage()) { |
| case ShaderStage::Vertex: |
| m_body.append("[[vertex]]"_s); |
| break; |
| case ShaderStage::Fragment: |
| m_body.append("[[fragment]]"_s); |
| break; |
| case ShaderStage::Compute: |
| m_body.append("[[kernel]]"_s); |
| break; |
| } |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::GroupAttribute& group) |
| { |
| unsigned bufferIndex = group.group().constantValue()->integerValue(); |
| if (m_entryPointStage.has_value() && *m_entryPointStage == ShaderStage::Vertex) { |
| ASSERT(m_shaderModule.configuration().maxBuffersPlusVertexBuffersForVertexStage > 0); |
| auto max = m_shaderModule.configuration().maxBuffersPlusVertexBuffersForVertexStage - 1; |
| bufferIndex = vertexBufferIndexForBindGroup(bufferIndex, max); |
| } |
| m_body.append("[[buffer("_s, bufferIndex, ")]]"_s); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::BindingAttribute& binding) |
| { |
| m_body.append("[[id("_s, binding.binding().constantValue()->integerValue(), ")]]"_s); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::LocationAttribute& location) |
| { |
| if (m_structRole.has_value()) { |
| auto role = *m_structRole; |
| switch (role) { |
| case AST::StructureRole::VertexOutput: |
| case AST::StructureRole::FragmentInput: |
| case AST::StructureRole::VertexOutputWrapper: |
| m_body.append("[[user(loc"_s, location.location().constantValue()->integerValue(), ")]]"_s); |
| return; |
| case AST::StructureRole::BindGroup: |
| case AST::StructureRole::UserDefined: |
| case AST::StructureRole::ComputeInput: |
| case AST::StructureRole::UserDefinedResource: |
| case AST::StructureRole::PackedResource: |
| return; |
| case AST::StructureRole::FragmentOutputWrapper: |
| case AST::StructureRole::FragmentOutput: |
| m_body.append("[[color("_s, location.location().constantValue()->integerValue(), ")]]"_s); |
| return; |
| case AST::StructureRole::VertexInput: |
| m_body.append("[[attribute("_s, location.location().constantValue()->integerValue(), ")]]"_s); |
| break; |
| } |
| } |
| } |
| |
// @workgroup_size has no attribute spelling in the generated Metal source.
void FunctionDefinitionWriter::visit(AST::WorkgroupSizeAttribute&)
{
    // This attribute shouldn't generate any code. The workgroup size is passed
    // to the API through the EntryPointInformation.
}
| |
// @size has no attribute spelling in the generated Metal source.
void FunctionDefinitionWriter::visit(AST::SizeAttribute&)
{
    // This attribute shouldn't generate any code. The size is used when serializing
    // structs (presumably via member padding).
}
| |
// @align has no attribute spelling in the generated Metal source.
void FunctionDefinitionWriter::visit(AST::AlignAttribute&)
{
    // This attribute shouldn't generate any code. The alignment is used when
    // serializing structs.
}
| |
| static ASCIILiteral convertToSampleMode(InterpolationType type, InterpolationSampling sampleType) |
| { |
| switch (type) { |
| case InterpolationType::Flat: |
| return "flat"_s; |
| case InterpolationType::Linear: |
| switch (sampleType) { |
| case InterpolationSampling::First: |
| case InterpolationSampling::Either: |
| case InterpolationSampling::Center: |
| return "center_no_perspective"_s; |
| case InterpolationSampling::Centroid: |
| return "centroid_no_perspective"_s; |
| case InterpolationSampling::Sample: |
| return "sample_no_perspective"_s; |
| } |
| case InterpolationType::Perspective: |
| switch (sampleType) { |
| case InterpolationSampling::First: |
| case InterpolationSampling::Either: |
| case InterpolationSampling::Center: |
| return "center_perspective"_s; |
| case InterpolationSampling::Centroid: |
| return "centroid_perspective"_s; |
| case InterpolationSampling::Sample: |
| return "sample_perspective"_s; |
| } |
| } |
| |
| ASSERT_NOT_REACHED(); |
| return "flat"_s; |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::InterpolateAttribute& attribute) |
| { |
| m_body.append("[["_s, convertToSampleMode(attribute.type(), attribute.sampling()), "]]"_s); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::InvariantAttribute&) |
| { |
| if (!m_structRole.has_value() || (*m_structRole != AST::StructureRole::VertexOutput && *m_structRole != AST::StructureRole::VertexOutputWrapper)) |
| return; |
| |
| m_body.append("[[invariant]]"_s); |
| } |
| |
// Types

// Writes the Metal spelling of `type` into m_body.
//
// `shouldPack` is OR-ed with shouldPackType() (the current struct/variable/parameter
// role). When it is set:
//   - 3-component vectors are emitted as packed_* types (see emitPackedVector),
//   - arrays of 3-component vectors use PackedVec3<...> elements,
//   - UserDefinedResource structs are wrapped in __PackedType<...>.
// The flag is forwarded to vector/matrix/array/primitive-struct/texture element
// types. It is NOT forwarded through references (the element is visited with the
// default, so only shouldPackType() applies). Pointers recompute it from their
// address space.
void FunctionDefinitionWriter::visit(const Type* type, bool shouldPack)
{
    using namespace WGSL::Types;

    shouldPack |= shouldPackType();

    WTF::switchOn(*type,
        [&](const Primitive& primitive) {
            switch (primitive.kind) {
            case Types::Primitive::AbstractInt:
            case Types::Primitive::I32:
                m_body.append("int"_s);
                break;
            case Types::Primitive::U32:
                m_body.append("unsigned"_s);
                break;
            case Types::Primitive::AbstractFloat:
            case Types::Primitive::F32:
                m_body.append("float"_s);
                break;
            case Types::Primitive::F16:
                m_body.append("half"_s);
                break;
            case Types::Primitive::Void:
            case Types::Primitive::Bool:
            case Types::Primitive::Sampler:
                // Appends the type's own textual form; presumably it already matches
                // the Metal spelling for these (NOTE(review): confirm).
                m_body.append(*type);
                break;
            case Types::Primitive::SamplerComparison:
                // Metal has a single sampler type for both WGSL sampler kinds.
                m_body.append("sampler"_s);
                break;
            case Types::Primitive::TextureExternal:
                m_body.append("texture_external"_s);
                break;
            case Types::Primitive::AccessMode:
            case Types::Primitive::TexelFormat:
            case Types::Primitive::AddressSpace:
                RELEASE_ASSERT_NOT_REACHED();
            }
        },
        [&](const Vector& vector) {
            if (emitPackedVector(vector, shouldPack))
                return;
            m_body.append("vec<"_s);
            visit(vector.element, shouldPack);
            m_body.append(", "_s, vector.size, '>');
        },
        [&](const Matrix& matrix) {
            m_body.append("matrix<"_s);
            visit(matrix.element, shouldPack);
            m_body.append(", "_s, matrix.columns, ", "_s, matrix.rows, '>');
        },
        [&](const Array& array) {
            m_body.append("array<"_s);
            auto* vector = std::get_if<Types::Vector>(array.element);
            if (vector && vector->size == 3 && shouldPack) {
                m_body.append("PackedVec3<"_s);
                visit(vector->element, shouldPack);
                m_body.append(">"_s);
            } else
                visit(array.element, shouldPack);
            m_body.append(", "_s);
            // Array length: a fixed count, an expression (presumably an
            // override-dependent size), or std::monostate (presumably a
            // runtime-sized array), which is written as length 1.
            WTF::switchOn(array.size,
                [&](unsigned size) { m_body.append(size); },
                [&](std::monostate) { m_body.append(1); },
                [&](AST::Expression* size) {
                    visit(*size);
                });
            m_body.append('>');
        },
        [&](const Struct& structure) {
            if (shouldPack && structure.structure.role() == AST::StructureRole::UserDefinedResource)
                m_body.append("__PackedType<"_s, structure.structure.name(), ">"_s);
            else
                m_body.append(structure.structure.name());
        },
        [&](const PrimitiveStruct& structure) {
            // Emitted as a template instantiation: Name<Arg0, Arg1, ...>.
            m_body.append(structure.name, '<');
            bool first = true;
            for (auto& value : structure.values) {
                if (!first)
                    m_body.append(", "_s);
                first = false;
                visit(value, shouldPack);
            }
            m_body.append('>');
        },
        [&](const Texture& texture) {
            // Sampled textures use access::sample, except multisampled 2D textures,
            // which are emitted as texture2d_ms with access::read.
            ASCIILiteral type;
            ASCIILiteral access = "sample"_s;
            switch (texture.kind) {
            case Types::Texture::Kind::Texture1d:
                type = "texture1d"_s;
                break;
            case Types::Texture::Kind::Texture2d:
                type = "texture2d"_s;
                break;
            case Types::Texture::Kind::Texture2dArray:
                type = "texture2d_array"_s;
                break;
            case Types::Texture::Kind::Texture3d:
                type = "texture3d"_s;
                break;
            case Types::Texture::Kind::TextureCube:
                type = "texturecube"_s;
                break;
            case Types::Texture::Kind::TextureCubeArray:
                type = "texturecube_array"_s;
                break;
            case Types::Texture::Kind::TextureMultisampled2d:
                type = "texture2d_ms"_s;
                access = "read"_s;
                break;
            }
            m_body.append(type, '<');
            visit(texture.element, shouldPack);
            m_body.append(", access::"_s, access, '>');
        },
        [&](const TextureStorage& texture) {
            // The component type comes from the texel format, via
            // shaderTypeForTexelFormat(); the access mode maps one-to-one.
            ASCIILiteral base;
            ASCIILiteral mode;
            switch (texture.kind) {
            case Types::TextureStorage::Kind::TextureStorage1d:
                base = "texture1d"_s;
                break;
            case Types::TextureStorage::Kind::TextureStorage2d:
                base = "texture2d"_s;
                break;
            case Types::TextureStorage::Kind::TextureStorage2dArray:
                base = "texture2d_array"_s;
                break;
            case Types::TextureStorage::Kind::TextureStorage3d:
                base = "texture3d"_s;
                break;
            }
            switch (texture.access) {
            case AccessMode::Read:
                mode = "read"_s;
                break;
            case AccessMode::Write:
                mode = "write"_s;
                break;
            case AccessMode::ReadWrite:
                mode = "read_write"_s;
                break;
            }
            m_body.append(base, '<');
            visit(shaderTypeForTexelFormat(texture.format, m_shaderModule.types()));
            m_body.append(", access::"_s, mode, '>');
        },
        [&](const TextureDepth& texture) {
            // Depth textures are always emitted with a float component type.
            ASCIILiteral base;
            switch (texture.kind) {
            case TextureDepth::Kind::TextureDepth2d:
                base = "depth2d"_s;
                break;
            case TextureDepth::Kind::TextureDepth2dArray:
                base = "depth2d_array"_s;
                break;
            case TextureDepth::Kind::TextureDepthCube:
                base = "depthcube"_s;
                break;
            case TextureDepth::Kind::TextureDepthCubeArray:
                base = "depthcube_array"_s;
                break;
            case TextureDepth::Kind::TextureDepthMultisampled2d:
                base = "depth2d_ms"_s;
                break;
            }
            m_body.append(base, "<float>"_s);
        },
        [&](const Reference& reference) {
            // When the address space has no Metal spelling (null), the reference
            // degrades to its element type. Otherwise it becomes
            // `[const ]<space> <element>&`. Note: the outer shouldPack is not
            // forwarded to the element here.
            auto addressSpace = serializeAddressSpace(reference.addressSpace);
            if (addressSpace.isNull()) {
                visit(reference.element);
                return;
            }
            if (reference.accessMode == AccessMode::Read)
                m_body.append("const "_s);
            m_body.append(addressSpace, ' ');
            visit(reference.element);
            m_body.append('&');
        },
        [&](const Pointer& pointer) {
            // `[const ][<space> ]<element>*`. This local shouldPack intentionally
            // shadows the outer flag: pointees are packed exactly when they live in
            // storage or uniform memory.
            auto addressSpace = serializeAddressSpace(pointer.addressSpace);
            if (pointer.accessMode == AccessMode::Read)
                m_body.append("const "_s);
            if (addressSpace)
                m_body.append(addressSpace, ' ');
            bool shouldPack = pointer.addressSpace == AddressSpace::Storage || pointer.addressSpace == AddressSpace::Uniform;
            visit(pointer.element, shouldPack);
            m_body.append('*');
        },
        [&](const Atomic& atomic) {
            // Only i32 maps to atomic_int; every other element type becomes atomic_uint.
            if (atomic.element == m_shaderModule.types().i32Type())
                m_body.append("atomic_int"_s);
            else
                m_body.append("atomic_uint"_s);
        },
        [&](const Function&) {
            RELEASE_ASSERT_NOT_REACHED();
        },
        [&](const TypeConstructor&) {
            RELEASE_ASSERT_NOT_REACHED();
        });
}
| |
| void FunctionDefinitionWriter::visit(AST::Parameter& parameter) |
| { |
| auto parameterRoleScope = SetForScope(m_parameterRole, parameter.role()); |
| visit(parameter.typeName().inferredType()); |
| m_body.append(' ', parameter.name()); |
| for (auto& attribute : parameter.attributes()) { |
| m_body.append(' '); |
| checkErrorAndVisit(attribute); |
| } |
| } |
| |
| void FunctionDefinitionWriter::visitArgumentBufferParameter(AST::Parameter& parameter) |
| { |
| m_body.append("constant "_s); |
| visit(parameter.typeName().inferredType()); |
| m_body.append("& "_s, parameter.name()); |
| for (auto& attribute : parameter.attributes()) { |
| m_body.append(' '); |
| checkErrorAndVisit(attribute); |
| } |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::Expression& expression) |
| { |
| visit(expression.inferredType(), expression); |
| } |
| |
| bool FunctionDefinitionWriter::outlineConstant(const Type* type, AST::Expression& expression) |
| { |
| auto constantValue = expression.constantValue(); |
| if (!constantValue) |
| return false; |
| |
| auto* maybeArrayValue = std::get_if<ConstantArray>(&*constantValue); |
| if (!maybeArrayValue) |
| return false; |
| |
| auto constantName = makeString("__wgslConst"_s, m_constID++); |
| |
| std::swap(m_body, m_constants); |
| |
| m_body.append("const constant "_s); |
| visit(type); |
| m_body.append(' ', constantName, " = "_s); |
| serializeConstant(type, *constantValue); |
| m_body.append(";\n"_s); |
| |
| std::swap(m_body, m_constants); |
| |
| m_body.append(constantName); |
| return true; |
| } |
| |
| void FunctionDefinitionWriter::visit(const Type* type, AST::Expression& expression) |
| { |
| if (outlineConstant(type, expression)) |
| return; |
| |
| if (auto constantValue = expression.constantValue()) { |
| serializeConstant(type, *constantValue); |
| return; |
| } |
| |
| if (auto* call = dynamicDowncast<AST::CallExpression>(expression)) |
| visit(type, *call); |
| else if (auto* identity = dynamicDowncast<AST::IdentityExpression>(expression)) |
| visit(type, identity->expression()); |
| else |
| AST::Visitor::visit(expression); |
| } |
| |
| static void visitArguments(FunctionDefinitionWriter* writer, AST::CallExpression& call, unsigned startOffset = 0) |
| { |
| writer->stringBuilder().append('('); |
| for (unsigned i = startOffset; i < call.arguments().size(); ++i) { |
| if (i != startOffset) |
| writer->stringBuilder().append(", "_s); |
| writer->visit(call.arguments()[i]); |
| } |
| writer->stringBuilder().append(')'); |
| } |
| |
| static void emitTextureDimensions(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| const auto* vector = std::get_if<Types::Vector>(call.inferredType()); |
| const auto& get = [&](ASCIILiteral property) { |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(".get_"_s, property, '('); |
| if (vector && call.arguments().size() > 1) { |
| writer->stringBuilder().append("min("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(".get_num_mip_levels(), uint("_s); |
| writer->visit(call.arguments()[1]); |
| writer->stringBuilder().append(')'); |
| writer->stringBuilder().append(')'); |
| } |
| writer->stringBuilder().append(')'); |
| }; |
| |
| if (!vector) { |
| get("width"_s); |
| return; |
| } |
| |
| auto size = vector->size; |
| ASSERT(size >= 2 && size <= 3); |
| writer->stringBuilder().append("uint"_s, String::number(size), '('); |
| get("width"_s); |
| writer->stringBuilder().append(", "_s); |
| get("height"_s); |
| if (size > 2) { |
| writer->stringBuilder().append(", "_s); |
| get("depth"_s); |
| } |
| writer->stringBuilder().append(')'); |
| } |
| |
| static void emitTextureGather(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| ASSERT(call.arguments().size() > 1); |
| unsigned offset = 0; |
| ASCIILiteral component; |
| bool hasOffset = true; |
| auto& firstArgument = call.arguments()[0]; |
| if (std::holds_alternative<Types::Primitive>(*firstArgument.inferredType())) { |
| offset = 1; |
| switch (firstArgument.constantValue()->integerValue()) { |
| case 0: |
| component = "x"_s; |
| break; |
| case 1: |
| component = "y"_s; |
| break; |
| case 2: |
| component = "z"_s; |
| break; |
| case 3: |
| component = "w"_s; |
| break; |
| default: |
| RELEASE_ASSERT_NOT_REACHED(); |
| } |
| |
| auto& textureType = std::get<Types::Texture>(*call.arguments()[1].inferredType()); |
| if (textureType.kind == Types::Texture::Kind::Texture2d || textureType.kind == Types::Texture::Kind::Texture2dArray) { |
| auto& lastArgument = call.arguments().last(); |
| auto* vectorType = std::get_if<Types::Vector>(lastArgument.inferredType()); |
| if (!vectorType || !satisfies(vectorType->element, Constraints::Integer)) |
| hasOffset = false; |
| } |
| } |
| writer->visit(call.arguments()[offset]); |
| writer->stringBuilder().append(".gather("_s); |
| for (unsigned i = offset + 1; i < call.arguments().size(); ++i) { |
| if (i != offset + 1) |
| writer->stringBuilder().append(", "_s); |
| writer->visit(call.arguments()[i]); |
| } |
| if (!hasOffset) |
| writer->stringBuilder().append(", int2(0)"_s); |
| if (!component.isNull()) |
| writer->stringBuilder().append(", component::"_s, component); |
| writer->stringBuilder().append(')'); |
| } |
| |
| static void emitTextureGatherCompare(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| ASSERT(call.arguments().size() > 1); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(".gather_compare"_s); |
| visitArguments(writer, call, 1); |
| } |
| |
| static void emitTextureLoad(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| auto& texture = call.arguments()[0]; |
| auto* textureType = texture.inferredType(); |
| |
| auto* primitive = std::get_if<Types::Primitive>(textureType); |
| bool isExternalTexture = primitive && primitive->kind == Types::Primitive::TextureExternal; |
| if (!isExternalTexture) { |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(".read"_s); |
| writer->stringBuilder().append('('); |
| bool is1d = true; |
| auto cast = "uint"_s; |
| if (const auto* vector = std::get_if<Types::Vector>(call.arguments()[1].inferredType())) { |
| is1d = false; |
| switch (vector->size) { |
| case 2: |
| cast = "uint2"_s; |
| break; |
| case 3: |
| cast = "uint3"_s; |
| break; |
| default: |
| RELEASE_ASSERT_NOT_REACHED(); |
| } |
| } |
| bool first = true; |
| auto argumentCount = call.arguments().size(); |
| for (unsigned i = 1; i < argumentCount; ++i) { |
| if (first) { |
| writer->stringBuilder().append(cast, '('); |
| writer->visit(call.arguments()[i]); |
| writer->stringBuilder().append(')'); |
| } else if (is1d && i == argumentCount - 1) { |
| // From the MSL spec for texture1d::read: |
| // > Since mipmaps are not supported for 1D textures, lod must be 0. |
| continue; |
| } else { |
| writer->stringBuilder().append(", "_s); |
| writer->visit(call.arguments()[i]); |
| } |
| first = false; |
| } |
| writer->stringBuilder().append(')'); |
| return; |
| } |
| |
| auto& coordinates = call.arguments()[1]; |
| writer->stringBuilder().append("({\n"_s); |
| { |
| IndentationScope scope(writer->indent()); |
| { |
| writer->stringBuilder().append(writer->indent(), "auto __coords = uint2(("_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(".UVRemapMatrix * float3(float2("_s); |
| writer->visit(coordinates); |
| writer->stringBuilder().append("), 1)).xy);\n"_s); |
| } |
| { |
| writer->stringBuilder().append(writer->indent(), "auto __y = float("_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(".FirstPlane.read(__coords).r);\n"_s); |
| } |
| { |
| writer->stringBuilder().append(writer->indent(), "auto __xAdjustment = "_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(writer->indent(), ".SecondPlane.get_width(0) / static_cast<float>("_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(writer->indent(), ".FirstPlane.get_width(0));"_s); |
| |
| writer->stringBuilder().append(writer->indent(), "auto __yAdjustment = "_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(writer->indent(), ".SecondPlane.get_height(0) / static_cast<float>("_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(writer->indent(), ".FirstPlane.get_height(0));"_s); |
| |
| writer->stringBuilder().append(writer->indent(), "auto __cbcr = float2("_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(".SecondPlane.read(uint2(uint(__coords.x * __xAdjustment), uint(__coords.y * __yAdjustment))).rg);\n"_s); |
| } |
| writer->stringBuilder().append(writer->indent(), "auto __ycbcr = float3(__y, __cbcr);\n"_s); |
| { |
| writer->stringBuilder().append(writer->indent(), "float4("_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(".ColorSpaceConversionMatrix * float4(__ycbcr, 1), 1);\n"_s); |
| } |
| } |
| writer->stringBuilder().append(writer->indent(), "})"_s); |
| } |
| |
| static void emitTextureSample(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| ASSERT(call.arguments().size() > 1); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(".sample"_s); |
| visitArguments(writer, call, 1); |
| } |
| |
| static void emitTextureSampleCompare(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| ASSERT(call.arguments().size() > 1); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(".sample_compare"_s); |
| visitArguments(writer, call, 1); |
| } |
| |
| static void emitTextureSampleGrad(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| |
| ASSERT(call.arguments().size() > 1); |
| auto& texture = call.arguments()[0]; |
| auto& textureType = std::get<Types::Texture>(*texture.inferredType()); |
| |
| unsigned gradientIndex; |
| ASCIILiteral gradientFunction; |
| switch (textureType.kind) { |
| case Types::Texture::Kind::Texture1d: |
| case Types::Texture::Kind::Texture2d: |
| case Types::Texture::Kind::TextureMultisampled2d: |
| gradientIndex = 3; |
| gradientFunction = "gradient2d"_s; |
| break; |
| |
| case Types::Texture::Kind::Texture3d: |
| gradientIndex = 3; |
| gradientFunction = "gradient3d"_s; |
| break; |
| |
| case Types::Texture::Kind::TextureCube: |
| gradientIndex = 3; |
| gradientFunction = "gradientcube"_s; |
| break; |
| |
| case Types::Texture::Kind::Texture2dArray: |
| gradientIndex = 4; |
| gradientFunction = "gradient2d"_s; |
| break; |
| |
| case Types::Texture::Kind::TextureCubeArray: |
| gradientIndex = 4; |
| gradientFunction = "gradientcube"_s; |
| break; |
| } |
| writer->visit(texture); |
| writer->stringBuilder().append(".sample("_s); |
| for (unsigned i = 1; i < gradientIndex; ++i) { |
| if (i != 1) |
| writer->stringBuilder().append(", "_s); |
| writer->visit(call.arguments()[i]); |
| } |
| writer->stringBuilder().append(", "_s, gradientFunction, '('); |
| writer->visit(call.arguments()[gradientIndex]); |
| writer->stringBuilder().append(", "_s); |
| writer->visit(call.arguments()[gradientIndex + 1]); |
| writer->stringBuilder().append(')'); |
| for (unsigned i = gradientIndex + 2; i < call.arguments().size(); ++i) { |
| writer->stringBuilder().append(", "_s); |
| writer->visit(call.arguments()[i]); |
| } |
| writer->stringBuilder().append(')'); |
| } |
| |
| static void emitTextureSampleLevel(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| bool isArray = false; |
| auto& texture = call.arguments()[0]; |
| if (auto* textureType = std::get_if<Types::Texture>(texture.inferredType())) { |
| switch (textureType->kind) { |
| case Types::Texture::Kind::Texture2dArray: |
| case Types::Texture::Kind::TextureCubeArray: |
| isArray = true; |
| break; |
| case Types::Texture::Kind::Texture1d: |
| case Types::Texture::Kind::Texture2d: |
| case Types::Texture::Kind::Texture3d: |
| case Types::Texture::Kind::TextureCube: |
| case Types::Texture::Kind::TextureMultisampled2d: |
| break; |
| } |
| } else if (auto* textureStorageType = std::get_if<Types::TextureStorage>(texture.inferredType())) { |
| switch (textureStorageType->kind) { |
| case Types::TextureStorage::Kind::TextureStorage2dArray: |
| isArray = true; |
| break; |
| case Types::TextureStorage::Kind::TextureStorage1d: |
| case Types::TextureStorage::Kind::TextureStorage2d: |
| case Types::TextureStorage::Kind::TextureStorage3d: |
| break; |
| } |
| } else { |
| auto& textureDepthType = std::get<Types::TextureDepth>(*texture.inferredType()); |
| switch (textureDepthType.kind) { |
| case Types::TextureDepth::Kind::TextureDepth2dArray: |
| case Types::TextureDepth::Kind::TextureDepthCubeArray: |
| isArray = true; |
| break; |
| case Types::TextureDepth::Kind::TextureDepth2d: |
| case Types::TextureDepth::Kind::TextureDepthCube: |
| case Types::TextureDepth::Kind::TextureDepthMultisampled2d: |
| break; |
| } |
| } |
| |
| unsigned levelIndex = isArray ? 4 : 3; |
| writer->visit(texture); |
| writer->stringBuilder().append(".sample("_s); |
| for (unsigned i = 1; i < levelIndex; ++i) { |
| if (i != 1) |
| writer->stringBuilder().append(','); |
| writer->visit(call.arguments()[i]); |
| } |
| writer->stringBuilder().append(", level("_s); |
| writer->visit(call.arguments()[levelIndex]); |
| writer->stringBuilder().append(')'); |
| for (unsigned i = levelIndex + 1; i < call.arguments().size(); ++i) { |
| writer->stringBuilder().append(','); |
| writer->visit(call.arguments()[i]); |
| } |
| writer->stringBuilder().append(')'); |
| } |
| |
| static void emitTextureSampleBaseClampToEdge(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| auto& texture = call.arguments()[0]; |
| auto* textureType = std::get_if<Types::Texture>(texture.inferredType()); |
| |
| if (textureType) { |
| // FIXME: <rdar://150364488> this needs to clamp the coordinates |
| writer->visit(texture); |
| writer->stringBuilder().append(".sample"_s); |
| visitArguments(writer, call, 1); |
| return; |
| } |
| |
| auto& sampler = call.arguments()[1]; |
| auto& coordinates = call.arguments()[2]; |
| writer->stringBuilder().append("({\n"_s); |
| { |
| IndentationScope scope(writer->indent()); |
| { |
| writer->stringBuilder().append(writer->indent(), "auto __coords = ("_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(".UVRemapMatrix * float3("_s); |
| writer->visit(coordinates); |
| writer->stringBuilder().append(", 1)).xy;\n"_s); |
| } |
| { |
| writer->stringBuilder().append(writer->indent(), "auto __y = float("_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(".FirstPlane.sample("_s); |
| writer->visit(sampler); |
| writer->stringBuilder().append(", __coords).r);\n"_s); |
| } |
| { |
| writer->stringBuilder().append(writer->indent(), "auto __cbcr = float2("_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(".SecondPlane.sample("_s); |
| writer->visit(sampler); |
| writer->stringBuilder().append(", __coords).rg);\n"_s); |
| } |
| writer->stringBuilder().append(writer->indent(), "auto __ycbcr = float3(__y, __cbcr);\n"_s); |
| { |
| writer->stringBuilder().append(writer->indent(), "float4("_s); |
| writer->visit(texture); |
| writer->stringBuilder().append(".ColorSpaceConversionMatrix * float4(__ycbcr, 1), 1);\n"_s); |
| } |
| } |
| writer->stringBuilder().append(writer->indent(), "})"_s); |
| } |
| |
| static void emitTextureSampleBias(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| auto& texture = call.arguments()[0]; |
| auto& textureType = std::get<Types::Texture>(*texture.inferredType()); |
| bool isArray = false; |
| switch (textureType.kind) { |
| case Types::Texture::Kind::Texture2dArray: |
| case Types::Texture::Kind::TextureCubeArray: |
| isArray = true; |
| break; |
| case Types::Texture::Kind::Texture1d: |
| case Types::Texture::Kind::Texture2d: |
| case Types::Texture::Kind::Texture3d: |
| case Types::Texture::Kind::TextureCube: |
| case Types::Texture::Kind::TextureMultisampled2d: |
| break; |
| } |
| |
| unsigned biasIndex = isArray ? 4 : 3; |
| writer->visit(texture); |
| writer->stringBuilder().append(".sample("_s); |
| for (unsigned i = 1; i < biasIndex; ++i) { |
| if (i != 1) |
| writer->stringBuilder().append(", "_s); |
| writer->visit(call.arguments()[i]); |
| } |
| writer->stringBuilder().append(", bias("_s); |
| writer->visit(call.arguments()[biasIndex]); |
| writer->stringBuilder().append(')'); |
| for (unsigned i = biasIndex + 1; i < call.arguments().size(); ++i) { |
| writer->stringBuilder().append(", "_s); |
| writer->visit(call.arguments()[i]); |
| } |
| writer->stringBuilder().append(')'); |
| } |
| |
| static void emitTextureNumLayers(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(".get_array_size()"_s); |
| } |
| |
| static void emitTextureNumLevels(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(".get_num_mip_levels()"_s); |
| } |
| |
| static void emitTextureNumSamples(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(".get_num_samples()"_s); |
| } |
| |
| static void emitTextureStore(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| auto cast = "uint"_s; |
| if (const auto* vector = std::get_if<Types::Vector>(call.arguments()[1].inferredType())) { |
| switch (vector->size) { |
| case 2: |
| cast = "uint2"_s; |
| break; |
| case 3: |
| cast = "uint3"_s; |
| break; |
| default: |
| RELEASE_ASSERT_NOT_REACHED(); |
| } |
| } |
| |
| AST::Expression& texture = call.arguments()[0]; |
| AST::Expression& coords = call.arguments()[1]; |
| AST::Expression* arrayIndex = nullptr; |
| AST::Expression* value = nullptr; |
| if (call.arguments().size() == 3) |
| value = &call.arguments()[2]; |
| else { |
| arrayIndex = &call.arguments()[2]; |
| value = &call.arguments()[3]; |
| } |
| |
| writer->visit(texture); |
| writer->stringBuilder().append(".write("_s); |
| writer->visit(*value); |
| writer->stringBuilder().append(", "_s, cast, '('); |
| writer->visit(coords); |
| writer->stringBuilder().append(')'); |
| if (arrayIndex) { |
| writer->stringBuilder().append(", "_s); |
| writer->visit(*arrayIndex); |
| } |
| writer->stringBuilder().append(')'); |
| |
| auto& textureType = std::get<Types::TextureStorage>(*texture.inferredType()); |
| if (textureType.access == AccessMode::ReadWrite) { |
| writer->stringBuilder().append(";\n"_s, writer->indent()); |
| writer->visit(texture); |
| writer->stringBuilder().append(".fence()"_s); |
| } |
| } |
| |
| static void emitStorageBarrier(FunctionDefinitionWriter* writer, AST::CallExpression&) |
| { |
| writer->stringBuilder().append("threadgroup_barrier(mem_flags::mem_device)"_s); |
| } |
| |
| static void emitTextureBarrier(FunctionDefinitionWriter* writer, AST::CallExpression&) |
| { |
| writer->stringBuilder().append("threadgroup_barrier(mem_flags::mem_texture)"_s); |
| } |
| |
| static void emitWorkgroupBarrier(FunctionDefinitionWriter* writer, AST::CallExpression&) |
| { |
| writer->stringBuilder().append("threadgroup_barrier(mem_flags::mem_threadgroup)"_s); |
| } |
| |
| static void emitWorkgroupUniformLoad(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("__workgroup_uniform_load("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(')'); |
| } |
| |
| static void atomicFunction(ASCIILiteral name, FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append(name, '('); |
| bool first = true; |
| for (auto& argument : call.arguments()) { |
| if (!first) |
| writer->stringBuilder().append(", "_s); |
| first = false; |
| writer->visit(argument); |
| } |
| writer->stringBuilder().append(", memory_order_relaxed)"_s); |
| } |
| |
| static void emitAtomicLoad(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| if (writer->metalAppleGPUFamily() >= 9) |
| writer->stringBuilder().append("({ volatile auto __wgslAtomicLoadResult = "_s); |
| atomicFunction("atomic_load_explicit"_s, writer, call); |
| if (writer->metalAppleGPUFamily() >= 9) |
| writer->stringBuilder().append("; __wgslAtomicLoadResult; })"_s); |
| } |
| |
| static void emitAtomicStore(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| atomicFunction("atomic_store_explicit"_s, writer, call); |
| } |
| |
| static void emitAtomicAdd(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| atomicFunction("atomic_fetch_add_explicit"_s, writer, call); |
| } |
| |
| static void emitAtomicSub(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| atomicFunction("atomic_fetch_sub_explicit"_s, writer, call); |
| } |
| |
| static void emitAtomicMax(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| atomicFunction("atomic_fetch_max_explicit"_s, writer, call); |
| } |
| |
| static void emitAtomicMin(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| atomicFunction("atomic_fetch_min_explicit"_s, writer, call); |
| } |
| |
| static void emitAtomicAnd(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| atomicFunction("atomic_fetch_and_explicit"_s, writer, call); |
| } |
| |
| static void emitAtomicOr(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| atomicFunction("atomic_fetch_or_explicit"_s, writer, call); |
| } |
| |
| static void emitAtomicXor(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| atomicFunction("atomic_fetch_xor_explicit"_s, writer, call); |
| } |
| |
| static void emitAtomicExchange(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| atomicFunction("atomic_exchange_explicit"_s, writer, call); |
| } |
| |
| [[noreturn]] static void emitArrayLength(FunctionDefinitionWriter*, AST::CallExpression&) |
| { |
| RELEASE_ASSERT_NOT_REACHED(); |
| } |
| |
| static void emitDistance(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| auto* argumentType = call.arguments()[0].inferredType(); |
| if (std::holds_alternative<Types::Primitive>(*argumentType)) { |
| writer->stringBuilder().append("abs("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(" - "_s); |
| writer->visit(call.arguments()[1]); |
| writer->stringBuilder().append(')'); |
| return; |
| } |
| writer->visit(call.target()); |
| visitArguments(writer, call); |
| } |
| |
| static void emitLength(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| auto* argumentType = call.arguments()[0].inferredType(); |
| if (!holds_alternative<Types::Vector>(*argumentType)) |
| writer->stringBuilder().append("abs"_s); |
| else |
| writer->stringBuilder().append("length"_s); |
| visitArguments(writer, call); |
| } |
| |
| static void emitDegrees(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("static_cast<"_s); |
| writer->visit(call.inferredType()); |
| writer->stringBuilder().append(">("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(" * "_s, String::number(180 / std::numbers::pi), ')'); |
| } |
| |
| static void emitDynamicOffset(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| auto* targetType = call.target().inferredType(); |
| auto& pointer = std::get<Types::Pointer>(*targetType); |
| auto addressSpace = serializeAddressSpace(pointer.addressSpace); |
| |
| writer->stringBuilder().append("(*("_s); |
| writer->visit(targetType); |
| writer->stringBuilder().append(")((("_s, addressSpace, " uint8_t*)&("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(")) + __DynamicOffsets["_s); |
| writer->visit(call.arguments()[1]); |
| writer->stringBuilder().append("]))"_s); |
| } |
| |
| static void emitBitcast(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("as_type<"_s); |
| writer->visit(call.target().inferredType()); |
| writer->stringBuilder().append(">("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(')'); |
| } |
| |
| static void emitPack2x16Float(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("as_type<uint>(half2("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append("))"_s); |
| } |
| |
| static void emitUnpack2x16Float(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("float2(as_type<half2>("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append("))"_s); |
| } |
| |
| static void emitPack4xI8(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("as_type<uint>(char4("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append("))"_s); |
| } |
| |
| static void emitPack4xI8Clamp(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("as_type<uint>(char4(clamp("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(", -128, 127)))"_s); |
| } |
| |
| static void emitUnpack4xI8(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("int4(as_type<char4>("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append("))"_s); |
| } |
| |
| static void emitPack4xU8(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("as_type<uint>(uchar4("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append("))"_s); |
| } |
| |
| static void emitPack4xU8Clamp(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("as_type<uint>(uchar4(min("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(", 255)))"_s); |
| } |
| |
| static void emitQuantizeToF16(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| auto& argument = call.arguments()[0]; |
| String suffix = ""_s; |
| if (auto* vectorType = std::get_if<Types::Vector>(argument.inferredType())) |
| suffix = String::number(vectorType->size); |
| writer->stringBuilder().append("float"_s, suffix, "(half"_s, suffix, '('); |
| writer->visit(argument); |
| writer->stringBuilder().append("))"_s); |
| } |
| |
| static void emitRadians(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("static_cast<"_s); |
| writer->visit(call.inferredType()); |
| writer->stringBuilder().append(">("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append(" * "_s, String::number(std::numbers::pi / 180), ')'); |
| } |
| |
| static void emitUnpack4xU8(FunctionDefinitionWriter* writer, AST::CallExpression& call) |
| { |
| writer->stringBuilder().append("uint4(as_type<uchar4>("_s); |
| writer->visit(call.arguments()[0]); |
| writer->stringBuilder().append("))"_s); |
| } |
| |
| void FunctionDefinitionWriter::visit(const Type* type, AST::CallExpression& call) |
| { |
| if (auto* target = dynamicDowncast<AST::ElaboratedTypeExpression>(call.target())) { |
| if (target->base() == "bitcast"_s) { |
| emitBitcast(this, call); |
| return; |
| } |
| } |
| |
| auto isArray = is<AST::ArrayTypeExpression>(call.target()); |
| auto isStruct = !isArray && std::holds_alternative<Types::Struct>(*call.target().inferredType()); |
| if (call.isConstructor() && (isArray || isStruct)) { |
| visit(type); |
| m_body.append('('); |
| const Type* arrayElementType = nullptr; |
| if (isArray) |
| arrayElementType = std::get<Types::Array>(*type).element; |
| |
| m_body.append("{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| for (auto& argument : call.arguments()) { |
| m_body.append(m_indent); |
| if (isStruct) |
| visit(argument); |
| else |
| visit(arrayElementType, argument); |
| m_body.append(",\n"_s); |
| } |
| } |
| m_body.append(m_indent, "})"_s); |
| return; |
| } |
| |
| if (auto* target = dynamicDowncast<AST::IdentifierExpression>(call.target())) { |
| static constexpr auto builtinMappings = std::to_array<std::pair<ComparableASCIILiteral, void(*)(FunctionDefinitionWriter*, AST::CallExpression&)>>({ |
| { "__dynamicOffset"_s, emitDynamicOffset }, |
| { "arrayLength"_s, emitArrayLength }, |
| { "atomicAdd"_s, emitAtomicAdd }, |
| { "atomicAnd"_s, emitAtomicAnd }, |
| { "atomicExchange"_s, emitAtomicExchange }, |
| { "atomicLoad"_s, emitAtomicLoad }, |
| { "atomicMax"_s, emitAtomicMax }, |
| { "atomicMin"_s, emitAtomicMin }, |
| { "atomicOr"_s, emitAtomicOr }, |
| { "atomicStore"_s, emitAtomicStore }, |
| { "atomicSub"_s, emitAtomicSub }, |
| { "atomicXor"_s, emitAtomicXor }, |
| { "degrees"_s, emitDegrees }, |
| { "distance"_s, emitDistance }, |
| { "length"_s, emitLength }, |
| { "pack2x16float"_s, emitPack2x16Float }, |
| { "pack4xI8"_s, emitPack4xI8 }, |
| { "pack4xI8Clamp"_s, emitPack4xI8Clamp }, |
| { "pack4xU8"_s, emitPack4xU8 }, |
| { "pack4xU8Clamp"_s, emitPack4xU8Clamp }, |
| { "quantizeToF16"_s, emitQuantizeToF16 }, |
| { "radians"_s, emitRadians }, |
| { "storageBarrier"_s, emitStorageBarrier }, |
| { "textureBarrier"_s, emitTextureBarrier }, |
| { "textureDimensions"_s, emitTextureDimensions }, |
| { "textureGather"_s, emitTextureGather }, |
| { "textureGatherCompare"_s, emitTextureGatherCompare }, |
| { "textureLoad"_s, emitTextureLoad }, |
| { "textureNumLayers"_s, emitTextureNumLayers }, |
| { "textureNumLevels"_s, emitTextureNumLevels }, |
| { "textureNumSamples"_s, emitTextureNumSamples }, |
| { "textureSample"_s, emitTextureSample }, |
| { "textureSampleBaseClampToEdge"_s, emitTextureSampleBaseClampToEdge }, |
| { "textureSampleBias"_s, emitTextureSampleBias }, |
| { "textureSampleCompare"_s, emitTextureSampleCompare }, |
| { "textureSampleCompareLevel"_s, emitTextureSampleCompare }, |
| { "textureSampleGrad"_s, emitTextureSampleGrad }, |
| { "textureSampleLevel"_s, emitTextureSampleLevel }, |
| { "textureStore"_s, emitTextureStore }, |
| { "unpack2x16float"_s, emitUnpack2x16Float }, |
| { "unpack4xI8"_s, emitUnpack4xI8 }, |
| { "unpack4xU8"_s, emitUnpack4xU8 }, |
| { "workgroupBarrier"_s, emitWorkgroupBarrier }, |
| { "workgroupUniformLoad"_s, emitWorkgroupUniformLoad }, |
| }); |
| static constexpr SortedArrayMap builtins { builtinMappings }; |
| const auto& targetName = target->identifier().id(); |
| if (auto mappedBuiltin = builtins.get(targetName)) { |
| mappedBuiltin(this, call); |
| return; |
| } |
| |
| #define EMIT_HELPER(name) \ |
| [](HelperGenerator& helperGenerator) { \ |
| if (!std::exchange(helperGenerator.didEmit##name, true)) \ |
| helperGenerator.emit##name(); \ |
| return "__wgsl"#name##_s;\ |
| } |
| |
| #define NOOP_HELPER(name) \ |
| [](HelperGenerator&) { return #name##_s; } |
| |
| static constexpr auto directMappings = std::to_array<std::pair<ComparableASCIILiteral, ASCIILiteral(*)(HelperGenerator&)>>({ |
| { "acos"_s, EMIT_HELPER(Acos) }, |
| { "acosh"_s, EMIT_HELPER(Acosh) }, |
| { "asin"_s, EMIT_HELPER(Asin) }, |
| { "atanh"_s, EMIT_HELPER(Atanh) }, |
| { "atomicCompareExchangeWeak"_s, NOOP_HELPER(__wgslAtomicCompareExchangeWeak) }, |
| { "countLeadingZeros"_s, NOOP_HELPER(clz) }, |
| { "countOneBits"_s, NOOP_HELPER(popcount) }, |
| { "countTrailingZeros"_s, NOOP_HELPER(ctz) }, |
| { "dot"_s, NOOP_HELPER(__wgslDot) }, |
| { "dot4I8Packed"_s, NOOP_HELPER(__wgslDot4I8Packed) }, |
| { "dot4U8Packed"_s, NOOP_HELPER(__wgslDot4U8Packed) }, |
| { "dpdx"_s, NOOP_HELPER(dfdx) }, |
| { "dpdxCoarse"_s, NOOP_HELPER(dfdx) }, |
| { "dpdxFine"_s, NOOP_HELPER(dfdx) }, |
| { "dpdy"_s, NOOP_HELPER(dfdy) }, |
| { "dpdyCoarse"_s, NOOP_HELPER(dfdy) }, |
| { "dpdyFine"_s, NOOP_HELPER(dfdy) }, |
| { "extractBits"_s, NOOP_HELPER(__wgslExtractBits) }, |
| { "faceForward"_s, NOOP_HELPER(faceforward) }, |
| { "firstLeadingBit"_s, NOOP_HELPER(__wgslFirstLeadingBit) }, |
| { "firstTrailingBit"_s, NOOP_HELPER(__wgslFirstTrailingBit) }, |
| { "frexp"_s, NOOP_HELPER(__wgslFrexp) }, |
| { "fwidthCoarse"_s, NOOP_HELPER(fwidth) }, |
| { "fwidthFine"_s, NOOP_HELPER(fwidth) }, |
| { "insertBits"_s, NOOP_HELPER(__wgslInsertBits) }, |
| { "inverseSqrt"_s, EMIT_HELPER(InverseSqrt) }, |
| { "log"_s, EMIT_HELPER(Log) }, |
| { "log2"_s, EMIT_HELPER(Log2) }, |
| { "modf"_s, NOOP_HELPER(__wgslModf) }, |
| { "pack2x16snorm"_s, EMIT_HELPER(PackFloatToSnorm2x16) }, |
| { "pack2x16unorm"_s, EMIT_HELPER(PackFloatToUnorm2x16) }, |
| { "pack4x8snorm"_s, EMIT_HELPER(PackFloatToSnorm4x8) }, |
| { "pack4x8unorm"_s, EMIT_HELPER(PackFloatToUnorm4x8) }, |
| { "reverseBits"_s, NOOP_HELPER(reverse_bits) }, |
| { "round"_s, NOOP_HELPER(rint) }, |
| { "sign"_s, NOOP_HELPER(__wgslSign) }, |
| { "sqrt"_s, EMIT_HELPER(Sqrt) }, |
| { "unpack2x16snorm"_s, NOOP_HELPER(unpack_snorm2x16_to_float) }, |
| { "unpack2x16unorm"_s, NOOP_HELPER(unpack_unorm2x16_to_float) }, |
| { "unpack4x8snorm"_s, NOOP_HELPER(unpack_snorm4x8_to_float) }, |
| { "unpack4x8unorm"_s, NOOP_HELPER(unpack_unorm4x8_to_float) }, |
| }); |
| |
| #undef EMIT_HELPER |
| #undef NOOP_HELPER |
| |
| static constexpr SortedArrayMap mappedNames { directMappings }; |
| if (call.isConstructor()) { |
| if (call.isFloatToIntConversion()) { |
| m_body.append("__wgslFtoi<"_s); |
| visit(type); |
| m_body.append(">"_s); |
| } else |
| visit(type); |
| } else if (auto mappedName = mappedNames.get(targetName)) |
| m_body.append(mappedName(m_helperGenerator)); |
| else |
| m_body.append(targetName); |
| visitArguments(this, call); |
| return; |
| } |
| |
| visit(type); |
| visitArguments(this, call); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::UnaryExpression& unary) |
| { |
| m_body.append('('); |
| switch (unary.operation()) { |
| case AST::UnaryOperation::Complement: |
| m_body.append('~'); |
| break; |
| case AST::UnaryOperation::Negate: |
| m_body.append('-'); |
| break; |
| case AST::UnaryOperation::Not: |
| m_body.append('!'); |
| break; |
| case AST::UnaryOperation::AddressOf: |
| m_body.append('&'); |
| break; |
| case AST::UnaryOperation::Dereference: |
| m_body.append('*'); |
| break; |
| } |
| visit(unary.expression()); |
| m_body.append(')'); |
| } |
| |
| void FunctionDefinitionWriter::serializeBinaryExpression(AST::Expression& lhs, AST::BinaryOperation operation, AST::Expression& rhs) |
| { |
| bool isDiv = operation == AST::BinaryOperation::Divide; |
| bool isMod = !isDiv && operation == AST::BinaryOperation::Modulo; |
| |
| if (isDiv || isMod) { |
| auto* rightType = rhs.inferredType(); |
| if (auto* vectorType = std::get_if<Types::Vector>(rightType)) |
| rightType = vectorType->element; |
| |
| ASCIILiteral helperFunction; |
| if (satisfies(rightType, Constraints::Integer)) { |
| if (isDiv) |
| helperFunction = "__wgslDiv"_s; |
| else |
| helperFunction = "__wgslMod"_s; |
| } else if (isMod) |
| helperFunction = "fmod"_s; |
| |
| if (!helperFunction.isNull()) { |
| m_body.append(helperFunction, '('); |
| visit(lhs); |
| m_body.append(", "_s); |
| visit(rhs); |
| m_body.append(')'); |
| return; |
| } |
| } |
| |
| m_body.append('('); |
| visit(lhs); |
| switch (operation) { |
| case AST::BinaryOperation::Add: |
| m_body.append(" + "_s); |
| break; |
| case AST::BinaryOperation::Subtract: |
| m_body.append(" - "_s); |
| break; |
| case AST::BinaryOperation::Multiply: |
| m_body.append(" * "_s); |
| break; |
| case AST::BinaryOperation::Divide: |
| m_body.append(" / "_s); |
| break; |
| case AST::BinaryOperation::Modulo: |
| m_body.append(" % "_s); |
| break; |
| case AST::BinaryOperation::And: |
| m_body.append(" & "_s); |
| break; |
| case AST::BinaryOperation::Or: |
| m_body.append(" | "_s); |
| break; |
| case AST::BinaryOperation::Xor: |
| m_body.append(" ^ "_s); |
| break; |
| |
| case AST::BinaryOperation::LeftShift: |
| m_body.append(" << "_s); |
| break; |
| case AST::BinaryOperation::RightShift: |
| m_body.append(" >> "_s); |
| break; |
| |
| case AST::BinaryOperation::Equal: |
| m_body.append(" == "_s); |
| break; |
| case AST::BinaryOperation::NotEqual: |
| m_body.append(" != "_s); |
| break; |
| case AST::BinaryOperation::GreaterThan: |
| m_body.append(" > "_s); |
| break; |
| case AST::BinaryOperation::GreaterEqual: |
| m_body.append(" >= "_s); |
| break; |
| case AST::BinaryOperation::LessThan: |
| m_body.append(" < "_s); |
| break; |
| case AST::BinaryOperation::LessEqual: |
| m_body.append(" <= "_s); |
| break; |
| |
| case AST::BinaryOperation::ShortCircuitAnd: |
| m_body.append(" && "_s); |
| break; |
| case AST::BinaryOperation::ShortCircuitOr: |
| m_body.append(" || "_s); |
| break; |
| } |
| visit(rhs); |
| m_body.append(')'); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::BinaryExpression& binary) |
| { |
| serializeBinaryExpression(binary.leftExpression(), binary.operation(), binary.rightExpression()); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::PointerDereferenceExpression& pointerDereference) |
| { |
| m_body.append("(*"_s); |
| visit(pointerDereference.target()); |
| m_body.append(')'); |
| } |
| void FunctionDefinitionWriter::visit(AST::IndexAccessExpression& access) |
| { |
| bool isPointer = std::holds_alternative<Types::Pointer>(*access.base().inferredType()); |
| if (isPointer) |
| m_body.append("(*("_s); |
| visit(access.base()); |
| if (isPointer) |
| m_body.append("))"_s); |
| m_body.append('['); |
| visit(access.index()); |
| m_body.append(']'); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::IdentifierExpression& identifier) |
| { |
| auto it = m_constantValues.find(identifier.identifier()); |
| if (it != m_constantValues.end()) [[unlikely]] { |
| m_body.append('('); |
| serializeConstant(identifier.inferredType(), it->value); |
| m_body.append(')'); |
| return; |
| } |
| m_body.append(identifier.identifier()); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::FieldAccessExpression& access) |
| { |
| visit(access.base()); |
| auto* baseType = access.base().inferredType(); |
| if (baseType && std::holds_alternative<Types::Pointer>(*baseType)) |
| m_body.append("->"_s); |
| else |
| m_body.append('.'); |
| m_body.append(access.fieldName()); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::BoolLiteral& literal) |
| { |
| m_body.append(literal.value() ? "true"_s : "false"_s); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::AbstractIntegerLiteral& literal) |
| { |
| m_body.append(literal.value()); |
| auto& primitiveType = std::get<Types::Primitive>(*literal.inferredType()); |
| if (primitiveType.kind == Types::Primitive::U32) |
| m_body.append('u'); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::Signed32Literal& literal) |
| { |
| m_body.append(literal.value()); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::Unsigned32Literal& literal) |
| { |
| m_body.append(literal.value(), 'u'); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::AbstractFloatLiteral& literal) |
| { |
| NumberToStringBuffer buffer; |
| m_body.append(WTF::numberToStringWithTrailingPoint(literal.value(), buffer)); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::Float32Literal& literal) |
| { |
| NumberToStringBuffer buffer; |
| m_body.append(WTF::numberToStringWithTrailingPoint(literal.value(), buffer)); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::Float16Literal& literal) |
| { |
| NumberToStringBuffer buffer; |
| m_body.append(WTF::numberToStringWithTrailingPoint(literal.value(), buffer)); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::Statement& statement) |
| { |
| AST::Visitor::visit(statement); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::AssignmentStatement& assignment) |
| { |
| visit(assignment.lhs()); |
| m_body.append(" = "_s); |
| const auto* assignmentType = assignment.lhs().inferredType(); |
| if (!assignmentType) { |
| // In theory this should never happen, but the assignments generated by |
| // the EntryPointRewriter do not have inferred types |
| visit(assignment.rhs()); |
| return; |
| } |
| |
| const auto& reference = std::get<Types::Reference>(*assignmentType); |
| visit(reference.element, assignment.rhs()); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::CallStatement& statement) |
| { |
| visit(statement.call().inferredType(), statement.call()); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::CompoundAssignmentStatement& statement) |
| { |
| bool serialized = false; |
| auto* leftExpression = &statement.leftExpression(); |
| if (auto* identity = dynamicDowncast<AST::IdentityExpression>(*leftExpression)) |
| leftExpression = &identity->expression(); |
| if (auto* call = dynamicDowncast<AST::CallExpression>(*leftExpression)) { |
| auto& target = call->target(); |
| if (auto* identifier = dynamicDowncast<AST::IdentifierExpression>(target)) { |
| if (identifier->identifier() == "__unpack"_s) { |
| serialized = true; |
| visit(call->arguments()[0]); |
| } |
| } |
| } |
| if (!serialized) |
| visit(statement.leftExpression()); |
| |
| m_body.append(" = "_s); |
| serializeBinaryExpression(statement.leftExpression(), statement.operation(), statement.rightExpression()); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::CompoundStatement& statement) |
| { |
| m_body.append("{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| visitStatements(statement.statements()); |
| } |
| m_body.append(m_indent, '}'); |
| } |
| |
| void FunctionDefinitionWriter::visitStatements(AST::Statement::List& statements) |
| { |
| for (auto& statement : statements) { |
| m_body.append(m_indent); |
| checkErrorAndVisit(statement); |
| switch (statement.kind()) { |
| case AST::NodeKind::AssignmentStatement: |
| case AST::NodeKind::BreakStatement: |
| case AST::NodeKind::CallStatement: |
| case AST::NodeKind::CompoundAssignmentStatement: |
| case AST::NodeKind::ContinueStatement: |
| case AST::NodeKind::DecrementIncrementStatement: |
| case AST::NodeKind::DiscardStatement: |
| case AST::NodeKind::PhonyAssignmentStatement: |
| case AST::NodeKind::ReturnStatement: |
| case AST::NodeKind::VariableStatement: |
| m_body.append(';'); |
| break; |
| default: |
| break; |
| } |
| m_body.append('\n'); |
| } |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::DecrementIncrementStatement& statement) |
| { |
| visit(statement.expression()); |
| switch (statement.operation()) { |
| case AST::DecrementIncrementStatement::Operation::Increment: |
| m_body.append("++"_s); |
| break; |
| case AST::DecrementIncrementStatement::Operation::Decrement: |
| m_body.append("--"_s); |
| break; |
| } |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::DiscardStatement&) |
| { |
| #if CPU(X86_64) |
| m_body.append("__asm volatile(\"\"); discard_fragment()"_s); |
| #else |
| m_body.append("discard_fragment();"_s); |
| #endif |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::IfStatement& statement) |
| { |
| m_body.append("if ("_s); |
| visit(statement.test()); |
| m_body.append(") "_s); |
| visit(statement.trueBody()); |
| if (statement.maybeFalseBody()) { |
| m_body.append(" else "_s); |
| visit(*statement.maybeFalseBody()); |
| } |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::PhonyAssignmentStatement& statement) |
| { |
| m_body.append("(void)("_s); |
| visit(statement.rhs()); |
| m_body.append(')'); |
| } |
| |
| static std::optional<std::pair<String, String>> returnIdentifierForFunction(WGSL::Builtin builtIn, AST::Function* function) |
| { |
| if (!function || function->stage() != ShaderStage::Fragment) |
| return std::nullopt; |
| |
| if (auto expression = function->maybeReturnType()) { |
| if (auto* inferredType = expression->inferredType()) { |
| auto& type = *inferredType; |
| auto* returnStruct = std::get_if<WGSL::Types::Struct>(&type); |
| if (!returnStruct) |
| return std::nullopt; |
| |
| for (auto& member : returnStruct->structure.members()) { |
| if (member.builtin() == builtIn) |
| return std::make_pair(returnStruct->structure.name(), member.name()); |
| for (auto& attribute : member.attributes()) { |
| auto* builtinAttribute = dynamicDowncast<AST::BuiltinAttribute>(attribute); |
| if (builtinAttribute && builtinAttribute->builtin() == builtIn) |
| return std::make_pair(returnStruct->structure.name(), member.name()); |
| } |
| } |
| } |
| } |
| |
| return std::nullopt; |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::ReturnStatement& statement) |
| { |
| auto fragDepthIdentifier = returnIdentifierForFunction(WGSL::Builtin::FragDepth, m_currentFunction); |
| auto sampleMaskIdentifier = returnIdentifierForFunction(WGSL::Builtin::SampleMask, m_currentFunction); |
| if (fragDepthIdentifier) |
| m_body.append(fragDepthIdentifier->first, " __wgslFragmentReturnResult = "_s); |
| else if (sampleMaskIdentifier) |
| m_body.append(sampleMaskIdentifier->first, " __wgslFragmentReturnResult = "_s); |
| else |
| m_body.append("return"_s); |
| if (statement.maybeExpression()) { |
| m_body.append(' '); |
| visit(*statement.maybeExpression()); |
| } |
| |
| if (fragDepthIdentifier) |
| m_body.append(";\n__wgslFragmentReturnResult."_s, fragDepthIdentifier->second, " = clamp(__wgslFragmentReturnResult."_s, fragDepthIdentifier->second, ", as_type<float>(__DynamicOffsets[0]), as_type<float>(__DynamicOffsets[1]));\n"_s); |
| if (sampleMaskIdentifier) |
| m_body.append(";\n__wgslFragmentReturnResult."_s, sampleMaskIdentifier->second, " = (__wgslFragmentReturnResult."_s, sampleMaskIdentifier->second, " & __DynamicOffsets[2]);\n"_s); |
| if (fragDepthIdentifier || sampleMaskIdentifier) |
| m_body.append("return __wgslFragmentReturnResult"_s); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::ForStatement& statement) |
| { |
| if (statement.isInternallyGenerated()) |
| m_body.append("for ("_s); |
| else |
| m_body.append("{ " DECLARE_FORWARD_PROGRESS " for ("_s); |
| |
| if (auto* initializer = statement.maybeInitializer()) |
| visit(*initializer); |
| m_body.append(';'); |
| if (auto* test = statement.maybeTest()) { |
| m_body.append(' '); |
| visit(*test); |
| } |
| m_body.append(';'); |
| if (auto* update = statement.maybeUpdate()) { |
| m_body.append(' '); |
| visit(*update); |
| } |
| |
| if (statement.isInternallyGenerated()) |
| m_body.append(')'); |
| else |
| m_body.append(") { " CHECK_FORWARD_PROGRESS " "_s); |
| visit(statement.body()); |
| if (!statement.isInternallyGenerated()) { |
| m_body.append('}'); |
| m_body.append('}'); |
| } |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::LoopStatement& statement) |
| { |
| m_body.append("{ " DECLARE_FORWARD_PROGRESS " while (true) { " CHECK_FORWARD_PROGRESS " \n"_s); |
| { |
| if (statement.containsSwitch()) |
| m_body.append("bool __continuing = false;\n"_s, m_indent); |
| auto& continuing = statement.continuing(); |
| SetForScope continuingScope(m_continuing, continuing.has_value() ? &*continuing : nullptr); |
| |
| IndentationScope scope(m_indent); |
| visitStatements(statement.body()); |
| |
| if (continuing.has_value()) { |
| m_body.append(m_indent); |
| visit(*continuing); |
| } |
| } |
| m_body.append(m_indent, '}'); |
| m_body.append(m_indent, '}'); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::Continuing& continuing) |
| { |
| // Do not emit the same continuing for continue statements within the continuing block |
| SetForScope continuingScope(m_continuing, nullptr); |
| |
| m_body.append("{\n"_s); |
| { |
| IndentationScope scope(m_indent); |
| visitStatements(continuing.body); |
| |
| if (auto* breakIf = continuing.breakIf) { |
| m_body.append(m_indent, "if ("_s); |
| visit(*breakIf); |
| m_body.append(")\n"_s); |
| |
| IndentationScope scope(m_indent); |
| m_body.append(m_indent, "break;\n"_s); |
| } |
| } |
| m_body.append(m_indent, "}\n"_s); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::WhileStatement& statement) |
| { |
| m_body.append("{ " DECLARE_FORWARD_PROGRESS " while ("_s); |
| visit(statement.test()); |
| m_body.append(") { " CHECK_FORWARD_PROGRESS " "_s); |
| visit(statement.body()); |
| m_body.append('}'); |
| m_body.append('}'); |
| } |
| |
| void FunctionDefinitionWriter::visit(AST::SwitchStatement& statement) |
| { |
| const auto& visitClause = [&](AST::SwitchClause& clause, bool isDefault = false) { |
| for (auto& selector : clause.selectors) { |
| m_body.append('\n', m_indent, "case "_s); |
| visit(selector); |
| m_body.append(':'); |
| } |
| if (isDefault) |
| m_body.append('\n', m_indent, "default:"_s); |
| m_body.append("\n{ " DECLARE_FORWARD_PROGRESS "\n"_s); |
| visit(clause.body); |
| |
| IndentationScope scope(m_indent); |
| m_body.append('\n', m_indent, "\n}\nbreak;"_s); |
| }; |
| |
| m_body.append("switch ("_s); |
| visit(statement.value()); |
| m_body.append(") {"_s); |
| for (auto& clause : statement.clauses()) |
| visitClause(clause); |
| visitClause(statement.defaultClause(), true); |
| m_body.append('\n', m_indent, '}'); |
| if (statement.isInsideLoop()) { |
| m_body.append('\n', m_indent, "if (__continuing) {"_s); |
| { |
| auto scope = IndentationScope(m_indent); |
| visit(*m_continuing); |
| } |
| m_body.append('\n', m_indent, '}'); |
| } else if (statement.isNestedInsideLoop()) { |
| m_body.append('\n', m_indent, "if (__continuing) {"_s); |
| { |
| auto scope = IndentationScope(m_indent); |
| m_body.append('\n', m_indent, "break;"_s); |
| } |
| m_body.append('\n', m_indent, '}'); |
| } |
| } |
| |
// Emits a bare `break`. No terminating semicolon is emitted here; presumably
// the caller that visits the statement appends it.
void FunctionDefinitionWriter::visit(AST::BreakStatement&)
{
    m_body.append("break"_s);
}
| |
| void FunctionDefinitionWriter::visit(AST::ContinueStatement& statement) |
| { |
| if (statement.isFromSwitchToContinuing()) { |
| m_body.append("__continuing = true;\n"_s); |
| m_body.append(m_indent, "break"_s); |
| return; |
| } |
| if (m_continuing) { |
| visit(*m_continuing); |
| m_body.append(m_indent); |
| } |
| m_body.append("continue"_s); |
| } |
| |
| void FunctionDefinitionWriter::serializeConstant(const Type* type, ConstantValue value) |
| { |
| using namespace Types; |
| |
| WTF::switchOn(*type, |
| [&](const Primitive& primitive) { |
| switch (primitive.kind) { |
| case Primitive::AbstractInt: |
| m_body.append(std::get<int64_t>(value)); |
| break; |
| case Primitive::I32: |
| m_body.append(std::get<int32_t>(value)); |
| break; |
| case Primitive::U32: |
| m_body.append(std::get<uint32_t>(value), 'u'); |
| break; |
| case Primitive::AbstractFloat: { |
| NumberToStringBuffer buffer; |
| m_body.append(WTF::numberToStringWithTrailingPoint(std::get<double>(value), buffer)); |
| break; |
| } |
| case Primitive::F32: { |
| NumberToStringBuffer buffer; |
| m_body.append(WTF::numberToStringWithTrailingPoint(std::get<float>(value), buffer)); |
| break; |
| } |
| case Primitive::F16: { |
| NumberToStringBuffer buffer; |
| m_body.append(WTF::numberToStringWithTrailingPoint(std::get<half>(value), buffer), 'h'); |
| break; |
| } |
| case Primitive::Bool: |
| m_body.append(std::get<bool>(value) ? "true"_s : "false"_s); |
| break; |
| case Primitive::Void: |
| case Primitive::Sampler: |
| case Primitive::SamplerComparison: |
| case Primitive::TextureExternal: |
| case Primitive::AccessMode: |
| case Primitive::TexelFormat: |
| case Primitive::AddressSpace: |
| RELEASE_ASSERT_NOT_REACHED(); |
| } |
| }, |
| [&](const Reference& reference) { |
| return serializeConstant(reference.element, value); |
| }, |
| [&](const Vector& vectorType) { |
| auto& vector = std::get<ConstantVector>(value); |
| visit(type); |
| m_body.append('('); |
| bool first = true; |
| for (auto& element : vector.elements) { |
| if (!first) |
| m_body.append(", "_s); |
| first = false; |
| serializeConstant(vectorType.element, element); |
| } |
| m_body.append(')'); |
| }, |
| [&](const Array& arrayType) { |
| auto& array = std::get<ConstantArray>(value); |
| visit(type); |
| m_body.append('{'); |
| bool first = true; |
| for (auto& element : array.elements) { |
| if (!first) |
| m_body.append(", "_s); |
| first = false; |
| serializeConstant(arrayType.element, element); |
| } |
| m_body.append('}'); |
| }, |
| [&](const Matrix& matrixType) { |
| auto& matrix = std::get<ConstantMatrix>(value); |
| m_body.append("matrix<"_s); |
| visit(matrixType.element); |
| m_body.append(", "_s, matrixType.columns, ", "_s, matrixType.rows, ">("_s); |
| bool first = true; |
| for (auto& element : matrix.elements) { |
| if (!first) |
| m_body.append(", "_s); |
| first = false; |
| serializeConstant(matrixType.element, element); |
| } |
| m_body.append(')'); |
| }, |
| [&](const Struct& structType) { |
| auto& constantStruct = std::get<ConstantStruct>(value); |
| m_body.append(structType.structure.name(), " { "_s); |
| for (auto& member : structType.structure.members()) { |
| m_body.append('.', member.name(), " = "_s); |
| serializeConstant(structType.fields.get(member.originalName()), constantStruct.fields.get(member.originalName())); |
| m_body.append(", "_s); |
| } |
| m_body.append(" }"_s); |
| }, |
| [&](const PrimitiveStruct& primitiveStruct) { |
| auto& constantStruct = std::get<ConstantStruct>(value); |
| const auto& keys = Types::PrimitiveStruct::keys[primitiveStruct.kind]; |
| |
| m_body.append(primitiveStruct.name, '<'); |
| bool first = true; |
| for (auto& value : primitiveStruct.values) { |
| if (!first) |
| m_body.append(", "_s); |
| first = false; |
| visit(value); |
| } |
| m_body.append("> {"_s); |
| first = true; |
| for (auto& entry : constantStruct.fields) { |
| if (!first) |
| m_body.append(", "_s); |
| first = false; |
| m_body.append('.', entry.key, " = "_s); |
| auto* key = keys.tryGet(entry.key); |
| RELEASE_ASSERT(key); |
| auto* type = primitiveStruct.values[*key]; |
| serializeConstant(type, entry.value); |
| } |
| m_body.append('}'); |
| }, |
| [&](const Pointer&) { |
| RELEASE_ASSERT_NOT_REACHED(); |
| }, |
| [&](const Function&) { |
| RELEASE_ASSERT_NOT_REACHED(); |
| }, |
| [&](const Texture&) { |
| RELEASE_ASSERT_NOT_REACHED(); |
| }, |
| [&](const TextureStorage&) { |
| RELEASE_ASSERT_NOT_REACHED(); |
| }, |
| [&](const TextureDepth&) { |
| RELEASE_ASSERT_NOT_REACHED(); |
| }, |
| [&](const Atomic&) { |
| RELEASE_ASSERT_NOT_REACHED(); |
| }, |
| [&](const TypeConstructor&) { |
| RELEASE_ASSERT_NOT_REACHED(); |
| }); |
| } |
| |
| void emitMetalFunctions(StringBuilder& stringBuilder, ShaderModule& shaderModule, PrepareResult& prepareResult, const HashMap<String, ConstantValue>& constantValues, DeviceState&& deviceState) |
| { |
| FunctionDefinitionWriter functionDefinitionWriter(shaderModule, stringBuilder, prepareResult, constantValues, WTF::move(deviceState)); |
| functionDefinitionWriter.write(); |
| } |
| |
| #undef DECLARE_FORWARD_PROGRESS |
| #undef CHECK_FORWARD_PROGRESS |
| |
| } // namespace Metal |
| } // namespace WGSL |