Vulkan: Refactor atomic counter retype code
A generic "retyper" class is extracted out of the atomic counter retype
code so that it can be reused when converting samplerCube to sampler2DArray
for seamful cubemap sampling emulation.
Bug: angleproject:3732
Change-Id: I8b5f835125b9513afcfe7baeea48afaf1299a027
Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/1733807
Commit-Queue: Shahbaz Youssefi <syoussefi@chromium.org>
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Reviewed-by: Tim Van Patten <timvp@google.com>
diff --git a/src/compiler/translator/Common.h b/src/compiler/translator/Common.h
index 179cfe2..be3fb2a 100644
--- a/src/compiler/translator/Common.h
+++ b/src/compiler/translator/Common.h
@@ -71,6 +71,7 @@
TVector() : std::vector<T, pool_allocator<T>>() {}
TVector(const pool_allocator<T> &a) : std::vector<T, pool_allocator<T>>(a) {}
TVector(size_type i) : std::vector<T, pool_allocator<T>>(i) {}
+    // Allows brace initialization, e.g. "TVector<int> swizzleOffsets = {0, 1};".
+    TVector(std::initializer_list<T> init)
+        : std::vector<T, pool_allocator<T>>(init.begin(), init.end())
+    {}
};
template <class K, class D, class H = std::hash<K>, class CMP = std::equal_to<K>>
diff --git a/src/compiler/translator/TranslatorVulkan.cpp b/src/compiler/translator/TranslatorVulkan.cpp
index 657dfcc..5f0c4c1 100644
--- a/src/compiler/translator/TranslatorVulkan.cpp
+++ b/src/compiler/translator/TranslatorVulkan.cpp
@@ -152,14 +152,6 @@
bool mInDefaultUniform;
};
-TIntermConstantUnion *CreateFloatConstant(float value)
-{
- const TType *constantType = StaticType::GetBasic<EbtFloat, 1>();
- TConstantUnion *constantValue = new TConstantUnion;
- constantValue->setFConst(value);
- return new TIntermConstantUnion(constantValue, *constantType);
-}
-
constexpr ImmutableString kFlippedPointCoordName = ImmutableString("flippedPointCoord");
constexpr ImmutableString kFlippedFragCoordName = ImmutableString("flippedFragCoord");
constexpr ImmutableString kEmulatedDepthRangeParams = ImmutableString("ANGLEDepthRangeParams");
@@ -223,8 +215,7 @@
TIntermSymbol *builtinRef = new TIntermSymbol(builtin);
// Create a swizzle to "builtin.y"
- TVector<int> swizzleOffsetY;
- swizzleOffsetY.push_back(1);
+ TVector<int> swizzleOffsetY = {1};
TIntermSwizzle *builtinY = new TIntermSwizzle(builtinRef, swizzleOffsetY);
// Create a symbol reference to our new variable that will hold the modified builtin.
@@ -296,16 +287,14 @@
TIntermSymbol *positionRef = new TIntermSymbol(position);
// Create a swizzle to "gl_Position.z"
- TVector<int> swizzleOffsetZ;
- swizzleOffsetZ.push_back(2);
+ TVector<int> swizzleOffsetZ = {2};
TIntermSwizzle *positionZ = new TIntermSwizzle(positionRef, swizzleOffsetZ);
// Create a constant "0.5"
- TIntermConstantUnion *oneHalf = CreateFloatConstant(0.5f);
+ TIntermConstantUnion *oneHalf = CreateFloatNode(0.5f);
// Create a swizzle to "gl_Position.w"
- TVector<int> swizzleOffsetW;
- swizzleOffsetW.push_back(3);
+ TVector<int> swizzleOffsetW = {3};
TIntermSwizzle *positionW = new TIntermSwizzle(positionRef->deepCopy(), swizzleOffsetW);
// Create the expression "(gl_Position.z + gl_Position.w) * 0.5".
@@ -517,27 +506,22 @@
// Create a swizzle to "ANGLEUniforms.viewport.xy".
TIntermBinary *viewportRef = CreateDriverUniformRef(driverUniforms, kViewport);
- TVector<int> swizzleOffsetXY;
- swizzleOffsetXY.push_back(0);
- swizzleOffsetXY.push_back(1);
+ TVector<int> swizzleOffsetXY = {0, 1};
TIntermSwizzle *viewportXY = new TIntermSwizzle(viewportRef->deepCopy(), swizzleOffsetXY);
// Create a swizzle to "ANGLEUniforms.viewport.zw".
- TVector<int> swizzleOffsetZW;
- swizzleOffsetZW.push_back(2);
- swizzleOffsetZW.push_back(3);
+ TVector<int> swizzleOffsetZW = {2, 3};
TIntermSwizzle *viewportZW = new TIntermSwizzle(viewportRef, swizzleOffsetZW);
// ANGLEPosition.xy / ANGLEPosition.w
TIntermSymbol *position = new TIntermSymbol(anglePosition);
TIntermSwizzle *positionXY = new TIntermSwizzle(position, swizzleOffsetXY);
- TVector<int> swizzleOffsetW;
- swizzleOffsetW.push_back(3);
+ TVector<int> swizzleOffsetW = {3};
TIntermSwizzle *positionW = new TIntermSwizzle(position->deepCopy(), swizzleOffsetW);
TIntermBinary *positionNDC = new TIntermBinary(EOpDiv, positionXY, positionW);
// ANGLEPosition * 0.5
- TIntermConstantUnion *oneHalf = CreateFloatConstant(0.5f);
+ TIntermConstantUnion *oneHalf = CreateFloatNode(0.5f);
TIntermBinary *halfPosition = new TIntermBinary(EOpVectorTimesScalar, positionNDC, oneHalf);
// (ANGLEPosition * 0.5) + 0.5
@@ -575,7 +559,7 @@
TIntermBinary *baSq = new TIntermBinary(EOpMul, ba, ba->deepCopy());
// 2.0 * ba * ba
- TIntermTyped *two = CreateFloatConstant(2.0f);
+ TIntermTyped *two = CreateFloatNode(2.0f);
TIntermBinary *twoBaSq = new TIntermBinary(EOpVectorTimesScalar, baSq, two);
// Assign to a temporary "ba2".
@@ -583,9 +567,7 @@
TIntermDeclaration *ba2Decl = CreateTempInitDeclarationNode(ba2Temp, twoBaSq);
// Create a swizzle to "ba2.yx".
- TVector<int> swizzleOffsetYX;
- swizzleOffsetYX.push_back(1);
- swizzleOffsetYX.push_back(0);
+ TVector<int> swizzleOffsetYX = {1, 0};
TIntermSymbol *ba2 = CreateTempSymbolNode(ba2Temp);
TIntermSwizzle *ba2YX = new TIntermSwizzle(ba2, swizzleOffsetYX);
@@ -599,21 +581,19 @@
TIntermSymbol *bp = CreateTempSymbolNode(bpTemp);
// Create a swizzle to "bp.x".
- TVector<int> swizzleOffsetX;
- swizzleOffsetX.push_back(0);
+ TVector<int> swizzleOffsetX = {0};
TIntermSwizzle *bpX = new TIntermSwizzle(bp, swizzleOffsetX);
// Using a small epsilon value ensures that we don't suffer from numerical instability when
// lines are exactly vertical or horizontal.
static constexpr float kEpisilon = 0.00001f;
- TIntermConstantUnion *epsilon = CreateFloatConstant(kEpisilon);
+ TIntermConstantUnion *epsilon = CreateFloatNode(kEpisilon);
// bp.x > epsilon
TIntermBinary *checkX = new TIntermBinary(EOpGreaterThan, bpX, epsilon);
// Create a swizzle to "bp.y".
- TVector<int> swizzleOffsetY;
- swizzleOffsetY.push_back(1);
+ TVector<int> swizzleOffsetY = {1};
TIntermSwizzle *bpY = new TIntermSwizzle(bp->deepCopy(), swizzleOffsetY);
// bp.y > epsilon
@@ -798,7 +778,7 @@
{
TIntermBinary *viewportYScale =
CreateDriverUniformRef(driverUniforms, kNegViewportYScale);
- TIntermConstantUnion *pivot = CreateFloatConstant(0.5f);
+ TIntermConstantUnion *pivot = CreateFloatNode(0.5f);
FlipBuiltinVariable(root, GetMainSequence(root), viewportYScale, &getSymbolTable(),
BuiltInVariable::gl_PointCoord(), kFlippedPointCoordName, pivot);
}
diff --git a/src/compiler/translator/tree_ops/RewriteAtomicCounters.cpp b/src/compiler/translator/tree_ops/RewriteAtomicCounters.cpp
index 7f7b88c..5067d27 100644
--- a/src/compiler/translator/tree_ops/RewriteAtomicCounters.cpp
+++ b/src/compiler/translator/tree_ops/RewriteAtomicCounters.cpp
@@ -13,6 +13,7 @@
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
+#include "compiler/translator/tree_util/ReplaceVariable.h"
namespace sh
{
@@ -150,10 +151,6 @@
: TIntermTraverser(true, true, true, symbolTable),
mAtomicCounters(atomicCounters),
mAcbBufferOffsets(acbBufferOffsets),
- mCurrentAtomicCounterOffset(0),
- mCurrentAtomicCounterBinding(0),
- mCurrentAtomicCounterDecl(nullptr),
- mCurrentAtomicCounterDeclParent(nullptr),
mAtomicCounterType(nullptr),
mAtomicCounterTypeConst(nullptr),
mAtomicCounterTypeDeclaration(nullptr)
@@ -161,86 +158,59 @@
+ // Replaces every atomic_uint uniform declaration with the declaration of its binding/offset
+ // variable (see declareAtomicCounter()); other declarations are left untouched.
bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
{
+ if (visit != PreVisit)
+ {
+ return true;
+ }
+
const TIntermSequence &sequence = *(node->getSequence());
TIntermTyped *variable = sequence.front()->getAsTyped();
const TType &type = variable->getType();
bool isAtomicCounter = type.getQualifier() == EvqUniform && type.isAtomicCounter();
- if (visit == PreVisit || visit == InVisit)
+ if (isAtomicCounter)
{
- if (isAtomicCounter)
- {
- mCurrentAtomicCounterDecl = node;
- mCurrentAtomicCounterDeclParent = getParentNode()->getAsBlock();
- mCurrentAtomicCounterOffset = type.getLayoutQualifier().offset;
- mCurrentAtomicCounterBinding = type.getLayoutQualifier().binding;
- }
+ // Atomic counters cannot have initializers, so the declaration must necessarily be a
+ // symbol.
+ TIntermSymbol *atomicCounterVariable = variable->getAsSymbolNode();
+ ASSERT(atomicCounterVariable != nullptr);
+
+ declareAtomicCounter(&atomicCounterVariable->variable(), node);
+ // The whole declaration is replaced, so there is no need to visit its children.
+ return false;
}
- else if (visit == PostVisit)
- {
- mCurrentAtomicCounterDecl = nullptr;
- mCurrentAtomicCounterDeclParent = nullptr;
- mCurrentAtomicCounterOffset = 0;
- mCurrentAtomicCounterBinding = 0;
- }
+
return true;
}
void visitFunctionPrototype(TIntermFunctionPrototype *node) override
{
const TFunction *function = node->getFunction();
- // Go over the parameters and replace the atomic arguments with a uint type. If this is
- // the function definition, keep the replaced variable for future encounters.
- mAtomicCounterFunctionParams.clear();
+ // Go over the parameters and replace the atomic arguments with a uint type.
+ // mRetyper.visitFunctionPrototype() first discards the parameter replacements tracked for
+ // the previous prototype.
+ mRetyper.visitFunctionPrototype();
for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
{
const TVariable *param = function->getParam(paramIndex);
TVariable *replacement = convertFunctionParameter(node, param);
if (replacement)
{
- mAtomicCounterFunctionParams[param] = replacement;
+ mRetyper.replaceFunctionParam(param, replacement);
}
}
- if (mAtomicCounterFunctionParams.empty())
- {
- return;
- }
-
- // Create a new function prototype and replace this with it.
- TFunction *replacementFunction = new TFunction(
- mSymbolTable, function->name(), SymbolType::UserDefined,
- new TType(function->getReturnType()), function->isKnownToNotHaveSideEffects());
- for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
- {
- const TVariable *param = function->getParam(paramIndex);
- TVariable *replacement = nullptr;
- if (param->getType().isAtomicCounter())
- {
- ASSERT(mAtomicCounterFunctionParams.count(param) != 0);
- replacement = mAtomicCounterFunctionParams[param];
- }
- else
- {
- replacement = new TVariable(mSymbolTable, param->name(),
- new TType(param->getType()), SymbolType::UserDefined);
- }
- replacementFunction->addParameter(replacement);
- }
-
+ // convertFunctionPrototype() returns nullptr when no parameter was retyped; in that case the
+ // original prototype is kept as is.
TIntermFunctionPrototype *replacementPrototype =
- new TIntermFunctionPrototype(replacementFunction);
- queueReplacement(replacementPrototype, OriginalNode::IS_DROPPED);
-
- mReplacedFunctions[function] = replacementFunction;
+ mRetyper.convertFunctionPrototype(mSymbolTable, function);
+ if (replacementPrototype)
+ {
+ queueReplacement(replacementPrototype, OriginalNode::IS_DROPPED);
+ }
}
bool visitAggregate(Visit visit, TIntermAggregate *node) override
{
if (visit == PreVisit)
{
- mAtomicCounterFunctionCallArgs.clear();
+ mRetyper.preVisitAggregate();
}
if (visit != PostVisit)
@@ -254,8 +224,13 @@
}
else if (node->getOp() == EOpCallFunctionInAST)
{
- convertASTFunction(node);
+ TIntermAggregate *substituteCall = mRetyper.convertASTFunction(node);
+ if (substituteCall)
+ {
+ queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
+ }
}
+ mRetyper.postVisitAggregate();
return true;
}
@@ -264,12 +239,6 @@
{
const TVariable *symbolVariable = &symbol->variable();
- if (mCurrentAtomicCounterDecl)
- {
- declareAtomicCounter(symbolVariable);
- return;
- }
-
if (!symbol->getType().isAtomicCounter())
{
return;
@@ -336,22 +305,14 @@
// atomicAdd(atomicCounters[ac.binding]counters[ac.offset+n]);
// }
//
- // In all cases, the argument transformation is stored in |mAtomicCounterFunctionCallArgs|.
- // In the function call's PostVisit, if it's a builtin, the look up in
- // |atomicCounters.counters| is done as well as the builtin function change. Otherwise,
- // the transformed argument is passed on as is.
+ // In all cases, the argument transformation is stored in mRetyper. In the function call's
+ // PostVisit, if it's a builtin, the look up in |atomicCounters.counters| is done as well as
+ // the builtin function change. Otherwise, the transformed argument is passed on as is.
//
- TIntermTyped *bindingOffset = nullptr;
- if (mAtomicCounterBindingOffsets.count(symbolVariable) != 0)
- {
- bindingOffset = new TIntermSymbol(mAtomicCounterBindingOffsets[symbolVariable]);
- }
- else
- {
- ASSERT(mAtomicCounterFunctionParams.count(symbolVariable) != 0);
- bindingOffset = new TIntermSymbol(mAtomicCounterFunctionParams[symbolVariable]);
- }
+ TIntermTyped *bindingOffset =
+ new TIntermSymbol(mRetyper.getVariableReplacement(symbolVariable));
+ ASSERT(bindingOffset != nullptr);
TIntermNode *argument = symbol;
@@ -389,22 +350,21 @@
TIntermBinary *modifiedOffset = new TIntermBinary(
EOpAddAssign, offsetField, arrayExpression->getRight()->deepCopy());
- TIntermSequence *modifySequence = new TIntermSequence();
- modifySequence->push_back(modifiedDecl);
- modifySequence->push_back(modifiedOffset);
+ TIntermSequence *modifySequence =
+ new TIntermSequence({modifiedDecl, modifiedOffset});
insertStatementsInParentBlock(*modifySequence);
bindingOffset = modifiedSymbol->deepCopy();
}
}
- mAtomicCounterFunctionCallArgs[argument] = bindingOffset;
+ mRetyper.replaceFunctionCallArg(argument, bindingOffset);
}
TIntermDeclaration *getAtomicCounterTypeDeclaration() { return mAtomicCounterTypeDeclaration; }
private:
- void declareAtomicCounter(const TVariable *symbolVariable)
+ void declareAtomicCounter(const TVariable *atomicCounterVar, TIntermDeclaration *node)
{
// Create a global variable that contains the binding and offset of this atomic counter
// declaration.
@@ -414,12 +374,16 @@
}
ASSERT(mAtomicCounterTypeConst);
- TVariable *bindingOffset = new TVariable(mSymbolTable, symbolVariable->name(),
+ TVariable *bindingOffset = new TVariable(mSymbolTable, atomicCounterVar->name(),
mAtomicCounterTypeConst, SymbolType::UserDefined);
- ASSERT(mCurrentAtomicCounterOffset % 4 == 0);
- TIntermTyped *bindingOffsetInitValue = CreateAtomicCounterConstant(
- mAtomicCounterTypeConst, mCurrentAtomicCounterBinding, mCurrentAtomicCounterOffset / 4);
+ const TType &atomicCounterType = atomicCounterVar->getType();
+ uint32_t offset = atomicCounterType.getLayoutQualifier().offset;
+ uint32_t binding = atomicCounterType.getLayoutQualifier().binding;
+
+ ASSERT(offset % 4 == 0);
+ TIntermTyped *bindingOffsetInitValue =
+ CreateAtomicCounterConstant(mAtomicCounterTypeConst, binding, offset / 4);
TIntermSymbol *bindingOffsetSymbol = new TIntermSymbol(bindingOffset);
TIntermBinary *bindingOffsetInit =
@@ -431,11 +395,10 @@
// Replace the atomic_uint declaration with the binding/offset declaration.
TIntermSequence replacement;
replacement.push_back(bindingOffsetDeclaration);
- mMultiReplacements.emplace_back(mCurrentAtomicCounterDeclParent, mCurrentAtomicCounterDecl,
- replacement);
+ mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, replacement);
// Remember the binding/offset variable.
- mAtomicCounterBindingOffsets[symbolVariable] = bindingOffset;
+ mRetyper.replaceGlobalVariable(atomicCounterVar, bindingOffset);
}
void declareAtomicCounterType()
@@ -529,9 +492,8 @@
}
const TIntermNode *param = (*arguments)[0];
- ASSERT(mAtomicCounterFunctionCallArgs.count(param) != 0);
- TIntermTyped *bindingOffset = mAtomicCounterFunctionCallArgs[param];
+ TIntermTyped *bindingOffset = mRetyper.getFunctionCallArgReplacement(param);
TIntermSequence *substituteArguments = new TIntermSequence;
substituteArguments->push_back(
@@ -551,60 +513,10 @@
queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
}
- void convertASTFunction(TIntermAggregate *node)
- {
- // See if the function needs replacement at all.
- const TFunction *function = node->getFunction();
- if (mReplacedFunctions.count(function) == 0)
- {
- return;
- }
-
- // atomic_uint arguments to this call are staged to be replaced at the same time.
- TFunction *substituteFunction = mReplacedFunctions[function];
- TIntermSequence *substituteArguments = new TIntermSequence;
-
- for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
- {
- TIntermNode *param = node->getChildNode(paramIndex);
-
- TIntermNode *replacement = nullptr;
- if (param->getAsTyped()->getType().isAtomicCounter())
- {
- ASSERT(mAtomicCounterFunctionCallArgs.count(param) != 0);
- replacement = mAtomicCounterFunctionCallArgs[param];
- }
- else
- {
- replacement = param->getAsTyped()->deepCopy();
- }
- substituteArguments->push_back(replacement);
- }
-
- TIntermTyped *substituteCall =
- TIntermAggregate::CreateFunctionCall(*substituteFunction, substituteArguments);
-
- queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
- }
-
const TVariable *mAtomicCounters;
const TIntermTyped *mAcbBufferOffsets;
- // A map from the atomic_uint variable to the binding/offset declaration.
- std::unordered_map<const TVariable *, TVariable *> mAtomicCounterBindingOffsets;
- // A map from functions with atomic_uint parameters to one where that's replaced with uint.
- std::unordered_map<const TFunction *, TFunction *> mReplacedFunctions;
- // A map from atomic_uint function parameters to their replacement uint parameter for the
- // current function definition.
- std::unordered_map<const TVariable *, TVariable *> mAtomicCounterFunctionParams;
- // A map from atomic_uint function call arguments to their replacement for the current
- // non-builtin function call.
- std::unordered_map<const TIntermNode *, TIntermTyped *> mAtomicCounterFunctionCallArgs;
-
- uint32_t mCurrentAtomicCounterOffset;
- uint32_t mCurrentAtomicCounterBinding;
- TIntermDeclaration *mCurrentAtomicCounterDecl;
- TIntermAggregateBase *mCurrentAtomicCounterDeclParent;
+ RetypeOpaqueVariablesHelper mRetyper;
TType *mAtomicCounterType;
TType *mAtomicCounterTypeConst;
diff --git a/src/compiler/translator/tree_util/IntermNode_util.cpp b/src/compiler/translator/tree_util/IntermNode_util.cpp
index 3fa7d57..a9976d4 100644
--- a/src/compiler/translator/tree_util/IntermNode_util.cpp
+++ b/src/compiler/translator/tree_util/IntermNode_util.cpp
@@ -113,14 +113,22 @@
return TIntermAggregate::CreateConstructor(constType, arguments);
}
+// Creates a scalar float constant node (qualifier EvqConst, undefined precision) holding |value|.
+TIntermConstantUnion *CreateFloatNode(float value)
+{
+    TConstantUnion *constantValue = new TConstantUnion[1];
+    constantValue->setFConst(value);
+    TType floatType(EbtFloat, EbpUndefined, EvqConst, 1);
+    return new TIntermConstantUnion(constantValue, floatType);
+}
+}
+
+// Creates a scalar int constant node (qualifier EvqConst, undefined precision) holding |index|.
TIntermConstantUnion *CreateIndexNode(int index)
{
TConstantUnion *u = new TConstantUnion[1];
u[0].setIConst(index);
TType type(EbtInt, EbpUndefined, EvqConst, 1);
- TIntermConstantUnion *node = new TIntermConstantUnion(u, type);
- return node;
+ return new TIntermConstantUnion(u, type);
}
TIntermConstantUnion *CreateBoolNode(bool value)
@@ -129,8 +137,7 @@
u[0].setBConst(value);
TType type(EbtBool, EbpUndefined, EvqConst, 1);
- TIntermConstantUnion *node = new TIntermConstantUnion(u, type);
- return node;
+ return new TIntermConstantUnion(u, type);
}
TVariable *CreateTempVariable(TSymbolTable *symbolTable, const TType *type)
diff --git a/src/compiler/translator/tree_util/IntermNode_util.h b/src/compiler/translator/tree_util/IntermNode_util.h
index 7d1d421..d0923f0 100644
--- a/src/compiler/translator/tree_util/IntermNode_util.h
+++ b/src/compiler/translator/tree_util/IntermNode_util.h
@@ -23,6 +23,7 @@
TIntermBlock *functionBody);
TIntermTyped *CreateZeroNode(const TType &type);
+TIntermConstantUnion *CreateFloatNode(float value);
TIntermConstantUnion *CreateIndexNode(int index);
TIntermConstantUnion *CreateBoolNode(bool value);
diff --git a/src/compiler/translator/tree_util/ReplaceVariable.cpp b/src/compiler/translator/tree_util/ReplaceVariable.cpp
index 7120cea..f9555be 100644
--- a/src/compiler/translator/tree_util/ReplaceVariable.cpp
+++ b/src/compiler/translator/tree_util/ReplaceVariable.cpp
@@ -9,6 +9,7 @@
#include "compiler/translator/tree_util/ReplaceVariable.h"
#include "compiler/translator/IntermNode.h"
+#include "compiler/translator/Symbol.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
namespace sh
@@ -61,4 +62,75 @@
traverser.updateTree();
}
+// Builds a replacement for |oldFunction|'s prototype in which every parameter registered through
+// replaceFunctionParam() is swapped for its new-typed variable and every other parameter is
+// copied. Returns nullptr when no parameter of the current prototype was replaced.
+TIntermFunctionPrototype *RetypeOpaqueVariablesHelper::convertFunctionPrototype(
+    TSymbolTable *symbolTable,
+    const TFunction *oldFunction)
+{
+    // Nothing was retyped in this prototype; the caller keeps the original one.
+    if (mReplacedFunctionParams.empty())
+    {
+        return nullptr;
+    }
+
+    // Same name, return type and side-effect information as the original function.
+    TFunction *newFunction = new TFunction(
+        symbolTable, oldFunction->name(), SymbolType::UserDefined,
+        new TType(oldFunction->getReturnType()), oldFunction->isKnownToNotHaveSideEffects());
+
+    const size_t paramCount = oldFunction->getParamCount();
+    for (size_t index = 0; index < paramCount; ++index)
+    {
+        const TVariable *oldParam = oldFunction->getParam(index);
+        auto tracked              = mReplacedFunctionParams.find(oldParam);
+        TVariable *newParam =
+            tracked != mReplacedFunctionParams.end()
+                ? tracked->second
+                : new TVariable(symbolTable, oldParam->name(), new TType(oldParam->getType()),
+                                SymbolType::UserDefined);
+        newFunction->addParameter(newParam);
+    }
+
+    // Remember the mapping so that calls to |oldFunction| can be redirected by
+    // convertASTFunction().
+    mReplacedFunctions[oldFunction] = newFunction;
+
+    return new TIntermFunctionPrototype(newFunction);
+}
+
+// Creates a call to the replacement of the function called by |node|, using the argument
+// replacements recorded for this call via replaceFunctionCallArg(). Returns nullptr if the called
+// function was never converted by convertFunctionPrototype().
+TIntermAggregate *RetypeOpaqueVariablesHelper::convertASTFunction(TIntermAggregate *node)
+{
+    // See if the function needs replacement at all.
+    const TFunction *function = node->getFunction();
+    auto replacedFunction     = mReplacedFunctions.find(function);
+    if (replacedFunction == mReplacedFunctions.end())
+    {
+        return nullptr;
+    }
+
+    // The argument replacements of this call live in the map pushed by preVisitAggregate(), so
+    // this must be called before the matching postVisitAggregate(). std::stack::top() on an
+    // empty stack is undefined behavior, hence the check. The top is loop-invariant.
+    ASSERT(!mReplacedFunctionCallArgs.empty());
+    const auto &callArgReplacements = mReplacedFunctionCallArgs.top();
+
+    // Arguments to this call are staged to be replaced at the same time.
+    TFunction *substituteFunction     = replacedFunction->second;
+    TIntermSequence *substituteArguments = new TIntermSequence;
+
+    for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
+    {
+        TIntermNode *param = node->getChildNode(paramIndex);
+
+        // Retyped arguments use their tracked replacement; the others are deep-copied so the
+        // new call does not share nodes with the original one.
+        TIntermNode *replacement = nullptr;
+        auto replacedArg         = callArgReplacements.find(param);
+        if (replacedArg != callArgReplacements.end())
+        {
+            replacement = replacedArg->second;
+        }
+        else
+        {
+            replacement = param->getAsTyped()->deepCopy();
+        }
+        substituteArguments->push_back(replacement);
+    }
+
+    return TIntermAggregate::CreateFunctionCall(*substituteFunction, substituteArguments);
+}
+
} // namespace sh
diff --git a/src/compiler/translator/tree_util/ReplaceVariable.h b/src/compiler/translator/tree_util/ReplaceVariable.h
index 3a9e34e..d2b1c8d 100644
--- a/src/compiler/translator/tree_util/ReplaceVariable.h
+++ b/src/compiler/translator/tree_util/ReplaceVariable.h
@@ -9,12 +9,22 @@
#ifndef COMPILER_TRANSLATOR_TREEUTIL_REPLACEVARIABLE_H_
#define COMPILER_TRANSLATOR_TREEUTIL_REPLACEVARIABLE_H_
+#include "common/debug.h"
+
+#include <stack>
+#include <unordered_map>
+
namespace sh
{
+class TFunction;
+class TIntermAggregate;
class TIntermBlock;
-class TVariable;
+class TIntermFunctionPrototype;
+class TIntermNode;
class TIntermTyped;
+class TSymbolTable;
+class TVariable;
void ReplaceVariable(TIntermBlock *root,
const TVariable *toBeReplaced,
@@ -22,6 +32,108 @@
void ReplaceVariableWithTyped(TIntermBlock *root,
const TVariable *toBeReplaced,
const TIntermTyped *replacement);
+
+// A helper class to keep track of opaque variable re-typing during a pass. Unlike the above
+// functions, this can be used to replace all opaque variables of a certain type with another in a
+// pass that possibly does other related transformations. Only opaque variables are supported,
+// since replacing local variables is not.
+//
+// The class uses "old" to refer to the original type of the variable and "new" to refer to the type
+// that will replace it.
+//
+// - replaceGlobalVariable(): Call to track a global variable that is replaced.
+// - In TIntermTraverser::visitFunctionPrototype():
+//   * Call visitFunctionPrototype().
+//   * For every replaced parameter, call replaceFunctionParam().
+//   * Call convertFunctionPrototype() to convert the prototype based on the above replacements.
+//     It also tracks the function with its replacement, so that calls to it can later be
+//     converted by convertASTFunction(). It returns nullptr if no parameter was replaced.
+// - In PreVisit of TIntermTraverser::visitAggregate():
+//   * Call preVisitAggregate().
+// - In TIntermTraverser::visitSymbol():
+//   * Replace non-function-call-argument symbols that refer to a global or function param with the
+//     replacement (getVariableReplacement()).
+//   * For function call arguments, call replaceFunctionCallArg() to track the replacement.
+// - In PostVisit of TIntermTraverser::visitAggregate():
+//   * Convert built-in functions per case. Call convertASTFunction() for non built-in functions
+//     for the replacement to be created (it returns nullptr if the function was not replaced).
+//   * Call postVisitAggregate() when done.
+//
+class RetypeOpaqueVariablesHelper
+{
+  public:
+    RetypeOpaqueVariablesHelper()  = default;
+    ~RetypeOpaqueVariablesHelper() = default;
+
+    // Global variable handling:
+    void replaceGlobalVariable(const TVariable *oldVar, TVariable *newVar)
+    {
+        ASSERT(mReplacedGlobalVariables.count(oldVar) == 0);
+        mReplacedGlobalVariables[oldVar] = newVar;
+    }
+    // Returns the replacement of |oldVar|, looking first at the replaced global variables and then
+    // at the replaced parameters of the current function prototype. This function should only be
+    // called if the variable is expected to have been replaced either way.
+    TVariable *getVariableReplacement(const TVariable *oldVar) const
+    {
+        auto replacedGlobal = mReplacedGlobalVariables.find(oldVar);
+        if (replacedGlobal != mReplacedGlobalVariables.end())
+        {
+            return replacedGlobal->second;
+        }
+        return getFunctionParamReplacement(oldVar);
+    }
+
+    // Function parameters handling:
+    void visitFunctionPrototype() { mReplacedFunctionParams.clear(); }
+    void replaceFunctionParam(const TVariable *oldParam, TVariable *newParam)
+    {
+        ASSERT(mReplacedFunctionParams.count(oldParam) == 0);
+        mReplacedFunctionParams[oldParam] = newParam;
+    }
+    TVariable *getFunctionParamReplacement(const TVariable *oldParam) const
+    {
+        ASSERT(mReplacedFunctionParams.count(oldParam) != 0);
+        return mReplacedFunctionParams.at(oldParam);
+    }
+
+    // Function call arguments handling. Each aggregate visit pushes its own map, which is popped
+    // in postVisitAggregate(), so the arguments of nested aggregates are tracked separately.
+    // std::stack::top()/pop() on an empty stack are undefined behavior, hence the ASSERTs.
+    void preVisitAggregate() { mReplacedFunctionCallArgs.emplace(); }
+    void postVisitAggregate()
+    {
+        ASSERT(!mReplacedFunctionCallArgs.empty());
+        mReplacedFunctionCallArgs.pop();
+    }
+    void replaceFunctionCallArg(const TIntermNode *oldArg, TIntermTyped *newArg)
+    {
+        ASSERT(!mReplacedFunctionCallArgs.empty());
+        ASSERT(mReplacedFunctionCallArgs.top().count(oldArg) == 0);
+        mReplacedFunctionCallArgs.top()[oldArg] = newArg;
+    }
+    TIntermTyped *getFunctionCallArgReplacement(const TIntermNode *oldArg) const
+    {
+        ASSERT(!mReplacedFunctionCallArgs.empty());
+        ASSERT(mReplacedFunctionCallArgs.top().count(oldArg) != 0);
+        return mReplacedFunctionCallArgs.top().at(oldArg);
+    }
+
+    // Helper code conversion methods.
+    TIntermFunctionPrototype *convertFunctionPrototype(TSymbolTable *symbolTable,
+                                                       const TFunction *oldFunction);
+    TIntermAggregate *convertASTFunction(TIntermAggregate *node);
+
+  private:
+    // A map from the old global variable to the new one.
+    std::unordered_map<const TVariable *, TVariable *> mReplacedGlobalVariables;
+
+    // A map from functions with old type parameters to one where that's replaced with the new type.
+    std::unordered_map<const TFunction *, TFunction *> mReplacedFunctions;
+
+    // A map from function old type parameters to their replacement new type parameter for the
+    // current function definition.
+    std::unordered_map<const TVariable *, TVariable *> mReplacedFunctionParams;
+
+    // A stack of maps (one per aggregate being visited) from function call old type arguments to
+    // their replacement for that function call.
+    std::stack<std::unordered_map<const TIntermNode *, TIntermTyped *>> mReplacedFunctionCallArgs;
+};
+
} // namespace sh
#endif // COMPILER_TRANSLATOR_TREEUTIL_REPLACEVARIABLE_H_