[InstCombine] simplify code for inserts -> splat; NFC

llvm-svn: 364441
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index d812c5b..693fe5d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -662,18 +662,17 @@
   return true;
 }
 
-// Turn a chain of inserts that splats a value into a canonical insert + shuffle
-// splat. That is:
-// insertelt(insertelt(insertelt(insertelt X, %k, 0), %k, 1), %k, 2) ... ->
-// shufflevector(insertelt(X, %k, 0), undef, zero)
-static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) {
-  // We are interested in the last insert in a chain. So, if this insert
-  // has a single user, and that user is an insert, bail.
+/// Turn a chain of inserts that splats a value into an insert + shuffle:
+/// insertelt(insertelt(insertelt(insertelt X, %k, 0), %k, 1), %k, 2) ... ->
+/// shufflevector(insertelt(X, %k, 0), undef, zero)
+static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) {
+  // We are interested in the last insert in a chain. So if this insert has a
+  // single user and that user is an insert, bail.
   if (InsElt.hasOneUse() && isa<InsertElementInst>(InsElt.user_back()))
     return nullptr;
 
-  VectorType *VT = cast<VectorType>(InsElt.getType());
-  int NumElements = VT->getNumElements();
+  auto *VecTy = cast<VectorType>(InsElt.getType());
+  unsigned NumElements = VecTy->getNumElements();
 
   // Do not try to do this for a one-element vector, since that's a nop,
   // and will cause an inf-loop.
@@ -709,20 +708,15 @@
   if (llvm::any_of(ElementPresent, [](bool Present) { return !Present; }))
     return nullptr;
 
-  // All right, create the insert + shuffle.
-  Instruction *InsertFirst;
-  if (cast<ConstantInt>(FirstIE->getOperand(2))->isZero())
-    InsertFirst = FirstIE;
-  else
-    InsertFirst = InsertElementInst::Create(
-        UndefValue::get(VT), SplatVal,
-        ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0),
-        "", &InsElt);
+  // Create the insert + shuffle.
+  Type *Int32Ty = Type::getInt32Ty(InsElt.getContext());
+  UndefValue *UndefVec = UndefValue::get(VecTy);
+  Constant *Zero = ConstantInt::get(Int32Ty, 0);
+  if (!cast<ConstantInt>(FirstIE->getOperand(2))->isZero())
+    FirstIE = InsertElementInst::Create(UndefVec, SplatVal, Zero, "", &InsElt);
 
-  Constant *ZeroMask = ConstantAggregateZero::get(
-      VectorType::get(Type::getInt32Ty(InsElt.getContext()), NumElements));
-
-  return new ShuffleVectorInst(InsertFirst, UndefValue::get(VT), ZeroMask);
+  Constant *ZeroMask = ConstantVector::getSplat(NumElements, Zero);
+  return new ShuffleVectorInst(FirstIE, UndefVec, ZeroMask);
 }
 
 /// If we have an insertelement instruction feeding into another insertelement
@@ -940,9 +934,7 @@
   if (Instruction *NewInsElt = hoistInsEltConst(IE, Builder))
     return NewInsElt;
 
-  // Turn a sequence of inserts that broadcasts a scalar into a single
-  // insert + shufflevector.
-  if (Instruction *Broadcast = foldInsSequenceIntoBroadcast(IE))
+  if (Instruction *Broadcast = foldInsSequenceIntoSplat(IE))
     return Broadcast;
 
   return nullptr;