Merging r229731:
------------------------------------------------------------------------
r229731 | sanjoy | 2015-02-18 11:32:25 -0800 (Wed, 18 Feb 2015) | 10 lines

Partial fix for bug 22589

Don't spend the entire iteration space in the scalar loop prologue if
computing the trip count overflows.  This change also gets rid of the
backedge check in the prologue loop and the extra check for
overflowing trip-count.

Differential Revision: http://reviews.llvm.org/D7715


------------------------------------------------------------------------

llvm-svn: 229757
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index f12cd61..8a32215 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -55,7 +55,7 @@
 /// - Branch around the original loop if the trip count is less
 ///   than the unroll factor.
 ///
-static void ConnectProlog(Loop *L, Value *TripCount, unsigned Count,
+static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
                           BasicBlock *LastPrologBB, BasicBlock *PrologEnd,
                           BasicBlock *OrigPH, BasicBlock *NewPH,
                           ValueToValueMapTy &VMap, Pass *P) {
@@ -105,12 +105,19 @@
     }
   }
 
-  // Create a branch around the orignal loop, which is taken if the
-  // trip count is less than the unroll factor.
+  // Create a branch around the orignal loop, which is taken if there are no
+  // iterations remaining to be executed after running the prologue.
   Instruction *InsertPt = PrologEnd->getTerminator();
+
+  assert(Count != 0 && "nonsensical Count!");
+
+  // If BECount <u (Count - 1) then (BECount + 1) & (Count - 1) == (BECount + 1)
+  // (since Count is a power of 2).  This means %xtraiter is (BECount + 1) and
+  // and all of the iterations of this loop were executed by the prologue.  Note
+  // that if BECount <u (Count - 1) then (BECount + 1) cannot unsigned-overflow.
   Instruction *BrLoopExit =
-    new ICmpInst(InsertPt, ICmpInst::ICMP_ULT, TripCount,
-                 ConstantInt::get(TripCount->getType(), Count));
+    new ICmpInst(InsertPt, ICmpInst::ICMP_ULT, BECount,
+                 ConstantInt::get(BECount->getType(), Count - 1));
   BasicBlock *Exit = L->getUniqueExitBlock();
   assert(Exit && "Loop must have a single exit block only");
   // Split the exit to maintain loop canonicalization guarantees
@@ -292,23 +299,28 @@
 
   // Only unroll loops with a computable trip count and the trip count needs
   // to be an int value (allowing a pointer type is a TODO item)
-  const SCEV *BECount = SE->getBackedgeTakenCount(L);
-  if (isa<SCEVCouldNotCompute>(BECount) || !BECount->getType()->isIntegerTy())
+  const SCEV *BECountSC = SE->getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(BECountSC) ||
+      !BECountSC->getType()->isIntegerTy())
     return false;
 
-  // If BECount is INT_MAX, we can't compute trip-count without overflow.
-  if (BECount->isAllOnesValue())
-    return false;
+  unsigned BEWidth = cast<IntegerType>(BECountSC->getType())->getBitWidth();
 
   // Add 1 since the backedge count doesn't include the first loop iteration
   const SCEV *TripCountSC =
-    SE->getAddExpr(BECount, SE->getConstant(BECount->getType(), 1));
+    SE->getAddExpr(BECountSC, SE->getConstant(BECountSC->getType(), 1));
   if (isa<SCEVCouldNotCompute>(TripCountSC))
     return false;
 
   // We only handle cases when the unroll factor is a power of 2.
   // Count is the loop unroll factor, the number of extra copies added + 1.
-  if ((Count & (Count-1)) != 0)
+  if (!isPowerOf2_32(Count))
+    return false;
+
+  // This constraint lets us deal with an overflowing trip count easily; see the
+  // comment on ModVal below.  This check is equivalent to `Log2(Count) <
+  // BEWidth`.
+  if (static_cast<uint64_t>(Count) > (1ULL << BEWidth))
     return false;
 
   // If this loop is nested, then the loop unroller changes the code in
@@ -330,16 +342,23 @@
   SCEVExpander Expander(*SE, "loop-unroll");
   Value *TripCount = Expander.expandCodeFor(TripCountSC, TripCountSC->getType(),
                                             PreHeaderBR);
+  Value *BECount = Expander.expandCodeFor(BECountSC, BECountSC->getType(),
+                                          PreHeaderBR);
 
   IRBuilder<> B(PreHeaderBR);
   Value *ModVal = B.CreateAnd(TripCount, Count - 1, "xtraiter");
 
-  // Check if for no extra iterations, then jump to cloned/unrolled loop.
-  // We have to check that the trip count computation didn't overflow when
-  // adding one to the backedge taken count.
-  Value *LCmp = B.CreateIsNotNull(ModVal, "lcmp.mod");
-  Value *OverflowCheck = B.CreateIsNull(TripCount, "lcmp.overflow");
-  Value *BranchVal = B.CreateOr(OverflowCheck, LCmp, "lcmp.or");
+  // If ModVal is zero, we know that either
+  //  1. there are no iteration to be run in the prologue loop
+  // OR
+  //  2. the addition computing TripCount overflowed
+  //
+  // If (2) is true, we know that TripCount really is (1 << BEWidth) and so the
+  // number of iterations that remain to be run in the original loop is a
+  // multiple Count == (1 << Log2(Count)) because Log2(Count) <= BEWidth (we
+  // explicitly check this above).
+
+  Value *BranchVal = B.CreateIsNotNull(ModVal, "lcmp.mod");
 
   // Branch to either the extra iterations or the cloned/unrolled loop
   // We will fix up the true branch label when adding loop body copies
@@ -362,10 +381,7 @@
   std::vector<BasicBlock *> NewBlocks;
   ValueToValueMapTy VMap;
 
-  // If unroll count is 2 and we can't overflow in tripcount computation (which
-  // is BECount + 1), then we don't need a loop for prologue, and we can unroll
-  // it. We can be sure that we don't overflow only if tripcount is a constant.
-  bool UnrollPrologue = (Count == 2 && isa<ConstantInt>(TripCount));
+  bool UnrollPrologue = Count == 2;
 
   // Clone all the basic blocks in the loop. If Count is 2, we don't clone
   // the loop, otherwise we create a cloned loop to execute the extra
@@ -391,7 +407,7 @@
   // Connect the prolog code to the original loop and update the
   // PHI functions.
   BasicBlock *LastLoopBB = cast<BasicBlock>(VMap[Latch]);
-  ConnectProlog(L, TripCount, Count, LastLoopBB, PEnd, PH, NewPH, VMap,
+  ConnectProlog(L, BECount, Count, LastLoopBB, PEnd, PH, NewPH, VMap,
                 LPM->getAsPass());
   NumRuntimeUnrolled++;
   return true;
diff --git a/llvm/test/Transforms/LoopUnroll/runtime-loop.ll b/llvm/test/Transforms/LoopUnroll/runtime-loop.ll
index 3a8777b..80571ec 100644
--- a/llvm/test/Transforms/LoopUnroll/runtime-loop.ll
+++ b/llvm/test/Transforms/LoopUnroll/runtime-loop.ll
@@ -4,9 +4,7 @@
 
 ; CHECK: %xtraiter = and i32 %n
 ; CHECK:  %lcmp.mod = icmp ne i32 %xtraiter, 0
-; CHECK:  %lcmp.overflow = icmp eq i32 %n, 0
-; CHECK:  %lcmp.or = or i1 %lcmp.overflow, %lcmp.mod
-; CHECK:  br i1 %lcmp.or, label %for.body.prol, label %for.body.preheader.split
+; CHECK:  br i1 %lcmp.mod, label %for.body.prol, label %for.body.preheader.split
 
 ; CHECK: for.body.prol:
 ; CHECK: %indvars.iv.prol = phi i64 [ %indvars.iv.next.prol, %for.body.prol ], [ 0, %for.body.preheader ]
diff --git a/llvm/test/Transforms/LoopUnroll/runtime-loop1.ll b/llvm/test/Transforms/LoopUnroll/runtime-loop1.ll
index 38b4f32..5ff75e3 100644
--- a/llvm/test/Transforms/LoopUnroll/runtime-loop1.ll
+++ b/llvm/test/Transforms/LoopUnroll/runtime-loop1.ll
@@ -3,7 +3,7 @@
 ; This tests that setting the unroll count works
 
 ; CHECK: for.body.prol:
-; CHECK: br i1 %prol.iter.cmp, label %for.body.prol, label %for.body.preheader.split
+; CHECK: br label %for.body.preheader.split
 ; CHECK: for.body:
 ; CHECK: br i1 %exitcond.1, label %for.end.loopexit.unr-lcssa, label %for.body
 ; CHECK-NOT: br i1 %exitcond.4, label %for.end.loopexit{{.*}}, label %for.body
diff --git a/llvm/test/Transforms/LoopUnroll/tripcount-overflow.ll b/llvm/test/Transforms/LoopUnroll/tripcount-overflow.ll
index d593685..052077c 100644
--- a/llvm/test/Transforms/LoopUnroll/tripcount-overflow.ll
+++ b/llvm/test/Transforms/LoopUnroll/tripcount-overflow.ll
@@ -1,19 +1,28 @@
 ; RUN: opt < %s -S -unroll-runtime -unroll-count=2 -loop-unroll | FileCheck %s
 target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
 
-; When prologue is fully unrolled, the branch on its end is unconditional.
-; Unrolling it is illegal if we can't prove that trip-count+1 doesn't overflow,
-; like in this example, where it comes from an argument.
-;
-; This test is based on an example from here:
-; http://stackoverflow.com/questions/23838661/why-is-clang-optimizing-this-code-out
-;
+; This test case documents how runtime loop unrolling handles the case
+; when the backedge-count is -1.
+
+; If %N, the backedge-taken count, is -1 then %0 unsigned-overflows
+; and is 0.  %xtraiter too is 0, signifying that the total trip-count
+; is divisible by 2.  The prologue then branches to the unrolled loop
+; and executes the 2^32 iterations there, in groups of 2.
+
+
+; CHECK: entry:
+; CHECK-NEXT: %0 = add i32 %N, 1
+; CHECK-NEXT: %xtraiter = and i32 %0, 1
+; CHECK-NEXT: %lcmp.mod = icmp ne i32 %xtraiter, 0
+; CHECK-NEXT: br i1 %lcmp.mod, label %while.body.prol, label %entry.split
+
 ; CHECK: while.body.prol:
-; CHECK: br i1
+; CHECK: br label %entry.split
+
 ; CHECK: entry.split:
 
 ; Function Attrs: nounwind readnone ssp uwtable
-define i32 @foo(i32 %N) #0 {
+define i32 @foo(i32 %N) {
 entry:
   br label %while.body
 
@@ -26,5 +35,3 @@
 while.end:                                        ; preds = %while.body
   ret i32 %i
 }
-
-attributes #0 = { nounwind readnone ssp uwtable "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }