blob: 457e1a93985f7c388d22e7d2c89c47c23956221a [file] [log] [blame] [edit]
///////////////////////////////////////////////////////////////////////////////
// //
// DxilPrecisePropagatePass.cpp //
// Copyright (C) Microsoft Corporation. All rights reserved. //
// This file is distributed under the University of Illinois Open Source //
// License. See LICENSE.TXT for details. //
// //
///////////////////////////////////////////////////////////////////////////////
#include "dxc/DXIL/DxilModule.h"
#include "dxc/HLSL/DxilGenerationPass.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/HLSL/HLOperations.h"
#include "llvm/Pass.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Casting.h"
#include <unordered_set>
#include <vector>
using namespace llvm;
using namespace hlsl;
namespace {
class DxilPrecisePropagatePass : public ModulePass {
public:
static char ID; // Pass identification, replacement for typeid
explicit DxilPrecisePropagatePass() : ModulePass(ID) {}
const char *getPassName() const override { return "DXIL Precise Propagate"; }
bool runOnModule(Module &M) override {
DxilModule &dxilModule = M.GetOrCreateDxilModule();
DxilTypeSystem &typeSys = dxilModule.GetTypeSystem();
std::unordered_set<Instruction*> processedSet;
std::vector<Function*> deadList;
for (Function &F : M.functions()) {
if (HLModule::HasPreciseAttribute(&F)) {
PropagatePreciseOnFunctionUser(F, typeSys, processedSet);
deadList.emplace_back(&F);
}
}
for (Function *F : deadList)
F->eraseFromParent();
return true;
}
private:
void PropagatePreciseOnFunctionUser(
Function &F, DxilTypeSystem &typeSys,
std::unordered_set<Instruction *> &processedSet);
};
char DxilPrecisePropagatePass::ID = 0;
}
static void PropagatePreciseAttribute(Instruction *I, DxilTypeSystem &typeSys,
                                      std::unordered_set<Instruction *> &processedSet);

/// Mark \p V precise and keep propagating to the values it depends on.
///
/// Only instructions that are FPMathOperators are considered. Constants,
/// arguments, non-FP instructions and instructions already in
/// \p processedSet are ignored.
///
/// \param V            Candidate value.
/// \param typeSys      Type system forwarded to the propagation helpers.
/// \param Context      LLVM context (not used by this function itself).
/// \param processedSet Instructions already marked; updated in place.
static void PropagatePreciseAttributeOnOperand(
    Value *V, DxilTypeSystem &typeSys, LLVMContext &Context,
    std::unordered_set<Instruction *> &processedSet) {
  Instruction *Inst = dyn_cast<Instruction>(V);
  // insert().second is false when Inst was already visited. That check is
  // evaluated last, so only eligible instructions are ever recorded.
  if (Inst == nullptr || !isa<FPMathOperator>(Inst) ||
      !processedSet.insert(Inst).second)
    return;

  // TODO: skip precise on integer type, sample instruction...

  // Instructions that carry fast-math flags get the precise flag setting.
  if (DxilModule::PreservesFastMathFlags(Inst))
    DxilModule::SetPreciseFastMathFlags(Inst);

  // Fast-math flags do not work on calls, so calls are marked via metadata.
  if (CallInst *Call = dyn_cast<CallInst>(Inst))
    HLModule::MarkPreciseAttributeWithMetadata(Call);

  PropagatePreciseAttribute(Inst, typeSys, processedSet);
}
static void PropagatePreciseAttributeOnPointer(
Value *Ptr, DxilTypeSystem &typeSys, LLVMContext &Context,
std::unordered_set<Instruction *> &processedSet) {
// Find all store and propagate on the val operand of store.
// For CallInst, if Ptr is used as out parameter, mark it.
for (User *U : Ptr->users()) {
Instruction *user = cast<Instruction>(U);
if (StoreInst *stInst = dyn_cast<StoreInst>(user)) {
Value *val = stInst->getValueOperand();
PropagatePreciseAttributeOnOperand(val, typeSys, Context, processedSet);
} else if (CallInst *CI = dyn_cast<CallInst>(user)) {
bool bReadOnly = true;
Function *F = CI->getCalledFunction();
const DxilFunctionAnnotation *funcAnnotation =
typeSys.GetFunctionAnnotation(F);
for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
if (Ptr != CI->getArgOperand(i))
continue;
const DxilParameterAnnotation &paramAnnotation =
funcAnnotation->GetParameterAnnotation(i);
// OutputPatch and OutputStream will be checked after scalar repl.
// Here only check out/inout
if (paramAnnotation.GetParamInputQual() == DxilParamInputQual::Out ||
paramAnnotation.GetParamInputQual() == DxilParamInputQual::Inout) {
bReadOnly = false;
break;
}
}
if (!bReadOnly)
PropagatePreciseAttributeOnOperand(CI, typeSys, Context, processedSet);
}
}
}
static void
PropagatePreciseAttribute(Instruction *I, DxilTypeSystem &typeSys,
std::unordered_set<Instruction *> &processedSet) {
LLVMContext &Context = I->getContext();
if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
PropagatePreciseAttributeOnPointer(AI, typeSys, Context, processedSet);
} else if (dyn_cast<CallInst>(I)) {
// Propagate every argument.
// TODO: only propagate precise argument.
for (Value *src : I->operands())
PropagatePreciseAttributeOnOperand(src, typeSys, Context, processedSet);
} else if (dyn_cast<FPMathOperator>(I)) {
// TODO: only propagate precise argument.
for (Value *src : I->operands())
PropagatePreciseAttributeOnOperand(src, typeSys, Context, processedSet);
} else if (LoadInst *ldInst = dyn_cast<LoadInst>(I)) {
Value *Ptr = ldInst->getPointerOperand();
PropagatePreciseAttributeOnPointer(Ptr, typeSys, Context, processedSet);
} else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I))
PropagatePreciseAttributeOnPointer(GEP, typeSys, Context, processedSet);
// TODO: support more case which need
}
/// Consume every call to the precise-marker function \p F.
///
/// Precise is propagated from each call's first argument, and the call is
/// then erased, so that \p F has no remaining users afterwards. Every user of
/// \p F is expected to be a CallInst; cast<> asserts otherwise.
///
/// \param F            Marker function whose calls are processed.
/// \param typeSys      Type system forwarded to the propagation helpers.
/// \param processedSet Instructions already marked; updated in place.
void DxilPrecisePropagatePass::PropagatePreciseOnFunctionUser(
    Function &F, DxilTypeSystem &typeSys,
    std::unordered_set<Instruction *> &processedSet) {
  LLVMContext &Ctx = F.getContext();
  // Always take the current first user. Erasing it removes that use, so the
  // use list drains in the same order a forward walk would visit it.
  while (!F.user_empty()) {
    CallInst *MarkerCall = cast<CallInst>(*F.user_begin());
    Value *PreciseValue = MarkerCall->getArgOperand(0);
    PropagatePreciseAttributeOnOperand(PreciseValue, typeSys, Ctx,
                                       processedSet);
    MarkerCall->eraseFromParent();
  }
}
/// Create a new instance of the precise-propagation pass. The caller
/// (typically a PassManager) takes ownership of the returned object.
ModulePass *llvm::createDxilPrecisePropagatePass() {
  ModulePass *pass = new DxilPrecisePropagatePass();
  return pass;
}

// Register the pass under the command-line name "hlsl-dxil-precise". The two
// trailing flags mark it as neither CFG-only nor an analysis pass.
INITIALIZE_PASS(DxilPrecisePropagatePass, "hlsl-dxil-precise",
                "DXIL precise attribute propagate", false, false)