Vulkan: Refactor atomic counter retype code

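A generic "retyper" class is extracted out of the atomic counter retype
code so that it can be reused when converting samplerCube to
sampler2DArray for seamful cubemap sampling emulation.

Along the way, the file-local CreateFloatConstant() in TranslatorVulkan
is replaced by a shared CreateFloatNode() helper, and TVector gains a
std::initializer_list constructor to simplify swizzle offset setup.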

Bug: angleproject:3732
Change-Id: I8b5f835125b9513afcfe7baeea48afaf1299a027
Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/1733807
Commit-Queue: Shahbaz Youssefi <syoussefi@chromium.org>
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Reviewed-by: Tim Van Patten <timvp@google.com>
diff --git a/src/compiler/translator/Common.h b/src/compiler/translator/Common.h
index 179cfe2..be3fb2a 100644
--- a/src/compiler/translator/Common.h
+++ b/src/compiler/translator/Common.h
@@ -71,6 +71,7 @@
     TVector() : std::vector<T, pool_allocator<T>>() {}
     TVector(const pool_allocator<T> &a) : std::vector<T, pool_allocator<T>>(a) {}
     TVector(size_type i) : std::vector<T, pool_allocator<T>>(i) {}
+    TVector(std::initializer_list<T> init) : std::vector<T, pool_allocator<T>>(init) {}
 };
 
 template <class K, class D, class H = std::hash<K>, class CMP = std::equal_to<K>>
diff --git a/src/compiler/translator/TranslatorVulkan.cpp b/src/compiler/translator/TranslatorVulkan.cpp
index 657dfcc..5f0c4c1 100644
--- a/src/compiler/translator/TranslatorVulkan.cpp
+++ b/src/compiler/translator/TranslatorVulkan.cpp
@@ -152,14 +152,6 @@
     bool mInDefaultUniform;
 };
 
-TIntermConstantUnion *CreateFloatConstant(float value)
-{
-    const TType *constantType     = StaticType::GetBasic<EbtFloat, 1>();
-    TConstantUnion *constantValue = new TConstantUnion;
-    constantValue->setFConst(value);
-    return new TIntermConstantUnion(constantValue, *constantType);
-}
-
 constexpr ImmutableString kFlippedPointCoordName    = ImmutableString("flippedPointCoord");
 constexpr ImmutableString kFlippedFragCoordName     = ImmutableString("flippedFragCoord");
 constexpr ImmutableString kEmulatedDepthRangeParams = ImmutableString("ANGLEDepthRangeParams");
@@ -223,8 +215,7 @@
     TIntermSymbol *builtinRef = new TIntermSymbol(builtin);
 
     // Create a swizzle to "builtin.y"
-    TVector<int> swizzleOffsetY;
-    swizzleOffsetY.push_back(1);
+    TVector<int> swizzleOffsetY = {1};
     TIntermSwizzle *builtinY = new TIntermSwizzle(builtinRef, swizzleOffsetY);
 
     // Create a symbol reference to our new variable that will hold the modified builtin.
@@ -296,16 +287,14 @@
     TIntermSymbol *positionRef = new TIntermSymbol(position);
 
     // Create a swizzle to "gl_Position.z"
-    TVector<int> swizzleOffsetZ;
-    swizzleOffsetZ.push_back(2);
+    TVector<int> swizzleOffsetZ = {2};
     TIntermSwizzle *positionZ = new TIntermSwizzle(positionRef, swizzleOffsetZ);
 
     // Create a constant "0.5"
-    TIntermConstantUnion *oneHalf = CreateFloatConstant(0.5f);
+    TIntermConstantUnion *oneHalf = CreateFloatNode(0.5f);
 
     // Create a swizzle to "gl_Position.w"
-    TVector<int> swizzleOffsetW;
-    swizzleOffsetW.push_back(3);
+    TVector<int> swizzleOffsetW = {3};
     TIntermSwizzle *positionW = new TIntermSwizzle(positionRef->deepCopy(), swizzleOffsetW);
 
     // Create the expression "(gl_Position.z + gl_Position.w) * 0.5".
@@ -517,27 +506,22 @@
 
     // Create a swizzle to "ANGLEUniforms.viewport.xy".
     TIntermBinary *viewportRef = CreateDriverUniformRef(driverUniforms, kViewport);
-    TVector<int> swizzleOffsetXY;
-    swizzleOffsetXY.push_back(0);
-    swizzleOffsetXY.push_back(1);
+    TVector<int> swizzleOffsetXY = {0, 1};
     TIntermSwizzle *viewportXY = new TIntermSwizzle(viewportRef->deepCopy(), swizzleOffsetXY);
 
     // Create a swizzle to "ANGLEUniforms.viewport.zw".
-    TVector<int> swizzleOffsetZW;
-    swizzleOffsetZW.push_back(2);
-    swizzleOffsetZW.push_back(3);
+    TVector<int> swizzleOffsetZW = {2, 3};
     TIntermSwizzle *viewportZW = new TIntermSwizzle(viewportRef, swizzleOffsetZW);
 
     // ANGLEPosition.xy / ANGLEPosition.w
     TIntermSymbol *position    = new TIntermSymbol(anglePosition);
     TIntermSwizzle *positionXY = new TIntermSwizzle(position, swizzleOffsetXY);
-    TVector<int> swizzleOffsetW;
-    swizzleOffsetW.push_back(3);
+    TVector<int> swizzleOffsetW = {3};
     TIntermSwizzle *positionW  = new TIntermSwizzle(position->deepCopy(), swizzleOffsetW);
     TIntermBinary *positionNDC = new TIntermBinary(EOpDiv, positionXY, positionW);
 
     // ANGLEPosition * 0.5
-    TIntermConstantUnion *oneHalf = CreateFloatConstant(0.5f);
+    TIntermConstantUnion *oneHalf = CreateFloatNode(0.5f);
     TIntermBinary *halfPosition   = new TIntermBinary(EOpVectorTimesScalar, positionNDC, oneHalf);
 
     // (ANGLEPosition * 0.5) + 0.5
@@ -575,7 +559,7 @@
     TIntermBinary *baSq = new TIntermBinary(EOpMul, ba, ba->deepCopy());
 
     // 2.0 * ba * ba
-    TIntermTyped *two      = CreateFloatConstant(2.0f);
+    TIntermTyped *two      = CreateFloatNode(2.0f);
     TIntermBinary *twoBaSq = new TIntermBinary(EOpVectorTimesScalar, baSq, two);
 
     // Assign to a temporary "ba2".
@@ -583,9 +567,7 @@
     TIntermDeclaration *ba2Decl = CreateTempInitDeclarationNode(ba2Temp, twoBaSq);
 
     // Create a swizzle to "ba2.yx".
-    TVector<int> swizzleOffsetYX;
-    swizzleOffsetYX.push_back(1);
-    swizzleOffsetYX.push_back(0);
+    TVector<int> swizzleOffsetYX = {1, 0};
     TIntermSymbol *ba2    = CreateTempSymbolNode(ba2Temp);
     TIntermSwizzle *ba2YX = new TIntermSwizzle(ba2, swizzleOffsetYX);
 
@@ -599,21 +581,19 @@
     TIntermSymbol *bp          = CreateTempSymbolNode(bpTemp);
 
     // Create a swizzle to "bp.x".
-    TVector<int> swizzleOffsetX;
-    swizzleOffsetX.push_back(0);
+    TVector<int> swizzleOffsetX = {0};
     TIntermSwizzle *bpX = new TIntermSwizzle(bp, swizzleOffsetX);
 
     // Using a small epsilon value ensures that we don't suffer from numerical instability when
     // lines are exactly vertical or horizontal.
     static constexpr float kEpisilon = 0.00001f;
-    TIntermConstantUnion *epsilon    = CreateFloatConstant(kEpisilon);
+    TIntermConstantUnion *epsilon    = CreateFloatNode(kEpisilon);
 
     // bp.x > epsilon
     TIntermBinary *checkX = new TIntermBinary(EOpGreaterThan, bpX, epsilon);
 
     // Create a swizzle to "bp.y".
-    TVector<int> swizzleOffsetY;
-    swizzleOffsetY.push_back(1);
+    TVector<int> swizzleOffsetY = {1};
     TIntermSwizzle *bpY = new TIntermSwizzle(bp->deepCopy(), swizzleOffsetY);
 
     // bp.y > epsilon
@@ -798,7 +778,7 @@
         {
             TIntermBinary *viewportYScale =
                 CreateDriverUniformRef(driverUniforms, kNegViewportYScale);
-            TIntermConstantUnion *pivot = CreateFloatConstant(0.5f);
+            TIntermConstantUnion *pivot = CreateFloatNode(0.5f);
             FlipBuiltinVariable(root, GetMainSequence(root), viewportYScale, &getSymbolTable(),
                                 BuiltInVariable::gl_PointCoord(), kFlippedPointCoordName, pivot);
         }
diff --git a/src/compiler/translator/tree_ops/RewriteAtomicCounters.cpp b/src/compiler/translator/tree_ops/RewriteAtomicCounters.cpp
index 7f7b88c..5067d27 100644
--- a/src/compiler/translator/tree_ops/RewriteAtomicCounters.cpp
+++ b/src/compiler/translator/tree_ops/RewriteAtomicCounters.cpp
@@ -13,6 +13,7 @@
 #include "compiler/translator/SymbolTable.h"
 #include "compiler/translator/tree_util/IntermNode_util.h"
 #include "compiler/translator/tree_util/IntermTraverse.h"
+#include "compiler/translator/tree_util/ReplaceVariable.h"
 
 namespace sh
 {
@@ -150,10 +151,6 @@
         : TIntermTraverser(true, true, true, symbolTable),
           mAtomicCounters(atomicCounters),
           mAcbBufferOffsets(acbBufferOffsets),
-          mCurrentAtomicCounterOffset(0),
-          mCurrentAtomicCounterBinding(0),
-          mCurrentAtomicCounterDecl(nullptr),
-          mCurrentAtomicCounterDeclParent(nullptr),
           mAtomicCounterType(nullptr),
           mAtomicCounterTypeConst(nullptr),
           mAtomicCounterTypeDeclaration(nullptr)
@@ -161,86 +158,59 @@
 
     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
     {
+        if (visit != PreVisit)
+        {
+            return true;
+        }
+
         const TIntermSequence &sequence = *(node->getSequence());
 
         TIntermTyped *variable = sequence.front()->getAsTyped();
         const TType &type      = variable->getType();
         bool isAtomicCounter   = type.getQualifier() == EvqUniform && type.isAtomicCounter();
 
-        if (visit == PreVisit || visit == InVisit)
+        if (isAtomicCounter)
         {
-            if (isAtomicCounter)
-            {
-                mCurrentAtomicCounterDecl       = node;
-                mCurrentAtomicCounterDeclParent = getParentNode()->getAsBlock();
-                mCurrentAtomicCounterOffset     = type.getLayoutQualifier().offset;
-                mCurrentAtomicCounterBinding    = type.getLayoutQualifier().binding;
-            }
+            // Atomic counters cannot have initializers, so the declaration must necessarily be a
+            // symbol.  The declaration is replaced entirely, so its children need not be visited.
+            TIntermSymbol *atomicCounterVariable = variable->getAsSymbolNode();
+            ASSERT(atomicCounterVariable != nullptr);
+
+            declareAtomicCounter(&atomicCounterVariable->variable(), node);
+            return false;
         }
-        else if (visit == PostVisit)
-        {
-            mCurrentAtomicCounterDecl       = nullptr;
-            mCurrentAtomicCounterDeclParent = nullptr;
-            mCurrentAtomicCounterOffset     = 0;
-            mCurrentAtomicCounterBinding    = 0;
-        }
+
         return true;
     }
 
     void visitFunctionPrototype(TIntermFunctionPrototype *node) override
     {
         const TFunction *function = node->getFunction();
-        // Go over the parameters and replace the atomic arguments with a uint type.  If this is
-        // the function definition, keep the replaced variable for future encounters.
-        mAtomicCounterFunctionParams.clear();
+        // Go over the parameters and replace the atomic arguments with a uint type.
+        mRetyper.visitFunctionPrototype();
         for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
         {
             const TVariable *param = function->getParam(paramIndex);
             TVariable *replacement = convertFunctionParameter(node, param);
             if (replacement)
             {
-                mAtomicCounterFunctionParams[param] = replacement;
+                mRetyper.replaceFunctionParam(param, replacement);
             }
         }
 
-        if (mAtomicCounterFunctionParams.empty())
-        {
-            return;
-        }
-
-        // Create a new function prototype and replace this with it.
-        TFunction *replacementFunction = new TFunction(
-            mSymbolTable, function->name(), SymbolType::UserDefined,
-            new TType(function->getReturnType()), function->isKnownToNotHaveSideEffects());
-        for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
-        {
-            const TVariable *param = function->getParam(paramIndex);
-            TVariable *replacement = nullptr;
-            if (param->getType().isAtomicCounter())
-            {
-                ASSERT(mAtomicCounterFunctionParams.count(param) != 0);
-                replacement = mAtomicCounterFunctionParams[param];
-            }
-            else
-            {
-                replacement = new TVariable(mSymbolTable, param->name(),
-                                            new TType(param->getType()), SymbolType::UserDefined);
-            }
-            replacementFunction->addParameter(replacement);
-        }
-
         TIntermFunctionPrototype *replacementPrototype =
-            new TIntermFunctionPrototype(replacementFunction);
-        queueReplacement(replacementPrototype, OriginalNode::IS_DROPPED);
-
-        mReplacedFunctions[function] = replacementFunction;
+            mRetyper.convertFunctionPrototype(mSymbolTable, function);
+        if (replacementPrototype)
+        {
+            queueReplacement(replacementPrototype, OriginalNode::IS_DROPPED);
+        }
     }
 
     bool visitAggregate(Visit visit, TIntermAggregate *node) override
     {
         if (visit == PreVisit)
         {
-            mAtomicCounterFunctionCallArgs.clear();
+            mRetyper.preVisitAggregate();
         }
 
         if (visit != PostVisit)
@@ -254,8 +224,13 @@
         }
         else if (node->getOp() == EOpCallFunctionInAST)
         {
-            convertASTFunction(node);
+            TIntermAggregate *substituteCall = mRetyper.convertASTFunction(node);
+            if (substituteCall)
+            {
+                queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
+            }
         }
+        mRetyper.postVisitAggregate();  // Pairs with preVisitAggregate() in PreVisit.
 
         return true;
     }
@@ -264,12 +239,6 @@
     {
         const TVariable *symbolVariable = &symbol->variable();
 
-        if (mCurrentAtomicCounterDecl)
-        {
-            declareAtomicCounter(symbolVariable);
-            return;
-        }
-
         if (!symbol->getType().isAtomicCounter())
         {
             return;
@@ -336,22 +305,14 @@
         //         atomicAdd(atomicCounters[ac.binding]counters[ac.offset+n]);
         //     }
         //
-        // In all cases, the argument transformation is stored in |mAtomicCounterFunctionCallArgs|.
-        // In the function call's PostVisit, if it's a builtin, the look up in
-        // |atomicCounters.counters| is done as well as the builtin function change.  Otherwise,
-        // the transformed argument is passed on as is.
+        // In all cases, the argument transformation is stored in mRetyper.  In the function call's
+        // PostVisit, if it's a builtin, the look up in |atomicCounters.counters| is done as well as
+        // the builtin function change.  Otherwise, the transformed argument is passed on as is.
         //
 
-        TIntermTyped *bindingOffset = nullptr;
-        if (mAtomicCounterBindingOffsets.count(symbolVariable) != 0)
-        {
-            bindingOffset = new TIntermSymbol(mAtomicCounterBindingOffsets[symbolVariable]);
-        }
-        else
-        {
-            ASSERT(mAtomicCounterFunctionParams.count(symbolVariable) != 0);
-            bindingOffset = new TIntermSymbol(mAtomicCounterFunctionParams[symbolVariable]);
-        }
+        TVariable *bindingOffsetVar = mRetyper.getVariableReplacement(symbolVariable);
+        ASSERT(bindingOffsetVar != nullptr);
+        TIntermTyped *bindingOffset = new TIntermSymbol(bindingOffsetVar);
 
         TIntermNode *argument = symbol;
 
@@ -389,22 +350,21 @@
                 TIntermBinary *modifiedOffset = new TIntermBinary(
                     EOpAddAssign, offsetField, arrayExpression->getRight()->deepCopy());
 
-                TIntermSequence *modifySequence = new TIntermSequence();
-                modifySequence->push_back(modifiedDecl);
-                modifySequence->push_back(modifiedOffset);
+                TIntermSequence *modifySequence =
+                    new TIntermSequence({modifiedDecl, modifiedOffset});
                 insertStatementsInParentBlock(*modifySequence);
 
                 bindingOffset = modifiedSymbol->deepCopy();
             }
         }
 
-        mAtomicCounterFunctionCallArgs[argument] = bindingOffset;
+        mRetyper.replaceFunctionCallArg(argument, bindingOffset);
     }
 
     TIntermDeclaration *getAtomicCounterTypeDeclaration() { return mAtomicCounterTypeDeclaration; }
 
   private:
-    void declareAtomicCounter(const TVariable *symbolVariable)
+    void declareAtomicCounter(const TVariable *atomicCounterVar, TIntermDeclaration *node)
     {
         // Create a global variable that contains the binding and offset of this atomic counter
         // declaration.
@@ -414,12 +374,16 @@
         }
         ASSERT(mAtomicCounterTypeConst);
 
-        TVariable *bindingOffset = new TVariable(mSymbolTable, symbolVariable->name(),
+        TVariable *bindingOffset = new TVariable(mSymbolTable, atomicCounterVar->name(),
                                                  mAtomicCounterTypeConst, SymbolType::UserDefined);
 
-        ASSERT(mCurrentAtomicCounterOffset % 4 == 0);
-        TIntermTyped *bindingOffsetInitValue = CreateAtomicCounterConstant(
-            mAtomicCounterTypeConst, mCurrentAtomicCounterBinding, mCurrentAtomicCounterOffset / 4);
+        const TType &atomicCounterType = atomicCounterVar->getType();
+        uint32_t offset                = atomicCounterType.getLayoutQualifier().offset;
+        uint32_t binding               = atomicCounterType.getLayoutQualifier().binding;
+
+        ASSERT(offset % 4 == 0);
+        TIntermTyped *bindingOffsetInitValue =
+            CreateAtomicCounterConstant(mAtomicCounterTypeConst, binding, offset / 4);
 
         TIntermSymbol *bindingOffsetSymbol = new TIntermSymbol(bindingOffset);
         TIntermBinary *bindingOffsetInit =
@@ -431,11 +395,10 @@
         // Replace the atomic_uint declaration with the binding/offset declaration.
         TIntermSequence replacement;
         replacement.push_back(bindingOffsetDeclaration);
-        mMultiReplacements.emplace_back(mCurrentAtomicCounterDeclParent, mCurrentAtomicCounterDecl,
-                                        replacement);
+        mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, replacement);
 
         // Remember the binding/offset variable.
-        mAtomicCounterBindingOffsets[symbolVariable] = bindingOffset;
+        mRetyper.replaceGlobalVariable(atomicCounterVar, bindingOffset);
     }
 
     void declareAtomicCounterType()
@@ -529,9 +492,8 @@
         }
 
         const TIntermNode *param = (*arguments)[0];
-        ASSERT(mAtomicCounterFunctionCallArgs.count(param) != 0);
 
-        TIntermTyped *bindingOffset = mAtomicCounterFunctionCallArgs[param];
+        TIntermTyped *bindingOffset = mRetyper.getFunctionCallArgReplacement(param);
 
         TIntermSequence *substituteArguments = new TIntermSequence;
         substituteArguments->push_back(
@@ -551,60 +513,10 @@
         queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
     }
 
-    void convertASTFunction(TIntermAggregate *node)
-    {
-        // See if the function needs replacement at all.
-        const TFunction *function = node->getFunction();
-        if (mReplacedFunctions.count(function) == 0)
-        {
-            return;
-        }
-
-        // atomic_uint arguments to this call are staged to be replaced at the same time.
-        TFunction *substituteFunction        = mReplacedFunctions[function];
-        TIntermSequence *substituteArguments = new TIntermSequence;
-
-        for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
-        {
-            TIntermNode *param = node->getChildNode(paramIndex);
-
-            TIntermNode *replacement = nullptr;
-            if (param->getAsTyped()->getType().isAtomicCounter())
-            {
-                ASSERT(mAtomicCounterFunctionCallArgs.count(param) != 0);
-                replacement = mAtomicCounterFunctionCallArgs[param];
-            }
-            else
-            {
-                replacement = param->getAsTyped()->deepCopy();
-            }
-            substituteArguments->push_back(replacement);
-        }
-
-        TIntermTyped *substituteCall =
-            TIntermAggregate::CreateFunctionCall(*substituteFunction, substituteArguments);
-
-        queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
-    }
-
     const TVariable *mAtomicCounters;
     const TIntermTyped *mAcbBufferOffsets;
 
-    // A map from the atomic_uint variable to the binding/offset declaration.
-    std::unordered_map<const TVariable *, TVariable *> mAtomicCounterBindingOffsets;
-    // A map from functions with atomic_uint parameters to one where that's replaced with uint.
-    std::unordered_map<const TFunction *, TFunction *> mReplacedFunctions;
-    // A map from atomic_uint function parameters to their replacement uint parameter for the
-    // current function definition.
-    std::unordered_map<const TVariable *, TVariable *> mAtomicCounterFunctionParams;
-    // A map from atomic_uint function call arguments to their replacement for the current
-    // non-builtin function call.
-    std::unordered_map<const TIntermNode *, TIntermTyped *> mAtomicCounterFunctionCallArgs;
-
-    uint32_t mCurrentAtomicCounterOffset;
-    uint32_t mCurrentAtomicCounterBinding;
-    TIntermDeclaration *mCurrentAtomicCounterDecl;
-    TIntermAggregateBase *mCurrentAtomicCounterDeclParent;
+    RetypeOpaqueVariablesHelper mRetyper;  // Tracks atomic_uint -> binding/offset replacements.
 
     TType *mAtomicCounterType;
     TType *mAtomicCounterTypeConst;
diff --git a/src/compiler/translator/tree_util/IntermNode_util.cpp b/src/compiler/translator/tree_util/IntermNode_util.cpp
index 3fa7d57..a9976d4 100644
--- a/src/compiler/translator/tree_util/IntermNode_util.cpp
+++ b/src/compiler/translator/tree_util/IntermNode_util.cpp
@@ -113,14 +113,22 @@
     return TIntermAggregate::CreateConstructor(constType, arguments);
 }
 
+TIntermConstantUnion *CreateFloatNode(float value)
+{
+    TConstantUnion *u = new TConstantUnion[1];
+    u[0].setFConst(value);
+
+    TType type(EbtFloat, EbpUndefined, EvqConst, 1);
+    return new TIntermConstantUnion(u, type);
+}
+
 TIntermConstantUnion *CreateIndexNode(int index)
 {
     TConstantUnion *u = new TConstantUnion[1];
     u[0].setIConst(index);
 
     TType type(EbtInt, EbpUndefined, EvqConst, 1);
-    TIntermConstantUnion *node = new TIntermConstantUnion(u, type);
-    return node;
+    return new TIntermConstantUnion(u, type);
 }
 
 TIntermConstantUnion *CreateBoolNode(bool value)
@@ -129,8 +137,7 @@
     u[0].setBConst(value);
 
     TType type(EbtBool, EbpUndefined, EvqConst, 1);
-    TIntermConstantUnion *node = new TIntermConstantUnion(u, type);
-    return node;
+    return new TIntermConstantUnion(u, type);
 }
 
 TVariable *CreateTempVariable(TSymbolTable *symbolTable, const TType *type)
diff --git a/src/compiler/translator/tree_util/IntermNode_util.h b/src/compiler/translator/tree_util/IntermNode_util.h
index 7d1d421..d0923f0 100644
--- a/src/compiler/translator/tree_util/IntermNode_util.h
+++ b/src/compiler/translator/tree_util/IntermNode_util.h
@@ -23,6 +23,7 @@
                                                                 TIntermBlock *functionBody);
 
 TIntermTyped *CreateZeroNode(const TType &type);
+TIntermConstantUnion *CreateFloatNode(float value);
 TIntermConstantUnion *CreateIndexNode(int index);
 TIntermConstantUnion *CreateBoolNode(bool value);
 
diff --git a/src/compiler/translator/tree_util/ReplaceVariable.cpp b/src/compiler/translator/tree_util/ReplaceVariable.cpp
index 7120cea..f9555be 100644
--- a/src/compiler/translator/tree_util/ReplaceVariable.cpp
+++ b/src/compiler/translator/tree_util/ReplaceVariable.cpp
@@ -9,6 +9,7 @@
 #include "compiler/translator/tree_util/ReplaceVariable.h"
 
 #include "compiler/translator/IntermNode.h"
+#include "compiler/translator/Symbol.h"
 #include "compiler/translator/tree_util/IntermTraverse.h"
 
 namespace sh
@@ -61,4 +62,84 @@
     traverser.updateTree();
 }
 
+// Builds a replacement prototype for |oldFunction|.  Parameters recorded with
+// replaceFunctionParam() since the last visitFunctionPrototype() are swapped for their new-typed
+// variables; all other parameters are recreated with the same name and type.  The mapping from
+// |oldFunction| to the new function is recorded for convertASTFunction().  Returns nullptr if no
+// parameter was replaced.
+//
+// Every call creates a fresh TFunction, so if the same |oldFunction| is converted more than once,
+// only the latest replacement is kept in the mapping.
+TIntermFunctionPrototype *RetypeOpaqueVariablesHelper::convertFunctionPrototype(
+    TSymbolTable *symbolTable,
+    const TFunction *oldFunction)
+{
+    if (mReplacedFunctionParams.empty())
+    {
+        return nullptr;
+    }
+
+    // Create a new function prototype for replacement.
+    TFunction *replacementFunction = new TFunction(
+        symbolTable, oldFunction->name(), SymbolType::UserDefined,
+        new TType(oldFunction->getReturnType()), oldFunction->isKnownToNotHaveSideEffects());
+    for (size_t paramIndex = 0; paramIndex < oldFunction->getParamCount(); ++paramIndex)
+    {
+        const TVariable *param = oldFunction->getParam(paramIndex);
+        TVariable *replacement = nullptr;
+        auto replaced          = mReplacedFunctionParams.find(param);
+        if (replaced != mReplacedFunctionParams.end())
+        {
+            replacement = replaced->second;
+        }
+        else
+        {
+            replacement = new TVariable(symbolTable, param->name(), new TType(param->getType()),
+                                        SymbolType::UserDefined);
+        }
+        replacementFunction->addParameter(replacement);
+    }
+    mReplacedFunctions[oldFunction] = replacementFunction;
+
+    return new TIntermFunctionPrototype(replacementFunction);
+}
+
+// Builds a replacement call for |node| if its callee was converted by convertFunctionPrototype().
+// Arguments recorded with replaceFunctionCallArg() for the current aggregate are substituted; all
+// other arguments are deep-copied.  Returns nullptr if the callee needs no replacement.
+TIntermAggregate *RetypeOpaqueVariablesHelper::convertASTFunction(TIntermAggregate *node)
+{
+    // See if the function needs replacement at all.
+    const TFunction *function = node->getFunction();
+    auto replacedFunction     = mReplacedFunctions.find(function);
+    if (replacedFunction == mReplacedFunctions.end())
+    {
+        return nullptr;
+    }
+
+    // Arguments to this call are staged to be replaced at the same time.
+    const auto &replacedArgs             = mReplacedFunctionCallArgs.top();
+    TFunction *substituteFunction        = replacedFunction->second;
+    TIntermSequence *substituteArguments = new TIntermSequence;
+
+    for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
+    {
+        TIntermNode *param = node->getChildNode(paramIndex);
+
+        TIntermNode *replacement = nullptr;
+        auto replacedArg         = replacedArgs.find(param);
+        if (replacedArg != replacedArgs.end())
+        {
+            replacement = replacedArg->second;
+        }
+        else
+        {
+            replacement = param->getAsTyped()->deepCopy();
+        }
+        substituteArguments->push_back(replacement);
+    }
+
+    return TIntermAggregate::CreateFunctionCall(*substituteFunction, substituteArguments);
+}
+
 }  // namespace sh
diff --git a/src/compiler/translator/tree_util/ReplaceVariable.h b/src/compiler/translator/tree_util/ReplaceVariable.h
index 3a9e34e..d2b1c8d 100644
--- a/src/compiler/translator/tree_util/ReplaceVariable.h
+++ b/src/compiler/translator/tree_util/ReplaceVariable.h
@@ -9,12 +9,22 @@
 #ifndef COMPILER_TRANSLATOR_TREEUTIL_REPLACEVARIABLE_H_
 #define COMPILER_TRANSLATOR_TREEUTIL_REPLACEVARIABLE_H_
 
+#include "common/debug.h"
+
+#include <stack>
+#include <unordered_map>
+
 namespace sh
 {
 
+class TFunction;
+class TIntermAggregate;
 class TIntermBlock;
-class TVariable;
+class TIntermFunctionPrototype;
+class TIntermNode;
 class TIntermTyped;
+class TSymbolTable;
+class TVariable;
 
 void ReplaceVariable(TIntermBlock *root,
                      const TVariable *toBeReplaced,
@@ -22,6 +32,110 @@
                               const TVariable *toBeReplaced,
                               const TIntermTyped *replacement);
 
+
+// A helper class to keep track of opaque variable re-typing during a pass.  Unlike the above
+// functions, this can be used to replace all opaque variables of a certain type with another in a
+// pass that possibly does other related transformations.  Only opaque variables are supported, as
+// only global variables and function parameters are tracked; local variables cannot be replaced.
+//
+// The class uses "old" to refer to the original type of the variable and "new" to refer to the type
+// that will replace it.
+//
+// - replaceGlobalVariable(): Call to track a global variable that is replaced.
+// - In TIntermTraverser::visitFunctionPrototype():
+//   * Call visitFunctionPrototype().
+//   * For every replaced parameter, call replaceFunctionParam().
+//   * Call convertFunctionPrototype() to convert the prototype based on the above replacements
+//     and track the function with its replacement.
+//   * If it returns non-null, queue the returned prototype as the replacement of the node.
+// - In PreVisit of TIntermTraverser::visitAggregate():
+//   * Call preVisitAggregate().
+// - In TIntermTraverser::visitSymbol():
+//   * Replace non-function-call-argument symbols that refer to a global or function param with the
+//     replacement (getVariableReplacement()).
+//   * For function call arguments, call replaceFunctionCallArg() to track the replacement.
+// - In PostVisit of TIntermTraverser::visitAggregate():
+//   * Convert built-in functions per case.  Call convertASTFunction() for non built-in functions
+//     for the replacement to be created; queue it as the node's replacement if non-null.
+//   * Call postVisitAggregate() when done.  Every preVisitAggregate() must be paired with one.
+//
+class RetypeOpaqueVariablesHelper
+{
+  public:
+    RetypeOpaqueVariablesHelper() {}
+    ~RetypeOpaqueVariablesHelper() {}
+
+    // Global variable handling:
+    void replaceGlobalVariable(const TVariable *oldVar, TVariable *newVar)
+    {
+        ASSERT(mReplacedGlobalVariables.count(oldVar) == 0);
+        mReplacedGlobalVariables[oldVar] = newVar;
+    }
+    TVariable *getVariableReplacement(const TVariable *oldVar) const
+    {
+        auto replacedGlobal = mReplacedGlobalVariables.find(oldVar);
+        if (replacedGlobal != mReplacedGlobalVariables.end())
+        {
+            return replacedGlobal->second;
+        }
+
+        // Not a global, so it must be a parameter of the current function.  This function should
+        // only be called if the variable is expected to have been replaced either way.
+        return getFunctionParamReplacement(oldVar);
+    }
+
+    // Function parameters handling:
+    void visitFunctionPrototype() { mReplacedFunctionParams.clear(); }
+    void replaceFunctionParam(const TVariable *oldParam, TVariable *newParam)
+    {
+        ASSERT(mReplacedFunctionParams.count(oldParam) == 0);
+        mReplacedFunctionParams[oldParam] = newParam;
+    }
+    TVariable *getFunctionParamReplacement(const TVariable *oldParam) const
+    {
+        ASSERT(mReplacedFunctionParams.count(oldParam) != 0);
+        return mReplacedFunctionParams.at(oldParam);
+    }
+
+    // Function call arguments handling:
+    void preVisitAggregate() { mReplacedFunctionCallArgs.emplace(); }
+    void postVisitAggregate()
+    {
+        ASSERT(!mReplacedFunctionCallArgs.empty());
+        mReplacedFunctionCallArgs.pop();
+    }
+    void replaceFunctionCallArg(const TIntermNode *oldArg, TIntermTyped *newArg)
+    {
+        ASSERT(mReplacedFunctionCallArgs.top().count(oldArg) == 0);
+        mReplacedFunctionCallArgs.top()[oldArg] = newArg;
+    }
+    TIntermTyped *getFunctionCallArgReplacement(const TIntermNode *oldArg) const
+    {
+        ASSERT(mReplacedFunctionCallArgs.top().count(oldArg) != 0);
+        return mReplacedFunctionCallArgs.top().at(oldArg);
+    }
+
+    // Helper code conversion methods.
+    TIntermFunctionPrototype *convertFunctionPrototype(TSymbolTable *symbolTable,
+                                                       const TFunction *oldFunction);
+    TIntermAggregate *convertASTFunction(TIntermAggregate *node);
+
+  private:
+    // A map from the old global variable to the new one.
+    std::unordered_map<const TVariable *, TVariable *> mReplacedGlobalVariables;
+
+    // A map from functions with old type parameters to one where that's replaced with the new type.
+    std::unordered_map<const TFunction *, TFunction *> mReplacedFunctions;
+
+    // A map from function old type parameters to their replacement new type parameter for the
+    // current function definition.
+    std::unordered_map<const TVariable *, TVariable *> mReplacedFunctionParams;
+
+    // A map from function call old type arguments to their replacement for the current function
+    // call.
+    std::stack<std::unordered_map<const TIntermNode *, TIntermTyped *>> mReplacedFunctionCallArgs;
+};
+
 }  // namespace sh
 
 #endif  // COMPILER_TRANSLATOR_TREEUTIL_REPLACEVARIABLE_H_