Merge pull request #1968 from MrTrillian/19h1-fixes

Additional 19H1 fixes

Add a pass to remove unwanted addrspacecasts from the final DXIL.
Avoid elementwise matrix array copies when source and destination orientations match.
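As context for the diff below: addrspacecast instructions appear when groupshared (addrspace(3)) pointers escape into flat pointers, and the new cleanup pass rewrites their users so the casts become dead and can be erased. A minimal, hedged sketch of the core idea with illustrative names (the real pass, added to DxilPreparePasses.cpp below, additionally rebuilds GEPs, bitcasts, constant expressions, and phis in the source address space):

```cpp
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"

// Sketch only: redirect load/store uses of an addrspacecast back to the
// original pointer. Loads and stores accept a pointer operand in any address
// space, so these uses can bypass the cast directly.
static bool stripTrivialAddrSpaceCasts(llvm::Function &F) {
  bool Changed = false;
  for (llvm::BasicBlock &BB : F) {
    for (auto II = BB.begin(); II != BB.end();) {
      llvm::Instruction *I = &*II++;
      auto *ASC = llvm::dyn_cast<llvm::AddrSpaceCastInst>(I);
      if (!ASC)
        continue;
      llvm::Value *Src = ASC->getOperand(0); // e.g. an addrspace(3) pointer
      for (auto UI = ASC->use_begin(), UE = ASC->use_end(); UI != UE;) {
        llvm::Use &U = *UI++;
        llvm::User *Usr = U.getUser();
        bool IsLoad = llvm::isa<llvm::LoadInst>(Usr);
        bool IsStorePtr = llvm::isa<llvm::StoreInst>(Usr) &&
            U.getOperandNo() == llvm::StoreInst::getPointerOperandIndex();
        if (IsLoad || IsStorePtr)
          U.set(Src); // pointer operand only, never a stored value
      }
      if (ASC->use_empty()) {
        ASC->eraseFromParent();
        Changed = true;
      }
    }
  }
  return Changed;
}
```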
diff --git a/include/dxc/HLSL/DxilGenerationPass.h b/include/dxc/HLSL/DxilGenerationPass.h
index 22e480c..e0a8d47 100644
--- a/include/dxc/HLSL/DxilGenerationPass.h
+++ b/include/dxc/HLSL/DxilGenerationPass.h
@@ -74,6 +74,7 @@
 ModulePass *createPausePassesPass();
 ModulePass *createResumePassesPass();
 FunctionPass *createMatrixBitcastLowerPass();
+ModulePass *createDxilCleanupAddrSpaceCastPass();
 
 void initializeDxilCondenseResourcesPass(llvm::PassRegistry&);
 void initializeDxilLowerCreateHandleForLibPass(llvm::PassRegistry&);
@@ -105,6 +106,7 @@
 void initializePausePassesPass(llvm::PassRegistry&);
 void initializeResumePassesPass(llvm::PassRegistry&);
 void initializeMatrixBitcastLowerPassPass(llvm::PassRegistry&);
+void initializeDxilCleanupAddrSpaceCastPass(llvm::PassRegistry&);
 
 bool AreDxilResourcesDense(llvm::Module *M, hlsl::DxilResourceBase **ppNonDense);
 
diff --git a/lib/HLSL/DxcOptimizer.cpp b/lib/HLSL/DxcOptimizer.cpp
index d2c1543..ca581ba 100644
--- a/lib/HLSL/DxcOptimizer.cpp
+++ b/lib/HLSL/DxcOptimizer.cpp
@@ -85,6 +85,7 @@
     initializeDSEPass(Registry);
     initializeDeadInstEliminationPass(Registry);
     initializeDxilAllocateResourcesForLibPass(Registry);
+    initializeDxilCleanupAddrSpaceCastPass(Registry);
     initializeDxilCondenseResourcesPass(Registry);
     initializeDxilConvergentClearPass(Registry);
     initializeDxilConvergentMarkPass(Registry);
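
The registration above, together with the create/initialize declarations added to DxilGenerationPass.h, follows LLVM's usual pass-plumbing convention. A hedged, generic sketch with made-up names (not DXC code) of how the three pieces fit together:

```cpp
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
using namespace llvm;

namespace {
// Illustrative no-op module pass; the real DxilCleanupAddrSpaceCast is
// defined in DxilPreparePasses.cpp later in this diff.
class ExampleCleanup : public ModulePass {
public:
  static char ID;
  ExampleCleanup() : ModulePass(ID) {}
  bool runOnModule(Module &M) override { return false; }
};
}

char ExampleCleanup::ID = 0;

// Defines initializeExampleCleanupPass(PassRegistry&) and registers the pass
// under a command-line name so tools like opt can run it by flag.
INITIALIZE_PASS(ExampleCleanup, "example-cleanup", "Example Cleanup", false, false)

// Factory called by pipeline builders (e.g. PassManagerBuilder below).
ModulePass *createExampleCleanupPass() { return new ExampleCleanup(); }
```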
diff --git a/lib/HLSL/DxilPreparePasses.cpp b/lib/HLSL/DxilPreparePasses.cpp
index 50769bd..92aef99 100644
--- a/lib/HLSL/DxilPreparePasses.cpp
+++ b/lib/HLSL/DxilPreparePasses.cpp
@@ -146,6 +146,8 @@
 
 ///////////////////////////////////////////////////////////////////////////////
 
+bool CleanupSharedMemoryAddrSpaceCast(Module &M);
+
 namespace {
 
 static void TransferEntryFunctionAttributes(Function *F, Function *NewFunc) {
@@ -295,8 +297,11 @@
 
       RemoveUnusedStaticGlobal(M);
 
+      // Remove unnecessary address space casts.
+      CleanupSharedMemoryAddrSpaceCast(M);
+
       // Clear inbound for GEP which has none-const index.
-      LegalizeShareMemoryGEPInbound(M);
+      LegalizeSharedMemoryGEPInbound(M);
 
       // Strip parameters of entry function.
       StripEntryParameters(M, DM, IsLib);
@@ -375,7 +380,7 @@
     }
   }
 
-  void LegalizeShareMemoryGEPInbound(Module &M) {
+  void LegalizeSharedMemoryGEPInbound(Module &M) {
     const DataLayout &DL = M.getDataLayout();
     // Clear inbound for GEP which has none-const index.
     for (GlobalVariable &GV : M.globals()) {
@@ -452,6 +457,226 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 namespace {
+typedef MapVector< PHINode*, SmallVector<Value*,8> > PHIReplacementMap;
+bool RemoveAddrSpaceCasts(Value *Val, Value *NewVal,
+                          PHIReplacementMap &phiReplacements,
+                          DenseMap<Value*, Value*> &valueMap) {
+  bool bChanged = false;
+  for (auto itU = Val->use_begin(), itEnd = Val->use_end(); itU != itEnd; ) {
+    Use &use = *(itU++);
+    User *user = use.getUser();
+    Value *userReplacement = user;
+    bool bConstructReplacement = false;
+    bool bCleanupInst = false;
+    auto valueMapIter = valueMap.find(user);
+    if (valueMapIter != valueMap.end())
+      userReplacement = valueMapIter->second;
+    else if (Val != NewVal)
+      bConstructReplacement = true;
+    if (ConstantExpr* CE = dyn_cast<ConstantExpr>(user)) {
+      if (CE->getOpcode() == Instruction::BitCast) {
+        if (bConstructReplacement) {
+          // Replicate bitcast in target address space
+          Type* NewTy = PointerType::get(
+            CE->getType()->getPointerElementType(),
+            NewVal->getType()->getPointerAddressSpace());
+          userReplacement = ConstantExpr::getBitCast(cast<Constant>(NewVal), NewTy);
+        }
+      } else if (CE->getOpcode() == Instruction::GetElementPtr) {
+        if (bConstructReplacement) {
+          // Replicate GEP in target address space
+          GEPOperator *GEP = cast<GEPOperator>(CE);
+          SmallVector<Value*, 8> idxList(GEP->idx_begin(), GEP->idx_end());
+          userReplacement = ConstantExpr::getGetElementPtr(
+            nullptr, cast<Constant>(NewVal), idxList, GEP->isInBounds());
+        }
+      } else if (CE->getOpcode() == Instruction::AddrSpaceCast) {
+        userReplacement = NewVal;
+        bConstructReplacement = false;
+      } else {
+        DXASSERT(false, "RemoveAddrSpaceCasts: unhandled pointer ConstantExpr");
+      }
+    } else if (Instruction *I = dyn_cast<Instruction>(user)) {
+      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
+        if (bConstructReplacement) {
+          IRBuilder<> Builder(GEP);
+          SmallVector<Value*, 8> idxList(GEP->idx_begin(), GEP->idx_end());
+          if (GEP->isInBounds())
+            userReplacement = Builder.CreateInBoundsGEP(NewVal, idxList, GEP->getName());
+          else
+            userReplacement = Builder.CreateGEP(NewVal, idxList, GEP->getName());
+        }
+      } else if (BitCastInst *BC = dyn_cast<BitCastInst>(user)) {
+        if (bConstructReplacement) {
+          IRBuilder<> Builder(BC);
+          Type* NewTy = PointerType::get(
+            BC->getType()->getPointerElementType(),
+            NewVal->getType()->getPointerAddressSpace());
+          userReplacement = Builder.CreateBitCast(NewVal, NewTy);
+        }
+      } else if (PHINode *PHI = dyn_cast<PHINode>(user)) {
+        // set replacement phi values for PHI pass
+        unsigned numValues = PHI->getNumIncomingValues();
+        auto &phiValues = phiReplacements[PHI];
+        if (phiValues.empty())
+          phiValues.resize(numValues, nullptr);
+        for (unsigned idx = 0; idx < numValues; ++idx) {
+          if (phiValues[idx] == nullptr &&
+              PHI->getIncomingValue(idx) == Val) {
+            phiValues[idx] = NewVal;
+            bChanged = true;
+          }
+        }
+        continue;
+      } else if (isa<AddrSpaceCastInst>(user)) {
+        userReplacement = NewVal;
+        bConstructReplacement = false;
+        bCleanupInst = true;
+      } else if (isa<CallInst>(user)) {
+        continue;
+      } else {
+        if (Val != NewVal) {
+          use.set(NewVal);
+          bChanged = true;
+        }
+        continue;
+      }
+    }
+    if (bConstructReplacement && user != userReplacement)
+      valueMap[user] = userReplacement;
+    bChanged |= RemoveAddrSpaceCasts(user, userReplacement, phiReplacements,
+                                      valueMap);
+    if (bCleanupInst && user->use_empty()) {
+      // Clean up the old instruction if it's now unused. Erasing it mid-walk
+      // is safe as long as the instruction held only a single use of Val.
+      if (Instruction *I = dyn_cast<Instruction>(user))
+        I->eraseFromParent();
+      bChanged = true;
+    }
+  }
+  return bChanged;
+}
+}
+
+bool CleanupSharedMemoryAddrSpaceCast(Module &M) {
+  bool bChanged = false;
+  // Eliminate address space casts if possible
+  // Collect phi nodes so we can replace iteratively after pass over GVs
+  PHIReplacementMap phiReplacements;
+  DenseMap<Value*, Value*> valueMap;
+  for (GlobalVariable &GV : M.globals()) {
+    if (dxilutil::IsSharedMemoryGlobal(&GV)) {
+      bChanged |= RemoveAddrSpaceCasts(&GV, &GV, phiReplacements,
+                                       valueMap);
+    }
+  }
+  bool bConverged = false;
+  while (!phiReplacements.empty() && !bConverged) {
+    bConverged = true;
+    for (auto &phiReplacement : phiReplacements) {
+      PHINode *PHI = phiReplacement.first;
+      unsigned origAddrSpace = PHI->getType()->getPointerAddressSpace();
+      unsigned incomingAddrSpace = UINT_MAX;
+      bool bReplacePHI = true;
+      bool bRemovePHI = false;
+      for (auto V : phiReplacement.second) {
+        if (nullptr == V) {
+          // cannot replace phi (yet)
+          bReplacePHI = false;
+          break;
+        }
+        unsigned addrSpace = V->getType()->getPointerAddressSpace();
+        if (incomingAddrSpace == UINT_MAX) {
+          incomingAddrSpace = addrSpace;
+        } else if (addrSpace != incomingAddrSpace) {
+          bRemovePHI = true;
+          break;
+        }
+      }
+      if (origAddrSpace == incomingAddrSpace)
+        bRemovePHI = true;
+      if (bRemovePHI) {
+        // Cannot replace phi.  Remove it and restart.
+        phiReplacements.erase(PHI);
+        bConverged = false;
+        break;
+      }
+      if (!bReplacePHI)
+        continue;
+      auto &NewVal = valueMap[PHI];
+      PHINode *NewPHI = nullptr;
+      if (NewVal) {
+        NewPHI = cast<PHINode>(NewVal);
+      } else {
+        IRBuilder<> Builder(PHI);
+        NewPHI = Builder.CreatePHI(
+          PointerType::get(PHI->getType()->getPointerElementType(),
+                           incomingAddrSpace),
+          PHI->getNumIncomingValues(),
+          PHI->getName());
+        NewVal = NewPHI;
+        for (unsigned idx = 0; idx < PHI->getNumIncomingValues(); idx++) {
+          NewPHI->addIncoming(phiReplacement.second[idx],
+                              PHI->getIncomingBlock(idx));
+        }
+      }
+      if (RemoveAddrSpaceCasts(PHI, NewPHI, phiReplacements,
+                               valueMap)) {
+        bConverged = false;
+        bChanged = true;
+      }
+      if (PHI->use_empty()) {
+        phiReplacements.erase(PHI);
+        bConverged = false;
+        bChanged = true;
+        break;
+      }
+    }
+  }
+
+  // Cleanup unused replacement instructions
+  SmallVector<WeakVH, 8> cleanupInsts;
+  for (auto it : valueMap) {
+    if (isa<Instruction>(it.first))
+      cleanupInsts.push_back(it.first);
+    if (isa<Instruction>(it.second))
+      cleanupInsts.push_back(it.second);
+  }
+  for (auto V : cleanupInsts) {
+    if (!V)
+      continue;
+    if (PHINode *PHI = dyn_cast<PHINode>(V))
+      RecursivelyDeleteDeadPHINode(PHI);
+    else if (Instruction *I = dyn_cast<Instruction>(V))
+      RecursivelyDeleteTriviallyDeadInstructions(I);
+  }
+
+  return bChanged;
+}
+
+class DxilCleanupAddrSpaceCast : public ModulePass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit DxilCleanupAddrSpaceCast() : ModulePass(ID) {}
+
+  const char *getPassName() const override { return "HLSL DXIL Cleanup Address Space Cast"; }
+
+  bool runOnModule(Module &M) override {
+    return CleanupSharedMemoryAddrSpaceCast(M);
+  }
+};
+
+char DxilCleanupAddrSpaceCast::ID = 0;
+
+ModulePass *llvm::createDxilCleanupAddrSpaceCastPass() {
+  return new DxilCleanupAddrSpaceCast();
+}
+
+INITIALIZE_PASS(DxilCleanupAddrSpaceCast, "hlsl-dxil-cleanup-addrspacecast", "HLSL DXIL Cleanup Address Space Cast", false, false)
+
+///////////////////////////////////////////////////////////////////////////////
+
+namespace {
 
 class DxilEmitMetadata : public ModulePass {
 public:
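
One subtlety in the pass above: a phi may merge pointers whose replacements aren't all known yet, or whose replacements land in different address spaces, so phis are collected into phiReplacements and resolved to a fixed point after the initial walk over globals. A hedged, standalone restatement of the acceptance rule with assumed names (simplified: the real code separately distinguishes "retry later" from "remove the phi and restart"):

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include <climits>

// A phi is rebuilt in a new address space only once every incoming value has
// a known replacement and all replacements agree on a single address space.
// Replacements that merely match the phi's original address space would be a
// no-op, and the pass drops such phis from its map instead.
static bool canRebuildPhi(llvm::ArrayRef<llvm::Value *> Incoming,
                          unsigned OrigAddrSpace, unsigned &NewAddrSpace) {
  NewAddrSpace = UINT_MAX;
  for (llvm::Value *V : Incoming) {
    if (!V)
      return false; // an incoming value is not resolved yet: retry later
    unsigned AS = V->getType()->getPointerAddressSpace();
    if (NewAddrSpace == UINT_MAX)
      NewAddrSpace = AS;
    else if (AS != NewAddrSpace)
      return false; // mixed incoming address spaces: cannot rebuild this phi
  }
  return NewAddrSpace != OrigAddrSpace;
}
```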
diff --git a/lib/HLSL/HLOperationLower.cpp b/lib/HLSL/HLOperationLower.cpp
index ea94bc1..688e5fb 100644
--- a/lib/HLSL/HLOperationLower.cpp
+++ b/lib/HLSL/HLOperationLower.cpp
@@ -6434,6 +6434,14 @@
       }
     }
     user->eraseFromParent();
+  } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(user)) {
+    // Recurse into the bitcast's users, then erase the now-dead cast.
+    for (auto U = BCI->user_begin(); U != BCI->user_end();) {
+      Value *BCIUser = *(U++);
+      TranslateStructBufSubscriptUser(cast<Instruction>(BCIUser), handle,
+        bufIdx, baseOffset, status, OP, DL);
+    }
+    BCI->eraseFromParent();
   } else {
     // should only used by GEP
     GetElementPtrInst *GEP = cast<GetElementPtrInst>(user);
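
The new BitCastInst case exists because calling a base-struct method on a structured-buffer element (see the addrspace_stress.hlsl test below) leaves a Derived-to-Base pointer bitcast between the subscript and the eventual loads and stores. A hedged skeleton of the recursion with the lowering state elided (the real helper threads the handle, buffer index, base offset, and status through each call):

```cpp
#include "llvm/IR/Instructions.h"

// Skeleton only: a bitcast needs no translation of its own, so its users are
// lowered as if they used the subscript pointer directly, after which the
// cast is dead and can be erased.
static void translateSubscriptUser(llvm::Instruction *User
                                   /*, handle, bufIdx, baseOffset, status */) {
  if (auto *BCI = llvm::dyn_cast<llvm::BitCastInst>(User)) {
    for (auto UI = BCI->user_begin(); UI != BCI->user_end();) {
      llvm::Instruction *Next = llvm::cast<llvm::Instruction>(*UI++);
      translateSubscriptUser(Next); // same handle/index/offset as the cast
    }
    BCI->eraseFromParent(); // all users rewritten; the cast is now dead
    return;
  }
  // ... GEP / load / store handling as in the real function ...
}
```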
diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp
index 7674d7b..32adc89 100644
--- a/lib/Transforms/IPO/PassManagerBuilder.cpp
+++ b/lib/Transforms/IPO/PassManagerBuilder.cpp
@@ -213,6 +213,8 @@
     return;
   }
 
+  MPM.add(createDxilCleanupAddrSpaceCastPass());
+
   MPM.add(createHLPreprocessPass());
   bool NoOpt = OptLevel == 0;
   if (!NoOpt) {
diff --git a/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp b/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
index 10611fc..312667a 100644
--- a/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
+++ b/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
@@ -2707,8 +2707,6 @@
     assert(NewGEP->getType() == GEP->getType() && "type mismatch");
     
     GEP->replaceAllUsesWith(NewGEP);
-    if (isa<Instruction>(GEP))
-      DeadInsts.push_back(GEP);
   } else {
     // End at array of basic type.
     Type *Ty = GEP->getType()->getPointerElementType();
@@ -2725,22 +2723,16 @@
         NewGEPs.emplace_back(NewGEP);
       }
       const bool bAllowReplace = isa<AllocaInst>(OldVal);
-      if (SROA_Helper::LowerMemcpy(GEP, /*annoation*/ nullptr, typeSys, DL,
-                                   bAllowReplace)) {
-        if (GEP->user_empty() && isa<Instruction>(GEP))
-          DeadInsts.push_back(GEP);
-        return;
-      }
-      SROA_Helper helper(GEP, NewGEPs, DeadInsts, typeSys, DL);
-      helper.RewriteForScalarRepl(GEP, Builder);
-      for (Value *NewGEP : NewGEPs) {
-        if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
-          // Delete unused newGEP.
-          cast<Instruction>(NewGEP)->eraseFromParent();
+      if (!SROA_Helper::LowerMemcpy(GEP, /*annotation*/ nullptr, typeSys, DL, bAllowReplace)) {
+        SROA_Helper helper(GEP, NewGEPs, DeadInsts, typeSys, DL);
+        helper.RewriteForScalarRepl(GEP, Builder);
+        for (Value *NewGEP : NewGEPs) {
+          if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
+            // Delete unused newGEP.
+            cast<Instruction>(NewGEP)->eraseFromParent();
+          }
         }
       }
-      if (GEP->user_empty() && isa<Instruction>(GEP))
-        DeadInsts.push_back(GEP);
     } else {
       Value *vecIdx = NewArgs.back();
       if (ConstantInt *immVecIdx = dyn_cast<ConstantInt>(vecIdx)) {
@@ -2758,14 +2750,22 @@
         assert(NewGEP->getType() == GEP->getType() && "type mismatch");
 
         GEP->replaceAllUsesWith(NewGEP);
-        if (isa<Instruction>(GEP))
-          DeadInsts.push_back(GEP);
       } else {
         // dynamic vector indexing.
         assert(0 && "should not reach here");
       }
     }
   }
+
+  // Remove the use so that the caller can keep iterating over its other users
+  DXASSERT(GEP->user_empty(), "All uses of the GEP should have been eliminated");
+  if (isa<Instruction>(GEP)) {
+    GEP->setOperand(GEP->getPointerOperandIndex(), UndefValue::get(GEP->getPointerOperand()->getType()));
+    DeadInsts.push_back(GEP);
+  }
+  else {
+    cast<Constant>(GEP)->destroyConstant();
+  }
 }
 
 /// isVectorOrStructArray - Check if T is array of vector or struct.
@@ -2828,7 +2828,6 @@
       Insert = Builder.CreateInsertElement(Insert, Load, i, "insert");
     }
     LI->replaceAllUsesWith(Insert);
-    DeadInsts.push_back(LI);
   } else if (isCompatibleAggregate(LIType, ValTy)) {
     if (isVectorOrStructArray(LIType)) {
       // Replace:
@@ -2846,7 +2845,6 @@
       Value *newLd =
           LoadVectorOrStructArray(cast<ArrayType>(LIType), NewElts, idxList, Builder);
       LI->replaceAllUsesWith(newLd);
-      DeadInsts.push_back(LI);
     } else {
       // Replace:
       //   %res = load { i32, i32 }* %alloc
@@ -2880,11 +2878,14 @@
       if (LIType->isStructTy()) {
         SimplifyStructValUsage(Insert, LdElts, DeadInsts);
       }
-      DeadInsts.push_back(LI);
     }
   } else {
     llvm_unreachable("other type don't need rewrite");
   }
+
+  // Remove the use so that the caller can keep iterating over its other users
+  LI->setOperand(LI->getPointerOperandIndex(), UndefValue::get(LI->getPointerOperand()->getType()));
+  DeadInsts.push_back(LI);
 }
 
 /// RewriteForStore - Replace OldVal with flattened NewElts in StoreInst.
@@ -2906,7 +2907,6 @@
       Value *Extract = Builder.CreateExtractElement(Val, i, Val->getName());
       Builder.CreateStore(Extract, NewElts[i]);
     }
-    DeadInsts.push_back(SI);
   } else if (isCompatibleAggregate(SIType, ValTy)) {
     if (isVectorOrStructArray(SIType)) {
       // Replace:
@@ -2936,7 +2936,6 @@
       SmallVector<Value *, 8> idxList;
       idxList.emplace_back(zero);
       StoreVectorOrStructArray(AT, Val, NewElts, idxList, Builder);
-      DeadInsts.push_back(SI);
     } else {
       // Replace:
       //   store { i32, i32 } %val, { i32, i32 }* %alloc
@@ -2959,11 +2958,14 @@
               Extract->getType(), {NewElts[i], Extract}, *M);
         }
       }
-      DeadInsts.push_back(SI);
     }
   } else {
     llvm_unreachable("other type don't need rewrite");
   }
+
+  // Remove the use so that the caller can keep iterating over its other users
+  SI->setOperand(SI->getPointerOperandIndex(), UndefValue::get(SI->getPointerOperand()->getType()));
+  DeadInsts.push_back(SI);
 }
 /// RewriteMemIntrin - MI is a memcpy/memset/memmove from or to AI.
 /// Rewrite it to copy or set the elements of the scalarized memory.
@@ -3006,6 +3008,10 @@
            I != E; ++I)
         if (*I == MI)
           return;
+
+      // Remove the uses so that the caller can keep iterating over its other users
+      MI->setOperand(0, UndefValue::get(MI->getOperand(0)->getType()));
+      MI->setOperand(1, UndefValue::get(MI->getOperand(1)->getType()));
       DeadInsts.push_back(MI);
       return;
     }
@@ -3136,6 +3142,11 @@
                               MI->isVolatile());
     }
   }
+
+  // Remove the use so that the caller can keep iterating over its other users
+  MI->setOperand(0, UndefValue::get(MI->getOperand(0)->getType()));
+  if (isa<MemTransferInst>(MI))
+    MI->setOperand(1, UndefValue::get(MI->getOperand(1)->getType()));
   DeadInsts.push_back(MI);
 }
 
@@ -3317,6 +3328,13 @@
   }
   SROA_Helper helper(CE, NewCasts, DeadInsts, typeSys, DL);
   helper.RewriteForScalarRepl(CE, Builder);
+
+  // Remove the use so that the caller can keep iterating over its other users
+  DXASSERT(CE->user_empty(), "All uses of the addrspacecast should have been eliminated");
+  if (Instruction *I = dyn_cast<Instruction>(CE))
+    I->eraseFromParent();
+  else
+    cast<Constant>(CE)->destroyConstant();
 }
 
 /// RewriteForConstExpr - Rewrite the GEP which is ConstantExpr.
@@ -3335,10 +3353,6 @@
       return;
     }
   }
-  // Skip unused CE. 
-  if (CE->use_empty())
-    return;
-
   for (Value::use_iterator UI = CE->use_begin(), E = CE->use_end(); UI != E;) {
     Use &TheUse = *UI++;
     if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) {
@@ -3352,37 +3366,49 @@
       RewriteForConstExpr(cast<ConstantExpr>(TheUse.getUser()), Builder);
     }
   }
+
+  // Remove the use so that the caller can keep iterating over its other users
+  DXASSERT(CE->user_empty(), "All uses of the constantexpr should have been eliminated");
+  CE->destroyConstant();
 }
 /// RewriteForScalarRepl - OldVal is being split into NewElts, so rewrite
 /// users of V, which references it, to use the separate elements.
 void SROA_Helper::RewriteForScalarRepl(Value *V, IRBuilder<> &Builder) {
+  // Don't walk the uses with iterators: we'll be removing them, and possibly
+  // adding new ones (when expanding memcpys), during the iteration.
+  Use* PrevUse = nullptr;
+  while (!V->use_empty()) {
+    Use &TheUse = *V->use_begin();
 
-  for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E;) {
-    Use &TheUse = *UI++;
+    DXASSERT_LOCALVAR(PrevUse, &TheUse != PrevUse,
+      "Infinite loop while SROA'ing value, use isn't getting eliminated.");
+    PrevUse = &TheUse;
 
+    // Each of these must either call ->eraseFromParent()
+    // or null out the use of V so that we make progress.
     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(TheUse.getUser())) {
       RewriteForConstExpr(CE, Builder);
-      continue;
     }
-    Instruction *User = cast<Instruction>(TheUse.getUser());
-
-    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
-      IRBuilder<> Builder(GEP);
-      RewriteForGEP(cast<GEPOperator>(GEP), Builder);
-    } else if (LoadInst *ldInst = dyn_cast<LoadInst>(User))
-      RewriteForLoad(ldInst);
-    else if (StoreInst *stInst = dyn_cast<StoreInst>(User))
-      RewriteForStore(stInst);
-    else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User))
-      RewriteMemIntrin(MI, cast<Instruction>(V));
-    else if (CallInst *CI = dyn_cast<CallInst>(User)) 
-      RewriteCall(CI);
-    else if (BitCastInst *BCI = dyn_cast<BitCastInst>(User))
-      RewriteBitCast(BCI);
-    else if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(User)) {
-      RewriteForAddrSpaceCast(CI, Builder);
-    } else {
-      assert(0 && "not support.");
+    else {
+      Instruction *User = cast<Instruction>(TheUse.getUser());
+      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
+        IRBuilder<> Builder(GEP);
+        RewriteForGEP(cast<GEPOperator>(GEP), Builder);
+      } else if (LoadInst *ldInst = dyn_cast<LoadInst>(User))
+        RewriteForLoad(ldInst);
+      else if (StoreInst *stInst = dyn_cast<StoreInst>(User))
+        RewriteForStore(stInst);
+      else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User))
+        RewriteMemIntrin(MI, V);
+      else if (CallInst *CI = dyn_cast<CallInst>(User)) 
+        RewriteCall(CI);
+      else if (BitCastInst *BCI = dyn_cast<BitCastInst>(User))
+        RewriteBitCast(BCI);
+      else if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(User)) {
+        RewriteForAddrSpaceCast(CI, Builder);
+      } else {
+        assert(0 && "not support.");
+      }
     }
   }
 }
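
The reshaped loop in RewriteForScalarRepl is the heart of the GitHub #1631 fix: a snapshotted use iterator misses uses a handler adds mid-walk (here, memcpy expansion), so the loop now drains the use list and requires each handler to detach the use it was handed. A hedged, generic restatement of the idiom with assumed names:

```cpp
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"

// Process uses until none remain. Every iteration must make progress by
// erasing the user or overwriting its operand with undef; otherwise the loop
// would spin forever, which is what the DXASSERT above guards against.
static void drainUses(llvm::Value *V) {
  while (!V->use_empty()) {
    llvm::Use &U = *V->use_begin();
    llvm::Instruction *User = llvm::cast<llvm::Instruction>(U.getUser());
    // ... rewrite *User in terms of the scalarized elements ...
    // Detach the use so that uses created during the rewrite (e.g. from an
    // expanded memcpy) land back on V's use list and get their own turn:
    U.set(llvm::UndefValue::get(V->getType()));
    (void)User; // actual deletion is deferred (DeadInsts in the real code)
  }
}
```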
diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp
index e96a1fe..fdc5225 100644
--- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp
+++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp
@@ -4021,6 +4021,8 @@
       // Skip function call.
     } else if (dyn_cast<BitCastInst>(U)) {
       // Skip bitcast.
+    } else if (dyn_cast<AddrSpaceCastInst>(U)) {
+      // Skip addrspacecast.
     } else {
       DXASSERT(0, "not support yet");
     }
@@ -6653,6 +6655,20 @@
   }
 }
 
+static bool AreMatrixArrayOrientationMatching(ASTContext& Context,
+  HLModule &Module, QualType LhsTy, QualType RhsTy) {
+  while (const clang::ArrayType *LhsArrayTy = Context.getAsArrayType(LhsTy)) {
+    LhsTy = LhsArrayTy->getElementType();
+    RhsTy = Context.getAsArrayType(RhsTy)->getElementType();
+  }
+
+  bool LhsRowMajor, RhsRowMajor;
+  LhsRowMajor = RhsRowMajor = Module.GetHLOptions().bDefaultRowMajor;
+  HasHLSLMatOrientation(LhsTy, &LhsRowMajor);
+  HasHLSLMatOrientation(RhsTy, &RhsRowMajor);
+  return LhsRowMajor == RhsRowMajor;
+}
+
 // Copy data from SrcPtr to DestPtr.
 // For matrix, use MatLoad/MatStore.
 // For matrix array, EmitHLSLAggregateCopy on each element.
@@ -6689,13 +6705,15 @@
     // Memcpy struct.
     CGF.Builder.CreateMemCpy(dstGEP, srcGEP, size, 1);
   } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
-    if (!HLMatrixType::isMatrixArray(Ty)) {
+    if (!HLMatrixType::isMatrixArray(Ty)
+      || AreMatrixArrayOrientationMatching(CGF.getContext(), *m_pHLModule, SrcType, DestType)) {
       Value *srcGEP = CGF.Builder.CreateInBoundsGEP(SrcPtr, idxList);
       Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
       unsigned size = this->TheModule.getDataLayout().getTypeAllocSize(AT);
       // Memcpy non-matrix array.
       CGF.Builder.CreateMemCpy(dstGEP, srcGEP, size, 1);
     } else {
+      // Copy matrix arrays elementwise if orientation changes are needed.
       llvm::Type *ET = AT->getElementType();
       QualType EltDestType = CGF.getContext().getBaseElementType(DestType);
       QualType EltSrcType = CGF.getContext().getBaseElementType(SrcType);
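
A note on AreMatrixArrayOrientationMatching: row_major/column_major annotations attach to the matrix element type, so the helper peels array dimensions first, and an unannotated side inherits the module-wide default. A hedged, standalone model of the matching rule in plain C++ (not the Clang types used above):

```cpp
#include <optional>

enum class Orientation { RowMajor, ColumnMajor };

// Each side either carries an explicit orientation annotation on its matrix
// element type or falls back to the compile-time default (bDefaultRowMajor
// above). Matching orientations mean both arrays share one memory layout, so
// a single memcpy is legal; a mismatch forces elementwise MatLoad/MatStore.
static bool orientationsMatch(std::optional<Orientation> Lhs,
                              std::optional<Orientation> Rhs,
                              bool DefaultRowMajor) {
  Orientation Def =
      DefaultRowMajor ? Orientation::RowMajor : Orientation::ColumnMajor;
  return Lhs.value_or(Def) == Rhs.value_or(Def);
}
```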
diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp
index 9a3420a..66d60ad 100644
--- a/tools/clang/lib/Sema/SemaHLSL.cpp
+++ b/tools/clang/lib/Sema/SemaHLSL.cpp
@@ -3681,7 +3681,8 @@
       type = RefType ? RefType->getPointeeType() : AttrType->getEquivalentType();
     }
 
-    return type->getCanonicalTypeUnqualified();
+    // Despite its name, getCanonicalTypeUnqualified preserves qualifiers on array element types, so explicitly rebuild an unqualified QualType
+    return QualType(type->getCanonicalTypeUnqualified()->getTypePtr(), 0);
   }
 
   /// <summary>Given a Clang type, return the ArBasicKind classification for its contents.</summary>
diff --git a/tools/clang/test/CodeGenHLSL/passes/sroa_hlsl/groupshared_array_struct_matrix_regression.hlsl b/tools/clang/test/CodeGenHLSL/passes/sroa_hlsl/groupshared_array_struct_matrix_regression.hlsl
new file mode 100644
index 0000000..502f100
--- /dev/null
+++ b/tools/clang/test/CodeGenHLSL/passes/sroa_hlsl/groupshared_array_struct_matrix_regression.hlsl
@@ -0,0 +1,14 @@
+// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s
+
+// Regression test for GitHub #1631, where SROA would generate more uses
+// of a value while processing it (due to expanding a memcpy) and fail
+// to process the new uses. This caused global structs of matrices to reach HLMatrixLower,
+// which couldn't handle them and would unexpectedly leave matrix intrinsics untouched.
+// Compilation would then fail with "error: Fail to lower matrix load/store."
+
+// CHECK: ret void
+
+struct S { int1x1 x, y; };
+groupshared S gs[1];
+void f(S s[1]) {}
+void main() { f(gs); }
\ No newline at end of file
diff --git a/tools/clang/test/CodeGenHLSL/quick-ll-test/cleanup-addrspacecast.ll b/tools/clang/test/CodeGenHLSL/quick-ll-test/cleanup-addrspacecast.ll
new file mode 100644
index 0000000..a070015
--- /dev/null
+++ b/tools/clang/test/CodeGenHLSL/quick-ll-test/cleanup-addrspacecast.ll
@@ -0,0 +1,411 @@
+; RUN: %opt %s -hlsl-dxil-cleanup-addrspacecast -S | FileCheck %s
+
+; Make sure addrspacecast is removed
+; CHECK-NOT: addrspacecast
+
+
+; ModuleID = 'MyModule'
+target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-ms-dx"
+
+%"class.RWStructuredBuffer<Derived>" = type { %struct.Derived }
+%struct.Derived = type { %struct.Base, float }
+%struct.Base = type { i32 }
+%"$Globals" = type { [2 x %struct.Derived], i32 }
+%dx.types.Handle = type { i8* }
+%dx.types.CBufRet.i32 = type { i32, i32, i32, i32 }
+%dx.types.CBufRet.f32 = type { float, float, float, float }
+%"class.RWStructuredBuffer<Base>" = type { %struct.Base }
+
+@"\01?sb_Derived@@3PAV?$RWStructuredBuffer@UDerived@@@@A" = external constant [2 x %"class.RWStructuredBuffer<Derived>"], align 4
+@"$Globals" = external constant %"$Globals"
+@"\01?gs_Derived0@@3UDerived@@A.1" = addrspace(3) global float undef
+@"\01?gs_Derived0@@3UDerived@@A.0.0" = addrspace(3) global i32 undef
+@"\01?gs_Derived1@@3UDerived@@A.1" = addrspace(3) global float undef
+@"\01?gs_Derived1@@3UDerived@@A.0.0" = addrspace(3) global i32 undef
+@"\01?gs_Derived@@3PAUDerived@@A.1" = addrspace(3) global [2 x float] undef
+@"\01?gs_Derived@@3PAUDerived@@A.0.0" = addrspace(3) global [2 x i32] undef
+@"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim" = addrspace(3) global [8 x float] undef
+@"\01?gs_matArray@@3PAY01$$CAV?$matrix@M$01$01@@A.v.v.1dim" = addrspace(3) global [16 x float] undef
+
+; Function Attrs: nounwind
+define void @main() #0 {
+entry:
+  %"$Globals_cbuffer" = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 2, i32 0, i32 0, i1 false)
+  %0 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 0)
+  %1 = extractvalue %dx.types.CBufRet.i32 %0, 0
+  store i32 %1, i32 addrspace(3)* @"\01?gs_Derived0@@3UDerived@@A.0.0", align 4
+  %2 = call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 0)
+  %3 = extractvalue %dx.types.CBufRet.f32 %2, 1
+  store float %3, float addrspace(3)* @"\01?gs_Derived0@@3UDerived@@A.1", align 4
+  %4 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %5 = extractvalue %dx.types.CBufRet.i32 %4, 2
+  %sub = sub nsw i32 1, %5
+  %6 = getelementptr [2 x i32], [2 x i32] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.0.0", i32 0, i32 %sub
+  %7 = getelementptr [2 x float], [2 x float] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.1", i32 0, i32 %sub
+  %8 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 %5)
+  %9 = extractvalue %dx.types.CBufRet.i32 %8, 0
+  store i32 %9, i32 addrspace(3)* %6, align 4
+  %10 = call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 %5)
+  %11 = extractvalue %dx.types.CBufRet.f32 %10, 1
+  store float %11, float addrspace(3)* %7, align 4
+  %12 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %13 = extractvalue %dx.types.CBufRet.i32 %12, 2
+  %14 = getelementptr [2 x i32], [2 x i32] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.0.0", i32 0, i32 %13
+  store i32 1, i32 addrspace(3)* %14, align 4, !tbaa !30
+  %15 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %16 = extractvalue %dx.types.CBufRet.i32 %15, 2
+  %y10 = getelementptr inbounds [2 x float], [2 x float] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.1", i32 0, i32 %16
+  store float 2.000000e+00, float addrspace(3)* %y10, align 4, !tbaa !34
+  %tobool = icmp eq i32 %1, 0
+  br i1 %tobool, label %if.then, label %if.else, !dx.controlflow.hints !36
+
+if.then:                                          ; preds = %entry
+  store i32 0, i32 addrspace(3)* @"\01?gs_Derived1@@3UDerived@@A.0.0", align 4
+  store float %3, float addrspace(3)* @"\01?gs_Derived1@@3UDerived@@A.1", align 4
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ; add addrspacecast inst
+  %new.asc.0 = addrspacecast i32 addrspace(3)* bitcast (float addrspace(3)* @"\01?gs_Derived1@@3UDerived@@A.1" to i32 addrspace(3)*) to i32*
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+  br label %if.end.11
+
+if.else:                                          ; preds = %entry
+  %17 = load i32, i32* addrspacecast (i32 addrspace(3)* getelementptr inbounds ([2 x i32], [2 x i32] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.0.0", i32 0, i32 0) to i32*), align 4, !tbaa !30
+  %tobool7 = icmp eq i32 %17, 0
+  br i1 %tobool7, label %if.else.10, label %if.then.8
+
+if.then.8:                                        ; preds = %if.else
+  %18 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %19 = extractvalue %dx.types.CBufRet.i32 %18, 2
+  %20 = getelementptr [2 x i32], [2 x i32] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.0.0", i32 0, i32 %19
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ; add constant addrspacecast expr inside gep inst
+  %new.gep.asc.1 = getelementptr [2 x i32], [2 x i32]* addrspacecast ([2 x i32] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.0.0" to [2 x i32]*), i32 0, i32 %19
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+  %21 = getelementptr [2 x float], [2 x float] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.1", i32 0, i32 %19
+  %22 = load i32, i32 addrspace(3)* %20, align 4
+  %23 = load float, float addrspace(3)* %21, align 4
+  store i32 %22, i32 addrspace(3)* @"\01?gs_Derived1@@3UDerived@@A.0.0", align 4
+  store float %23, float addrspace(3)* @"\01?gs_Derived1@@3UDerived@@A.1", align 4
+  %phitmp26 = sitofp i32 %22 to float
+  br label %if.end.11
+
+if.else.10:                                       ; preds = %if.else
+  %24 = load float, float addrspace(3)* getelementptr inbounds ([2 x float], [2 x float] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.1", i32 0, i32 0), align 4
+  store i32 0, i32 addrspace(3)* @"\01?gs_Derived1@@3UDerived@@A.0.0", align 4
+  store float %24, float addrspace(3)* @"\01?gs_Derived1@@3UDerived@@A.1", align 4
+  br label %if.end.11
+
+if.end.11:                                        ; preds = %if.then.8, %if.else.10, %if.then
+  %25 = phi float [ %phitmp26, %if.then.8 ], [ 0.000000e+00, %if.else.10 ], [ 0.000000e+00, %if.then ]
+  %26 = phi float [ %23, %if.then.8 ], [ %24, %if.else.10 ], [ %3, %if.then ]
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ; add phi mixing incoming addrspace cast scenarios and incoming global pointers
+  %new.phi.0 = phi i32* [ %new.gep.asc.1, %if.then.8 ], [ addrspacecast (i32 addrspace(3)* getelementptr inbounds ([2 x i32], [2 x i32] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.0.0", i32 0, i32 0) to i32*), %if.else.10 ], [ %new.asc.0, %if.then ]
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+  %27 = fmul fast float %25, %25
+  %mul2.i = fmul fast float %27, %26
+  %28 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %29 = extractvalue %dx.types.CBufRet.i32 %28, 2
+  %30 = getelementptr [2 x i32], [2 x i32] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.0.0", i32 0, i32 %29
+  %31 = addrspacecast i32 addrspace(3)* %30 to i32*
+  %32 = getelementptr [2 x float], [2 x float] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.1", i32 0, i32 %29
+  %33 = addrspacecast float addrspace(3)* %32 to float*
+  store i32 5, i32* %31, align 4, !tbaa !30
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ; Store to new phi ptr
+  store i32 13, i32* %new.phi.0, align 4, !tbaa !30
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+  store float 6.000000e+00, float* %33, align 4, !tbaa !34
+  %34 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %35 = extractvalue %dx.types.CBufRet.i32 %34, 2
+  %36 = getelementptr [2 x i32], [2 x i32] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.0.0", i32 0, i32 %35
+  %37 = addrspacecast i32 addrspace(3)* %36 to i32*
+  %conv = fptosi float %mul2.i to i32
+  store i32 %conv, i32* %37, align 4, !tbaa !30
+  %38 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %39 = extractvalue %dx.types.CBufRet.i32 %38, 2
+  %40 = add i32 %39, 0
+  %sb_Derived_UAV_structbuf29 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 %40, i1 false)
+  %41 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %42 = extractvalue %dx.types.CBufRet.i32 %41, 2
+  %43 = getelementptr [2 x i32], [2 x i32] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.0.0", i32 0, i32 %42
+  %44 = getelementptr [2 x float], [2 x float] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.1", i32 0, i32 %42
+  %45 = load i32, i32 addrspace(3)* %43, align 4
+  call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %sb_Derived_UAV_structbuf29, i32 0, i32 0, i32 %45, i32 undef, i32 undef, i32 undef, i8 1)
+  %46 = load float, float addrspace(3)* %44, align 4
+  call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %sb_Derived_UAV_structbuf29, i32 0, i32 4, float %46, float undef, float undef, float undef, i8 1)
+  %47 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %48 = extractvalue %dx.types.CBufRet.i32 %47, 2
+  %49 = add i32 %48, 0
+  %sb_Derived_UAV_structbuf28 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 %49, i1 false)
+  call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %sb_Derived_UAV_structbuf28, i32 1, i32 0, i32 7, i32 undef, i32 undef, i32 undef, i8 1)
+  call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %sb_Derived_UAV_structbuf28, i32 1, i32 0, i32 7, i32 undef, i32 undef, i32 undef, i8 1)
+  call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %sb_Derived_UAV_structbuf28, i32 1, i32 4, float 8.000000e+00, float undef, float undef, float undef, i8 1)
+  %50 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %51 = extractvalue %dx.types.CBufRet.i32 %50, 2
+  %cmp.23 = icmp slt i32 %51, 4
+  br i1 %cmp.23, label %for.body.preheader, label %for.end
+
+for.body.preheader:                               ; preds = %if.end.11
+  br label %for.body
+
+for.body:                                         ; preds = %for.body.preheader, %for.body
+  %j.025 = phi i32 [ %inc, %for.body ], [ %51, %for.body.preheader ]
+  %k.024 = phi i32 [ %sub46, %for.body ], [ 4, %for.body.preheader ]
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ; add phi mixing incoming scenarios
+  %new.phi.1 = phi i32* [ %new.ptr.1, %for.body ], [ %new.phi.0, %for.body.preheader ]
+  %new.load.1 = load i32, i32* %new.phi.1, align 4, !tbaa !30
+  store i32 %new.load.1, i32* bitcast (float* getelementptr inbounds ([8 x float], [8 x float]* addrspacecast ([8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim" to [8 x float]*), i32 0, i32 0) to i32*), align 4, !tbaa !30
+
+  ; If desired, for additional testing of function case:
+  ; use same constant in a function where we will be unable to replace it:
+  ;%new.unused.0 = call void @FunctionConsumingPtr(i32* bitcast (float* getelementptr inbounds ([8 x float], [8 x float]* addrspacecast ([8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim" to [8 x float]*), i32 0, i32 0) to i32*))
+
+  ; If desired, for additional testing of function case:
+  ; use %new.phi.1 in a function where we will be unable to replace it:
+  ;%new.unused.0 = call void @FunctionConsumingPtr(i32* %new.phi.1)
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+
+  %rem = srem i32 %j.025, 2
+  %52 = getelementptr [2 x i32], [2 x i32] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.0.0", i32 0, i32 %rem
+  %53 = load i32, i32 addrspace(3)* %52, align 4, !tbaa !30
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ; use the value
+  %new.add.1 = add i32 %new.load.1, %53
+  %conv26 = sitofp i32 %new.add.1 to float
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+  %sub27 = sub nsw i32 1, %j.025
+  %rem28 = srem i32 %sub27, 2
+  %y309 = getelementptr inbounds [2 x float], [2 x float] addrspace(3)* @"\01?gs_Derived@@3PAUDerived@@A.1", i32 0, i32 %rem28
+  %54 = load float, float addrspace(3)* %y309, align 4, !tbaa !34
+  %div = sdiv i32 %j.025, 2
+  %55 = mul i32 %div, 2
+  %56 = add i32 %rem, %55
+  %57 = mul i32 %56, 2
+  %58 = add i32 0, %57
+  %59 = getelementptr [8 x float], [8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim", i32 0, i32 %58
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ; set pointer for loop to bitcast of float pointer
+  %new.ptr.0 = addrspacecast float addrspace(3)* %59 to float*
+  %new.ptr.1 = bitcast float* %new.ptr.0 to i32*
+  ; new.ptr.1 is used in phi at beginning of this block
+  ; If desired, for additional testing of function case:
+  ; also use it in a function where we will be unable to replace it:
+  ;%new.unused.0 = call void @FunctionConsumingPtr(i32* %new.ptr.1)
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+  store float %conv26, float addrspace(3)* %59, align 8
+  %60 = mul i32 %div, 2
+  %61 = add i32 %rem, %60
+  %62 = mul i32 %61, 2
+  %63 = add i32 1, %62
+  %64 = getelementptr [8 x float], [8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim", i32 0, i32 %63
+  store float %54, float addrspace(3)* %64, align 4
+  %rem38 = srem i32 %k.024, 2
+  %div39 = sdiv i32 %k.024, 2
+  %65 = mul i32 %div39, 2
+  %66 = add i32 %rem38, %65
+  %67 = mul i32 %66, 2
+  %68 = add i32 0, %67
+  %69 = getelementptr [8 x float], [8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim", i32 0, i32 %68
+  %70 = load float, float addrspace(3)* %69, align 8
+  %71 = mul i32 %div39, 2
+  %72 = add i32 %rem38, %71
+  %73 = mul i32 %72, 2
+  %74 = add i32 1, %73
+  %75 = getelementptr [8 x float], [8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim", i32 0, i32 %74
+  %76 = load float, float addrspace(3)* %75, align 4
+  %77 = mul i32 %rem38, 2
+  %78 = add i32 %div39, %77
+  %79 = mul i32 %78, 2
+  %80 = add i32 0, %79
+  %81 = getelementptr [8 x float], [8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim", i32 0, i32 %80
+  %82 = load float, float addrspace(3)* %81, align 8
+  %83 = mul i32 %rem38, 2
+  %84 = add i32 %div39, %83
+  %85 = mul i32 %84, 2
+  %86 = add i32 1, %85
+  %87 = getelementptr [8 x float], [8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim", i32 0, i32 %86
+  %88 = load float, float addrspace(3)* %87, align 4
+  %89 = mul i32 %div, 2
+  %90 = add i32 %rem, %89
+  %91 = mul i32 %90, 4
+  %92 = add i32 0, %91
+  %93 = getelementptr [16 x float], [16 x float] addrspace(3)* @"\01?gs_matArray@@3PAY01$$CAV?$matrix@M$01$01@@A.v.v.1dim", i32 0, i32 %92
+  store float %70, float addrspace(3)* %93, align 16
+  %94 = mul i32 %div, 2
+  %95 = add i32 %rem, %94
+  %96 = mul i32 %95, 4
+  %97 = add i32 1, %96
+  %98 = getelementptr [16 x float], [16 x float] addrspace(3)* @"\01?gs_matArray@@3PAY01$$CAV?$matrix@M$01$01@@A.v.v.1dim", i32 0, i32 %97
+  store float %82, float addrspace(3)* %98, align 4
+  %99 = mul i32 %div, 2
+  %100 = add i32 %rem, %99
+  %101 = mul i32 %100, 4
+  %102 = add i32 2, %101
+  %103 = getelementptr [16 x float], [16 x float] addrspace(3)* @"\01?gs_matArray@@3PAY01$$CAV?$matrix@M$01$01@@A.v.v.1dim", i32 0, i32 %102
+  store float %76, float addrspace(3)* %103, align 8
+  %104 = mul i32 %div, 2
+  %105 = add i32 %rem, %104
+  %106 = mul i32 %105, 4
+  %107 = add i32 3, %106
+  %108 = getelementptr [16 x float], [16 x float] addrspace(3)* @"\01?gs_matArray@@3PAY01$$CAV?$matrix@M$01$01@@A.v.v.1dim", i32 0, i32 %107
+  store float %88, float addrspace(3)* %108, align 4
+  %sub46 = add nsw i32 %k.024, -1
+  %inc = add nsw i32 %j.025, 1
+  %exitcond = icmp eq i32 %inc, 4
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:                                 ; preds = %for.body
+  %phitmp = srem i32 %51, 2
+  br label %for.end
+
+for.end:                                          ; preds = %for.end.loopexit, %if.end.11
+  %k.0.lcssa = phi i32 [ 0, %if.end.11 ], [ %phitmp, %for.end.loopexit ]
+  %109 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %110 = extractvalue %dx.types.CBufRet.i32 %109, 2
+  %sub47 = sub nsw i32 1, %110
+  %111 = add i32 %110, 2
+  %112 = mul i32 %sub47, 2
+  %113 = add i32 %110, %112
+  %114 = mul i32 %113, 4
+  %115 = add i32 %110, %114
+  %116 = getelementptr [16 x float], [16 x float] addrspace(3)* @"\01?gs_matArray@@3PAY01$$CAV?$matrix@M$01$01@@A.v.v.1dim", i32 0, i32 %115
+  %117 = load float, float addrspace(3)* %116, align 4
+  %118 = mul i32 %sub47, 2
+  %119 = add i32 %110, %118
+  %120 = mul i32 %119, 4
+  %121 = add i32 %111, %120
+  %122 = getelementptr [16 x float], [16 x float] addrspace(3)* @"\01?gs_matArray@@3PAY01$$CAV?$matrix@M$01$01@@A.v.v.1dim", i32 0, i32 %121
+  %123 = load float, float addrspace(3)* %122, align 4
+  %124 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %125 = extractvalue %dx.types.CBufRet.i32 %124, 2
+  %sub50 = sub nsw i32 1, %125
+  %126 = mul i32 %125, 2
+  %127 = add i32 %sub50, %126
+  %128 = mul i32 %127, 2
+  %129 = add i32 0, %128
+  %130 = getelementptr [8 x float], [8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim", i32 0, i32 %129
+  store float %117, float addrspace(3)* %130, align 8
+  %131 = mul i32 %125, 2
+  %132 = add i32 %sub50, %131
+  %133 = mul i32 %132, 2
+  %134 = add i32 1, %133
+  %135 = getelementptr [8 x float], [8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim", i32 0, i32 %134
+  store float %123, float addrspace(3)* %135, align 4
+  %136 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %137 = extractvalue %dx.types.CBufRet.i32 %136, 2
+  %sub54 = sub nsw i32 1, %137
+  %138 = mul i32 %sub54, 2
+  %139 = add i32 %k.0.lcssa, %138
+  %140 = mul i32 %139, 2
+  %141 = add i32 0, %140
+  %142 = getelementptr [8 x float], [8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim", i32 0, i32 %141
+  %143 = load float, float addrspace(3)* %142, align 8
+  %conv57 = fptosi float %143 to i32
+  %144 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %145 = extractvalue %dx.types.CBufRet.i32 %144, 2
+  %146 = add i32 %145, 0
+  %sb_Derived_UAV_structbuf27 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 %146, i1 false)
+  call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %sb_Derived_UAV_structbuf27, i32 1, i32 0, i32 %conv57, i32 undef, i32 undef, i32 undef, i8 1)
+  %147 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %148 = extractvalue %dx.types.CBufRet.i32 %147, 2
+  %sub61 = sub nsw i32 1, %148
+  %149 = mul i32 %sub61, 2
+  %150 = add i32 %148, %149
+  %151 = mul i32 %150, 2
+  %152 = add i32 1, %151
+  %153 = getelementptr [8 x float], [8 x float] addrspace(3)* @"\01?gs_vecArray@@3PAY01$$CAV?$vector@M$01@@A.v.1dim", i32 0, i32 %152
+  %154 = load float, float addrspace(3)* %153, align 4
+  %155 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)
+  %156 = extractvalue %dx.types.CBufRet.i32 %155, 2
+  %157 = add i32 %156, 0
+  %sb_Derived_UAV_structbuf = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 %157, i1 false)
+  call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %sb_Derived_UAV_structbuf, i32 1, i32 4, float %154, float undef, float undef, float undef, i8 1)
+  ret void
+}
+
+; Function Attrs: nounwind
+declare void @FunctionConsumingPtr(i32*) #2
+
+
+; Function Attrs: nounwind readonly
+declare %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32, %dx.types.Handle, i32) #1
+
+; Function Attrs: nounwind readonly
+declare %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32, %dx.types.Handle, i32) #1
+
+; Function Attrs: nounwind readonly
+declare %dx.types.Handle @dx.op.createHandle(i32, i8, i32, i32, i1) #1
+
+; Function Attrs: nounwind
+declare void @dx.op.bufferStore.f32(i32, %dx.types.Handle, i32, i32, float, float, float, float, i8) #2
+
+; Function Attrs: nounwind
+declare void @dx.op.bufferStore.i32(i32, %dx.types.Handle, i32, i32, i32, i32, i32, i32, i8) #2
+
+attributes #0 = { nounwind "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-realign-stack" "stack-protector-buffer-size"="0" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #1 = { nounwind readonly }
+attributes #2 = { nounwind }
+
+!pauseresume = !{!0}
+!llvm.ident = !{!1}
+!dx.version = !{!2}
+!dx.valver = !{!3}
+!dx.shaderModel = !{!4}
+!dx.resources = !{!5}
+!dx.typeAnnotations = !{!11, !23}
+!dx.entryPoints = !{!27}
+
+!0 = !{!"hlsl-dxilemit", !"hlsl-dxilload"}
+!1 = !{!"clang version 3.7 (tags/RELEASE_370/final)"}
+!2 = !{i32 1, i32 0}
+!3 = !{i32 1, i32 4}
+!4 = !{!"cs", i32 6, i32 0}
+!5 = !{null, !6, !9, null}
+!6 = !{!7}
+!7 = !{i32 0, [2 x %"class.RWStructuredBuffer<Derived>"]* undef, !"sb_Derived", i32 0, i32 0, i32 2, i32 12, i1 false, i1 false, i1 false, !8}
+!8 = !{i32 1, i32 8}
+!9 = !{!10}
+!10 = !{i32 0, %"$Globals"* undef, !"$Globals", i32 0, i32 0, i32 1, i32 28, null}
+!11 = !{i32 0, %struct.Derived undef, !12, %struct.Base undef, !15, %"class.RWStructuredBuffer<Derived>" undef, !17, %"class.RWStructuredBuffer<Base>" undef, !19, %"$Globals" undef, !20}
+!12 = !{i32 8, !13, !14}
+!13 = !{i32 6, !"Base", i32 3, i32 0}
+!14 = !{i32 6, !"y", i32 3, i32 4, i32 7, i32 9}
+!15 = !{i32 4, !16}
+!16 = !{i32 6, !"x", i32 3, i32 0, i32 7, i32 4}
+!17 = !{i32 8, !18}
+!18 = !{i32 6, !"h", i32 3, i32 0}
+!19 = !{i32 4, !18}
+!20 = !{i32 28, !21, !22}
+!21 = !{i32 6, !"c_Derived", i32 3, i32 0}
+!22 = !{i32 6, !"i", i32 3, i32 24, i32 7, i32 4}
+!23 = !{i32 1, void ()* @main, !24}
+!24 = !{!25}
+!25 = !{i32 1, !26, !26}
+!26 = !{}
+!27 = !{void ()* @main, !"main", null, !5, !28}
+!28 = !{i32 4, !29}
+!29 = !{i32 1, i32 1, i32 1}
+!30 = !{!31, !31, i64 0}
+!31 = !{!"int", !32, i64 0}
+!32 = !{!"omnipotent char", !33, i64 0}
+!33 = !{!"Simple C/C++ TBAA"}
+!34 = !{!35, !35, i64 0}
+!35 = !{!"float", !32, i64 0}
+!36 = distinct !{!36, !"dx.controlflow.hints", i32 1}
diff --git a/tools/clang/test/CodeGenHLSL/quick-test/addrspace_stress.hlsl b/tools/clang/test/CodeGenHLSL/quick-test/addrspace_stress.hlsl
new file mode 100644
index 0000000..94f245e
--- /dev/null
+++ b/tools/clang/test/CodeGenHLSL/quick-test/addrspace_stress.hlsl
@@ -0,0 +1,71 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// Exercise a variety of groupshared access patterns; make sure we don't
+// crash and that every addrspacecast is removed.
+// CHECK: @main()
+// CHECK-NOT: addrspacecast
+// CHECK: ret void
+
+struct Base {
+  int x;
+  int getX() { return x; }
+  void setX(int X) { x = X; }
+};
+struct Derived : Base {
+  float y;
+  float getYtimesX2() { return y * x * getX(); }
+  void setXY(int X, float Y) { x = X; setX(X); y = Y; }
+};
+
+Derived c_Derived[2];
+groupshared Derived gs_Derived0;
+groupshared Derived gs_Derived1;
+groupshared Derived gs_Derived[2];
+groupshared float2 gs_vecArray[2][2];
+groupshared float2x2 gs_matArray[2][2];
+
+int i;
+
+RWStructuredBuffer<Derived> sb_Derived[2];
+RWStructuredBuffer<Base> sb_Base[2];
+
+void assign(out Derived o, in Derived i) { o = i; }
+
+[numthreads(1,1,1)]
+void main() {
+  gs_Derived0 = c_Derived[0];
+  gs_Derived[1 - i] = c_Derived[i];
+  gs_Derived[i].x = 1;
+  gs_Derived[i].y = 2;
+  int x = -1;
+  float y = -1;
+  [branch]
+  if (!gs_Derived0.getX())
+    assign(gs_Derived1, gs_Derived0);
+  else if (gs_Derived[0].getX())
+    assign(gs_Derived1, gs_Derived[i]);
+  else
+    assign(gs_Derived1, gs_Derived[0]);
+  float f = gs_Derived1.getYtimesX2();
+  gs_Derived[i].setXY(5, 6);
+  gs_Derived[i].setX(f);
+  sb_Derived[i][0] = gs_Derived[i];
+
+  // Used to crash:
+  // HLOperationLower(6439): in TranslateStructBufSubscriptUser
+  // because it doesn't expect the bitcast to Base inside setXY when setX is called.
+  sb_Derived[i][1].setXY(7, 8);
+
+  sb_Base[i][2] = (Base)gs_Derived[i];
+
+  [loop]
+  int k = 4;
+  for (int j = i; j < 4; ++j) {
+    gs_vecArray[j/2][j%2] = float2(gs_Derived[j % 2].x, gs_Derived[(1 - j) % 2].y);
+    gs_matArray[j/2][j%2] = float2x2(gs_vecArray[k/2][k%2], gs_vecArray[k%2][k/2]);
+    k -= 1;
+  }
+  gs_vecArray[i][1-i] = gs_matArray[1-i][i][i];
+  sb_Derived[i][1].x = gs_vecArray[1-i][k%2].x;
+  sb_Derived[i][1].y = gs_vecArray[1-i][i].y;
+}
diff --git a/tools/clang/test/CodeGenHLSL/quick-test/addrspacecast.hlsl b/tools/clang/test/CodeGenHLSL/quick-test/addrspacecast.hlsl
index d2cdd00..47e981a 100644
--- a/tools/clang/test/CodeGenHLSL/quick-test/addrspacecast.hlsl
+++ b/tools/clang/test/CodeGenHLSL/quick-test/addrspacecast.hlsl
@@ -1,7 +1,9 @@
 // RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
 
-// Make sure generate addrspacecast.
-// CHECK: addrspacecast ([6 x float] addrspace(3)*
+// Make sure addrspacecast is cleaned up.
+// CHECK: @main()
+// CHECK-NOT: addrspacecast
+// CHECK: ret void
 
 struct ST
 {
diff --git a/tools/clang/test/CodeGenHLSL/quick-test/flat_addrspacecast.hlsl b/tools/clang/test/CodeGenHLSL/quick-test/flat_addrspacecast.hlsl
index f0d9878..f6b6b1b 100644
--- a/tools/clang/test/CodeGenHLSL/quick-test/flat_addrspacecast.hlsl
+++ b/tools/clang/test/CodeGenHLSL/quick-test/flat_addrspacecast.hlsl
@@ -1,7 +1,9 @@
 // RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
 
-// Make sure generate addrspacecast.
-// CHECK: addrspacecast (float addrspace(3)*
+// Make sure addrspacecast is cleaned up.
+// CHECK: @main()
+// CHECK-NOT: addrspacecast
+// CHECK: ret void
 
 struct ST
 {
diff --git a/tools/clang/test/CodeGenHLSL/quick-test/matrix_array_arg_copy_optimized_away.hlsl b/tools/clang/test/CodeGenHLSL/quick-test/matrix_array_arg_copy_optimized_away.hlsl
new file mode 100644
index 0000000..af5cdff
--- /dev/null
+++ b/tools/clang/test/CodeGenHLSL/quick-test/matrix_array_arg_copy_optimized_away.hlsl
@@ -0,0 +1,24 @@
+// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s
+
+// Regression test for a bug where matrix array arguments would get expanded
+// into elementwise copies (to respect pass-by-value semantics in case an
+// orientation change was needed) that later optimization passes failed to clean up.
+
+const int1x1 cb_matrices[64];
+const row_major int1x1 cb_matrices_rm[64];
+const int cb_index;
+
+int get_cm(column_major int1x1 matrices[64]) { return matrices[cb_index]; }
+int get_rm(row_major int1x1 matrices[64]) { return matrices[cb_index]; }
+
+int2 main() : OUT
+{
+    // There should be no dynamic GEP of an array,
+    // we should be dynamically indexing the constant buffer directly
+    // CHECK-NOT: getelementptr
+    return int2(
+        // CHECK: call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %{{.*}}, i32 %{{.*}})
+        get_cm(cb_matrices), // Test implicit column major to explicit column major (no conversion needed)
+        // CHECK: call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %{{.*}}, i32 %{{.*}})
+        get_rm(cb_matrices_rm)); // Test explicit row major to explicit row major (no conversion needed)
+}
\ No newline at end of file
diff --git a/tools/clang/test/CodeGenHLSL/quick-test/remove-addrspacecastinst.hlsl b/tools/clang/test/CodeGenHLSL/quick-test/remove-addrspacecastinst.hlsl
new file mode 100644
index 0000000..00ac16d
--- /dev/null
+++ b/tools/clang/test/CodeGenHLSL/quick-test/remove-addrspacecastinst.hlsl
@@ -0,0 +1,10 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// CHECK: @main()
+// CHECK-NOT: addrspacecast
+// CHECK: ret void
+
+struct Foo { int x; int getX() { return x; } };
+groupshared Foo foo[2];
+int i;
+int main() : OUT { return foo[i].getX(); }
diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py
index f371a74..6983b90 100644
--- a/utils/hct/hctdb.py
+++ b/utils/hct/hctdb.py
@@ -1589,6 +1589,7 @@
         add_pass('red', 'ReducibilityAnalysis', 'Reducibility Analysis', [])
         add_pass('viewid-state', 'ComputeViewIdState', 'Compute information related to ViewID', [])
         add_pass('hlsl-translate-dxil-opcode-version', 'DxilTranslateRawBuffer', 'Translates one version of dxil to another', [])
+        add_pass('hlsl-dxil-cleanup-addrspacecast', 'DxilCleanupAddrSpaceCast', 'HLSL DXIL Cleanup Address Space Cast (part of hlsl-dxilfinalize)', [])
 
         category_lib="llvm"
         add_pass('ipsccp', 'IPSCCP', 'Interprocedural Sparse Conditional Constant Propagation', [])