blob: 50d6c4de6883e8d85d9fa768dcdc08c30e5b0d2c [file] [log] [blame] [edit]
/*
* Copyright (c) 2023 Apple Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "config.h"
#include "GlobalVariableRewriter.h"
#include "AST.h"
#include "ASTIdentifier.h"
#include "ASTVisitor.h"
#include "CallGraph.h"
#include "WGSL.h"
#include "WGSLShaderModule.h"
#include <ranges>
#include <wtf/DataLog.h>
#include <wtf/HashMap.h>
#include <wtf/ListHashSet.h>
#include <wtf/SetForScope.h>
#include <wtf/text/MakeString.h>
namespace WGSL {
constexpr bool shouldLogGlobalVariableRewriting = false;
// Rewrites how module-scope variables are accessed:
//  - collectGlobals() records every module-scope variable declaration, the @group/@binding of
//    resource globals, and synthesizes a "__<name>_ArrayLength" variable for each resource global;
//  - run() then rewrites each entry point of the call graph; callees that read globals get those
//    globals appended as extra parameters, and their call sites get matching arguments;
//  - pack() inserts "__pack"/"__unpack" calls where an expression's memory layout ("packing")
//    does not match what its consumer expects.
class RewriteGlobalVariables : public AST::Visitor {
public:
RewriteGlobalVariables(ShaderModule& shaderModule, const HashMap<String, PipelineLayout*>& pipelineLayouts, HashMap<String, Reflection::EntryPointInformation>& entryPointInformations)
: AST::Visitor()
, m_shaderModule(shaderModule)
, m_pipelineLayouts(pipelineLayouts)
, m_entryPointInformations(entryPointInformations)
{
}
// Runs the pass; returns the first error produced by global collection or by any entry point.
std::optional<Error> run();
void visit(AST::Function&) override;
void visit(AST::Variable&) override;
void visit(AST::Parameter&) override;
void visit(AST::CompoundStatement&) override;
void visit(AST::AssignmentStatement&) override;
void visit(AST::VariableStatement&) override;
void visit(AST::PhonyAssignmentStatement&) override;
void visit(AST::CompoundAssignmentStatement&) override;
void visit(AST::Expression&) override;
private:
// A module-scope variable. `resource` is set for variables declared with @group/@binding and for
// the synthesized "__<name>_ArrayLength" variables that share their group.
struct Global {
struct Resource {
unsigned group;
unsigned binding;
};
std::optional<Resource> resource;
AST::Variable* declaration;
};
// HashMap keyed by a uint64_t index that also accepts 0 as a key.
template<typename Value>
using IndexMap = HashMap<uint64_t, Value, WTF::IntHash<uint64_t>, WTF::UnsignedWithZeroKeyHashTraits<uint64_t>>;
// NOTE(review): presumably keyed by group, then by binding — confirm against determineUsedGlobals().
using UsedResources = IndexMap<IndexMap<Global*>>;
using UsedPrivateGlobals = Vector<Global*>;
struct UsedGlobals {
UsedResources resources;
UsedPrivateGlobals privateGlobals;
};
// A statement queued for insertion at position `index` of the enclosing compound statement.
struct Insertion {
AST::Statement* statement;
unsigned index;
};
static AST::Identifier argumentBufferParameterName(unsigned group);
static AST::Identifier dynamicOffsetVariableName();
AST::Identifier argumentBufferStructName(unsigned group);
void def(const AST::Identifier&, AST::Variable*);
std::optional<Error> collectGlobals();
std::optional<Error> visitEntryPoint(const CallGraph::EntryPoint&);
void visitCallee(const CallGraph::Callee&);
Result<UsedGlobals> determineUsedGlobals(const AST::Function&);
void collectDynamicOffsetGlobals(const PipelineLayout&);
void usesOverride(AST::Variable&);
Vector<unsigned> insertStructs(const UsedResources&);
Result<Vector<unsigned>> insertStructs(PipelineLayout&, const UsedResources&);
AST::StructureMember& createArgumentBufferEntry(unsigned binding, AST::Variable&);
AST::StructureMember& createArgumentBufferEntry(unsigned binding, const SourceSpan&, const String& name, AST::Expression& type);
void finalizeArgumentBufferStruct(unsigned group, Vector<std::pair<unsigned, AST::StructureMember*>>&);
void insertDynamicOffsetsBufferIfNeeded(const AST::Function&);
void insertDynamicOffsetsBufferIfNeeded(const SourceSpan&, const AST::Function&);
void insertParameter(const SourceSpan&, const AST::Function&, unsigned, AST::Identifier&&, AST::Expression* = nullptr, AST::ParameterRole = AST::ParameterRole::BindGroup);
void insertParameters(AST::Function&, const Vector<unsigned>&);
void insertMaterializations(AST::Function&, const UsedResources&);
void insertLocalDefinitions(AST::Function&, const UsedPrivateGlobals&);
const Global* readVariable(AST::IdentifierExpression&);
void insertBeforeCurrentStatement(AST::Statement&);
AST::Expression& bufferLengthType();
AST::Expression& bufferLengthReferenceType();
// zero initialization
void initializeVariables(AST::Function&, const UsedPrivateGlobals&, size_t);
void insertWorkgroupBarrier(AST::Function&, size_t);
AST::Identifier& findOrInsertLocalInvocationIndex(AST::Function&);
AST::Statement::List storeInitialValue(const UsedPrivateGlobals&);
void storeInitialValue(AST::Expression&, AST::Statement::List&, unsigned);
// Packing of resource types (structs/arrays/vec3 that need a packed memory layout).
void packResource(AST::Variable&);
void packArrayResource(AST::Variable&, const Types::Array*);
void packStructResource(AST::Variable&, const Types::Struct*);
const Type* packType(const Type*);
const Type* packStructType(const Types::Struct*);
const Type* packArrayType(const Types::Array*);
void updateReference(AST::Variable&, AST::Expression&);
// Packing/unpacking of expressions.
Packing pack(Packing, AST::Expression&);
Packing getPacking(AST::IdentifierExpression&);
Packing getPacking(AST::FieldAccessExpression&);
Packing getPacking(AST::IndexAccessExpression&);
Packing getPacking(AST::BinaryExpression&);
Packing getPacking(AST::UnaryExpression&);
Packing getPacking(AST::CallExpression&);
Packing getPacking(AST::IdentityExpression&);
Packing packingForType(const Type*);
AST::IdentifierExpression& getBase(AST::Expression&, unsigned&);
ShaderModule& m_shaderModule;
// Every module-scope variable by name (plus the synthesized array-length variables).
HashMap<String, Global> m_globals;
// Resource globals keyed by (group + 1, binding + 1).
// NOTE(review): presumably offset by one so that no key component is zero — confirm.
HashMap<std::tuple<unsigned, unsigned>, AST::Variable*> m_globalsByBinding;
// group -> (binding, global name) pairs, sorted by binding in collectGlobals().
IndexMap<Vector<std::pair<unsigned, String>>> m_groupBindingMap;
IndexMap<const Type*> m_structTypes;
HashMap<String, AST::Variable*> m_defs;
// Names of globals read by the function currently being rewritten (seeded with its callees' reads).
ListHashSet<String> m_reads;
// Function -> names of arrays whose "__<name>_ArrayLength" value must be passed into it.
HashMap<AST::Function*, ListHashSet<String>> m_lengthParameters;
// Functions already rewritten as callees, with the set of globals each one reads.
HashMap<AST::Function*, ListHashSet<String>> m_visitedFunctions;
Reflection::EntryPointInformation* m_entryPointInformation { nullptr };
HashMap<uint32_t, uint32_t, DefaultHash<uint32_t>, WTF::UnsignedWithZeroKeyHashTraits<uint32_t>> m_generateLayoutGroupMapping;
PipelineLayout* m_generatedLayout { nullptr };
// Position of the statement being visited within the innermost compound statement.
unsigned m_currentStatementIndex { 0 };
// Incremented once per entry point visited by run().
unsigned m_entryPointID { 0 };
// Statements to insert into the innermost compound statement once it has been visited.
Vector<Insertion> m_pendingInsertions;
// Cache: original struct type -> its packed counterpart.
HashMap<const Types::Struct*, const Type*> m_packedStructTypes;
ShaderStage m_stage { ShaderStage::Vertex };
const HashMap<String, PipelineLayout*>& m_pipelineLayouts;
HashMap<String, Reflection::EntryPointInformation>& m_entryPointInformations;
// Resource global -> its synthesized array-length variable, and the reverse mapping.
HashMap<AST::Variable*, AST::Variable*> m_bufferLengthMap;
HashMap<AST::Variable*, AST::Variable*> m_reverseBufferLengthMap;
// Lazily-created shared type expressions for the array-length variables.
AST::Expression* m_bufferLengthType { nullptr };
AST::Expression* m_bufferLengthReferenceType { nullptr };
AST::Function* m_currentFunction { nullptr };
HashMap<std::pair<unsigned, unsigned>, unsigned> m_globalsUsingDynamicOffset;
// Expressions appended as forwarded-global call arguments; pack() leaves them untouched.
HashSet<AST::Expression*> m_doNotUnpack;
// Running byte size of `var` declarations in the current function (checked against the WGSL limit).
CheckedUint32 m_combinedFunctionVariablesSize;
};
std::optional<Error> RewriteGlobalVariables::run()
{
    dataLogLnIf(shouldLogGlobalVariableRewriting, "BEGIN: GlobalVariableRewriter");

    // Every global must be known before any entry point can be rewritten.
    auto collectionError = collectGlobals();
    if (collectionError)
        return collectionError;

    for (auto& entryPoint : m_shaderModule.callGraph().entrypoints()) {
        auto entryPointError = visitEntryPoint(entryPoint);
        // The ID advances before the error check, so a failing entry point is counted too.
        m_entryPointID++;
        if (entryPointError)
            return entryPointError;
    }

    dataLogLnIf(shouldLogGlobalVariableRewriting, "END: GlobalVariableRewriter");
    return std::nullopt;
}
void RewriteGlobalVariables::visitCallee(const CallGraph::Callee& callee)
{
    // Appends to the callee one parameter per global it reads (m_reads), plus one u32
    // "__<name>_ArrayLength" parameter per array whose length it needs (m_lengthParameters).
    const auto& updateCallee = [&] {
        for (auto& read : m_reads) {
            auto it = m_globals.find(read);
            RELEASE_ASSERT(it != m_globals.end());
            auto& global = it->value;
            AST::Expression* type;
            if (global.declaration->flavor() == AST::VariableFlavor::Var)
                type = global.declaration->maybeReferenceType();
            else {
                // Overrides use their value type rather than a reference type; synthesize a type
                // expression from the store type when none was written in the source.
                ASSERT(global.declaration->flavor() == AST::VariableFlavor::Override);
                type = global.declaration->maybeTypeName();
                if (!type) {
                    auto* storeType = global.declaration->storeType();
                    auto& typeExpression = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make(storeType->toString()));
                    typeExpression.m_inferredType = storeType;
                    type = &typeExpression;
                }
            }
            ASSERT(type);
            auto parameterRole = global.declaration->role() == AST::VariableRole::PackedResource ? AST::ParameterRole::PackedResource : AST::ParameterRole::UserDefined;
            m_shaderModule.append(callee.target->parameters(), m_shaderModule.astBuilder().construct<AST::Parameter>(
                SourceSpan::empty(),
                AST::Identifier::make(read),
                *type,
                AST::Attribute::List { },
                parameterRole
            ));
        }

        auto it = m_lengthParameters.find(callee.target);
        if (it != m_lengthParameters.end() && !it->value.isEmpty()) {
            auto& lengthType = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make("u32"_s));
            lengthType.m_inferredType = m_shaderModule.types().u32Type();
            for (auto& length : it->value) {
                auto lengthName = makeString("__"_s, length, "_ArrayLength"_s);
                // Already appended above as an ordinary global read.
                if (m_reads.contains(lengthName))
                    continue;
                m_shaderModule.append(callee.target->parameters(), m_shaderModule.astBuilder().construct<AST::Parameter>(
                    SourceSpan::empty(),
                    AST::Identifier::make(lengthName),
                    lengthType,
                    AST::Attribute::List { },
                    AST::ParameterRole::UserDefined
                ));
            }
        }
    };

    // Appends the matching arguments at each call site listed in `callee`, in the same order as the
    // parameters added by updateCallee().
    const auto& updateCallSites = [&] {
        for (auto& read : m_reads) {
            // The declaration lookup does not depend on the call site, so do it once per global.
            auto it = m_globals.find(read);
            RELEASE_ASSERT(it != m_globals.end());
            auto* storeType = it->value.declaration->storeType();
            for (auto& [_, call] : callee.callSites) {
                auto& global = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
                    SourceSpan::empty(),
                    AST::Identifier::make(read)
                );
                global.m_inferredType = storeType;
                m_shaderModule.append(call->arguments(), global);
                // Forwarded globals are passed as-is; pack() skips expressions in this set.
                m_doNotUnpack.add(&global);
            }
        }

        auto it = m_lengthParameters.find(callee.target);
        if (it == m_lengthParameters.end() || it->value.isEmpty())
            return;

        // Iterate over a copy: the loop below adds entries for the callers to m_lengthParameters,
        // which may rehash the map and invalidate both `it` and the set it refers to.
        const ListHashSet<String> calleeLengthParameters = it->value;
        for (auto& lengthParameter : calleeLengthParameters) {
            auto lengthName = makeString("__"_s, lengthParameter, "_ArrayLength"_s);
            if (m_reads.contains(lengthName))
                continue;

            // Position of the parameter named `lengthParameter`, used to find the argument passed for it.
            unsigned index = 0;
            for (auto& parameter : callee.target->parameters()) {
                if (parameter.name() == lengthParameter)
                    break;
                ++index;
            }

            for (auto& [caller, call] : callee.callSites) {
                auto& argument = call->arguments()[index];
                unsigned arrayOffset = 0;
                auto& base = getBase(argument, arrayOffset);
                auto& identifier = base.identifier();
                // The caller now needs the length of the array it passed, too.
                auto result = m_lengthParameters.add(caller, ListHashSet<String> { });
                result.iterator->value.add(identifier);

                auto callerLengthName = makeString("__"_s, identifier, "_ArrayLength"_s);
                auto& length = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
                    SourceSpan::empty(),
                    AST::Identifier::make(callerLengthName)
                );
                length.m_inferredType = m_shaderModule.types().u32Type();

                // When the array is a trailing struct member, subtract the member's offset.
                AST::Expression* lhs = &length;
                if (arrayOffset) {
                    auto& arrayOffsetExpression = m_shaderModule.astBuilder().construct<AST::Unsigned32Literal>(
                        SourceSpan::empty(),
                        arrayOffset
                    );
                    arrayOffsetExpression.m_inferredType = m_shaderModule.types().u32Type();
                    lhs = &m_shaderModule.astBuilder().construct<AST::BinaryExpression>(
                        SourceSpan::empty(),
                        length,
                        arrayOffsetExpression,
                        AST::BinaryOperation::Subtract
                    );
                    lhs->m_inferredType = m_shaderModule.types().u32Type();
                }
                m_shaderModule.append(call->arguments(), *lhs);
            }
        }
    };

    auto it = m_visitedFunctions.find(callee.target);
    if (it != m_visitedFunctions.end()) {
        // The callee's parameters were already rewritten; reuse its recorded reads and only
        // patch the call sites.
        dataLogLnIf(shouldLogGlobalVariableRewriting, "> Already visited callee: ", callee.target->name());
        m_reads = it->value;
        updateCallSites();
        return;
    }

    dataLogLnIf(shouldLogGlobalVariableRewriting, "> Visiting callee: ", callee.target->name());
    visit(*callee.target);
    updateCallee();
    updateCallSites();
    m_visitedFunctions.add(callee.target, m_reads);
}
void RewriteGlobalVariables::visit(AST::Function& function)
{
    // Rewrite every callee first; the union of the globals they read seeds this function's read set.
    ListHashSet<String> calleeReads;
    for (auto& callee : m_shaderModule.callGraph().callees(function)) {
        visitCallee(callee);
        if (hasError())
            return;
        for (const auto& name : m_reads)
            calleeReads.add(name);
    }
    m_reads = WTF::move(calleeReads);

    // Reset the per-function state before walking this function's body.
    m_defs.clear();
    m_combinedFunctionVariablesSize = 0;
    def(function.name(), nullptr);

    m_currentFunction = &function;
    AST::Visitor::visit(function);
    m_currentFunction = nullptr;

    // https://www.w3.org/TR/WGSL/#limits
    constexpr unsigned maximumCombinedFunctionVariablesSize = 8192;
    bool withinLimit = !m_combinedFunctionVariablesSize.hasOverflowed()
        && m_combinedFunctionVariablesSize.value() <= maximumCombinedFunctionVariablesSize;
    if (withinLimit) [[likely]]
        return;
    setError(Error(makeString("The combined byte size of all variables in this function exceeds "_s, String::number(maximumCombinedFunctionVariablesSize), " bytes"_s), function.span()));
}
void RewriteGlobalVariables::visit(AST::Parameter& parameter)
{
def(parameter.name(), nullptr);
AST::Visitor::visit(parameter);
}
void RewriteGlobalVariables::visit(AST::Variable& variable)
{
def(variable.name(), &variable);
AST::Visitor::visit(variable);
}
void RewriteGlobalVariables::visit(AST::CompoundStatement& statement)
{
auto indexScope = SetForScope(m_currentStatementIndex, 0);
auto insertionScope = SetForScope(m_pendingInsertions, Vector<Insertion>());
for (auto& statement : statement.statements()) {
AST::Visitor::visit(statement);
++m_currentStatementIndex;
}
unsigned offset = 0;
for (auto& insertion : m_pendingInsertions) {
m_shaderModule.insert(statement.statements(), insertion.index + offset, AST::Statement::Ref(*insertion.statement));
++offset;
}
}
void RewriteGlobalVariables::visit(AST::CompoundAssignmentStatement& statement)
{
    // The target is requested unpacked; the operand must then match whatever packing the target has.
    auto targetPacking = pack(Packing::Unpacked, statement.leftExpression());
    pack(targetPacking, statement.rightExpression());
}
void RewriteGlobalVariables::visit(AST::AssignmentStatement& statement)
{
    // The destination keeps whatever layout it has, and that layout decides what the value must match.
    Packing targetPacking = pack(Packing::Either, statement.lhs());
    ASSERT(targetPacking != Packing::Either);

    // A packed vec3 destination accepts the value in either packing.
    Packing valuePacking = targetPacking == Packing::PackedVec3 ? Packing::Either : targetPacking;
    pack(valuePacking, statement.rhs());
}
void RewriteGlobalVariables::visit(AST::VariableStatement& statement)
{
    auto& variable = statement.variable();

    // Only `var` declarations count toward the per-function variable-size limit.
    if (variable.flavor() == AST::VariableFlavor::Var)
        m_combinedFunctionVariablesSize += variable.storeType()->size();

    // Initializers are always produced in unpacked form.
    auto* initializer = variable.maybeInitializer();
    if (!initializer)
        return;
    pack(Packing::Unpacked, *initializer);
}
void RewriteGlobalVariables::visit(AST::PhonyAssignmentStatement& statement)
{
    // `_ = expr` has no destination, so the expression may keep whichever packing it already has.
    auto& discarded = statement.rhs();
    pack(Packing::Either, discarded);
}
void RewriteGlobalVariables::visit(AST::Expression& expression)
{
    // Expressions reached through the generic visitor are required in unpacked form.
    constexpr auto required = Packing::Unpacked;
    pack(required, expression);
}
// Makes `expression` compatible with `expectedPacking` and returns the packing it ends up with.
// Expressions recorded in m_doNotUnpack (globals forwarded as call arguments) are left untouched.
// For the expression kinds handled in the switch below, getPacking() reports the current packing
// (recursively packing operands). If that does not intersect `expectedPacking`, the expression is
// wrapped in a `__pack`/`__unpack` call and the opposite packing (Either ^ packing) is returned.
// Any other expression kind is visited generically and reported as Unpacked.
Packing RewriteGlobalVariables::pack(Packing expectedPacking, AST::Expression& expression)
{
if (m_doNotUnpack.contains(&expression))
return expectedPacking;
const auto& visitAndReplace = [&](auto& expression) -> Packing {
auto packing = getPacking(expression);
if (expectedPacking & packing)
return packing;
auto* type = expression.inferredType();
if (auto* referenceType = std::get_if<Types::Reference>(type))
type = referenceType->element;
ASCIILiteral operation;
if (std::holds_alternative<Types::Struct>(*type)) {
// Non-constructible structs are left as they are and keep their current packing.
if (!type->isConstructible())
return packing;
operation = packing & Packing::Packed ? "__unpack"_s : "__pack"_s;
} else if (std::holds_alternative<Types::Array>(*type)) {
if (packing & Packing::Packed) {
operation = "__unpack"_s;
m_shaderModule.setUsesUnpackArray();
} else {
operation = "__pack"_s;
m_shaderModule.setUsesPackArray();
}
} else {
// Any remaining type uses the vector pack/unpack helpers.
if (packing & Packing::Packed) {
operation = "__unpack"_s;
m_shaderModule.setUsesUnpackVector();
} else {
operation = "__pack"_s;
m_shaderModule.setUsesPackVector();
}
}
RELEASE_ASSERT(!operation.isNull());
auto& callee = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
SourceSpan::empty(),
AST::Identifier::make(operation)
);
callee.m_inferredType = m_shaderModule.types().u32Type();
// Copy the original expression into a new node to serve as the call argument, then replace
// `expression` with the call (NOTE(review): presumably replace() overwrites the original node
// in place, which is why a copy is needed — confirm).
auto& argument = m_shaderModule.astBuilder().construct<std::remove_cvref_t<decltype(expression)>>(expression);
auto& call = m_shaderModule.astBuilder().construct<AST::CallExpression>(
SourceSpan::empty(),
callee,
AST::Expression::List { argument }
);
call.m_inferredType = argument.inferredType();
m_shaderModule.replace(expression, call);
return static_cast<Packing>(Packing::Either ^ packing);
};
switch (expression.kind()) {
case AST::NodeKind::IdentifierExpression:
return visitAndReplace(uncheckedDowncast<AST::IdentifierExpression>(expression));
case AST::NodeKind::FieldAccessExpression:
return visitAndReplace(uncheckedDowncast<AST::FieldAccessExpression>(expression));
case AST::NodeKind::IndexAccessExpression:
return visitAndReplace(uncheckedDowncast<AST::IndexAccessExpression>(expression));
case AST::NodeKind::BinaryExpression:
return visitAndReplace(uncheckedDowncast<AST::BinaryExpression>(expression));
case AST::NodeKind::UnaryExpression:
return visitAndReplace(uncheckedDowncast<AST::UnaryExpression>(expression));
case AST::NodeKind::CallExpression:
return visitAndReplace(uncheckedDowncast<AST::CallExpression>(expression));
case AST::NodeKind::IdentityExpression:
return visitAndReplace(uncheckedDowncast<AST::IdentityExpression>(expression));
default:
AST::Visitor::visit(expression);
return Packing::Unpacked;
}
}
Packing RewriteGlobalVariables::getPacking(AST::IdentifierExpression& identifier)
{
    // readVariable() also marks the identifier as read, so it must run for every identifier.
    const Global* global = readVariable(identifier);
    bool isResource = global && global->resource.has_value();

    // Only resource globals can be packed; locals, parameters and non-resource globals never are.
    return isResource ? packingForType(identifier.inferredType()) : Packing::Unpacked;
}
// Reports the packing of a member access `base.field`, after packing `base` as needed.
// - Vector bases (component access / swizzles) always yield Unpacked. When the result is itself a
//   vector and `base` is a pointer, the access is rewritten to `(*base).field` and the base is
//   requested unpacked; single-component access accepts the base in either packing.
// - Otherwise the base must be a struct: an unpacked base yields Unpacked, and a packed base
//   yields the packing of the accessed field's type.
Packing RewriteGlobalVariables::getPacking(AST::FieldAccessExpression& expression)
{
auto* baseType = expression.base().inferredType();
if (!baseType) {
// All AST nodes should have an inferred type, but we create field
// access nodes from the EntryPointRewriter which don't have a trivial
// type, so we work around it by returning unpacked since those are only
// used for marshalling inputs/outputs and don't need to packed/unpacked.
return Packing::Unpacked;
}
if (auto* referenceType = std::get_if<Types::Reference>(baseType))
baseType = referenceType->element;
bool isPointer = false;
if (auto* pointerType = std::get_if<Types::Pointer>(baseType)) {
isPointer = true;
baseType = pointerType->element;
}
if (std::holds_alternative<Types::Vector>(*baseType)) {
// Multi-component result (the access produces a vector).
if (std::holds_alternative<Types::Vector>(*expression.inferredType())) {
if (isPointer) {
// Make the pointer dereference explicit: `p.field` becomes `(*p).field`.
auto& dereference = m_shaderModule.astBuilder().construct<AST::UnaryExpression>(
expression.base().span(),
expression.base(),
AST::UnaryOperation::Dereference
);
dereference.m_inferredType = baseType;
auto& fieldAccessExpression = m_shaderModule.astBuilder().construct<AST::FieldAccessExpression>(
expression.span(),
dereference,
AST::Identifier::make(expression.fieldName())
);
fieldAccessExpression.m_inferredType = expression.inferredType();
m_shaderModule.replace(expression, fieldAccessExpression);
}
pack(Packing::Unpacked, expression.base());
} else
pack(Packing::Either, expression.base());
return Packing::Unpacked;
}
auto basePacking = pack(Packing::Either, expression.base());
if (basePacking & Packing::Unpacked)
return Packing::Unpacked;
ASSERT(std::holds_alternative<Types::Struct>(*baseType));
auto& structType = std::get<Types::Struct>(*baseType);
// NOTE(review): if the field name were missing, fields.get() would presumably return nullptr and
// packingForType() would dereference it; this relies on the type checker having validated the access.
auto* fieldType = structType.fields.get(expression.originalFieldName());
return packingForType(fieldType);
}
Packing RewriteGlobalVariables::getPacking(AST::IndexAccessExpression& expression)
{
    // The indexed value keeps whatever layout it has; the index itself is always a plain value.
    auto containerPacking = pack(Packing::Either, expression.base());
    pack(Packing::Unpacked, expression.index());
    if (containerPacking & Packing::Unpacked)
        return Packing::Unpacked;

    // The container is packed: look through references and pointers to the indexed type.
    auto* containerType = expression.base().inferredType();
    if (auto* reference = std::get_if<Types::Reference>(containerType))
        containerType = reference->element;
    if (auto* pointer = std::get_if<Types::Pointer>(containerType))
        containerType = pointer->element;

    // Indexing a vector yields a scalar, which is reported as unpacked.
    if (std::holds_alternative<Types::Vector>(*containerType))
        return Packing::Unpacked;

    ASSERT(std::holds_alternative<Types::Array>(*containerType));
    return packingForType(std::get<Types::Array>(*containerType).element);
}
Packing RewriteGlobalVariables::getPacking(AST::BinaryExpression& binary)
{
    // Both operands are required unpacked (left first, then right), and so is the result.
    auto& lhs = binary.leftExpression();
    pack(Packing::Unpacked, lhs);
    auto& rhs = binary.rightExpression();
    pack(Packing::Unpacked, rhs);
    return Packing::Unpacked;
}
// Reports the packing of a unary expression.
// - `&x`: the operand keeps its packing and Either is returned, since pointers are never packed/unpacked.
// - `*p`: the operand keeps its packing; the result carries the packing of the pointee type when `p`
//   points into the storage or uniform address space, and is Unpacked otherwise.
// - Any other operator requires its operand unpacked and returns the operand's resulting packing.
Packing RewriteGlobalVariables::getPacking(AST::UnaryExpression& expression)
{
    if (expression.operation() == AST::UnaryOperation::AddressOf) {
        pack(Packing::Either, expression.expression());
        // we can't pack/unpack pointers, so we return Either to avoid that
        return Packing::Either;
    }
    if (expression.operation() == AST::UnaryOperation::Dereference) {
        // similarly to above, pointers are handled differently, so we can't trust
        // the packing of the underlying element and instead we skip any packing
        // of the operand (since we can't pack/unpack pointers) and return the
        // packing of the resulting type from dereferencing
        pack(Packing::Either, expression.expression());
        auto* pointer = std::get_if<Types::Pointer>(expression.expression().inferredType());
        // The operand of a dereference must have pointer type; fail deterministically instead of
        // dereferencing a null pointer if a non-pointer type ever reaches this point.
        RELEASE_ASSERT(pointer);
        // NOTE(review): presumably only these address spaces hold data in the packed resource layout.
        if (pointer->addressSpace == AddressSpace::Storage || pointer->addressSpace == AddressSpace::Uniform)
            return expression.inferredType()->packing();
        return Packing::Unpacked;
    }
    return pack(Packing::Unpacked, expression.expression());
}
Packing RewriteGlobalVariables::getPacking(AST::IdentityExpression& identity)
{
    // An identity expression is transparent: it reports exactly the packing of its operand.
    auto& operand = identity.expression();
    return pack(Packing::Either, operand);
}
// Reports the packing of a call expression (always Unpacked).
// A call to `arrayLength(p)` is rewritten into `(__X_ArrayLength [- offset]) / stride`, where X is
// the identifier `p` is rooted at (see getBase()), `offset` is the accumulated offset of the trailing
// struct member holding the array, and `stride` is the element size rounded up to its alignment.
// NOTE(review): given the division by the stride, __X_ArrayLength is presumably a byte length — confirm.
// When X is not a global, X is recorded in m_lengthParameters so visitCallee() threads the length
// through as a parameter. Any other call has all of its arguments unpacked.
Packing RewriteGlobalVariables::getPacking(AST::CallExpression& call)
{
if (auto target = dynamicDowncast<AST::IdentifierExpression>(call.target())) {
if (target->identifier() == "arrayLength"_s) {
ASSERT(call.arguments().size() == 1);
auto& arrayPointer = call.arguments()[0];
unsigned arrayOffset = 0;
auto& base = getBase(arrayPointer, arrayOffset);
auto& identifier = base.identifier();
// Not a global (e.g. a pointer parameter): the length must be passed in by the caller.
if (!m_globals.contains(identifier)) {
auto result = m_lengthParameters.add(m_currentFunction, ListHashSet<String> { });
result.iterator->value.add(identifier);
}
auto lengthName = makeString("__"_s, identifier, "_ArrayLength"_s);
auto& length = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
SourceSpan::empty(),
AST::Identifier::make(lengthName)
);
length.m_inferredType = m_shaderModule.types().u32Type();
auto* arrayPointerType = arrayPointer.inferredType();
ASSERT(std::holds_alternative<Types::Pointer>(*arrayPointerType));
auto& arrayType = std::get<Types::Pointer>(*arrayPointerType).element;
ASSERT(std::holds_alternative<Types::Array>(*arrayType));
auto* elementType = std::get<Types::Array>(*arrayType).element;
// Array stride: element size rounded up to a multiple of the element alignment.
auto arrayStride = elementType->size();
arrayStride = WTF::roundUpToMultipleOf(elementType->alignment(), arrayStride);
auto& strideExpression = m_shaderModule.astBuilder().construct<AST::Unsigned32Literal>(
SourceSpan::empty(),
arrayStride
);
strideExpression.m_inferredType = m_shaderModule.types().u32Type();
// Subtract the offset of the containing struct member, if the array is not at the start.
AST::Expression* lhs = &length;
if (arrayOffset) {
auto& arrayOffsetExpression = m_shaderModule.astBuilder().construct<AST::Unsigned32Literal>(
SourceSpan::empty(),
arrayOffset
);
arrayOffsetExpression.m_inferredType = m_shaderModule.types().u32Type();
lhs = &m_shaderModule.astBuilder().construct<AST::BinaryExpression>(
SourceSpan::empty(),
length,
arrayOffsetExpression,
AST::BinaryOperation::Subtract
);
lhs->m_inferredType = m_shaderModule.types().u32Type();
}
m_shaderModule.setUsesDivision();
auto& elementCount = m_shaderModule.astBuilder().construct<AST::BinaryExpression>(
SourceSpan::empty(),
*lhs,
strideExpression,
AST::BinaryOperation::Divide
);
elementCount.m_inferredType = m_shaderModule.types().u32Type();
m_shaderModule.replace(call, elementCount);
// mark both the array and array length as read
readVariable(base);
readVariable(length);
return Packing::Unpacked;
}
}
for (auto& argument : call.arguments())
pack(Packing::Unpacked, argument);
return Packing::Unpacked;
}
Packing RewriteGlobalVariables::packingForType(const Type* type)
{
    // Delegates to the type's own notion of packing; `type` must be non-null.
    const Type& queried = *type;
    return queried.packing();
}
AST::IdentifierExpression& RewriteGlobalVariables::getBase(AST::Expression& expression, unsigned& arrayOffset)
{
    // Walks from `expression` down to the identifier it is rooted at, looking through identity and
    // unary expressions and struct member accesses. Each member access must target the struct's last
    // member, and that member's offset is added to `arrayOffset`.
    AST::Expression* current = &expression;
    while (true) {
        if (auto* identity = dynamicDowncast<AST::IdentityExpression>(*current)) {
            current = &identity->expression();
            continue;
        }
        if (auto* unary = dynamicDowncast<AST::UnaryExpression>(*current)) {
            current = &unary->expression();
            continue;
        }
        if (auto* fieldAccess = dynamicDowncast<AST::FieldAccessExpression>(*current)) {
            auto& base = fieldAccess->base();
            auto* baseType = base.inferredType();
            if (auto* reference = std::get_if<Types::Reference>(baseType))
                baseType = reference->element;
            if (auto* pointer = std::get_if<Types::Pointer>(baseType))
                baseType = pointer->element;
            auto& structure = std::get<Types::Struct>(*baseType).structure;
            auto& trailingMember = structure.members().last();
            RELEASE_ASSERT(trailingMember.name().id() == fieldAccess->fieldName().id());
            arrayOffset += trailingMember.offset();
            current = &base;
            continue;
        }
        if (auto* identifier = dynamicDowncast<AST::IdentifierExpression>(*current))
            return *identifier;
        RELEASE_ASSERT_NOT_REACHED();
    }
}
static unsigned buffersForStage(const Configuration& configuration, ShaderStage stage)
{
switch (stage) {
case ShaderStage::Compute:
return configuration.maxBuffersForComputeStage;
case ShaderStage::Vertex:
return configuration.maxBuffersPlusVertexBuffersForVertexStage;
case ShaderStage::Fragment:
return configuration.maxBuffersForFragmentStage;
}
}
std::optional<Error> RewriteGlobalVariables::collectGlobals()
{
Vector<std::tuple<AST::Variable*, unsigned>> bufferLengths;
// we can't use a range-based for loop here since we might create new structs
// and insert them into the declarations vector
auto size = m_shaderModule.declarations().size();
for (unsigned i = 0; i < size; ++i) {
auto* globalVar = dynamicDowncast<AST::Variable>(m_shaderModule.declarations()[i]);
if (!globalVar)
continue;
std::optional<Global::Resource> resource;
if (globalVar->group().has_value()) {
RELEASE_ASSERT(globalVar->binding().has_value());
unsigned bufferIndex = *globalVar->group();
auto buffersCountForStage = buffersForStage(m_shaderModule.configuration(), m_stage);
if (bufferIndex >= buffersCountForStage)
return Error(makeString("global has buffer index "_s, bufferIndex, " which exceeds the max allowed buffer index "_s, buffersCountForStage, " for this stage"_s), SourceSpan::empty());
resource = { *globalVar->group(), *globalVar->binding() };
}
dataLogLnIf(shouldLogGlobalVariableRewriting, "> Found global: ", globalVar->name(), ", isResource: ", resource.has_value() ? "yes" : "no");
auto result = m_globals.add(globalVar->name(), Global {
resource,
globalVar
});
ASSERT_UNUSED(result, result.isNewEntry);
if (resource.has_value()) {
m_globalsByBinding.add({ resource->group + 1, resource->binding + 1 }, globalVar);
auto result = m_groupBindingMap.add(resource->group, Vector<std::pair<unsigned, String>>());
result.iterator->value.append({ resource->binding, globalVar->name() });
packResource(*globalVar);
bufferLengths.append({ globalVar, resource->group });
}
}
for (auto& [_, vector] : m_groupBindingMap)
std::ranges::sort(vector, { }, &std::pair<unsigned, String>::first);
if (!bufferLengths.isEmpty()) {
for (const auto& [variable, group] : bufferLengths) {
auto name = AST::Identifier::make(makeString("__"_s, variable->name(), "_ArrayLength"_s));
auto& lengthVariable = m_shaderModule.astBuilder().construct<AST::Variable>(
SourceSpan::empty(),
AST::VariableFlavor::Var,
AST::Identifier::make(name),
&bufferLengthType(),
nullptr
);
lengthVariable.m_referenceType = &bufferLengthReferenceType();
auto it = m_groupBindingMap.find(group);
ASSERT(it != m_groupBindingMap.end());
auto binding = it->value.last().first + 1;
it->value.append({ binding, name });
auto result = m_globals.add(name, Global {
{ {
group,
binding,
} },
&lengthVariable
});
ASSERT_UNUSED(result, result.isNewEntry);
m_bufferLengthMap.add(variable, &lengthVariable);
m_reverseBufferLengthMap.add(&lengthVariable, variable);
}
}
return std::nullopt;
}
AST::Expression& RewriteGlobalVariables::bufferLengthType()
{
    // Lazily create one shared `u32` type expression (used for the synthesized array-length variables).
    if (!m_bufferLengthType) {
        auto& u32Name = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make("u32"_s));
        u32Name.m_inferredType = m_shaderModule.types().u32Type();
        m_bufferLengthType = &u32Name;
    }
    return *m_bufferLengthType;
}
AST::Expression& RewriteGlobalVariables::bufferLengthReferenceType()
{
    // Lazily create one shared read-only, handle-address-space reference to `u32`.
    if (!m_bufferLengthReferenceType) {
        auto& elementType = bufferLengthType();
        auto& referenceExpression = m_shaderModule.astBuilder().construct<AST::ReferenceTypeExpression>(
            SourceSpan::empty(),
            elementType
        );
        auto& types = m_shaderModule.types();
        referenceExpression.m_inferredType = types.referenceType(AddressSpace::Handle, types.u32Type(), AccessMode::Read);
        m_bufferLengthReferenceType = &referenceExpression;
    }
    return *m_bufferLengthReferenceType;
}
void RewriteGlobalVariables::packResource(AST::Variable& global)
{
    // Only array and struct resources can need a packed layout; other types are left alone.
    auto* typeName = global.maybeTypeName();
    ASSERT(typeName);
    auto* declaredType = typeName->inferredType();
    if (auto* structType = std::get_if<Types::Struct>(declaredType))
        packStructResource(global, structType);
    else if (auto* arrayType = std::get_if<Types::Array>(declaredType))
        packArrayResource(global, arrayType);
}
void RewriteGlobalVariables::packStructResource(AST::Variable& global, const Types::Struct* structType)
{
    auto* packedStructType = packStructType(structType);
    if (!packedStructType)
        return; // packStructType() found nothing to pack.

    auto& packedStructure = std::get<Types::Struct>(*packedStructType).structure;
    auto& packedTypeName = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
        SourceSpan::empty(),
        AST::Identifier::make(packedStructure.name().id())
    );
    packedTypeName.m_inferredType = packedStructType;

    // Point the declared type and the reference type at the packed struct, then mark the global.
    m_shaderModule.replace(downcast<AST::IdentifierExpression>(*global.maybeTypeName()), packedTypeName);
    updateReference(global, packedTypeName);
    m_shaderModule.replace(&global.role(), AST::VariableRole::PackedResource);
}
void RewriteGlobalVariables::packArrayResource(AST::Variable& global, const Types::Array* arrayType)
{
    auto* packedArrayType = packArrayType(arrayType);
    if (!packedArrayType)
        return; // The element type needs no packing.

    auto& packedElementName = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
        SourceSpan::empty(),
        AST::Identifier::make("__PackedArray"_s) // This name is not used by the codegen
    );
    packedElementName.m_inferredType = std::get<Types::Array>(*packedArrayType).element;

    // Build `array<packed element, same count>` and swap it in for the declared array type.
    auto& originalArrayName = downcast<AST::ArrayTypeExpression>(*global.maybeTypeName());
    auto& packedArrayName = m_shaderModule.astBuilder().construct<AST::ArrayTypeExpression>(
        originalArrayName.span(),
        &packedElementName,
        originalArrayName.maybeElementCount()
    );
    packedArrayName.m_inferredType = packedArrayType;

    m_shaderModule.replace(originalArrayName, packedArrayName);
    updateReference(global, packedArrayName);
    m_shaderModule.replace(&global.role(), AST::VariableRole::PackedResource);
}
void RewriteGlobalVariables::updateReference(AST::Variable& global, AST::Expression& packedType)
{
    auto* referenceTypeName = global.maybeReferenceType();
    ASSERT(referenceTypeName);
    auto& originalReference = downcast<AST::ReferenceTypeExpression>(*referenceTypeName);
    auto* originalReferenceType = std::get_if<Types::Reference>(originalReference.inferredType());
    ASSERT(originalReferenceType);

    // Keep the address space and access mode; only the referenced element type changes.
    auto& packedReference = m_shaderModule.astBuilder().construct<AST::ReferenceTypeExpression>(
        SourceSpan::empty(),
        packedType
    );
    packedReference.m_inferredType = m_shaderModule.types().referenceType(
        originalReferenceType->addressSpace,
        packedType.inferredType(),
        originalReferenceType->accessMode
    );
    m_shaderModule.replace(originalReference, packedReference);
}
// Returns the packed counterpart of `type`, or null when `type` needs no
// packing. Structs and arrays are delegated to their dedicated helpers. A vec3
// is returned unchanged, but the module is flagged as using packed vec3s.
// Every other type yields null.
const Type* RewriteGlobalVariables::packType(const Type* type)
{
    if (auto* vectorType = std::get_if<Types::Vector>(type); vectorType && vectorType->size == 3) {
        m_shaderModule.setUsesPackedVec3();
        return type;
    }

    if (auto* arrayType = std::get_if<Types::Array>(type))
        return packArrayType(arrayType);

    if (auto* structType = std::get_if<Types::Struct>(type))
        return packStructType(structType);

    return nullptr;
}
// Returns the packed counterpart of `structType`, creating it on first use, or
// null when the struct needs no packing.
//
// A struct whose role is already UserDefinedResource was packed earlier, so
// the cached entry in m_packedStructTypes is returned. Otherwise every member
// type is packed first, which recursively handles nested structs and arrays.
// If no member needed packing and the struct has no size/alignment
// attributes, null is returned. Otherwise the struct is re-marked as
// UserDefinedResource, and a "__<name>_Packed" struct with the same members and
// role PackedResource is created. That struct is appended to the module's
// declarations and cached in m_packedStructTypes.
const Type* RewriteGlobalVariables::packStructType(const Types::Struct* structType)
{
    if (structType->structure.role() == AST::StructureRole::UserDefinedResource)
        return m_packedStructTypes.get(structType);
    // Ensure we pack nested structs
    bool packedAnyMember = false;
    for (auto& member : structType->structure.members()) {
        // packType() also has side effects (it may create nested packed types
        // and set module flags), so every member must be visited.
        if (packType(member.type().inferredType()))
            packedAnyMember = true;
    }
    if (!packedAnyMember && !structType->structure.hasSizeOrAlignmentAttributes())
        return nullptr;
    ASSERT(structType->structure.role() == AST::StructureRole::UserDefined);
    m_shaderModule.replace(&structType->structure.role(), AST::StructureRole::UserDefinedResource);
    String packedStructName = makeString("__"_s, structType->structure.name(), "_Packed"_s);
    // The last argument passes the original structure to the packed one.
    // NOTE(review): presumably used to link the packed struct to its
    // unpacked source; confirm against AST::Structure.
    auto& packedStruct = m_shaderModule.astBuilder().construct<AST::Structure>(
        SourceSpan::empty(),
        AST::Identifier::make(packedStructName),
        AST::StructureMember::List(structType->structure.members()),
        AST::Attribute::List { },
        AST::StructureRole::PackedResource,
        &structType->structure
    );
    m_shaderModule.append(m_shaderModule.declarations(), packedStruct);
    const Type* packedStructType = m_shaderModule.types().structType(packedStruct);
    packedStruct.m_inferredType = packedStructType;
    m_packedStructTypes.add(structType, packedStructType);
    return packedStructType;
}
// Packs an array type by packing its element type. Returns null when the
// element needs no packing. Otherwise it flags the module as needing the
// array pack/unpack helpers and returns an array of the packed element with
// the same size.
const Type* RewriteGlobalVariables::packArrayType(const Types::Array* arrayType)
{
    const Type* elementPacked = packType(arrayType->element);
    if (elementPacked) {
        m_shaderModule.setUsesUnpackArray();
        m_shaderModule.setUsesPackArray();
        return m_shaderModule.types().arrayType(elementPacked, arrayType->size);
    }
    return nullptr;
}
// Byte size of `variable`'s store type rounded up to a multiple of 16, or 0
// when the variable has no store type.
static size_t getRoundedSize(const AST::Variable& variable)
{
    const auto* storeType = variable.storeType();
    auto unroundedSize = storeType ? storeType->size() : 0;
    return roundUpToMultipleOf(16, unroundedSize);
}
// Appends a parameter named `name` with role `parameterRole` to `function`,
// annotated with an @group(`group`) attribute.
//
// When `type` is null, the parameter is typed as the argument-buffer struct
// for `group`: an identifier expression naming argumentBufferStructName(group),
// whose inferred type is taken from m_structTypes.
void RewriteGlobalVariables::insertParameter(const SourceSpan& span, const AST::Function& function, unsigned group, AST::Identifier&& name, AST::Expression* type, AST::ParameterRole parameterRole)
{
    if (!type) {
        type = &m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(span, argumentBufferStructName(group));
        type->m_inferredType = m_structTypes.get(group);
    }

    // Build the @group(<group>) attribute for the new parameter.
    auto& groupValue = m_shaderModule.astBuilder().construct<AST::AbstractIntegerLiteral>(span, group);
    groupValue.m_inferredType = m_shaderModule.types().abstractIntType();
    groupValue.setConstantValue(group);
    auto& groupAttribute = m_shaderModule.astBuilder().construct<AST::GroupAttribute>(span, groupValue);

    m_shaderModule.append(function.parameters(), m_shaderModule.astBuilder().construct<AST::Parameter>(
        span,
        WTF::move(name),
        *type,
        AST::Attribute::List { groupAttribute },
        parameterRole
    ));
}
// Rewrites `entryPoint` so that the module-scope globals it uses are handed to
// it explicitly, and records its reflection information.
//
// Entry points with no entry in m_pipelineLayouts are skipped. Otherwise a new
// EntryPointInformation is registered; visiting the same entry point twice
// trips the RELEASE_ASSERT. Stage-specific data is filled in: the workgroup
// size for compute, and whether any invariant output is used for vertex. A null
// pipeline layout means a default layout is generated while rewriting;
// otherwise the globals using dynamic offsets are collected from the supplied
// layout.
//
// After the body is visited, the used globals are computed. Resources go
// through insertStructs/insertParameters/insertMaterializations, and
// private/workgroup globals go through insertLocalDefinitions. Workgroup
// globals also add their 16-byte-rounded size to sizeForWorkgroupVariables.
//
// Returns the first error encountered, or std::nullopt on success.
std::optional<Error> RewriteGlobalVariables::visitEntryPoint(const CallGraph::EntryPoint& entryPoint)
{
    dataLogLnIf(shouldLogGlobalVariableRewriting, "> Visiting entrypoint: ", entryPoint.function.name());
    auto it = m_pipelineLayouts.find(entryPoint.originalName);
    if (it == m_pipelineLayouts.end())
        return std::nullopt;

    // Reset state accumulated while visiting a previous entry point.
    m_reads.clear();
    m_structTypes.clear();
    m_globalsUsingDynamicOffset.clear();
    m_generateLayoutGroupMapping.clear();

    auto result = m_entryPointInformations.add(entryPoint.originalName, Reflection::EntryPointInformation { });
    RELEASE_ASSERT(result.isNewEntry);
    m_stage = entryPoint.stage;
    m_entryPointInformation = &result.iterator->value;
    m_entryPointInformation->originalName = entryPoint.originalName;
    m_entryPointInformation->mangledName = entryPoint.function.name();

    switch (m_stage) {
    case ShaderStage::Compute: {
        // Only the first @workgroup_size attribute is used.
        for (auto& attribute : entryPoint.function.attributes()) {
            auto* workgroupSize = dynamicDowncast<AST::WorkgroupSizeAttribute>(attribute);
            if (!workgroupSize)
                continue;
            m_entryPointInformation->typedEntryPoint = Reflection::Compute { &workgroupSize->x(), workgroupSize->maybeY(), workgroupSize->maybeZ() };
            break;
        }
        break;
    }
    case ShaderStage::Vertex:
        m_entryPointInformation->typedEntryPoint = Reflection::Vertex { false };
        // Invariance can be declared on the return value itself or on any
        // member of a returned struct.
        if (entryPoint.function.returnTypeInvariant())
            m_entryPointInformation->usesInvariant = true;
        else if (auto* returnType = entryPoint.function.maybeReturnType()) {
            if (auto* structType = std::get_if<Types::Struct>(returnType->inferredType())) {
                for (const auto& member : structType->structure.members()) {
                    if (member.invariant()) {
                        m_entryPointInformation->usesInvariant = true;
                        break;
                    }
                }
            }
        }
        break;
    case ShaderStage::Fragment:
        m_entryPointInformation->typedEntryPoint = Reflection::Fragment { };
        break;
    }

    if (!it->value) {
        // No layout supplied: generate one into the reflection info.
        m_entryPointInformation->defaultLayout = { PipelineLayout { } };
        m_generatedLayout = &*m_entryPointInformation->defaultLayout;
    } else {
        m_generatedLayout = nullptr;
        collectDynamicOffsetGlobals(*it->value);
    }

    visit(entryPoint.function);
    if (hasError())
        return AST::Visitor::result().error();

    // insertDynamicOffsetsBufferIfNeeded() is still invoked on every early
    // exit below, so the entry point always gets it when required.
    if (m_reads.isEmpty()) {
        insertDynamicOffsetsBufferIfNeeded(entryPoint.function);
        return std::nullopt;
    }

    auto maybeUsedGlobals = determineUsedGlobals(entryPoint.function);
    if (!maybeUsedGlobals) {
        insertDynamicOffsetsBufferIfNeeded(entryPoint.function);
        return maybeUsedGlobals.error();
    }
    // Bind by reference: maybeUsedGlobals outlives every use below, so there
    // is no need to deep-copy its resource maps and global lists.
    auto& usedGlobals = *maybeUsedGlobals;

    auto maybeGroups = m_generatedLayout ? Result<Vector<unsigned>>(insertStructs(usedGlobals.resources)) : insertStructs(*it->value, usedGlobals.resources);
    if (!maybeGroups) {
        insertDynamicOffsetsBufferIfNeeded(entryPoint.function);
        return maybeGroups.error();
    }

    insertParameters(entryPoint.function, *maybeGroups);
    insertMaterializations(entryPoint.function, usedGlobals.resources);
    insertLocalDefinitions(entryPoint.function, usedGlobals.privateGlobals);

    for (auto* global : usedGlobals.privateGlobals) {
        if (!global || !global->declaration)
            continue;
        auto* variable = global->declaration;
        if (variable->addressSpace() == AddressSpace::Workgroup)
            m_entryPointInformation->sizeForWorkgroupVariables += getRoundedSize(*variable);
    }
    return std::nullopt;
}
// For every bind group layout entry that is visible to the current stage and
// declares a dynamic buffer offset for that stage, records the offset in
// m_globalsUsingDynamicOffset keyed by (group + 1, binding + 1). Groups are
// numbered by their position in `pipelineLayout`.
// NOTE(review): the +1 presumably keeps 0 out of the key, since WTF hash
// tables reserve 0 as the empty value for integer keys. Confirm against the
// map's key traits.
void RewriteGlobalVariables::collectDynamicOffsetGlobals(const PipelineLayout& pipelineLayout)
{
    unsigned groupNumber = 0;
    for (const auto& bindGroupLayout : pipelineLayout.bindGroupLayouts) {
        for (const auto& entry : bindGroupLayout.entries) {
            if (entry.visibility.contains(m_stage)) {
                auto dynamicOffsetForStage = [&] {
                    switch (m_stage) {
                    case ShaderStage::Compute:
                        return entry.computeBufferDynamicOffset;
                    case ShaderStage::Fragment:
                        return entry.fragmentBufferDynamicOffset;
                    case ShaderStage::Vertex:
                        return entry.vertexBufferDynamicOffset;
                    }
                }();
                if (dynamicOffsetForStage.has_value())
                    m_globalsUsingDynamicOffset.add({ groupNumber + 1, entry.binding + 1 }, *dynamicOffsetForStage);
            }
        }
        groupNumber++;
    }
}
// Translates a WGSL access mode into the storage-texture access stored in a
// generated StorageTextureBindingLayout.
static WGSL::StorageTextureAccess convertAccess(const AccessMode accessMode)
{
    switch (accessMode) {
    case AccessMode::Write:
        return WGSL::StorageTextureAccess::WriteOnly;
    case AccessMode::Read:
        return WGSL::StorageTextureAccess::ReadOnly;
    case AccessMode::ReadWrite:
        return WGSL::StorageTextureAccess::ReadWrite;
    }
}
// Builds the binding member of a generated bind group layout entry for
// `global` (the caller in this file passes a Global from m_globals). The
// member is derived from the store type of `global.declaration`:
//
// - Scalar primitives, vectors, matrices, arrays, structs and atomics become
//   a BufferBindingLayout sized by the store type. Its buffer kind comes from
//   the variable's reference type (see addressSpace below).
// - sampler / sampler_comparison become Filtering / Comparison
//   SamplerBindingLayouts, and texture_external becomes an
//   ExternalTextureBindingLayout.
// - Sampled, storage and depth textures map their kind to a view dimension
//   (and multisampled flag).
// - Types that can never be a binding's store type crash.
static BindGroupLayoutEntry::BindingMember bindingMemberForGlobal(auto& global)
{
    auto* variable = global.declaration;
    ASSERT(variable);
    auto* maybeReference = variable->maybeReferenceType();
    auto* type = variable->storeType();
    ASSERT(type);
    // Buffer kind from the variable's reference type: storage with read access
    // is ReadOnlyStorage, other storage is Storage, and everything else
    // (including a missing reference type) is Uniform.
    auto addressSpace = [&]() {
        if (maybeReference) {
            auto& reference = downcast<AST::ReferenceTypeExpression>(*maybeReference);
            auto* referenceType = std::get_if<Types::Reference>(reference.inferredType());
            if (referenceType && referenceType->addressSpace == AddressSpace::Storage)
                return referenceType->accessMode == AccessMode::Read ? BufferBindingType::ReadOnlyStorage : BufferBindingType::Storage;
        }
        return BufferBindingType::Uniform;
    };
    using namespace WGSL::Types;
    return WTF::switchOn(*type, [&](const Primitive& primitive) -> BindGroupLayoutEntry::BindingMember {
        switch (primitive.kind) {
        case Types::Primitive::AbstractInt:
        case Types::Primitive::I32:
        case Types::Primitive::U32:
        case Types::Primitive::AbstractFloat:
        case Types::Primitive::F32:
        case Types::Primitive::F16:
        case Types::Primitive::Void:
        case Types::Primitive::Bool:
            return BufferBindingLayout {
                .type = addressSpace(),
                .hasDynamicOffset = false,
                .minBindingSize = type->size()
            };
        case Types::Primitive::Sampler:
            return SamplerBindingLayout {
                .type = SamplerBindingType::Filtering
            };
        case Types::Primitive::SamplerComparison:
            return SamplerBindingLayout {
                .type = SamplerBindingType::Comparison
            };
        case Types::Primitive::TextureExternal:
            return ExternalTextureBindingLayout { };
        case Types::Primitive::AccessMode:
        case Types::Primitive::TexelFormat:
        case Types::Primitive::AddressSpace:
            RELEASE_ASSERT_NOT_REACHED();
        }
    }, [&](const Vector& vector) -> BindGroupLayoutEntry::BindingMember {
        // NOTE(review): `primitive` is computed but never used.
        auto* primitive = std::get_if<Primitive>(vector.element);
        UNUSED_PARAM(primitive);
        return BufferBindingLayout {
            .type = addressSpace(),
            .hasDynamicOffset = false,
            .minBindingSize = type->size()
        };
    }, [&](const Matrix& matrix) -> BindGroupLayoutEntry::BindingMember {
        UNUSED_PARAM(matrix);
        return BufferBindingLayout {
            .type = addressSpace(),
            .hasDynamicOffset = false,
            .minBindingSize = type->size()
        };
    }, [&](const Array& array) -> BindGroupLayoutEntry::BindingMember {
        UNUSED_PARAM(array);
        return BufferBindingLayout {
            .type = addressSpace(),
            .hasDynamicOffset = false,
            .minBindingSize = type->size()
        };
    }, [&](const Struct& structure) -> BindGroupLayoutEntry::BindingMember {
        UNUSED_PARAM(structure);
        return BufferBindingLayout {
            .type = addressSpace(),
            .hasDynamicOffset = false,
            .minBindingSize = type->size()
        };
    }, [&](const Texture& texture) -> BindGroupLayoutEntry::BindingMember {
        // Every texture kind is handled below, so viewDimension is always set.
        TextureViewDimension viewDimension;
        bool multisampled = false;
        switch (texture.kind) {
        case Types::Texture::Kind::Texture1d:
            viewDimension = TextureViewDimension::OneDimensional;
            break;
        case Types::Texture::Kind::Texture2d:
            viewDimension = TextureViewDimension::TwoDimensional;
            break;
        case Types::Texture::Kind::Texture2dArray:
            viewDimension = TextureViewDimension::TwoDimensionalArray;
            break;
        case Types::Texture::Kind::Texture3d:
            viewDimension = TextureViewDimension::ThreeDimensional;
            break;
        case Types::Texture::Kind::TextureCube:
            viewDimension = TextureViewDimension::Cube;
            break;
        case Types::Texture::Kind::TextureCubeArray:
            viewDimension = TextureViewDimension::CubeArray;
            break;
        case Types::Texture::Kind::TextureMultisampled2d:
            viewDimension = TextureViewDimension::TwoDimensional;
            multisampled = true;
            break;
        }
        // NOTE(review): the sample type is always UnfilterableFloat, whatever
        // the texture's sampled element type is (e.g. i32/u32). Confirm this
        // is intended for generated layouts.
        return TextureBindingLayout {
            .sampleType = TextureSampleType::UnfilterableFloat,
            .viewDimension = viewDimension,
            .multisampled = multisampled
        };
    }, [&](const TextureStorage& texture) -> BindGroupLayoutEntry::BindingMember {
        TextureViewDimension viewDimension;
        switch (texture.kind) {
        case Types::TextureStorage::Kind::TextureStorage1d:
            viewDimension = TextureViewDimension::OneDimensional;
            break;
        case Types::TextureStorage::Kind::TextureStorage2d:
            viewDimension = TextureViewDimension::TwoDimensional;
            break;
        case Types::TextureStorage::Kind::TextureStorage2dArray:
            viewDimension = TextureViewDimension::TwoDimensionalArray;
            break;
        case Types::TextureStorage::Kind::TextureStorage3d:
            viewDimension = TextureViewDimension::ThreeDimensional;
            break;
        }
        return StorageTextureBindingLayout {
            .access = convertAccess(texture.access),
            .format = texture.format,
            .viewDimension = viewDimension
        };
    }, [&](const TextureDepth& texture) -> BindGroupLayoutEntry::BindingMember {
        TextureViewDimension viewDimension;
        bool multisampled = false;
        switch (texture.kind) {
        case Types::TextureDepth::Kind::TextureDepth2d:
            viewDimension = TextureViewDimension::TwoDimensional;
            break;
        case Types::TextureDepth::Kind::TextureDepth2dArray:
            viewDimension = TextureViewDimension::TwoDimensionalArray;
            break;
        case Types::TextureDepth::Kind::TextureDepthCube:
            viewDimension = TextureViewDimension::Cube;
            break;
        case Types::TextureDepth::Kind::TextureDepthCubeArray:
            viewDimension = TextureViewDimension::CubeArray;
            break;
        case Types::TextureDepth::Kind::TextureDepthMultisampled2d:
            viewDimension = TextureViewDimension::TwoDimensional;
            multisampled = true;
            break;
        }
        return TextureBindingLayout {
            .sampleType = TextureSampleType::Depth,
            .viewDimension = viewDimension,
            .multisampled = multisampled
        };
    }, [&](const Atomic&) -> BindGroupLayoutEntry::BindingMember {
        return BufferBindingLayout {
            .type = addressSpace(),
            .hasDynamicOffset = false,
            .minBindingSize = type->size()
        };
    }, [&](const PrimitiveStruct&) -> BindGroupLayoutEntry::BindingMember {
        RELEASE_ASSERT_NOT_REACHED();
    }, [&](const Reference&) -> BindGroupLayoutEntry::BindingMember {
        RELEASE_ASSERT_NOT_REACHED();
    }, [&](const Pointer&) -> BindGroupLayoutEntry::BindingMember {
        RELEASE_ASSERT_NOT_REACHED();
    }, [&](const Function&) -> BindGroupLayoutEntry::BindingMember {
        RELEASE_ASSERT_NOT_REACHED();
    }, [&](const TypeConstructor&) -> BindGroupLayoutEntry::BindingMember {
        RELEASE_ASSERT_NOT_REACHED();
    });
}
// Classifies every global read by the current entry point (m_reads):
//
// - Overrides are registered as specialization constants (usesOverride) and
//   otherwise ignored.
// - var/let/const globals without a resource (@group/@binding) go into
//   privateGlobals. Their store types are remembered, split into workgroup
//   and private lists, for the combined-size limits.
// - Resources are indexed by group and then binding. Two used globals sharing
//   one (group, binding) pair produce an error.
//
// The combined-size limits are registered with addOverrideValidation()
// instead of being checked here. NOTE(review): presumably this is because
// type sizes may depend on override values that are not known yet; confirm.
auto RewriteGlobalVariables::determineUsedGlobals(const AST::Function& function) -> Result<UsedGlobals>
{
    UsedGlobals usedGlobals;
    // https://www.w3.org/TR/WGSL/#limits
    constexpr unsigned maximumCombinedPrivateVariablesSize = 8192;
    unsigned maximumCombinedWorkgroupVariablesSize = m_shaderModule.configuration().maximumCombinedWorkgroupVariablesSize;
    Vector<const Type*, 16> workgroupVariables;
    Vector<const Type*, 16> privateVariables;
    for (const auto& globalName : m_reads) {
        auto it = m_globals.find(globalName);
        RELEASE_ASSERT(it != m_globals.end());
        auto& global = it->value;
        AST::Variable& variable = *global.declaration;
        switch (variable.flavor()) {
        case AST::VariableFlavor::Override:
            usesOverride(variable);
            continue;
        case AST::VariableFlavor::Var:
        case AST::VariableFlavor::Let:
        case AST::VariableFlavor::Const:
            if (!global.resource.has_value()) {
                usedGlobals.privateGlobals.append(&global);
                if (auto* qualifier = variable.maybeQualifier(); qualifier && qualifier->addressSpace() == AddressSpace::Workgroup)
                    workgroupVariables.append(variable.storeType());
                else
                    privateVariables.append(variable.storeType());
                continue;
            }
            // Globals with a resource fall through to the resource handling below.
            break;
        }
        auto group = global.resource->group;
        auto binding = global.resource->binding;
        auto groupResult = usedGlobals.resources.add(group, IndexMap<Global*>());
        auto bindingResult = groupResult.iterator->value.add(binding, &global);
        // FIXME: <rdar://150368198> this check needs to occur during WGSL::staticCheck
        if (!bindingResult.isNewEntry)
            return makeUnexpected(Error(makeString("entry point '"_s, m_entryPointInformation->originalName, "' uses variables '"_s, bindingResult.iterator->value->declaration->originalName(), "' and '"_s, variable.originalName(), "', both which use the same resource binding: @group("_s, group, ") @binding("_s, binding, ')'), variable.span()));
    }
    // Deferred check: the combined size of workgroup variables must not exceed
    // the configured limit (checked arithmetic catches overflow).
    m_shaderModule.addOverrideValidation([span = function.span(), variables = WTF::move(workgroupVariables), maximumCombinedWorkgroupVariablesSize] -> std::optional<Error> {
        CheckedUint32 combinedWorkgroupVariablesSize = 0;
        for (const Type* type : variables)
            combinedWorkgroupVariablesSize += type->size();
        if (combinedWorkgroupVariablesSize.hasOverflowed() || combinedWorkgroupVariablesSize.value() > maximumCombinedWorkgroupVariablesSize) [[unlikely]]
            return { Error(makeString("The combined byte size of all variables in the workgroup address space exceeds "_s, String::number(maximumCombinedWorkgroupVariablesSize), " bytes"_s), span) };
        return std::nullopt;
    });
    // Deferred check: the same for non-workgroup (private/let/const) globals
    // against the fixed WGSL limit.
    m_shaderModule.addOverrideValidation([span = function.span(), variables = WTF::move(privateVariables)] -> std::optional<Error> {
        CheckedUint32 combinedPrivateVariablesSize = 0;
        for (const Type* type : variables)
            combinedPrivateVariablesSize += type->size();
        if (combinedPrivateVariablesSize.hasOverflowed() || combinedPrivateVariablesSize.value() > maximumCombinedPrivateVariablesSize) [[unlikely]]
            return { Error(makeString("The combined byte size of all variables in the private address space exceeds "_s, String::number(maximumCombinedPrivateVariablesSize), " bytes"_s), span) };
        return std::nullopt;
    });
    return usedGlobals;
}
// Registers the override `variable` as a specialization constant of the
// current entry point. The entry's key is the override's numeric @id when it
// has one, and its original name otherwise. The entry stores the variable's
// name, the scalar constant type derived from its store type, and its
// initializer (if any). Only bool/f32/f16/i32/u32 overrides are expected.
void RewriteGlobalVariables::usesOverride(AST::Variable& variable)
{
    const Type* storeType = variable.storeType();
    ASSERT(std::holds_alternative<Types::Primitive>(*storeType));
    auto scalarKind = std::get<Types::Primitive>(*storeType).kind;

    auto constantType = [&]() -> Reflection::SpecializationConstantType {
        switch (scalarKind) {
        case Types::Primitive::Bool:
            return Reflection::SpecializationConstantType::Boolean;
        case Types::Primitive::F32:
            return Reflection::SpecializationConstantType::Float;
        case Types::Primitive::F16:
            return Reflection::SpecializationConstantType::Half;
        case Types::Primitive::I32:
            return Reflection::SpecializationConstantType::Int;
        case Types::Primitive::U32:
            return Reflection::SpecializationConstantType::Unsigned;
        case Types::Primitive::Void:
        case Types::Primitive::AbstractInt:
        case Types::Primitive::AbstractFloat:
        case Types::Primitive::Sampler:
        case Types::Primitive::SamplerComparison:
        case Types::Primitive::TextureExternal:
        case Types::Primitive::AccessMode:
        case Types::Primitive::TexelFormat:
        case Types::Primitive::AddressSpace:
            RELEASE_ASSERT_NOT_REACHED();
        }
    }();

    String specializationKey = variable.originalName();
    if (auto overrideId = variable.id())
        specializationKey = String::number(*overrideId);

    m_entryPointInformation->specializationConstants.add(specializationKey, Reflection::SpecializationConstant { variable.name(), constantType, variable.maybeInitializer() });
}
// Coarse classification of a global's store type (see bindingTypeForType),
// compared against the kind of binding declared by a pipeline layout when
// validating bindings.
enum class BindingType {
    // Store type is absent or cannot be bound.
    Undefined,
    Buffer,
    Texture,
    TextureMultisampled,
    TextureStorageReadOnly,
    TextureStorageReadWrite,
    TextureStorageWriteOnly,
    Sampler,
    SamplerComparison,
    TextureExternal,
};
// Classifies a primitive store type. Samplers and external textures get their
// own binding kinds, scalar value types are buffer contents, and the
// access-mode / texel-format / address-space kinds are Undefined.
static BindingType bindingTypeForPrimitive(const Types::Primitive& primitive)
{
    switch (primitive.kind) {
    case Types::Primitive::Sampler:
        return BindingType::Sampler;
    case Types::Primitive::SamplerComparison:
        return BindingType::SamplerComparison;
    case Types::Primitive::TextureExternal:
        return BindingType::TextureExternal;
    case Types::Primitive::AccessMode:
    case Types::Primitive::TexelFormat:
    case Types::Primitive::AddressSpace:
        return BindingType::Undefined;
    case Types::Primitive::Void:
    case Types::Primitive::Bool:
    case Types::Primitive::AbstractInt:
    case Types::Primitive::I32:
    case Types::Primitive::U32:
    case Types::Primitive::AbstractFloat:
    case Types::Primitive::F16:
    case Types::Primitive::F32:
        return BindingType::Buffer;
    }
}
// Classifies `type` for binding validation. A null type is Undefined.
// Primitives defer to bindingTypeForPrimitive. Textures are split by
// multisampling (sampled/depth) or by access mode (storage). All other value,
// reference and pointer types count as Buffer. Function and type-constructor
// types crash.
static BindingType bindingTypeForType(const Type* type)
{
    if (!type)
        return BindingType::Undefined;

    return WTF::switchOn(*type,
        [](const Types::Primitive& primitive) -> BindingType {
            return bindingTypeForPrimitive(primitive);
        },
        [](const Types::Vector&) -> BindingType { return BindingType::Buffer; },
        [](const Types::Matrix&) -> BindingType { return BindingType::Buffer; },
        [](const Types::Array&) -> BindingType { return BindingType::Buffer; },
        [](const Types::Struct&) -> BindingType { return BindingType::Buffer; },
        [](const Types::PrimitiveStruct&) -> BindingType { return BindingType::Buffer; },
        [](const Types::Atomic&) -> BindingType { return BindingType::Buffer; },
        [](const Types::Reference&) -> BindingType { return BindingType::Buffer; },
        [](const Types::Pointer&) -> BindingType { return BindingType::Buffer; },
        [](const Types::Texture& texture) -> BindingType {
            if (texture.kind == Types::Texture::Kind::TextureMultisampled2d)
                return BindingType::TextureMultisampled;
            return BindingType::Texture;
        },
        [](const Types::TextureDepth& texture) -> BindingType {
            if (texture.kind == Types::TextureDepth::Kind::TextureDepthMultisampled2d)
                return BindingType::TextureMultisampled;
            return BindingType::Texture;
        },
        [](const Types::TextureStorage& storageTexture) -> BindingType {
            switch (storageTexture.access) {
            case AccessMode::Read:
                return BindingType::TextureStorageReadOnly;
            case AccessMode::Write:
                return BindingType::TextureStorageWriteOnly;
            case AccessMode::ReadWrite:
                return BindingType::TextureStorageReadWrite;
            }
        },
        [](const Types::Function&) -> BindingType {
            RELEASE_ASSERT_NOT_REACHED();
        },
        [](const Types::TypeConstructor&) -> BindingType {
            RELEASE_ASSERT_NOT_REACHED();
        });
}
// True when `variable`'s store type classifies as TextureExternal (the
// texture_external primitive).
static bool isExternalTexture(const AST::Variable& variable)
{
    auto classification = bindingTypeForType(variable.storeType());
    return classification == BindingType::TextureExternal;
}
// Builds the argument-buffer structs for the current entry point when the
// pipeline layout is being generated. m_generatedLayout is dereferenced
// unconditionally; the caller only picks this overload when it is non-null.
// It also appends a BindGroupLayoutEntry to the generated layout for every
// global placed in a struct.
//
// For each WGSL group in m_groupBindingMap that has used resources, that
// group's globals are visited in map order and given consecutive Metal ids.
// External textures take 4 ids; everything else takes 1. A global is
// included when either:
// - it is a buffer-length variable (present in m_reverseBufferLengthMap)
//   whose owning global is a non-external-texture in the storage address
//   space that was already emitted earlier in this same group; or
// - otherwise, its binding is used and the global is read.
// WGSL group numbers are renumbered densely in the generated layout through
// m_generateLayoutGroupMapping. When a buffer-length variable is emitted, its
// Metal id is written into its owner's entry.
//
// Returns the WGSL group numbers for which a struct was created.
Vector<unsigned> RewriteGlobalVariables::insertStructs(const UsedResources& usedResources)
{
    Vector<unsigned> groups;
    for (auto& groupBinding : m_groupBindingMap) {
        auto usedResource = usedResources.find(groupBinding.key);
        if (usedResource == usedResources.end())
            continue;
        auto& bindingGlobalMap = groupBinding.value;
        const IndexMap<Global*>& usedBindings = usedResource->value;
        Vector<std::pair<unsigned, AST::StructureMember*>> entries;
        unsigned metalId = 0;
        // Buffer-length variable -> index of its owner's entry in the
        // generated bind group layout. Scoped to the current WGSL group.
        HashMap<AST::Variable*, unsigned> bufferSizeToOwnerMap;
        for (auto [binding, globalName] : bindingGlobalMap) {
            unsigned group = groupBinding.key;
            auto it = m_globals.find(globalName);
            RELEASE_ASSERT(it != m_globals.end());
            auto& global = it->value;
            if (auto bufferIt = m_reverseBufferLengthMap.find(global.declaration); bufferIt != m_reverseBufferLengthMap.end()) {
                // `global` is a buffer-length variable. bufferIt->value is
                // presumably the buffer it belongs to (the inverse of
                // m_bufferLengthMap used below).
                if (isExternalTexture(*bufferIt->value))
                    continue;
                if (bufferIt->value->addressSpace() != AddressSpace::Storage)
                    continue;
                // Only emit the length if its owner was emitted in this group.
                if (!bufferSizeToOwnerMap.contains(global.declaration))
                    continue;
            } else if (!usedBindings.contains(binding))
                continue;
            else if (!m_reads.contains(globalName))
                continue;
            ASSERT(global.declaration->maybeTypeName());
            entries.append({ metalId, &createArgumentBufferEntry(metalId, *global.declaration) });
            // Translate the WGSL group into the generated layout's dense group
            // index, creating a new bind group layout the first time it is seen.
            if (auto it = m_generateLayoutGroupMapping.find(groupBinding.key); it != m_generateLayoutGroupMapping.end())
                group = it->value;
            else {
                auto newGroup = m_generatedLayout->bindGroupLayouts.size();
                m_generateLayoutGroupMapping.add(group, newGroup);
                m_generatedLayout->bindGroupLayouts.append({ group, { } });
                group = newGroup;
            }
            BindGroupLayoutEntry entry {
                .binding = metalId,
                .webBinding = global.resource->binding,
                .visibility = m_stage,
                .bindingMember = bindingMemberForGlobal(global),
                .name = global.declaration->name(),
                .vertexArgumentBufferIndex = std::nullopt,
                .vertexArgumentBufferSizeIndex = std::nullopt,
                .vertexBufferDynamicOffset = std::nullopt,
                .fragmentArgumentBufferIndex = std::nullopt,
                .fragmentArgumentBufferSizeIndex = std::nullopt,
                .fragmentBufferDynamicOffset = std::nullopt,
                .computeArgumentBufferIndex = std::nullopt,
                .computeArgumentBufferSizeIndex = std::nullopt,
                .computeBufferDynamicOffset = std::nullopt
            };
            auto bufferSizeIt = m_bufferLengthMap.find(global.declaration);
            if (bufferSizeIt != m_bufferLengthMap.end()) {
                // `global` owns a length variable: remember where this entry
                // will land (it is appended at the end of this iteration).
                auto* variable = bufferSizeIt->value;
                bufferSizeToOwnerMap.add(variable, m_generatedLayout->bindGroupLayouts[group].entries.size());
            } else if (auto ownerIt = bufferSizeToOwnerMap.find(global.declaration); ownerIt != bufferSizeToOwnerMap.end()) {
                // FIXME: <rdar://150369108> since we only ever generate a layout
                // for one shader stage at a time, we always store the indices in
                // the vertex slot, but we should use a structs to pass information
                // from the compiler to the API (instead of reusing the same struct
                // the API uses to pass information to the compiler)
                m_generatedLayout->bindGroupLayouts[group].entries[ownerIt->value].vertexArgumentBufferSizeIndex = metalId;
            }
            // External textures occupy four consecutive Metal ids.
            auto* type = global.declaration->storeType();
            if (isPrimitive(type, Types::Primitive::TextureExternal))
                metalId += 4;
            else
                ++metalId;
            m_generatedLayout->bindGroupLayouts[group].entries.append(WTF::move(entry));
        }
        if (entries.isEmpty())
            continue;
        groups.append(groupBinding.key);
        finalizeArgumentBufferStruct(groupBinding.key, entries);
    }
    return groups;
}
// Convenience overload: builds the argument-buffer member for `variable` from
// its span, its name and its (required) reference-type expression.
AST::StructureMember& RewriteGlobalVariables::createArgumentBufferEntry(unsigned binding, AST::Variable& variable)
{
    auto& referenceTypeExpression = *variable.maybeReferenceType();
    return createArgumentBufferEntry(binding, variable.span(), variable.name(), referenceTypeExpression);
}
// Creates a struct member `name: type` annotated with @binding(`binding`), and
// counts it in the entry point's bindingCount. The member is returned but is
// not attached to any struct here.
AST::StructureMember& RewriteGlobalVariables::createArgumentBufferEntry(unsigned binding, const SourceSpan& span, const String& name, AST::Expression& type)
{
    auto& bindingLiteral = m_shaderModule.astBuilder().construct<AST::AbstractIntegerLiteral>(span, binding);
    bindingLiteral.m_inferredType = m_shaderModule.types().abstractIntType();
    bindingLiteral.setConstantValue(binding);

    auto& memberBindingAttribute = m_shaderModule.astBuilder().construct<AST::BindingAttribute>(span, bindingLiteral);

    m_entryPointInformation->bindingCount++;

    return m_shaderModule.astBuilder().construct<AST::StructureMember>(span, AST::Identifier::make(name), type, AST::Attribute::List { memberBindingAttribute });
}
void RewriteGlobalVariables::finalizeArgumentBufferStruct(unsigned group, Vector<std::pair<unsigned, AST::StructureMember*>>& entries)
{
std::ranges::sort(entries, { }, &std::pair<unsigned, AST::StructureMember*>::first);
AST::StructureMember::List structMembers;
for (auto& [_, member] : entries)
structMembers.append(*member);
auto& argumentBufferStruct = m_shaderModule.astBuilder().construct<AST::Structure>(
SourceSpan::empty(),
argumentBufferStructName(group),
WTF::move(structMembers),
AST::Attribute::List { },
AST::StructureRole::BindGroup
);
argumentBufferStruct.m_inferredType = m_shaderModule.types().structType(argumentBufferStruct);
m_shaderModule.append(m_shaderModule.declarations(), argumentBufferStruct);
m_structTypes.add(group, argumentBufferStruct.m_inferredType);
}
// Address space implied by a layout binding member. Uniform buffers map to
// Uniform, both storage buffer kinds map to Storage, and samplers and
// textures of every kind map to Handle.
static AddressSpace addressSpaceForBindingMember(const BindGroupLayoutEntry::BindingMember& bindingMember)
{
    return WTF::switchOn(bindingMember,
        [](const BufferBindingLayout& bufferBinding) {
            switch (bufferBinding.type) {
            case BufferBindingType::Uniform:
                return AddressSpace::Uniform;
            case BufferBindingType::Storage:
            case BufferBindingType::ReadOnlyStorage:
                return AddressSpace::Storage;
            }
        },
        [](const SamplerBindingLayout&) { return AddressSpace::Handle; },
        [](const TextureBindingLayout&) { return AddressSpace::Handle; },
        [](const StorageTextureBindingLayout&) { return AddressSpace::Handle; },
        [](const ExternalTextureBindingLayout&) { return AddressSpace::Handle; });
}
// Access mode implied by a layout binding member. Only a read-write storage
// buffer yields ReadWrite; every other member, including storage textures,
// yields Read.
static AccessMode accessModeForBindingMember(const BindGroupLayoutEntry::BindingMember& bindingMember)
{
    return WTF::switchOn(bindingMember,
        [](const BufferBindingLayout& bufferBinding) {
            switch (bufferBinding.type) {
            case BufferBindingType::Storage:
                return AccessMode::ReadWrite;
            case BufferBindingType::Uniform:
            case BufferBindingType::ReadOnlyStorage:
                return AccessMode::Read;
            }
        },
        [](const SamplerBindingLayout&) { return AccessMode::Read; },
        [](const TextureBindingLayout&) { return AccessMode::Read; },
        [](const StorageTextureBindingLayout&) { return AccessMode::Read; },
        [](const ExternalTextureBindingLayout&) { return AccessMode::Read; });
}
// True when `variable`'s store type classifies as a buffer binding.
static bool isBuffer(const AST::Variable& variable)
{
    auto classification = bindingTypeForType(variable.storeType());
    return classification == BindingType::Buffer;
}
// True when `variable` is a sampler compatible with the layout's sampler
// binding type. Filtering and non-filtering layouts accept a plain sampler;
// comparison layouts require sampler_comparison.
static bool isSampler(const AST::Variable& variable, SamplerBindingType bindingType)
{
    auto classification = bindingTypeForType(variable.storeType());
    switch (bindingType) {
    case SamplerBindingType::Comparison:
        return classification == BindingType::SamplerComparison;
    case SamplerBindingType::Filtering:
    case SamplerBindingType::NonFiltering:
        return classification == BindingType::Sampler;
    }
}
// Whether a sampled texture of `kind` matches a layout's view dimension and
// multisampled flag. A multisampled layout only accepts
// TextureMultisampled2d with a 2D view. 1D and 3D textures also reject the
// Depth sample type.
static bool textureKindEqualsViewDimension(Types::Texture::Kind kind, TextureViewDimension viewDimension, bool isMultisampled, TextureSampleType sampleType)
{
    using Kind = Types::Texture::Kind;
    if (isMultisampled)
        return viewDimension == TextureViewDimension::TwoDimensional && kind == Kind::TextureMultisampled2d;

    const bool depthSampled = sampleType == TextureSampleType::Depth;
    switch (viewDimension) {
    case TextureViewDimension::OneDimensional:
        return !depthSampled && kind == Kind::Texture1d;
    case TextureViewDimension::TwoDimensional:
        return kind == Kind::Texture2d;
    case TextureViewDimension::TwoDimensionalArray:
        return kind == Kind::Texture2dArray;
    case TextureViewDimension::Cube:
        return kind == Kind::TextureCube;
    case TextureViewDimension::CubeArray:
        return kind == Kind::TextureCubeArray;
    case TextureViewDimension::ThreeDimensional:
        return !depthSampled && kind == Kind::Texture3d;
    }
    return false;
}
// Whether a depth texture of `kind` matches a layout's view dimension and
// multisampled flag. A multisampled layout only accepts
// TextureDepthMultisampled2d with a 2D view. There are no 1D or 3D depth
// texture kinds.
static bool depthTextureKindEqualsViewDimension(Types::TextureDepth::Kind kind, TextureViewDimension viewDimension, bool isMultisampled)
{
    using Kind = Types::TextureDepth::Kind;
    if (isMultisampled)
        return viewDimension == TextureViewDimension::TwoDimensional && kind == Kind::TextureDepthMultisampled2d;

    switch (viewDimension) {
    case TextureViewDimension::OneDimensional:
    case TextureViewDimension::ThreeDimensional:
        return false;
    case TextureViewDimension::TwoDimensional:
        return kind == Kind::TextureDepth2d;
    case TextureViewDimension::TwoDimensionalArray:
        return kind == Kind::TextureDepth2dArray;
    case TextureViewDimension::Cube:
        return kind == Kind::TextureDepthCube;
    case TextureViewDimension::CubeArray:
        return kind == Kind::TextureDepthCubeArray;
    }
    return false;
}
// Diagnostic name of a BindingType. Unknown values also print "Undefined".
static ASCIILiteral nameForBindingType(BindingType bindingType)
{
    switch (bindingType) {
    case BindingType::Buffer:
        return "Buffer"_s;
    case BindingType::Texture:
        return "Texture"_s;
    case BindingType::TextureMultisampled:
        return "TextureMultisampled"_s;
    case BindingType::TextureStorageReadOnly:
        return "TextureStorageReadOnly"_s;
    case BindingType::TextureStorageReadWrite:
        return "TextureStorageReadWrite"_s;
    case BindingType::TextureStorageWriteOnly:
        return "TextureStorageWriteOnly"_s;
    case BindingType::Sampler:
        return "Sampler"_s;
    case BindingType::SamplerComparison:
        return "SamplerComparison"_s;
    case BindingType::TextureExternal:
        return "TextureExternal"_s;
    case BindingType::Undefined:
    default:
        return "Undefined"_s;
    }
}
// Diagnostic name of a texture sample type. Unknown values print "Undefined".
static ASCIILiteral nameForTextureSampleType(TextureSampleType sampleType)
{
    switch (sampleType) {
    case TextureSampleType::Depth:
        return "Depth"_s;
    case TextureSampleType::Float:
        return "Float"_s;
    case TextureSampleType::UnfilterableFloat:
        return "UnfilterableFloat"_s;
    case TextureSampleType::SignedInt:
        return "SignedInt"_s;
    case TextureSampleType::UnsignedInt:
        return "UnsignedInt"_s;
    default:
        return "Undefined"_s;
    }
}
// Diagnostic name of a depth texture kind, used in the view-dimension
// mismatch error of texture binding validation. Each kind prints its own
// enumerator name so the error identifies the actual WGSL type.
static ASCIILiteral nameForTextureDepthKind(Types::TextureDepth::Kind kind)
{
    switch (kind) {
    case Types::TextureDepth::Kind::TextureDepth2d:
        return "TextureDepth2d"_s;
    case Types::TextureDepth::Kind::TextureDepth2dArray:
        return "TextureDepth2dArray"_s;
    case Types::TextureDepth::Kind::TextureDepthCube:
        return "TextureDepthCube"_s;
    case Types::TextureDepth::Kind::TextureDepthCubeArray:
        return "TextureDepthCubeArray"_s;
    case Types::TextureDepth::Kind::TextureDepthMultisampled2d:
        return "TextureDepthMultisampled2d"_s;
    default:
        return "Undefined"_s;
    }
}
// Diagnostic name of a sampled texture kind (presumably used by texture
// binding validation errors). Each kind prints its own enumerator name rather
// than collapsing the array, cube and multisampled variants into "Texture2d".
static ASCIILiteral nameForTextureKind(Types::Texture::Kind kind)
{
    switch (kind) {
    case Types::Texture::Kind::Texture1d:
        return "Texture1d"_s;
    case Types::Texture::Kind::Texture2d:
        return "Texture2d"_s;
    case Types::Texture::Kind::Texture2dArray:
        return "Texture2dArray"_s;
    case Types::Texture::Kind::TextureCube:
        return "TextureCube"_s;
    case Types::Texture::Kind::TextureCubeArray:
        return "TextureCubeArray"_s;
    case Types::Texture::Kind::TextureMultisampled2d:
        return "TextureMultisampled2d"_s;
    case Types::Texture::Kind::Texture3d:
        return "Texture3d"_s;
    default:
        return "Undefined"_s;
    }
}
// Diagnostic name of a texture view dimension, spelled as the WebGPU
// GPUTextureViewDimension strings ("1d", "2d", "2d-array", "cube",
// "cube-array", "3d") so that error messages match the API the user wrote.
static ASCIILiteral nameForTextureViewDimension(TextureViewDimension viewDimension)
{
    switch (viewDimension) {
    case TextureViewDimension::OneDimensional:
        return "1d"_s;
    case TextureViewDimension::TwoDimensional:
        return "2d"_s;
    case TextureViewDimension::TwoDimensionalArray:
        return "2d-array"_s;
    case TextureViewDimension::Cube:
        return "cube"_s;
    case TextureViewDimension::CubeArray:
        return "cube-array"_s;
    case TextureViewDimension::ThreeDimensional:
        return "3d"_s;
    }
    return "undefined"_s;
}
// Name of a WGSL primitive type kind for validation diagnostics.
// Abstract (not directly spellable) kinds are wrapped in angle brackets.
static ASCIILiteral nameForPrimitiveKind(Types::Primitive::Kind primitiveKind)
{
    switch (primitiveKind) {
    case Types::Primitive::AbstractInt:
        return "<AbstractInt>"_s;
    case Types::Primitive::I32:
        return "int32"_s;
    case Types::Primitive::U32:
        return "uint32"_s;
    case Types::Primitive::AbstractFloat:
        return "<AbstractFloat>"_s;
    case Types::Primitive::F16:
        return "f16"_s;
    case Types::Primitive::F32:
        return "f32"_s;
    case Types::Primitive::Void:
        return "void"_s;
    case Types::Primitive::Bool:
        return "bool"_s;
    case Types::Primitive::Sampler:
        return "sampler"_s;
    case Types::Primitive::SamplerComparison:
        // Fixed typo: was "sampler_comparion".
        return "sampler_comparison"_s;
    case Types::Primitive::TextureExternal:
        return "texture_external"_s;
    case Types::Primitive::AccessMode:
        return "access_mode"_s;
    case Types::Primitive::TexelFormat:
        return "texel_format"_s;
    case Types::Primitive::AddressSpace:
        return "address_space"_s;
    }
    return "undefined"_s;
}
// Validates a WGSL texture global against a sampled-texture bind group layout
// entry. Returns an empty string when they are compatible; otherwise returns a
// diagnostic for the first failing check, in this order:
//   1. binding type (multisampled vs. not),
//   2. for depth textures: Depth sample type and view dimension,
//   3. for sampled textures: view dimension,
//   4. presence of an element type,
//   5. element type is a primitive,
//   6. element type vs. layout sample type.
static String errorValidatingTexture(const AST::Variable& variable, const TextureBindingLayout& textureBinding)
{
    bool isMultisampled = textureBinding.multisampled;
    auto targetValue = isMultisampled ? BindingType::TextureMultisampled : BindingType::Texture;
    auto storeType = variable.storeType();
    auto bindingForType = bindingTypeForType(storeType);
    if (bindingForType != targetValue)
        return makeString("types don't match: WGSL type "_s, nameForBindingType(bindingForType), " target type "_s, nameForBindingType(targetValue));

    auto sampleType = textureBinding.sampleType;
    const Types::Texture* possibleTexture = std::get_if<Types::Texture>(storeType);
    if (!possibleTexture) {
        // Not a sampled texture. The only other form accepted here is a depth
        // texture bound with the Depth sample type.
        const Types::TextureDepth* possibleTextureDepth = std::get_if<Types::TextureDepth>(storeType);
        if (!possibleTextureDepth || sampleType != TextureSampleType::Depth)
            return makeString("depth validation failed: "_s, nameForTextureSampleType(sampleType));
        if (!depthTextureKindEqualsViewDimension(possibleTextureDepth->kind, textureBinding.viewDimension, isMultisampled))
            return makeString("viewDimensions don't match: "_s, nameForTextureDepthKind(possibleTextureDepth->kind), ", textureBinding view dimension "_s, nameForTextureViewDimension(textureBinding.viewDimension), ", multisampled = "_s, isMultisampled ? "yes"_s : "no"_s);
        return emptyString();
    }

    if (!textureKindEqualsViewDimension(possibleTexture->kind, textureBinding.viewDimension, isMultisampled, sampleType))
        return makeString("viewDimensions don't match: "_s, nameForTextureKind(possibleTexture->kind), ", bindingViewDimension = "_s, nameForTextureViewDimension(textureBinding.viewDimension), ", multisampled = "_s, isMultisampled ? "yes"_s : "no"_s, ", bindingSampleType = "_s, nameForTextureSampleType(sampleType));

    // Checked separately so a missing element type is not misreported as a
    // view-dimension mismatch.
    if (!possibleTexture->element)
        return makeString("WGSL texture has no elementType: sampleType "_s, nameForTextureSampleType(sampleType));

    const auto* primitive = std::get_if<Types::Primitive>(possibleTexture->element);
    if (!primitive)
        return makeString("WGSL texture element type is not a primitive: sampleType "_s, nameForTextureSampleType(sampleType));

    bool elementMatches = false;
    switch (sampleType) {
    case TextureSampleType::Float:
    case TextureSampleType::UnfilterableFloat:
        elementMatches = primitive->kind == Types::Primitive::F32 || primitive->kind == Types::Primitive::F16;
        break;
    case TextureSampleType::Depth:
        // This function rejects a Depth sample type for a non-depth texture.
        break;
    case TextureSampleType::SignedInt:
        elementMatches = primitive->kind == Types::Primitive::I32;
        break;
    case TextureSampleType::UnsignedInt:
        elementMatches = primitive->kind == Types::Primitive::U32;
        break;
    }
    if (!elementMatches)
        return makeString("element types don't match: sampleType "_s, nameForTextureSampleType(sampleType), ", primitive->kind "_s, nameForPrimitiveKind(primitive->kind));
    return emptyString();
}
// True when a WGSL storage-texture kind corresponds to the layout's view
// dimension. Cube and cube-array dimensions never match a storage texture.
static bool storageTextureKindEqualsViewDimension(Types::TextureStorage::Kind kind, TextureViewDimension viewDimension)
{
    using StorageKind = Types::TextureStorage::Kind;
    switch (viewDimension) {
    case TextureViewDimension::Cube:
    case TextureViewDimension::CubeArray:
        return false;
    case TextureViewDimension::OneDimensional:
        return kind == StorageKind::TextureStorage1d;
    case TextureViewDimension::TwoDimensional:
        return kind == StorageKind::TextureStorage2d;
    case TextureViewDimension::TwoDimensionalArray:
        return kind == StorageKind::TextureStorage2dArray;
    case TextureViewDimension::ThreeDimensional:
        return kind == StorageKind::TextureStorage3d;
    }
    return false;
}
// True when `variable` is a WGSL storage texture whose kind, texel format and
// access mode are compatible with the storage-texture layout entry. A
// write-only WGSL declaration is accepted for both write-only and read-write
// layout access; the other access modes must match exactly.
static bool isStorageTexture(const AST::Variable& variable, const StorageTextureBindingLayout& storageTexture)
{
    auto storeType = variable.storeType();
    const auto* declared = std::get_if<Types::TextureStorage>(storeType);
    if (!declared
        || !storageTextureKindEqualsViewDimension(declared->kind, storageTexture.viewDimension)
        || declared->format != storageTexture.format)
        return false;

    auto layoutAccess = storageTexture.access;
    auto bindingType = bindingTypeForType(storeType);
    if (bindingType == BindingType::TextureStorageReadOnly)
        return layoutAccess == StorageTextureAccess::ReadOnly;
    if (bindingType == BindingType::TextureStorageReadWrite)
        return layoutAccess == StorageTextureAccess::ReadWrite;
    if (bindingType == BindingType::TextureStorageWriteOnly)
        return layoutAccess == StorageTextureAccess::WriteOnly || layoutAccess == StorageTextureAccess::ReadWrite;
    return false;
}
// Dispatches on the kind of layout entry and checks that `variable`'s WGSL
// type is of the matching resource kind. Returns an empty string on success,
// otherwise a diagnostic.
static String errorValidatingTypes(const AST::Variable& variable, const BindGroupLayoutEntry::BindingMember& bindingMember)
{
    auto verdict = [](bool matches, ASCIILiteral failure) -> String {
        if (matches)
            return emptyString();
        return failure;
    };
    return WTF::switchOn(bindingMember,
        [&](const BufferBindingLayout&) {
            return verdict(isBuffer(variable), "WGSL variable is not a buffer"_s);
        },
        [&](const SamplerBindingLayout& samplerBinding) {
            return verdict(isSampler(variable, samplerBinding.type), "WGSL variable is not a sampler"_s);
        },
        [&](const TextureBindingLayout& textureBinding) {
            return errorValidatingTexture(variable, textureBinding);
        },
        [&](const StorageTextureBindingLayout& storageTexture) {
            return verdict(isStorageTexture(variable, storageTexture), "WGSL variable is not a storage texture"_s);
        },
        [&](const ExternalTextureBindingLayout&) {
            return verdict(isExternalTexture(variable), "WGSL variable is not an external texture"_s);
        });
}
// Full compatibility check between a WGSL global and its bind group layout
// entry: resource type first, then address space and access mode. The latter
// two are compared only when the WGSL declaration states them explicitly.
// Returns an empty string on success, otherwise a diagnostic.
static String errorValidatingVariableAndEntryMatch(const AST::Variable& variable, const BindGroupLayoutEntry& entry)
{
    String typeError = errorValidatingTypes(variable, entry.bindingMember);
    if (!typeError.isEmpty())
        return typeError;

    auto declaredAddressSpace = variable.addressSpace();
    auto expectedAddressSpace = addressSpaceForBindingMember(entry.bindingMember);
    if (declaredAddressSpace && *declaredAddressSpace != expectedAddressSpace)
        return "variableAddressSpace != entryAddressSpace"_s;

    auto declaredAccessMode = variable.accessMode();
    auto expectedAccessMode = accessModeForBindingMember(entry.bindingMember);
    if (declaredAccessMode && *declaredAccessMode != expectedAccessMode)
        return "variableAccessMode != entryAccessMode"_s;

    return emptyString();
}
// Builds one argument-buffer struct per bind group in `layout`, as seen from
// the current stage (m_stage).
//
// For each layout entry visible in m_stage, a member is appended at that
// entry's per-stage argument buffer index:
//   - the shader's own global, when usedResources contains (group, binding);
//   - otherwise a placeholder named "__ArgumentBufferPlaceholder_<index>",
//     whose type is synthesized from the layout entry.
// Entries that have a per-stage argument buffer *size* index also get a
// buffer-length member: the variable from m_bufferLengthMap when the global is
// used, otherwise a placeholder of bufferLengthType().
//
// After all groups are processed, each used global that was matched to an
// entry is validated with errorValidatingVariableAndEntryMatch. For buffer
// bindings, the entry's minBindingSize is overwritten with the store type's
// size. A used global that appears in m_reads but was never matched to any
// entry is an error.
//
// Returns the indices of groups for which finalizeArgumentBufferStruct was
// called. Groups that ended up with no members are skipped.
Result<Vector<unsigned>> RewriteGlobalVariables::insertStructs(PipelineLayout& layout, const UsedResources& usedResources)
{
Vector<unsigned> groups;
unsigned group = 0;
// Global -> layout entry it was matched with. Buffer-length variables are
// recorded with a null entry so they count as serialized but are not validated.
HashMap<AST::Variable*, BindGroupLayoutEntry*> serializedVariables;
for (auto& bindGroupLayout : layout.bindGroupLayouts) {
Vector<std::pair<unsigned, AST::StructureMember*>> entries;
Vector<std::pair<unsigned, AST::Variable*>> bufferLengths;
for (auto& entry : bindGroupLayout.entries) {
if (!entry.visibility.contains(m_stage))
continue;
// Select the argument buffer index and size index for the current stage.
// NOTE(review): argumentBufferIndex is dereferenced below without a
// has_value() check, so this assumes every entry visible in m_stage has an
// index for that stage.
auto argumentBufferIndex = [&] {
switch (m_stage) {
case ShaderStage::Vertex:
return entry.vertexArgumentBufferIndex;
case ShaderStage::Fragment:
return entry.fragmentArgumentBufferIndex;
case ShaderStage::Compute:
return entry.computeArgumentBufferIndex;
}
}();
auto argumentBufferSizeIndex = [&] {
switch (m_stage) {
case ShaderStage::Vertex:
return entry.vertexArgumentBufferSizeIndex;
case ShaderStage::Fragment:
return entry.fragmentArgumentBufferSizeIndex;
case ShaderStage::Compute:
return entry.computeArgumentBufferSizeIndex;
}
}();
AST::Variable* variable = nullptr;
// m_globalsByBinding is looked up with (group + 1, binding + 1), whereas
// usedResources is keyed by the unshifted group and binding.
auto globalIt = m_globalsByBinding.find({ group + 1, entry.binding + 1 });
if (globalIt != m_globalsByBinding.end()) {
auto groupIt = usedResources.find(group);
if (groupIt != usedResources.end()) {
auto& bindings = groupIt->value;
auto bindingIt = bindings.find(entry.binding);
if (bindingIt != bindings.end()) {
variable = bindingIt->value->declaration;
serializedVariables.add(variable, &entry);
entries.append({ *argumentBufferIndex, &createArgumentBufferEntry(*argumentBufferIndex, *variable) });
}
}
}
// The shader does not use this binding: emit a placeholder member whose type
// is derived from the layout entry alone.
if (!variable) {
auto& type = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make("void"_s));
type.m_inferredType = WTF::switchOn(entry.bindingMember,
[&](const BufferBindingLayout& buffer) -> const Type* {
// Buffers become a pointer-to-void in the address space/access mode implied
// by the buffer binding type.
AddressSpace addressSpace;
AccessMode accessMode;
switch (buffer.type) {
case BufferBindingType::Uniform:
addressSpace = AddressSpace::Uniform;
accessMode = AccessMode::Read;
break;
case BufferBindingType::Storage:
addressSpace = AddressSpace::Storage;
accessMode = AccessMode::ReadWrite;
break;
case BufferBindingType::ReadOnlyStorage:
addressSpace = AddressSpace::Storage;
accessMode = AccessMode::Read;
break;
}
return m_shaderModule.types().pointerType(addressSpace, m_shaderModule.types().voidType(), accessMode);
},
[&](const SamplerBindingLayout&) -> const Type* {
return m_shaderModule.types().samplerType();
},
[&](const TextureBindingLayout& layout) -> const Type* {
// Depth sample types map to an f32 sampled texture here, not a depth texture.
const Type* sampleType;
switch (layout.sampleType) {
case TextureSampleType::Float:
case TextureSampleType::UnfilterableFloat:
case TextureSampleType::Depth:
sampleType = m_shaderModule.types().f32Type();
break;
case TextureSampleType::SignedInt:
sampleType = m_shaderModule.types().i32Type();
break;
case TextureSampleType::UnsignedInt:
sampleType = m_shaderModule.types().u32Type();
break;
}
// `multisampled` only affects the 2D view dimension.
Types::Texture::Kind textureKind;
switch (layout.viewDimension) {
case TextureViewDimension::OneDimensional:
textureKind = Types::Texture::Kind::Texture1d;
break;
case TextureViewDimension::TwoDimensional:
if (layout.multisampled)
textureKind = Types::Texture::Kind::TextureMultisampled2d;
else
textureKind = Types::Texture::Kind::Texture2d;
break;
case TextureViewDimension::TwoDimensionalArray:
textureKind = Types::Texture::Kind::Texture2dArray;
break;
case TextureViewDimension::Cube:
textureKind = Types::Texture::Kind::TextureCube;
break;
case TextureViewDimension::CubeArray:
textureKind = Types::Texture::Kind::TextureCubeArray;
break;
case TextureViewDimension::ThreeDimensional:
textureKind = Types::Texture::Kind::Texture3d;
break;
}
return m_shaderModule.types().textureType(textureKind, sampleType);
},
[&](const StorageTextureBindingLayout& layout) -> const Type* {
// Cube / cube-array storage texture layouts hit RELEASE_ASSERT_NOT_REACHED.
Types::TextureStorage::Kind textureStorageKind;
switch (layout.viewDimension) {
case TextureViewDimension::OneDimensional:
textureStorageKind = Types::TextureStorage::Kind::TextureStorage1d;
break;
case TextureViewDimension::TwoDimensional:
textureStorageKind = Types::TextureStorage::Kind::TextureStorage2d;
break;
case TextureViewDimension::TwoDimensionalArray:
textureStorageKind = Types::TextureStorage::Kind::TextureStorage2dArray;
break;
case TextureViewDimension::ThreeDimensional:
textureStorageKind = Types::TextureStorage::Kind::TextureStorage3d;
break;
default:
RELEASE_ASSERT_NOT_REACHED();
}
AccessMode accessMode;
switch (layout.access) {
case StorageTextureAccess::WriteOnly:
accessMode = AccessMode::Write;
break;
case StorageTextureAccess::ReadOnly:
accessMode = AccessMode::Read;
break;
case StorageTextureAccess::ReadWrite:
accessMode = AccessMode::ReadWrite;
break;
}
return m_shaderModule.types().textureStorageType(textureStorageKind, layout.format, accessMode);
},
[&](const ExternalTextureBindingLayout&) -> const Type* {
m_shaderModule.setUsesExternalTextures();
return m_shaderModule.types().textureExternalType();
});
entries.append({
*argumentBufferIndex,
&createArgumentBufferEntry(*argumentBufferIndex, SourceSpan::empty(), makeString("__ArgumentBufferPlaceholder_"_s, String::number(*argumentBufferIndex)), type)
});
}
// Defer buffer-length members until every binding in the group has been seen.
// `variable` may be null here (unused binding).
if (argumentBufferSizeIndex.has_value())
bufferLengths.append({ *argumentBufferSizeIndex, variable });
}
for (auto [binding, variable] : bufferLengths) {
if (variable) {
auto it = m_bufferLengthMap.find(variable);
RELEASE_ASSERT(it != m_bufferLengthMap.end());
serializedVariables.add(it->value, nullptr);
entries.append({ binding, &createArgumentBufferEntry(binding, *it->value) });
} else {
entries.append({
binding,
&createArgumentBufferEntry(binding, SourceSpan::empty(), makeString("__ArgumentBufferPlaceholder_"_s, String::number(binding)), bufferLengthType())
});
}
}
// A group with no members for this stage produces no struct, but the group
// counter still advances so later groups keep their indices.
if (entries.isEmpty()) {
++group;
continue;
}
groups.append(group);
finalizeArgumentBufferStruct(group++, entries);
}
// Validate every used global against the layout entry it was matched with,
// and reject reads of globals that no layout entry covers.
for (auto& [_, bindingGlobalMap] : usedResources) {
for (auto [_, global] : bindingGlobalMap) {
auto* variable = global->declaration;
if (auto entryIt = serializedVariables.find(variable); entryIt != serializedVariables.end() && entryIt->value) {
if (auto error = errorValidatingVariableAndEntryMatch(*variable, *entryIt->value); error.length())
return makeUnexpected(Error(makeString("Shader is incompatible with layout pipeline: "_s, error), SourceSpan::empty()));
if (auto* bufferBindingLayout = std::get_if<BufferBindingLayout>(&entryIt->value->bindingMember))
bufferBindingLayout->minBindingSize = variable->storeType()->size();
}
if (!m_reads.contains(variable->name()))
continue;
if (!serializedVariables.contains(variable))
return makeUnexpected(Error("Shader is incompatible with layout pipeline"_s, SourceSpan::empty()));
}
}
return { groups };
}
// Adds the dynamic-offsets buffer parameter to `function` when it is needed:
// either some global uses a dynamic offset, or this is a fragment stage whose
// module uses frag depth or the sample mask. The parameter is placed in the
// per-stage buffer slot taken from the module configuration, and typed as a
// uniform, read-only pointer to u32.
void RewriteGlobalVariables::insertDynamicOffsetsBufferIfNeeded(const SourceSpan& span, const AST::Function& function)
{
    bool needsBuffer = !m_globalsUsingDynamicOffset.isEmpty();
    if (!needsBuffer && m_stage == ShaderStage::Fragment)
        needsBuffer = m_shaderModule.usesFragDepth() || m_shaderModule.usesSampleMask();
    if (!needsBuffer)
        return;

    const auto& configuration = m_shaderModule.configuration();
    unsigned bufferSlot = 0;
    switch (m_stage) {
    case ShaderStage::Vertex:
        bufferSlot = configuration.maxBuffersPlusVertexBuffersForVertexStage;
        break;
    case ShaderStage::Fragment:
        bufferSlot = configuration.maxBuffersForFragmentStage;
        break;
    case ShaderStage::Compute:
        bufferSlot = configuration.maxBuffersForComputeStage;
        break;
    }

    auto& offsetsType = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(span, AST::Identifier::make("u32"_s));
    offsetsType.m_inferredType = m_shaderModule.types().pointerType(AddressSpace::Uniform, m_shaderModule.types().u32Type(), AccessMode::Read);
    insertParameter(span, function, bufferSlot, AST::Identifier::make(dynamicOffsetVariableName()), &offsetsType, AST::ParameterRole::UserDefined);
}
// Convenience overload that uses the function's own source span.
void RewriteGlobalVariables::insertDynamicOffsetsBufferIfNeeded(const AST::Function& function)
{
    const auto& functionSpan = function.span();
    insertDynamicOffsetsBufferIfNeeded(functionSpan, function);
}
// Adds an argument-buffer parameter (named by argumentBufferParameterName) for
// each group in `groups`, followed by the dynamic-offsets buffer parameter
// when one is needed.
void RewriteGlobalVariables::insertParameters(AST::Function& function, const Vector<unsigned>& groups)
{
    auto functionSpan = function.span();
    for (unsigned bindGroup : groups)
        insertParameter(functionSpan, function, bindGroup, argumentBufferParameterName(bindGroup));
    insertDynamicOffsetsBufferIfNeeded(functionSpan, function);
}
// For every used resource global, prepends to `function`'s body a
//     let <name> = __ArgumentBuffer_<group>.<field>;
// declaration with role PackedResource. The field is the global's name, or
// "__<name>" for external textures, which also marks the module as using
// external textures. When the global has a dynamic offset, the initializer is
// instead a call to `__dynamicOffset(<field access>, <offset>)`.
// Each declaration is inserted at index 0, so the final statement order is
// the reverse of iteration order.
void RewriteGlobalVariables::insertMaterializations(AST::Function& function, const UsedResources& usedResources)
{
auto span = function.span();
for (auto& [group, bindings] : usedResources) {
// Refers to the argument buffer parameter added by insertParameters
// (same name helper).
auto& argument = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
span,
AST::Identifier::make(argumentBufferParameterName(group))
);
for (auto& [binding, global] : bindings) {
auto& name = global->declaration->name();
String fieldName = name;
auto* storeType = global->declaration->storeType();
if (isPrimitive(storeType, Types::Primitive::TextureExternal)) {
fieldName = makeString("__"_s, name);
m_shaderModule.setUsesExternalTextures();
}
auto& access = m_shaderModule.astBuilder().construct<AST::FieldAccessExpression>(
SourceSpan::empty(),
argument,
AST::Identifier::make(WTF::move(fieldName))
);
AST::Expression* initializer = &access;
// Dynamic-offset lookup uses the same (group + 1, binding + 1) key shape as
// m_globalsByBinding, and is only attempted for globals with a binding.
auto it = global->declaration->binding() ? m_globalsUsingDynamicOffset.find({ group + 1, binding + 1 }) : m_globalsUsingDynamicOffset.end();
if (it != m_globalsUsingDynamicOffset.end()) {
auto offset = it->value;
auto& target = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
SourceSpan::empty(),
AST::Identifier::make("__dynamicOffset"_s)
);
// The callee is typed as a pointer to the store type, using the address
// space and access mode of the global's reference type.
auto& reference = std::get<Types::Reference>(*global->declaration->maybeReferenceType()->inferredType());
target.m_inferredType = m_shaderModule.types().pointerType(reference.addressSpace, storeType, reference.accessMode);
auto& offsetExpression = m_shaderModule.astBuilder().construct<AST::Unsigned32Literal>(span, offset);
offsetExpression.m_inferredType = m_shaderModule.types().u32Type();
offsetExpression.setConstantValue(offset);
initializer = &m_shaderModule.astBuilder().construct<AST::CallExpression>(
SourceSpan::empty(),
target,
AST::Expression::List { access, offsetExpression }
);
}
auto& variable = m_shaderModule.astBuilder().construct<AST::Variable>(
SourceSpan::empty(),
AST::VariableFlavor::Let,
AST::Identifier::make(name),
nullptr,
global->declaration->maybeReferenceType(),
initializer,
AST::Attribute::List { },
AST::VariableRole::PackedResource
);
auto& variableStatement = m_shaderModule.astBuilder().construct<AST::VariableStatement>(SourceSpan::empty(), variable);
m_shaderModule.insert(function.body().statements(), 0, std::reference_wrapper<AST::Statement>(variableStatement));
}
}
}
// Moves each used private global's declaration into `function` as a local
// variable statement at the start of the body, then emits the workgroup
// initialization code right after those declarations.
void RewriteGlobalVariables::insertLocalDefinitions(AST::Function& function, const UsedPrivateGlobals& usedPrivateGlobals)
{
    auto& statements = function.body().statements();
    auto sizeBeforeInsertion = statements.size();
    for (auto* privateGlobal : usedPrivateGlobals) {
        auto& declaration = m_shaderModule.astBuilder().construct<AST::VariableStatement>(SourceSpan::empty(), *privateGlobal->declaration);
        m_shaderModule.insert(statements, 0, std::reference_wrapper<AST::Statement>(declaration));
    }
    auto insertedCount = statements.size() - sizeBeforeInsertion;
    initializeVariables(function, usedPrivateGlobals, insertedCount);
}
// Emits zero-initialization for the workgroup-address-space globals in
// `globals`, inserted at statement index `offset` of `function`'s body as:
//     if (<local_invocation_index> == 0u) { <initializations> }
//     workgroupBarrier();
// A local_invocation_index parameter is added to `function` if it has none.
// Does nothing when there is nothing to initialize.
void RewriteGlobalVariables::initializeVariables(AST::Function& function, const UsedPrivateGlobals& globals, size_t offset)
{
    auto initializations = storeInitialValue(globals);
    if (initializations.isEmpty())
        return;

    // Both statements are inserted at `offset`; the barrier goes in first so
    // that it ends up after the `if` inserted below.
    insertWorkgroupBarrier(function, offset);

    auto localInvocationIndex = findOrInsertLocalInvocationIndex(function);
    auto& testLhs = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
        SourceSpan::empty(),
        AST::Identifier::make(localInvocationIndex.id())
    );
    testLhs.m_inferredType = m_shaderModule.types().u32Type();
    auto& testRhs = m_shaderModule.astBuilder().construct<AST::Unsigned32Literal>(SourceSpan::empty(), 0);
    // Previously this line re-assigned testLhs, leaving the literal untyped.
    testRhs.m_inferredType = m_shaderModule.types().u32Type();
    auto& testExpression = m_shaderModule.astBuilder().construct<AST::BinaryExpression>(
        SourceSpan::empty(),
        testLhs,
        testRhs,
        AST::BinaryOperation::Equal
    );
    testExpression.m_inferredType = m_shaderModule.types().boolType();

    auto& body = m_shaderModule.astBuilder().construct<AST::CompoundStatement>(
        SourceSpan::empty(),
        AST::Attribute::List { },
        WTF::move(initializations)
    );
    auto& ifStatement = m_shaderModule.astBuilder().construct<AST::IfStatement>(
        SourceSpan::empty(),
        testExpression,
        body,
        nullptr,
        AST::Attribute::List { }
    );
    m_shaderModule.insert(function.body().statements(), offset, std::reference_wrapper<AST::Statement>(ifStatement));
}
// Inserts a `workgroupBarrier();` call statement at index `offset` of
// `function`'s body. Both the callee and the call are typed void.
void RewriteGlobalVariables::insertWorkgroupBarrier(AST::Function& function, size_t offset)
{
    auto& builder = m_shaderModule.astBuilder();
    auto* voidType = m_shaderModule.types().voidType();

    auto& barrierCallee = builder.construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make("workgroupBarrier"_s));
    barrierCallee.m_inferredType = voidType;

    auto& barrierCall = builder.construct<AST::CallExpression>(SourceSpan::empty(), barrierCallee, AST::Expression::List { });
    barrierCall.m_inferredType = voidType;

    auto& barrierStatement = builder.construct<AST::CallStatement>(SourceSpan::empty(), barrierCall);
    m_shaderModule.insert(function.body().statements(), offset, std::reference_wrapper<AST::Statement>(barrierStatement));
}
// Returns the name of `function`'s local_invocation_index builtin parameter.
// If none exists, appends a new `__localInvocationIndex: u32` parameter
// carrying that builtin attribute and returns its name.
AST::Identifier& RewriteGlobalVariables::findOrInsertLocalInvocationIndex(AST::Function& function)
{
    for (auto& candidate : function.parameters()) {
        auto builtin = candidate.builtin();
        if (builtin.has_value() && *builtin == Builtin::LocalInvocationIndex)
            return candidate.name();
    }

    auto& builder = m_shaderModule.astBuilder();
    auto& u32TypeName = builder.construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make("u32"_s));
    u32TypeName.m_inferredType = m_shaderModule.types().u32Type();
    auto& builtinAttribute = builder.construct<AST::BuiltinAttribute>(SourceSpan::empty(), Builtin::LocalInvocationIndex);
    auto& newParameter = builder.construct<AST::Parameter>(
        SourceSpan::empty(),
        AST::Identifier::make("__localInvocationIndex"_s),
        u32TypeName,
        AST::Attribute::List { builtinAttribute },
        AST::ParameterRole::UserDefined
    );
    m_shaderModule.append(function.parameters(), newParameter);
    return newParameter.name();
}
// Builds the zero-initialization statements for the globals in `globals` that
// are declared in the workgroup address space. All other globals are skipped.
AST::Statement::List RewriteGlobalVariables::storeInitialValue(const UsedPrivateGlobals& globals)
{
    AST::Statement::List initializers;
    for (auto* candidate : globals) {
        auto& declaration = *candidate->declaration;
        auto addressSpace = declaration.addressSpace();
        bool isWorkgroup = addressSpace.has_value() && *addressSpace == AddressSpace::Workgroup;
        if (!isWorkgroup)
            continue;
        auto* storeType = declaration.storeType();
        auto& reference = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
            SourceSpan::empty(),
            AST::Identifier::make(declaration.name().id())
        );
        reference.m_inferredType = storeType;
        storeInitialValue(reference, initializers, 0);
    }
    return initializers;
}
// Appends to `statements` the code that zero-initializes `target`, chosen by
// target's inferred type:
//   - fixed-size array: a `for (var __i<arrayDepth> = 0u; __i < len; __i += 1u)`
//     loop (marked internally generated) that recurses on each element with
//     arrayDepth + 1. `len` is the array's override expression or an unsigned
//     literal. Runtime-sized arrays trigger a RELEASE_ASSERT.
//   - struct: if the type is constructible, a single zero-initializing
//     assignment; otherwise recurse into each member via a field access.
//   - atomic: `atomicStore(&target, 0)`.
//   - anything else: a single zero-initializing assignment.
// `arrayDepth` only selects the loop index name, so nested loops use distinct
// index variables.
void RewriteGlobalVariables::storeInitialValue(AST::Expression& target, AST::Statement::List& statements, unsigned arrayDepth)
{
// Emits `target = <constructor call with no arguments>`. Note that this clears
// target.m_inferredType after copying it onto the callee and the call.
const auto& zeroInitialize = [&]() {
// This piece of code generation relies on 2 implementation details from the metal serializer:
// - The callee's name won't be used if the call is set to constructor
// - There's a special case to handle the case where the left-hand side
// of the assignment doesn't have a type, so we can erase it
auto& callee = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make("__initialize"_s));
callee.m_inferredType = target.inferredType();
auto& call = m_shaderModule.astBuilder().construct<AST::CallExpression>(
SourceSpan::empty(),
callee,
AST::Expression::List { }
);
call.m_inferredType = target.inferredType();
call.m_isConstructor = true;
target.m_inferredType = nullptr;
auto& assignmentStatement = m_shaderModule.astBuilder().construct<AST::AssignmentStatement>(
SourceSpan::empty(),
target,
call
);
statements.append(AST::Statement::Ref(assignmentStatement));
};
auto* type = target.inferredType();
if (auto* arrayType = std::get_if<Types::Array>(type)) {
RELEASE_ASSERT(!arrayType->isRuntimeSized());
String indexVariableName = makeString("__i"_s, arrayDepth);
// This single `indexVariable` node is shared by the element access, the loop
// test and the loop update below.
auto& indexVariable = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
SourceSpan::empty(),
AST::Identifier::make(indexVariableName)
);
indexVariable.m_inferredType = m_shaderModule.types().u32Type();
auto& arrayAccess = m_shaderModule.astBuilder().construct<AST::IndexAccessExpression>(
SourceSpan::empty(),
target,
indexVariable
);
arrayAccess.m_inferredType = arrayType->element;
AST::Statement::List forBodyStatements;
storeInitialValue(arrayAccess, forBodyStatements, arrayDepth + 1);
auto& zero = m_shaderModule.astBuilder().construct<AST::Unsigned32Literal>(
SourceSpan::empty(),
0
);
zero.m_inferredType = m_shaderModule.types().u32Type();
auto& forVariable = m_shaderModule.astBuilder().construct<AST::Variable>(
SourceSpan::empty(),
AST::VariableFlavor::Var,
AST::Identifier::make(indexVariableName),
nullptr,
&zero
);
auto& forInitializer = m_shaderModule.astBuilder().construct<AST::VariableStatement>(
SourceSpan::empty(),
forVariable
);
// Loop bound: the override-sized expression when the array has one,
// otherwise a u32 literal holding the constant element count.
auto* arrayLength = [&]() -> AST::Expression* {
if (auto* overrideExpression = std::get_if<AST::Expression*>(&arrayType->size))
return *overrideExpression;
auto& arrayLength = m_shaderModule.astBuilder().construct<AST::Unsigned32Literal>(
SourceSpan::empty(),
std::get<unsigned>(arrayType->size)
);
arrayLength.m_inferredType = m_shaderModule.types().u32Type();
return &arrayLength;
}();
auto& forTest = m_shaderModule.astBuilder().construct<AST::BinaryExpression>(
SourceSpan::empty(),
indexVariable,
*arrayLength,
AST::BinaryOperation::LessThan
);
forTest.m_inferredType = m_shaderModule.types().boolType();
auto& one = m_shaderModule.astBuilder().construct<AST::Unsigned32Literal>(
SourceSpan::empty(),
1
);
one.m_inferredType = m_shaderModule.types().u32Type();
auto& forUpdate = m_shaderModule.astBuilder().construct<AST::CompoundAssignmentStatement>(
SourceSpan::empty(),
indexVariable,
one,
AST::BinaryOperation::Add
);
auto& forBody = m_shaderModule.astBuilder().construct<AST::CompoundStatement>(
SourceSpan::empty(),
AST::Attribute::List { },
WTF::move(forBodyStatements)
);
auto& forStatement = m_shaderModule.astBuilder().construct<AST::ForStatement>(
SourceSpan::empty(),
&forInitializer,
&forTest,
&forUpdate,
forBody
);
forStatement.setInternallyGenerated();
statements.append(AST::Statement::Ref(forStatement));
return;
}
if (auto* structType = std::get_if<Types::Struct>(type)) {
// Constructible structs are initialized in one assignment; others are
// initialized member by member.
if (type->isConstructible()) {
zeroInitialize();
return;
}
for (auto& member : structType->structure.members()) {
auto* fieldType = member.type().inferredType();
auto& fieldAccess = m_shaderModule.astBuilder().construct<AST::FieldAccessExpression>(
SourceSpan::empty(),
target,
AST::Identifier::make(member.name())
);
fieldAccess.m_inferredType = fieldType;
storeInitialValue(fieldAccess, statements, arrayDepth);
}
return;
}
// Atomics are initialized with `atomicStore(&target, 0)`. The address-of
// expression is typed void here (as the original code does).
if (type && std::holds_alternative<Types::Atomic>(*type)) {
auto& callee = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make("atomicStore"_s));
callee.m_inferredType = m_shaderModule.types().voidType();
auto& pointer = m_shaderModule.astBuilder().construct<AST::UnaryExpression>(
SourceSpan::empty(),
target,
AST::UnaryOperation::AddressOf
);
pointer.m_inferredType = m_shaderModule.types().voidType();
auto& value = m_shaderModule.astBuilder().construct<AST::AbstractIntegerLiteral>(SourceSpan::empty(), 0);
value.m_inferredType = m_shaderModule.types().abstractIntType();
auto& call = m_shaderModule.astBuilder().construct<AST::CallExpression>(
SourceSpan::empty(),
callee,
AST::Expression::List { pointer, value }
);
call.m_inferredType = m_shaderModule.types().voidType();
auto& callStatement = m_shaderModule.astBuilder().construct<AST::CallStatement>(
SourceSpan::empty(),
call
);
statements.append(AST::Statement::Ref(callStatement));
return;
}
zeroInitialize();
}
// Records a local definition of `name`. readVariable() treats names present
// in m_defs as non-globals.
void RewriteGlobalVariables::def(const AST::Identifier& name, AST::Variable* variable)
{
    const auto& location = name.span();
    dataLogLnIf(shouldLogGlobalVariableRewriting, "> def: ", name, " at line:", location.line, " column: ", location.lineOffset);
    m_defs.add(name, variable);
}
// Resolves `identifier` to a non-const module-scope global and records the
// read in m_reads. Returns nullptr if the name is locally defined (m_defs),
// is not a global, or is a `const` global. On the first read of a given
// global, its type name and initializer are visited too (presumably so that
// anything they reference is also tracked — confirm against visitor logic).
auto RewriteGlobalVariables::readVariable(AST::IdentifierExpression& identifier) -> const Global*
{
    const auto& name = identifier.identifier();
    if (m_defs.find(name) != m_defs.end())
        return nullptr;

    auto globalIt = m_globals.find(name);
    if (globalIt == m_globals.end() || globalIt->value.declaration->flavor() == AST::VariableFlavor::Const)
        return nullptr;
    auto& global = globalIt->value;

    dataLogLnIf(shouldLogGlobalVariableRewriting, "> read global: ", name, " at line:", identifier.span().line, " column: ", identifier.span().lineOffset);
    if (!m_reads.add(name).isNewEntry)
        return &global;

    if (auto* typeName = global.declaration->maybeTypeName())
        visit(*typeName);
    if (auto* initialValue = global.declaration->maybeInitializer())
        visit(*initialValue);
    return &global;
}
// Queues `statement` to be inserted before the statement currently being
// visited (identified by m_currentStatementIndex).
void RewriteGlobalVariables::insertBeforeCurrentStatement(AST::Statement& statement)
{
    auto insertionIndex = m_currentStatementIndex;
    m_pendingInsertions.append({ &statement, insertionIndex });
}
// Parameter name for bind group `group`'s argument buffer:
// "__ArgumentBuffer_<group>".
AST::Identifier RewriteGlobalVariables::argumentBufferParameterName(unsigned group)
{
    String parameterName = makeString("__ArgumentBuffer_"_s, group);
    return AST::Identifier::make(WTF::move(parameterName));
}
// Struct type name for bind group `group`'s argument buffer, made unique per
// entry point: "__ArgumentBufferT_<entryPointID>_<group>".
AST::Identifier RewriteGlobalVariables::argumentBufferStructName(unsigned group)
{
    String structName = makeString("__ArgumentBufferT_"_s, m_entryPointID, "_"_s, group);
    return AST::Identifier::make(WTF::move(structName));
}
// Name of the dynamic-offsets buffer parameter added by
// insertDynamicOffsetsBufferIfNeeded.
AST::Identifier RewriteGlobalVariables::dynamicOffsetVariableName()
{
    auto variableName = "__DynamicOffsets"_str;
    return AST::Identifier::make(WTF::move(variableName));
}
// Public entry point: runs the global-variable rewriting pass over
// `shaderModule` and returns an error if the pass fails.
std::optional<Error> rewriteGlobalVariables(ShaderModule& shaderModule, const HashMap<String, PipelineLayout*>& pipelineLayouts, HashMap<String, Reflection::EntryPointInformation>& entryPointInformations)
{
    RewriteGlobalVariables rewriter(shaderModule, pipelineLayouts, entryPointInformations);
    return rewriter.run();
}
} // namespace WGSL