blob: b3a9025ccdb4167d8923042a0b8282c77dc1577f [file] [log] [blame]
//
// Copyright 2023 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#include "compiler/translator/tree_ops/msl/RewriteInterpolants.h"
#include "compiler/translator/StaticType.h"
#include "compiler/translator/msl/AstHelpers.h"
#include "compiler/translator/tree_util/BuiltIn.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/ReplaceVariable.h"
namespace sh
{
namespace
{
// First pass of the interpolant rewrite.
//
// While walking the tree it:
//  * records whether any declaration uses a `sample` input qualifier,
//  * collects the (non-flat) input variables that are passed as the first
//    argument of an interpolation built-in, and
//  * queues a replacement that routes the offset argument of every
//    interpolateAtOffset call through ANGLEFlipInterpolationOffset().
//
// The queued replacements only take effect once updateTree() is called on this
// traverser; the flip function definition (if any) is exposed through
// getFlipFunctionDefinition() so the caller can add it to the tree.
class FindInterpolantsTraverser : public TIntermTraverser
{
  public:
    // NOTE(review): the three booleans are presumably (preVisit, inVisit, postVisit),
    // i.e. this traverser only pre-visits nodes - confirm against TIntermTraverser.
    FindInterpolantsTraverser(TSymbolTable *symbolTable, const DriverUniformMetal *driverUniforms)
        : TIntermTraverser(true, false, false, symbolTable),
          mDriverUniforms(driverUniforms),
          mUsesSampleInterpolation(false)
    {}

    // Flags sample-rate interpolation when the first declarator of a declaration
    // is qualified as `sample in` (with or without `noperspective`).
    bool visitDeclaration(Visit, TIntermDeclaration *node) override
    {
        const TIntermSequence &sequence = *(node->getSequence());
        ASSERT(!sequence.empty());

        const TIntermTyped &typedNode = *(sequence.front()->getAsTyped());
        TQualifier qualifier          = typedNode.getQualifier();
        if (qualifier == EvqSampleIn || qualifier == EvqNoPerspectiveSampleIn)
        {
            mUsesSampleInterpolation = true;
        }

        return true;
    }

    // Returns the ANGLEFlipInterpolationOffset function, building its definition on
    // first use. The generated body is equivalent to:
    //
    //     return offset * <flipXY driver uniform for the fragment stage>;
    //
    // The definition is cached in mFlipFunction so that every interpolateAtOffset
    // call shares the same function.
    const TFunction *getFlipFunction()
    {
        if (mFlipFunction != nullptr)
        {
            return mFlipFunction->getFunction();
        }

        // The same high-precision vec2 type is used for the parameter and the return value.
        const TType *vec2Type = StaticType::GetQualified<EbtFloat, EbpHigh, EvqParamIn, 2>();
        TVariable *offsetParam = new TVariable(mSymbolTable, ImmutableString("offset"), vec2Type,
                                               SymbolType::AngleInternal);
        TFunction *function =
            new TFunction(mSymbolTable, ImmutableString("ANGLEFlipInterpolationOffset"),
                          SymbolType::AngleInternal, vec2Type, true);
        function->addParameter(offsetParam);

        TIntermTyped *flipXY =
            mDriverUniforms->getFlipXY(mSymbolTable, DriverUniformFlip::Fragment);
        TIntermTyped *flipped = new TIntermBinary(EOpMul, new TIntermSymbol(offsetParam), flipXY);
        TIntermBranch *returnStatement = new TIntermBranch(EOpReturn, flipped);

        TIntermBlock *body = new TIntermBlock;
        body->appendStatement(returnStatement);

        mFlipFunction = new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
        return function;
    }

    // Handles calls to the fragment-shader interpolation built-ins; every other
    // aggregate is ignored (traversal still continues into its children).
    bool visitAggregate(Visit visit, TIntermAggregate *node) override
    {
        if (!BuiltInGroup::IsInterpolationFS(node->getOp()))
        {
            return true;
        }

        TIntermNode *operand = node->getSequence()->at(0);
        ASSERT(operand);

        // For all of the interpolation functions, <interpolant> must be an input
        // variable or an element of an input variable declared as an array.
        const TIntermSymbol *symbolNode = operand->getAsSymbolNode();
        if (!symbolNode)
        {
            const TIntermBinary *binaryNode = operand->getAsBinaryNode();
            if (binaryNode &&
                (binaryNode->getOp() == EOpIndexDirect || binaryNode->getOp() == EOpIndexIndirect))
            {
                symbolNode = binaryNode->getLeft()->getAsSymbolNode();
            }
        }
        ASSERT(symbolNode);

        // If <interpolant> is declared with a "flat" qualifier, the interpolated
        // value will have the same value everywhere for a single primitive, so
        // the location used for the interpolation has no effect and the functions
        // just return that same value.
        const TVariable *variable = &symbolNode->variable();
        if (variable->getType().getQualifier() != EvqFlatIn)
        {
            mInterpolants.insert(variable);
        }

        // Flip offset's Y if needed.
        if (node->getOp() == EOpInterpolateAtOffset)
        {
            TIntermTyped *offsetNode      = node->getSequence()->at(1)->getAsTyped();
            TIntermTyped *correctedOffset = TIntermAggregate::CreateFunctionCall(
                *getFlipFunction(), new TIntermSequence{offsetNode});

            // The original offset expression stays in the tree as the argument of the
            // new call, so it is reported as becoming a child of the replacement rather
            // than as dropped.
            // NOTE(review): with IS_DROPPED, replacements queued later with offsetNode as
            // their parent (e.g. an interpolateAtOffset call used directly as the offset
            // of another one) would presumably be retargeted to the wrapper call, which
            // does not own those children - confirm against TIntermTraverser::updateTree.
            queueReplacementWithParent(node, offsetNode, correctedOffset,
                                       OriginalNode::BECOMES_CHILD);
        }

        return true;
    }

    // True if a `sample`-qualified input declaration was visited.
    bool usesSampleInterpolation() const { return mUsesSampleInterpolation; }

    // Non-flat variables seen as the interpolant argument of an interpolation built-in.
    const std::unordered_set<const TVariable *> &getInterpolants() const { return mInterpolants; }

    // Definition of ANGLEFlipInterpolationOffset, or nullptr if it was never needed.
    TIntermFunctionDefinition *getFlipFunctionDefinition() { return mFlipFunction; }

  private:
    const DriverUniformMetal *mDriverUniforms;
    bool mUsesSampleInterpolation;
    std::unordered_set<const TVariable *> mInterpolants;
    TIntermFunctionDefinition *mFlipFunction = nullptr;
};
class WrapInterpolantsTraverser : public TIntermTraverser
{
public:
WrapInterpolantsTraverser(TSymbolTable *symbolTable)
: TIntermTraverser(true, false, false, symbolTable), mUsesSampleInterpolant(false)
{}
void visitSymbol(TIntermSymbol *node) override
{
// Skip all symbols not previously marked as
// interpolants by FindInterpolantsTraverser
const TType &type = node->variable().getType();
if (!type.isInterpolant())
{
return;
}
TIntermNode *ancestor = getAncestorNode(0);
ASSERT(ancestor);
// Only root-level input varying declarations should be
// reachable by this line and they must not be wrapped.
if (ancestor->getAsDeclarationNode())
{
return;
}
auto checkSkip = [](TIntermNode *node, TIntermNode *parentNode) {
if (TIntermAggregate *callNode = parentNode->getAsAggregate())
{
if (BuiltInGroup::IsInterpolationFS(callNode->getOp()) &&
callNode->getSequence()->at(0) == node)
{
return true;
}
}
return false;
};
// Skip symbols used as the first operand of interpolation functions
if (checkSkip(node, ancestor))
{
return;
}
TIntermNode *original = node;
if (TIntermBinary *binaryNode = ancestor->getAsBinaryNode())
{
if (binaryNode->getOp() == EOpIndexDirect || binaryNode->getOp() == EOpIndexIndirect)
{
ancestor = getAncestorNode(1);
ASSERT(ancestor);
// Skip array elements used as the first operand of interpolation functions
if (checkSkip(binaryNode, ancestor))
{
return;
}
original = binaryNode;
}
}
const char *functionName = nullptr;
TIntermSequence *arguments = new TIntermSequence{original};
switch (type.getQualifier())
{
case EvqFragmentIn:
case EvqSmoothIn:
case EvqNoPerspectiveIn:
// `metal::interpolant` variables cannot be used directly,
// so MSL has a dedicated interpolation function to obtain
// their pixel-center values. This function is included in
// the `MetalFragmentSample` built-in functions group.
functionName = "interpolateAtCenter";
break;
case EvqCentroidIn:
case EvqNoPerspectiveCentroidIn:
functionName = "interpolateAtCentroid";
break;
case EvqSampleIn:
case EvqNoPerspectiveSampleIn:
functionName = "interpolateAtSample";
arguments->push_back(new TIntermSymbol(BuiltInVariable::gl_SampleID()));
mUsesSampleInterpolant = true;
break;
default:
UNREACHABLE();
break;
}
TIntermTyped *replacement = CreateBuiltInFunctionCallNode(
functionName, arguments, *mSymbolTable, kESSLInternalBackendBuiltIns);
queueReplacementWithParent(ancestor, original, replacement, OriginalNode::BECOMES_CHILD);
}
bool usesSampleInterpolant() const { return mUsesSampleInterpolant; }
private:
bool mUsesSampleInterpolant;
};
} // anonymous namespace
// Rewrites fragment inputs that are used with interpolation built-ins.
//
// Steps:
//  1. Find every input passed to an interpolation built-in and route the offset
//     of each interpolateAtOffset call through ANGLEFlipInterpolationOffset.
//  2. Replace each such input with a copy whose type has the interpolant flag set.
//  3. Wrap the remaining direct uses of those inputs in the interpolation
//     built-in matching their qualifier.
//
// compiler / root / symbolTable: forwarded to the traversers and tree utilities.
// driverUniforms:                used to obtain the flipXY value for the flip function.
// outUsesSampleInterpolation:    set to whether a `sample`-qualified input declaration exists.
// outUsesSampleInterpolant:      set to whether any use was wrapped with interpolateAtSample;
//                                false when there was nothing to wrap.
//
// Returns false as soon as any tree update or variable replacement fails, true otherwise.
// On failure the out-parameters may be only partially written.
[[nodiscard]] bool RewriteInterpolants(TCompiler &compiler,
                                       TIntermBlock &root,
                                       TSymbolTable &symbolTable,
                                       const DriverUniformMetal *driverUniforms,
                                       bool *outUsesSampleInterpolation,
                                       bool *outUsesSampleInterpolant)
{
    // Find all fragment inputs used with interpolation functions.
    FindInterpolantsTraverser findInterpolantsTraverser(&symbolTable, driverUniforms);
    root.traverse(&findInterpolantsTraverser);

    // Define ANGLEFlipInterpolationOffset if interpolateAtOffset was used.
    if (findInterpolantsTraverser.getFlipFunctionDefinition() != nullptr)
    {
        const size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(&root);
        root.insertStatement(firstFunctionIndex,
                             findInterpolantsTraverser.getFlipFunctionDefinition());
    }

    // Apply the queued interpolateAtOffset argument replacements.
    if (!findInterpolantsTraverser.updateTree(&compiler, &root))
    {
        return false;
    }

    *outUsesSampleInterpolation = findInterpolantsTraverser.usesSampleInterpolation();

    // Nothing has been wrapped yet. Writing the value here guarantees the
    // out-parameter is defined on the early-return path below instead of being
    // left at whatever the caller happened to pass in.
    *outUsesSampleInterpolant = false;

    // Skip further operations when interpolation functions are not used.
    if (findInterpolantsTraverser.getInterpolants().empty())
    {
        return true;
    }

    // Adjust variable types as per MSL requirements
    //
    // * Inputs with omitted and smooth interpolation qualifiers will be written as
    //   metal::interpolant<T, metal::interpolation::perspective>
    //
    // * Inputs with noperspective interpolation qualifiers will be written as
    //   metal::interpolant<T, metal::interpolation::no_perspective>
    for (const TVariable *var : findInterpolantsTraverser.getInterpolants())
    {
        // Same name and symbol type as the original; only the interpolant flag differs.
        TType *replacementType = new TType(var->getType());
        replacementType->setInterpolant(true);
        TVariable *replacement =
            new TVariable(&symbolTable, var->name(), replacementType, var->symbolType());
        if (!ReplaceVariable(&compiler, &root, var, replacement))
        {
            return false;
        }
    }

    // Wrap direct usages of interpolants with explicit interpolation
    // functions depending on their auxiliary qualifiers
    //          in vec4 interpolant -> ANGLE_interpolateAtCenter(interpolant)
    // centroid in vec4 interpolant -> ANGLE_interpolateAtCentroid(interpolant)
    // sample   in vec4 interpolant -> ANGLE_interpolateAtSample(interpolant, gl_SampleID)
    WrapInterpolantsTraverser wrapInterpolantsTraverser(&symbolTable);
    root.traverse(&wrapInterpolantsTraverser);
    if (!wrapInterpolantsTraverser.updateTree(&compiler, &root))
    {
        return false;
    }

    *outUsesSampleInterpolant = wrapInterpolantsTraverser.usesSampleInterpolant();
    return true;
}
} // namespace sh