// blob: 45066ec0feb17b6d3d5d1d20539575466882f0d5 [file] [log] [blame] [edit]
/*
* Copyright (c) 2023 Apple Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "config.h"
#include "EntryPointRewriter.h"
#include "AST.h"
#include "ASTVisitor.h"
#include "TypeStore.h"
#include "Types.h"
#include "WGSL.h"
#include "WGSLShaderModule.h"
#include <wtf/text/MakeString.h>
namespace WGSL {
// Rewrites one entry point function of a shader module in place.
//
// The original parameters are removed from the function. Parameters (or struct
// members) carrying a builtin attribute are re-added as individual parameters;
// every other input is gathered into a generated input struct that is passed
// as a single StageIn parameter. Statements that re-create the original
// parameter names inside the function body are collected in
// m_materializations and inserted at the start of the body.
class EntryPointRewriter {
public:
    EntryPointRewriter(ShaderModule&, const AST::Function&, ShaderStage);

    // Runs the complete rewrite for the function given to the constructor.
    void rewrite();

private:
    // Name, type expression and attributes of either a function parameter or
    // a member of a struct-typed parameter.
    struct MemberOrParameter {
        AST::Identifier name;
        AST::Expression& type;
        AST::Attribute::List attributes;
    };

    // A MemberOrParameter that carries a builtin attribute, together with the
    // builtin it names.
    struct BuiltinMemberOrParameter : MemberOrParameter {
        Builtin builtin;
    };

    // Selects the right-hand side produced by materialize().
    enum class IsBuiltin {
        No = 0,
        Yes = 1,
    };

    void collectParameters();
    void checkReturnType();
    void constructInputStruct();
    // builtinName must be non-null when IsBuiltin::Yes is passed.
    void materialize(Vector<String>& path, MemberOrParameter&, IsBuiltin, const String* builtinName = nullptr);
    void visit(Vector<String>& path, MemberOrParameter&&);
    void appendBuiltins();

    ShaderStage m_stage;
    ShaderModule& m_shaderModule;
    const AST::Function& m_function;

    // Inputs that become individual function parameters in appendBuiltins().
    Vector<BuiltinMemberOrParameter> m_builtins;
    // Inputs that become members of the generated input struct.
    Vector<MemberOrParameter> m_parameters;
    // Statements inserted at index 0 of the function body by rewrite().
    AST::Statement::List m_materializations;

    // Type of the generated input struct. Only assigned by
    // constructInputStruct(); stays null when no input struct is generated.
    const Type* m_structType { nullptr };
    String m_structTypeName;
    String m_structParameterName;

    // Counter used to build the `__builtin<N>` parameter names.
    unsigned m_builtinID { 0 };
};
EntryPointRewriter::EntryPointRewriter(ShaderModule& shaderModule, const AST::Function& function, ShaderStage stage)
: m_stage(stage)
, m_shaderModule(shaderModule)
, m_function(function)
{
}
// Performs the rewrite of the entry point:
//  1. derives the input struct type name (`__<function>_inT`) and the name of
//     the parameter holding it (`__<function>_in`);
//  2. collects and clears the original parameters, adjusts the return type,
//     and re-adds the builtins as parameters;
//  3. when non-builtin inputs were collected, builds the input struct and
//     appends a StageIn parameter of that type;
//  4. inserts the collected materialization statements at index 0 of the body.
void EntryPointRewriter::rewrite()
{
    m_structTypeName = makeString("__"_s, m_function.name(), "_inT"_s);
    m_structParameterName = makeString("__"_s, m_function.name(), "_in"_s);

    collectParameters();
    checkReturnType();
    appendBuiltins();

    if (!m_parameters.isEmpty()) {
        constructInputStruct();

        // Append the parameter `${structParameterName} : ${structTypeName}`.
        auto& builder = m_shaderModule.astBuilder();
        auto& structTypeExpression = builder.construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make(m_structTypeName));
        structTypeExpression.m_inferredType = m_structType;

        auto& stageInParameter = builder.construct<AST::Parameter>(
            SourceSpan::empty(),
            AST::Identifier::make(m_structParameterName),
            structTypeExpression,
            AST::Attribute::List { },
            AST::ParameterRole::StageIn
        );
        m_shaderModule.append(m_function.parameters(), stageInParameter);
    }

    m_shaderModule.insertVector(m_function.body().statements(), 0, m_materializations);
}
void EntryPointRewriter::collectParameters()
{
for (auto& parameter : m_function.parameters()) {
Vector<String> path;
visit(path, MemberOrParameter { parameter.name(), const_cast<AST::Expression&>(parameter.typeName()), parameter.attributes() });
}
m_shaderModule.clear(m_function.parameters());
}
void EntryPointRewriter::checkReturnType()
{
if (m_stage == ShaderStage::Compute)
return;
auto* namedTypeName = dynamicDowncast<AST::IdentifierExpression>(m_function.maybeReturnType());
if (namedTypeName) {
if (auto* structType = std::get_if<Types::Struct>(namedTypeName->inferredType())) {
const auto& duplicateStruct = [&] (AST::StructureRole role, ASCIILiteral suffix) {
ASSERT(structType->structure.role() == AST::StructureRole::UserDefined);
String returnStructName = makeString("__"_s, structType->structure.name(), '_', suffix);
auto& returnStruct = m_shaderModule.astBuilder().construct<AST::Structure>(
SourceSpan::empty(),
AST::Identifier::make(returnStructName),
AST::StructureMember::List(structType->structure.members()),
AST::Attribute::List { },
role
);
m_shaderModule.append(m_shaderModule.declarations(), returnStruct);
auto& returnType = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
SourceSpan::empty(),
AST::Identifier::make(returnStructName)
);
returnType.m_inferredType = m_shaderModule.types().structType(returnStruct);
m_shaderModule.replace(*namedTypeName, returnType);
};
if (m_stage == ShaderStage::Fragment) {
duplicateStruct(AST::StructureRole::FragmentOutput, "FragmentOutput"_s);
return;
}
duplicateStruct(AST::StructureRole::VertexOutput, "VertexOutput"_s);
return;
}
}
if (!m_function.maybeReturnType() || (m_stage != ShaderStage::Fragment && m_stage != ShaderStage::Vertex) || m_function.returnAttributes().isEmpty())
return;
auto stageName = m_stage == ShaderStage::Fragment ? "Fragment"_s : "Vertex"_s;
String returnStructName = makeString("__"_s, m_function.name(), '_', stageName, "Output"_s);
auto& fieldType = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
SourceSpan::empty(),
AST::Identifier::make("__type"_s)
);
fieldType.m_inferredType = m_function.maybeReturnType()->inferredType();
auto& member = m_shaderModule.astBuilder().construct<AST::StructureMember>(
SourceSpan::empty(),
AST::Identifier::make("__value"_s),
fieldType,
AST::Attribute::List(m_function.returnAttributes())
);
auto role = m_stage == ShaderStage::Fragment ? AST::StructureRole::FragmentOutputWrapper : AST::StructureRole::VertexOutputWrapper;
auto& returnStruct = m_shaderModule.astBuilder().construct<AST::Structure>(
SourceSpan::empty(),
AST::Identifier::make(returnStructName),
AST::StructureMember::List({ member }),
AST::Attribute::List { },
role
);
m_shaderModule.append(m_shaderModule.declarations(), returnStruct);
auto& returnType = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
SourceSpan::empty(),
AST::Identifier::make(returnStructName)
);
returnType.m_inferredType = m_shaderModule.types().structType(returnStruct);
if (namedTypeName)
m_shaderModule.replace(*namedTypeName, returnType);
else if (auto* elaboratedExpression = dynamicDowncast<AST::ElaboratedTypeExpression>(m_function.maybeReturnType()))
m_shaderModule.replace(*elaboratedExpression, returnType);
else if (auto* arrayType = dynamicDowncast<AST::ArrayTypeExpression>(m_function.maybeReturnType()))
m_shaderModule.replace(*arrayType, returnType);
else
RELEASE_ASSERT_NOT_REACHED();
}
void EntryPointRewriter::constructInputStruct()
{
// insert `var ${parameter.name()} = ${structName}.${parameter.name()}`
AST::StructureMember::List structMembers;
for (auto& parameter : m_parameters) {
structMembers.append(m_shaderModule.astBuilder().construct<AST::StructureMember>(
SourceSpan::empty(),
WTF::move(parameter.name),
parameter.type,
WTF::move(parameter.attributes)
));
}
AST::StructureRole role;
switch (m_stage) {
case ShaderStage::Compute:
role = AST::StructureRole::ComputeInput;
break;
case ShaderStage::Vertex:
role = AST::StructureRole::VertexInput;
break;
case ShaderStage::Fragment:
role = AST::StructureRole::FragmentInput;
break;
}
auto& structure = m_shaderModule.astBuilder().construct<AST::Structure>(
SourceSpan::empty(),
AST::Identifier::make(m_structTypeName),
WTF::move(structMembers),
AST::Attribute::List { },
role
);
m_shaderModule.append(m_shaderModule.declarations(), structure);
m_structType = m_shaderModule.types().structType(structure);
}
// Appends to m_materializations the statement that makes one input available
// under its original name inside the function body.
//
// The right-hand side is either the identifier `*builtinName`
// (IsBuiltin::Yes) or the field access `${structParameterName}.${data.name}`
// (IsBuiltin::No).
//
// - Empty path: emits `var ${data.name} : ${data.type} = rhs`.
// - Non-empty path: emits the assignment `${path...}.${data.name} = rhs`.
//
// `path` is extended temporarily and restored before returning.
// `builtinName` must be non-null when `isBuiltin` is IsBuiltin::Yes.
void EntryPointRewriter::materialize(Vector<String>& path, MemberOrParameter& data, IsBuiltin isBuiltin, const String* builtinName)
{
    auto& builder = m_shaderModule.astBuilder();

    AST::Expression::Ptr rhs;
    if (isBuiltin == IsBuiltin::Yes) {
        // The parameter defaults to nullptr; catch a missing name in debug
        // builds instead of dereferencing a null pointer.
        ASSERT(builtinName);
        rhs = &builder.construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make(*builtinName));
    } else {
        rhs = &builder.construct<AST::FieldAccessExpression>(
            SourceSpan::empty(),
            builder.construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make(m_structParameterName)),
            AST::Identifier::make(data.name)
        );
    }

    if (path.isEmpty()) {
        m_materializations.append(builder.construct<AST::VariableStatement>(
            SourceSpan::empty(),
            builder.construct<AST::Variable>(
                SourceSpan::empty(),
                AST::VariableFlavor::Var,
                AST::Identifier::make(data.name),
                nullptr, // TODO: do we need a VariableQualifier?
                &data.type,
                rhs,
                AST::Attribute::List { }
            )
        ));
        return;
    }

    // Build the chain `path[0].path[1]. ... .${data.name}` as the assignment
    // target; data.name is pushed only for the duration of this loop.
    path.append(data.name);
    unsigned i = 0;
    AST::Expression::Ref lhs = builder.construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make(path[i++]));
    while (i < path.size()) {
        lhs = builder.construct<AST::FieldAccessExpression>(
            SourceSpan::empty(),
            WTF::move(lhs),
            AST::Identifier::make(path[i++])
        );
    }
    path.removeLast();

    m_materializations.append(builder.construct<AST::AssignmentStatement>(
        SourceSpan::empty(),
        WTF::move(lhs),
        *rhs
    ));
}
// Classifies one parameter or struct member.
//
// - Struct-typed entry: declares `var ${data.name} : ${data.type}` without an
//   initializer, then visits every member with data.name pushed onto `path`.
// - Entry with a builtin attribute: recorded in m_builtins. When it came from
//   inside a struct (non-empty path), it is recorded under a fresh
//   `__builtin<N>` name and an assignment restoring the struct field is
//   materialized.
// - Any other entry: a load from the input struct is materialized and the
//   entry is moved into m_parameters.
void EntryPointRewriter::visit(Vector<String>& path, MemberOrParameter&& data)
{
    if (auto* structType = std::get_if<Types::Struct>(data.type.inferredType())) {
        auto& builder = m_shaderModule.astBuilder();
        auto& declaration = builder.construct<AST::Variable>(
            SourceSpan::empty(),
            AST::VariableFlavor::Var,
            AST::Identifier::make(data.name),
            nullptr,
            &data.type,
            nullptr,
            AST::Attribute::List { }
        );
        m_materializations.append(builder.construct<AST::VariableStatement>(SourceSpan::empty(), declaration));

        path.append(data.name);
        for (auto& member : structType->structure.members())
            visit(path, MemberOrParameter { member.name(), member.type(), member.attributes() });
        path.removeLast();
        return;
    }

    // Use the first builtin attribute, if there is one.
    std::optional<Builtin> builtin;
    for (auto& attribute : data.attributes) {
        auto* builtinAttribute = dynamicDowncast<AST::BuiltinAttribute>(attribute);
        if (!builtinAttribute)
            continue;
        builtin = builtinAttribute->builtin();
        break;
    }

    if (!builtin.has_value()) {
        // The parameter was moved into a struct, so we need to reload it:
        // ${path}.${data.name} = ${struct}.${data.name}
        materialize(path, data, IsBuiltin::No);
        m_parameters.append(WTF::move(data));
        return;
    }

    if (path.isEmpty()) {
        // It was already a parameter, so there is nothing to materialize.
        m_builtins.append({ data, *builtin });
        return;
    }

    // The builtin was hoisted from a struct into a parameter, so the struct
    // has to be reconstructed:
    // ${path}.${data.name} = __builtin${builtinID}
    // ${data.name} is not used on the right-hand side because it is the name
    // of a struct field and might not be unique.
    auto builtinName = makeString("__builtin"_s, String::number(m_builtinID++));
    materialize(path, data, IsBuiltin::Yes, &builtinName);
    m_builtins.append({
        {
            AST::Identifier::make(builtinName),
            data.type,
            data.attributes
        },
        *builtin
    });
}
// Appends one UserDefined parameter to the function for every collected
// builtin, tagging it with the builtin it represents. The attributes of each
// collected entry are moved into the new parameter.
void EntryPointRewriter::appendBuiltins()
{
    auto& builder = m_shaderModule.astBuilder();
    for (auto& entry : m_builtins) {
        auto& builtinParameter = builder.construct<AST::Parameter>(
            SourceSpan::empty(),
            AST::Identifier::make(entry.name),
            entry.type,
            WTF::move(entry.attributes),
            AST::ParameterRole::UserDefined
        );
        builtinParameter.m_builtin = entry.builtin;
        m_shaderModule.append(m_function.parameters(), builtinParameter);
    }
}
// Rewrites every entry point in the module's call graph whose original name
// is a key of `pipelineLayouts`; all other entry points are left untouched.
void rewriteEntryPoints(ShaderModule& shaderModule, const HashMap<String, PipelineLayout*>& pipelineLayouts)
{
    for (auto& entryPoint : shaderModule.callGraph().entrypoints()) {
        const bool hasLayout = pipelineLayouts.contains(entryPoint.originalName);
        if (hasLayout) {
            EntryPointRewriter entryPointRewriter(shaderModule, entryPoint.function, entryPoint.stage);
            entryPointRewriter.rewrite();
        }
    }
}
} // namespace WGSL