// blob: a18d23972bbad9bfdea7b0a956b40faf1fda97c1 [file] [log] [blame]
// Copyright 2018 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "clspv/Option.h"
#include "Builtins.h"
#include "Passes.h"
using namespace clspv;
using namespace llvm;
#define DEBUG_TYPE "signedcomparefixup"
namespace {
// Hidden command line switch, off by default. When set, the pass prints each
// call and comparison it is about to replace to standard output.
cl::opt<bool> ShowSCF("show-scf", cl::desc("Show signed compare fixup"),
                      cl::Hidden, cl::init(false));
// Module pass that rewrites signed relational integer comparisons, and the
// signed integer min/max/clamp builtin calls that imply them, into sequences
// that do not use a signed comparison.
class SignedCompareFixupPass final : public ModulePass {
public:
  static char ID;

  SignedCompareFixupPass() : ModulePass(ID) {}

  // Rewrite the module to avoid signed integer comparisons.
  bool runOnModule(Module &M) override;

private:
  // Reports whether |pred| is one of the four signed ordering predicates:
  // sgt, sge, slt or sle. Every other predicate yields false.
  bool IsSignedRelational(CmpInst::Predicate pred) {
    return pred == CmpInst::ICMP_SGT || pred == CmpInst::ICMP_SGE ||
           pred == CmpInst::ICMP_SLT || pred == CmpInst::ICMP_SLE;
  }

  // Expands |call|, which is a smin, smax or sclamp call, into an equivalent
  // instruction stream. The comparisons this introduces are appended to
  // |work_list| so that they get rewritten as well.
  void ReplaceBuiltin(CallInst *call, Builtins::BuiltinType type,
                      SmallVectorImpl<ICmpInst *> *work_list);
};
} // namespace
// Out-of-class definition of the static ID member declared in
// SignedCompareFixupPass; the constructor hands it to ModulePass.
char SignedCompareFixupPass::ID = 0;
// Invokes LLVM's INITIALIZE_PASS macro for this pass with the name
// "SignedCompareFixupPass" and the description "Signed Integer Compare Fixup".
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags of INITIALIZE_PASS — confirm against the
// LLVM version in use.
INITIALIZE_PASS(SignedCompareFixupPass, "SignedCompareFixupPass",
"Signed Integer Compare Fixup", false, false)
namespace clspv {
// Factory for the signed compare fixup pass. Returns a freshly heap-allocated
// SignedCompareFixupPass through its ModulePass base pointer.
// NOTE(review): ownership presumably passes to the caller (e.g. a pass
// manager) — confirm at the call sites.
ModulePass *createSignedCompareFixupPass() {
  ModulePass *pass = new SignedCompareFixupPass();
  return pass;
}
} // namespace clspv
namespace {
// Rewrites |M| so that it no longer contains signed relational integer
// comparisons (sgt, sge, slt, sle).
//
// The pass works in two phases:
//   1. Walk every instruction. Signed relational icmps are queued. Calls to
//      the min, max and clamp builtins whose first parameter is a signed
//      integer are expanded by ReplaceBuiltin into compare/select sequences,
//      and the comparisons that creates are queued too.
//   2. Replace every queued comparison with a subtraction followed by a test
//      of the sign bit of the difference.
//
// Does nothing and returns false unless
// clspv::Option::HackSignedCompareFixup() is set. Otherwise returns true if
// the module was modified.
//
// NOTE(review): the subtractions wrap, so the replacement disagrees with a
// true signed comparison whenever the difference of the operands overflows
// (for 32-bit operands, INT_MIN > 1 is rewritten to a value that is true).
// That behaviour is unchanged here; confirm it is an accepted limitation.
bool SignedCompareFixupPass::runOnModule(Module &M) {
  bool Changed = false;
  if (!clspv::Option::HackSignedCompareFixup()) {
    return Changed;
  }

  // Builtin calls that have been expanded. They are erased at the very end,
  // after all of their uses have been redirected.
  SmallVector<Instruction *, 16> to_remove;
  // Every signed comparison to rewrite, including those that ReplaceBuiltin
  // introduces.
  SmallVector<ICmpInst *, 16> work_list;

  for (auto &F : M) {
    for (auto &BB : F) {
      for (auto &inst : BB) {
        if (auto *icmp = dyn_cast<ICmpInst>(&inst)) {
          if (IsSignedRelational(icmp->getPredicate())) {
            work_list.push_back(icmp);
          }
        } else if (auto *call = dyn_cast<CallInst>(&inst)) {
          // An indirect call has no statically known callee, so it cannot be
          // one of the builtins handled below. Skip it rather than handing a
          // null function to the builtin lookup.
          auto *callee = call->getCalledFunction();
          if (!callee) {
            continue;
          }
          // Same bug exists for smin, smax and sclamp. Break those calls down
          // and then replace the resulting comparisons too.
          auto &info = Builtins::Lookup(callee);
          switch (info.getType()) {
          case Builtins::kMax:
          case Builtins::kMin:
          case Builtins::kClamp: {
            auto param_info = info.getParameter(0);
            // Note that this also covers vector operands.
            if (param_info.is_signed &&
                param_info.type_id == llvm::Type::IntegerTyID) {
              if (ShowSCF) {
                outs() << "SCF: Replace " << *call << "\n";
              }
              // New instructions are inserted before |call|, i.e. behind the
              // current iteration position, so they are not visited again.
              ReplaceBuiltin(call, info.getType(), &work_list);
              Changed = true;
              to_remove.push_back(call);
            }
            break;
          }
          default:
            break;
          }
        }
      }
    }
  }

  if (ShowSCF) {
    for (auto *icmp : work_list) {
      outs() << "SCF: Replace " << *icmp << "\n";
    }
  }

  IRBuilder<> Builder(M.getContext());
  for (auto *icmp : work_list) {
    Changed = true;
    auto *x = icmp->getOperand(0);
    auto *y = icmp->getOperand(1);
    auto *x_type = x->getType();

    // Constants of the operand type; for vector operands these are splats.
    auto *zero = ConstantInt::get(x_type, 0);
    auto *one = ConstantInt::get(x_type, 1);
    // Width of one element, valid for scalar and vector operands alike. The
    // sign mask is built as an APInt so that widths above 64 bits are
    // handled without shifting a 64-bit integer out of range.
    const unsigned bit_width = x_type->getScalarSizeInBits();
    auto *sign_bit = ConstantInt::get(x_type, APInt::getSignMask(bit_width));

    // Every predicate is reduced to "some difference is non-negative", i.e.
    // the sign bit of that difference is clear:
    //
    //   x >  y   <=>   x - y - 1 >= 0
    //   x >= y   <=>   x - y     >= 0
    //   x <  y   <=>   y - x - 1 >= 0
    //   x <= y   <=>   y - x     >= 0
    Value *minuend = nullptr;
    Value *subtrahend = nullptr;
    bool is_strict = false;
    switch (icmp->getPredicate()) {
    case CmpInst::ICMP_SGT:
      minuend = x;
      subtrahend = y;
      is_strict = true;
      break;
    case CmpInst::ICMP_SGE:
      minuend = x;
      subtrahend = y;
      break;
    case CmpInst::ICMP_SLT:
      minuend = y;
      subtrahend = x;
      is_strict = true;
      break;
    case CmpInst::ICMP_SLE:
      minuend = y;
      subtrahend = x;
      break;
    default:
      llvm_unreachable("Unhandled integer comparison");
      break;
    }

    Builder.SetInsertPoint(icmp);
    Value *diff = Builder.CreateSub(minuend, subtrahend);
    if (is_strict) {
      diff = Builder.CreateSub(diff, one);
    }
    auto *diff_sign = Builder.CreateAnd(diff, sign_bit);
    Value *replacement = Builder.CreateICmpEQ(diff_sign, zero);
    icmp->replaceAllUsesWith(replacement);
  }

  // All uses have been redirected, so the originals can now be deleted.
  for (auto *inst : to_remove) {
    inst->eraseFromParent();
  }
  for (auto *inst : work_list) {
    inst->eraseFromParent();
  }

  if (ShowSCF) {
    outs() << "\n\nSCF: DONE\n";
  }
  return Changed;
}
// Expands |call|, a min, max or clamp builtin call of kind |type|, into
// icmp-slt/select instructions inserted immediately before it, and redirects
// every use of |call| to the expanded value. |call| itself is not erased
// here. Each comparison that was really emitted as an instruction is
// appended to |work_list|.
void SignedCompareFixupPass::ReplaceBuiltin(
    CallInst *call, Builtins::BuiltinType type,
    SmallVectorImpl<ICmpInst *> *work_list) {
  IRBuilder<> builder(call->getContext());
  builder.SetInsertPoint(call);

  // CreateICmpSLT may fold to a constant, in which case the value is not an
  // ICmpInst and there is nothing to queue.
  const auto enqueue = [work_list](Value *maybe_cmp) {
    if (auto *as_icmp = dyn_cast<ICmpInst>(maybe_cmp)) {
      work_list->push_back(as_icmp);
    }
  };

  Value *result = nullptr;
  if (type != Builtins::kClamp) {
    Value *const first = call->getOperand(0);
    Value *const second = call->getOperand(1);
    // Both builtins select |second| when the comparison holds and |first|
    // otherwise; only the comparison differs:
    //   min: second < first
    //   max: first < second
    Value *const pick_second = (type == Builtins::kMin)
                                   ? builder.CreateICmpSLT(second, first)
                                   : builder.CreateICmpSLT(first, second);
    result = builder.CreateSelect(pick_second, second, first);
    enqueue(pick_second);
  } else {
    // clamp(x,b,c) is min(max(x, b), c). Results are undefined if b > c, so
    // both comparisons can be made against x.
    Value *const value = call->getOperand(0);
    Value *const low = call->getOperand(1);
    Value *const high = call->getOperand(2);
    Value *const below_low = builder.CreateICmpSLT(value, low);
    Value *const raised = builder.CreateSelect(below_low, low, value);
    Value *const above_high = builder.CreateICmpSLT(high, value);
    result = builder.CreateSelect(above_high, high, raised);
    enqueue(below_low);
    enqueue(above_high);
  }

  call->replaceAllUsesWith(result);
}
} // namespace