blob: 98a4c8ca63918659ad0133f80101c0625a069baf [file] [log] [blame]
//
// Copyright 2025 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// EnsureLoopForwardProgress is an AST traverser that inserts a volatile variable access (a call to
// the internal `loopForwardProgress` built-in) at the start of the body of every loop that it
// cannot prove to be finite.
//
#include "compiler/translator/tree_ops/msl/EnsureLoopForwardProgress.h"

#include <limits>

#include "compiler/translator/Compiler.h"
#include "compiler/translator/StaticType.h"
#include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
namespace sh
{
namespace
{
// Returns the variable referenced by `node` when `node` is a plain symbol reference, and nullptr
// for any other kind of expression.
const TVariable *ViewSymbolVariable(TIntermTyped &node)
{
    TIntermSymbol *symbolNode = node.getAsSymbolNode();
    return symbolNode != nullptr ? &symbolNode->variable() : nullptr;
}
// Returns true if the expression's qualifier marks it as a value the shader cannot modify:
// constants, uniforms, and the listed shader-input qualifiers. Any other qualifier (including
// temporaries) is conservatively treated as writable.
bool IsReadOnlyExpr(TIntermTyped &node)
{
    const TQualifier qualifier = node.getQualifier();
    switch (qualifier)
    {
        case EvqConst:
        case EvqAttribute:
        case EvqUniform:
        case EvqVaryingIn:
        case EvqSmoothIn:
        case EvqFlatIn:
        case EvqNoPerspectiveIn:
        case EvqCentroidIn:
        case EvqSampleIn:
        case EvqNoPerspectiveCentroidIn:
        case EvqNoPerspectiveSampleIn:
            return true;
        default:
            return false;
    }
}
const TVariable *computeFiniteLoopVariable(TIntermLoop *loop)
{
// Currently matches only to loop of form:
// for (**; cond ; expr)
// where
// cond is `variable` `relation` `readonly symbol` and `variable`is of type int or uint
// expr increments or decrements the variable by one.
// Assumes ints wrap around in a defined way.
TIntermTyped *cond = loop->getCondition();
if (cond == nullptr)
{
return nullptr;
}
TIntermTyped *expr = loop->getExpression();
if (expr == nullptr)
{
return nullptr;
}
TIntermBinary *binCond = cond->getAsBinaryNode();
if (binCond == nullptr)
{
return nullptr;
}
const TVariable *variable = ViewSymbolVariable(*binCond->getLeft());
if (variable == nullptr)
{
return nullptr;
}
if (!IsInteger(variable->getType().getBasicType()))
{
return nullptr;
}
switch (binCond->getOp())
{
case EOpEqual:
case EOpNotEqual:
case EOpLessThan:
case EOpGreaterThan:
case EOpLessThanEqual:
case EOpGreaterThanEqual:
break;
default:
return nullptr;
}
// Loop index must be compared with a constant or uniform or similar read-only variable.
if (!IsReadOnlyExpr(*binCond->getRight()))
{
return nullptr;
}
if (TIntermUnary *unary = expr->getAsUnaryNode())
{
switch (unary->getOp())
{
case EOpPostIncrement:
case EOpPreIncrement:
case EOpPreDecrement:
case EOpPostDecrement:
break;
default:
return nullptr;
}
if (variable != ViewSymbolVariable(*unary->getOperand()))
{
return nullptr;
}
}
else if (TIntermBinary *binExpr = expr->getAsBinaryNode())
{
switch (binExpr->getOp())
{
case EOpAddAssign:
case EOpSubAssign:
break;
default:
return nullptr;
}
if (variable != ViewSymbolVariable(*binExpr->getLeft()))
{
return nullptr;
}
const TConstantUnion *value = binExpr->getRight()->getConstantValue();
if (value == nullptr)
{
return nullptr;
}
switch (value->getType())
{
case EbtInt:
if (value->getIConst() == -1 || value->getIConst() == 1)
{
break;
}
return nullptr;
case EbtUInt:
if (value->getUConst() == 1)
{
break;
}
return nullptr;
default:
UNREACHABLE();
return nullptr;
}
}
return variable;
}
// One entry in a linked stack of the loops enclosing the current traversal position. Each entry
// remembers the loop variable that makes its loop provably finite, if any.
class LoopInfoStack
{
  public:
    LoopInfoStack(TIntermLoop *node, LoopInfoStack *parent);

    // The loop is considered finite for as long as it still tracks a loop variable.
    bool isFinite() const { return mVariable != nullptr; }

    // Entry of the enclosing loop, or nullptr for the outermost loop.
    LoopInfoStack *getParent() const { return mParent; }

    // Drops the tracked variable; the loop can no longer be proven finite.
    void setNotFinite() { mVariable = nullptr; }

    // Returns the innermost entry, starting with this one, that tracks `variable`, or nullptr.
    LoopInfoStack *findLoopForVariable(const TVariable *variable);

  private:
    // Enclosing loop entry; nullptr at the outermost loop.
    LoopInfoStack *mParent{nullptr};
    // Variable proving this loop finite; nullptr if it is not (or no longer) provably finite.
    const TVariable *mVariable{nullptr};
};
LoopInfoStack::LoopInfoStack(TIntermLoop *node, LoopInfoStack *parent)
: mParent(parent), mVariable(computeFiniteLoopVariable(node))
{}
// Walks outwards from this entry and returns the first (innermost) entry whose tracked variable
// is `variable`, or nullptr if no enclosing loop tracks it.
LoopInfoStack *LoopInfoStack::findLoopForVariable(const TVariable *variable)
{
    for (LoopInfoStack *entry = this; entry != nullptr; entry = entry->mParent)
    {
        if (entry->mVariable == variable)
        {
            return entry;
        }
    }
    return nullptr;
}
// Traverser that wraps the body of every loop it cannot prove finite in a new block that first
// calls the internal `loopForwardProgress` built-in.
class EnsureLoopForwardProgressTraverser final : public TLValueTrackingTraverser
{
  public:
    EnsureLoopForwardProgressTraverser(TSymbolTable *symbolTable);

    // Invalidates the finiteness proof of a loop whose loop variable is written.
    void visitSymbol(TIntermSymbol *node) override;

    // Maintains mLoopInfoStack around the loop body and inserts the progress call when needed.
    void traverseLoop(TIntermLoop *node) override;

  private:
    // Innermost loop currently being traversed, or nullptr outside of any loop.
    LoopInfoStack *mLoopInfoStack{nullptr};
};
EnsureLoopForwardProgressTraverser::EnsureLoopForwardProgressTraverser(TSymbolTable *symbolTable)
    : TLValueTrackingTraverser(true, false, false, symbolTable)
{
    // `symbolTable` is forwarded to the base traverser; traverseLoop() reads it back through
    // mSymbolTable to build the `loopForwardProgress` call.
}
// Determines whether `node` is provably finite and, if it is not, prepends a call to the
// internal `loopForwardProgress` built-in to its body.
void EnsureLoopForwardProgressTraverser::traverseLoop(TIntermLoop *node)
{
    ScopedNodeInTraversalPath addToPath(this, node);

    // The init statement, condition and step expression of this loop execute as part of the
    // enclosing loop's body. Traverse them while only the enclosing loops are on the stack, so
    // that writes there invalidate enclosing loops iterating over the same variable. An example
    // is `for (i = 0; i < 10; i++) { for (i = 0; i < 5; i++) {} }`, which never terminates.
    // This loop's own step expression (e.g. `i++`) was already validated by
    // computeFiniteLoopVariable() and must not invalidate this loop itself.
    if (TIntermNode *init = node->getInit())
    {
        init->traverse(this);
    }
    if (TIntermTyped *condition = node->getCondition())
    {
        condition->traverse(this);
    }
    if (TIntermTyped *expression = node->getExpression())
    {
        expression->traverse(this);
    }

    // Traverse the body with this loop pushed: any write to its loop variable there (including
    // inside nested loops) means the loop can no longer be proven finite.
    LoopInfoStack loopInfo{node, mLoopInfoStack};
    mLoopInfoStack = &loopInfo;
    node->getBody()->traverse(this);
    mLoopInfoStack = loopInfo.getParent();

    if (!loopInfo.isFinite())
    {
        // Wrap the original body in a new block whose first statement is the progress call.
        TIntermBlock *newBody = new TIntermBlock();
        TIntermSequence *sequence = newBody->getSequence();
        sequence->push_back(CreateBuiltInFunctionCallNode("loopForwardProgress", {}, *mSymbolTable,
                                                          kESSLInternalBackendBuiltIns));
        sequence->push_back(node->getBody());
        node->setBody(newBody);
    }
}
// Called for every symbol reference. A write (l-value use) of a variable tracked by an
// enclosing loop means that loop can no longer be proven finite.
void EnsureLoopForwardProgressTraverser::visitSymbol(TIntermSymbol *node)
{
    // Reads never matter, and nothing needs tracking outside of loops.
    if (mLoopInfoStack == nullptr || !isLValueRequiredHere())
    {
        return;
    }
    LoopInfoStack *owner = mLoopInfoStack->findLoopForVariable(&node->variable());
    if (owner != nullptr)
    {
        owner->setNotFinite();
    }
}
} // namespace
// Runs the forward-progress traverser over the whole AST rooted at `root`, then finalizes the
// tree with updateTree() and returns its result.
bool EnsureLoopForwardProgress(TCompiler *compiler, TIntermNode *root)
{
    TSymbolTable &symbolTable = compiler->getSymbolTable();
    EnsureLoopForwardProgressTraverser traverser(&symbolTable);
    root->traverse(&traverser);
    return traverser.updateTree(compiler, root);
}
} // namespace sh