blob: 5bf58f485897a053fa126828e8e2a70393b6dc0d [file] [log] [blame] [edit]
#include "StateFunctionTransform.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ValueMap.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include "FunctionBuilder.h"
#include "LLVMUtils.h"
#include "LiveValues.h"
#include "Reducibility.h"
#define DBGS dbgs
//#define DBGS errs
using namespace llvm;
// Exact name of the indirect-call function. Calls whose callee has this name
// are recorded as call sites by findCallSitesIntrinsicsAndReturns().
// NOTE(review): the leading "\x1" byte presumably marks the name as already
// mangled so that it is emitted verbatim -- confirm.
static const char *CALL_INDIRECT_NAME = "\x1?Fallback_CallIndirect@@YAXH@Z";
// Name prefix of the SetPendingAttr functions. Calls whose callee name starts
// with this prefix are collected by findCallSitesIntrinsicsAndReturns() and
// lowered to stack stores in init().
static const char *SET_PENDING_ATTR_PREFIX = "\x1?Fallback_SetPendingAttr@@";
// Formats printf-style arguments into a std::string.
//
// The argument list is walked twice: once to measure the formatted length and
// once to produce the text. An empty string is returned when the measured
// length is zero or negative (i.e. the length query reported an error).
inline std::string stringf(const char *fmt, ...) {
  // Pass 1: measure.
  va_list measureArgs;
  va_start(measureArgs, fmt);
#ifdef WIN32
  const int length = _vscprintf(fmt, measureArgs);
#else
  const int length = vsnprintf(0, 0, fmt, measureArgs);
#endif
  va_end(measureArgs);

  if (length <= 0)
    return std::string();

  // Pass 2: format into a buffer of exactly `length` characters. The extra
  // byte passed to vsnprintf covers the terminating NUL.
  std::string text;
  text.resize(length);
  va_list formatArgs;
  va_start(formatArgs, fmt);
  vsnprintf(const_cast<char *>(text.data()), length + 1, fmt, formatArgs);
  va_end(formatArgs);
  return text;
}
// Strips name decoration. For a name of the form "\x1?<base>@@<rest>" this
// returns <base>; any other name (no "\x1?" prefix, or no "@@") is returned
// unchanged.
// NOTE(review): the "?<base>@@..." shape looks like MSVC-style decoration
// rather than ELF mangling -- confirm.
static std::string cleanName(StringRef name) {
  if (name.startswith("\x1?")) {
    const size_t terminator = name.find("@@");
    if (terminator != name.npos)
      return name.substr(2, terminator - 2); // skip the 2-char "\x1?" prefix
  }
  return name;
}
// Appends suffix to valueName and returns the result.
//
// - If valueName is empty the result is empty too; this avoids producing
//   names like ".ptr" for unnamed values.
// - If suffix is empty, valueName is returned unchanged (previously this
//   called front() on an empty StringRef).
// - If valueName ends with '.' and suffix starts with '.', only one dot is
//   kept at the joint.
static std::string addSuffix(StringRef valueName, StringRef suffix) {
  if (valueName.empty() || suffix.empty())
    return valueName.str();
  if (valueName.back() == '.' && suffix.front() == '.') // avoid double dots
    return (valueName + suffix.substr(1)).str();
  return (valueName + suffix).str();
}
// Returns name truncated at the LAST occurrence of suffix. Everything from
// that occurrence onward is dropped, including any text that follows the
// suffix. If suffix does not occur in name, name is returned unchanged.
static std::string stripSuffix(StringRef name, StringRef suffix) {
  const size_t cut = name.rfind(suffix);
  if (cut == name.npos)
    return name.str();
  return name.substr(0, cut).str();
}
// Inserts str before the extension of filename and returns the new path.
//
// The extension starts at the last '.' that appears after the last path
// separator ('/' or '\\'). A '.' that only occurs in a directory component
// (e.g. "out.d/dump" or "./dump") is not an extension; in that case, and when
// there is no '.' at all, str is appended to the end of filename.
static std::string insertBeforeExtension(const std::string &filename,
                                         const std::string &str) {
  std::string ret = filename;
  const size_t dot = filename.rfind('.');
  const size_t sep = filename.find_last_of("/\\");
  const bool dotInBaseName =
      dot != std::string::npos && (sep == std::string::npos || dot > sep);
  if (dotInBaseName)
    ret.insert(dot, str);
  else
    ret += str;
  return ret;
}
// Builds a dump file path from baseName by inserting
//   [-<functionName>]-<id>-<suffix>
// before baseName's extension (see insertBeforeExtension()). The
// "-<functionName>" part is omitted when functionName is empty, and <id> is
// zero-padded to at least two digits.
static std::string createDumpPath(const std::string &baseName, unsigned id,
                                  const std::string &suffix,
                                  const std::string &functionName) {
  std::string s;
  if (!functionName.empty())
    s = "-" + functionName;
  // id is unsigned, so it must be formatted with %u (not %d).
  s += stringf("-%02u-", id) + suffix;
  return insertBeforeExtension(baseName, s);
}
// Return byte offset aligned to the alignment required by inst.
static uint64_t align(uint64_t offset, Instruction *inst, DataLayout &DL) {
unsigned alignment = 0;
if (AllocaInst *ai = dyn_cast<AllocaInst>(inst))
alignment = ai->getAlignment();
if (alignment == 0)
alignment = DL.getPrefTypeAlignment(inst->getType());
return RoundUpToAlignment(offset, alignment);
}
template <class T> // T can be Value* or Instruction*
T createCastForStack(T ptr, llvm::Type *targetPtrElemType,
llvm::Instruction *insertBefore) {
llvm::PointerType *requiredType = llvm::PointerType::get(
targetPtrElemType, ptr->getType()->getPointerAddressSpace());
if (ptr->getType() == requiredType)
return ptr;
return new llvm::BitCastInst(ptr, requiredType, ptr->getName(), insertBefore);
}
// Converts val to an i32 value, inserting any needed instruction before
// insertBefore:
//   - i32 values are returned unchanged,
//   - i1 values are zero-extended,
//   - everything else is bitcast to i32.
// New instructions are named "<val name>.int" (unnamed when val is unnamed).
static Value *createCastToInt(Value *val, Instruction *insertBefore) {
  LLVMContext &ctx = val->getContext();
  Type *srcTy = val->getType();
  Type *i32Ty = Type::getInt32Ty(ctx);
  if (srcTy == i32Ty)
    return val;

  if (srcTy == Type::getInt1Ty(ctx))
    return new ZExtInst(val, i32Ty, addSuffix(val->getName(), ".int"),
                        insertBefore);

  return new BitCastInst(val, i32Ty, addSuffix(val->getName(), ".int"),
                         insertBefore);
}
// Converts the i32 value intVal to type ty, inserting any needed instruction
// before insertBefore:
//   - for i32, intVal itself is returned (and left untouched),
//   - for i1, the result is the signed comparison intVal > 0,
//   - for everything else, intVal is bitcast to ty.
// When a new instruction is created it takes over intVal's original name and
// intVal is renamed to "<name>.int".
static Value *createCastFromInt(Value *intVal, Type *ty,
                                Instruction *insertBefore) {
  LLVMContext &ctx = intVal->getContext();
  if (ty == Type::getInt32Ty(ctx))
    return intVal;

  // Capture the name before renaming the source value.
  std::string resultName = intVal->getName();
  intVal->setName(addSuffix(resultName, ".int"));

  if (ty != Type::getInt1Ty(ctx))
    return new BitCastInst(intVal, ty, resultName, insertBefore);

  // Booleans are recreated with a compare against zero.
  return new ICmpInst(insertBefore, CmpInst::ICMP_SGT, intVal,
                      makeInt32(0, intVal->getContext()), resultName);
}
// Debugging aid: names every unnamed, non-void instruction in func "v".
// LLVM makes the names unique by appending a numeric suffix.
static void dbgNameUnnamedVals(Function *func) {
  Type *voidTy = Type::getVoidTy(func->getContext());
  for (auto &inst : inst_range(func)) {
    if (inst.hasName() || inst.getType() == voidTy)
      continue;
    inst.setName("v");
  }
}
// Returns an iterator to the first instruction in function's entry block that
// is not an alloca, scanning from the top of the block. This assumes that the
// allocas are grouped at the top of the entry block.
static BasicBlock::iterator afterEntryBlockAllocas(Function *function) {
  BasicBlock &entry = function->getEntryBlock();
  BasicBlock::iterator pos = entry.begin();
  for (; isa<AllocaInst>(pos); ++pos) {
    // skip leading allocas
  }
  return pos;
}
// Returns entryBlock followed by every block reachable from it through
// terminator successors. Blocks are visited depth-first (last discovered
// successor first) and each block appears exactly once.
static BasicBlockVector getReachableBlocks(BasicBlock *entryBlock) {
  BasicBlockVector reachable;
  ::BasicBlockSet seen;
  seen.insert(entryBlock);

  // LIFO work list of discovered-but-not-yet-expanded blocks.
  std::vector<BasicBlock *> pending(1, entryBlock);
  while (!pending.empty()) {
    BasicBlock *current = pending.back();
    pending.pop_back();
    reachable.push_back(current);

    TerminatorInst *term = current->getTerminator();
    const unsigned int numSuccs = term->getNumSuccessors();
    for (unsigned int s = 0; s != numSuccs; ++s) {
      BasicBlock *next = term->getSuccessor(s);
      if (seen.insert(next).second)
        pending.push_back(next);
    }
  }
  return reachable;
}
// Creates an empty function with the same return type, parameter types,
// varargs flag, linkage, name, argument names and attributes as oldFunction.
//
// Each old argument is mapped to its new counterpart in VMap. No module is
// passed to Function::Create, so inserting the result into a module is left
// to the caller.
static Function *cloneFunctionPrototype(const Function *oldFunction,
                                        ValueToValueMapTy &VMap) {
  // Rebuild the function type from the old argument list.
  FunctionType *oldTy = oldFunction->getFunctionType();
  std::vector<Type *> paramTypes;
  for (auto arg = oldFunction->arg_begin(), argEnd = oldFunction->arg_end();
       arg != argEnd; ++arg)
    paramTypes.push_back(arg->getType());
  FunctionType *newTy =
      FunctionType::get(oldTy->getReturnType(), paramTypes, oldTy->isVarArg());

  Function *newFunction =
      Function::Create(newTy, oldFunction->getLinkage(), oldFunction->getName());

  // Name the new arguments after the old ones and record the mapping.
  Function::arg_iterator newArg = newFunction->arg_begin();
  for (auto arg = oldFunction->arg_begin(), argEnd = oldFunction->arg_end();
       arg != argEnd; ++arg, ++newArg) {
    newArg->setName(arg->getName());
    VMap[arg] = newArg;
  }

  // Copy parameter attributes (attribute index = argument number + 1).
  AttributeSet oldAttrs = oldFunction->getAttributes();
  for (auto arg = oldFunction->arg_begin(), argEnd = oldFunction->arg_end();
       arg != argEnd; ++arg) {
    Argument *mapped = dyn_cast<Argument>(VMap[arg]);
    if (!mapped)
      continue;
    AttributeSet paramAttrs = oldAttrs.getParamAttributes(arg->getArgNo() + 1);
    if (paramAttrs.getNumSlots() > 0)
      mapped->addAttr(paramAttrs);
  }

  // Copy return-value attributes, then function attributes.
  newFunction->setAttributes(newFunction->getAttributes().addAttributes(
      newFunction->getContext(), AttributeSet::ReturnIndex,
      oldAttrs.getRetAttributes()));
  newFunction->setAttributes(newFunction->getAttributes().addAttributes(
      newFunction->getContext(), AttributeSet::FunctionIndex,
      oldAttrs.getFnAttributes()));
  return newFunction;
}
// Creates a new function by cloning blocks reachable from entryBlock.
//
// The new function gets the prototype of entryBlock's parent function (see
// cloneFunctionPrototype()). The clone of entryBlock is inserted first, so it
// becomes the entry block of the new function. VMap is filled with the
// old-to-new mapping of arguments, blocks and instructions.
static Function *cloneBlocksReachableFrom(BasicBlock *entryBlock,
ValueToValueMapTy &VMap) {
Function *oldFunction = entryBlock->getParent();
Function *newFunction = cloneFunctionPrototype(oldFunction, VMap);
// Insert a clone of the entry block into the function.
BasicBlock *newEntry = CloneBasicBlock(entryBlock, VMap, "", newFunction);
VMap[entryBlock] = newEntry;
// Clone all other blocks.
BasicBlockVector blocks = getReachableBlocks(entryBlock);
for (auto block : blocks) {
// The entry block has already been cloned above.
if (block == entryBlock)
continue;
BasicBlock *clonedBlock = CloneBasicBlock(block, VMap, "", newFunction);
VMap[block] = clonedBlock;
}
// Remap new instructions to reference blocks and instructions of the new
// function. Operands without a VMap entry are left as they are
// (RF_IgnoreMissingEntries).
for (auto block : blocks) {
auto clonedBlock = cast<BasicBlock>(VMap[block]);
for (BasicBlock::iterator I = clonedBlock->begin(); I != clonedBlock->end();
++I) {
RemapInstruction(I, VMap,
RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
}
}
// Remove phi operands incoming from blocks that are not present in the new
// function anymore.
for (auto &block : *newFunction) {
PHINode *firstPHI = dyn_cast<PHINode>(block.begin());
if (firstPHI == nullptr)
continue; // phi instructions only at beginning
// Create set of actual predecessors
BasicBlockSet preds(pred_begin(&block), pred_end(&block));
// Same number of incoming values as distinct predecessors: the block is
// treated as having nothing to prune.
if (preds.size() == firstPHI->getNumIncomingValues())
continue;
// Remove phi incoming blocks not in preds
for (auto iter = block.begin(); isa<PHINode>(iter); ++iter) {
std::vector<unsigned int> toRemove;
PHINode *phi = cast<PHINode>(iter);
for (unsigned int op = 0, opEnd = phi->getNumIncomingValues();
op != opEnd; ++op) {
BasicBlock *pred = phi->getIncomingBlock(op);
if (preds.count(pred) == 0) {
toRemove.push_back(op);
}
}
// Remove from the highest index down so that the remaining indices stay
// valid. Passing false as the second argument keeps the phi itself even
// if it ends up with no incoming values.
for (auto I = toRemove.rbegin(), E = toRemove.rend(); I != E; ++I)
phi->removeIncomingValue(*I, false);
}
}
return newFunction;
}
// oldVal must be a call to a (dummy) function. Every call to that function
// found in caller by getCallsToFunction() has its uses replaced with newVal
// and is then erased. If the dummy function has no uses left afterwards, the
// function itself is erased as well.
static void replaceValAndRemoveUnusedDummyFunc(Value *oldVal, Value *newVal,
                                               Function *caller) {
  CallInst *dummyCall = dyn_cast<CallInst>(oldVal);
  assert(dummyCall != nullptr && "Must be a call");
  Function *dummyFunc = dummyCall->getCalledFunction();

  for (CallInst *site : getCallsToFunction(dummyFunc, caller)) {
    site->replaceAllUsesWith(newVal);
    site->eraseFromParent();
  }

  if (dummyFunc->getNumUses() != 0)
    return; // still referenced elsewhere
  dummyFunc->eraseFromParent();
}
// Get the integer value of val. If val is not a ConstantInt return false.
static bool getConstantValue(int &constant, const Value *val) {
const ConstantInt *CI = dyn_cast<ConstantInt>(val);
if (!CI)
return false;
if (CI->getBitWidth() > 32)
return false;
constant = static_cast<int>(CI->getSExtValue());
return true;
}
static int getConstantValue(const Value *val) {
const ConstantInt *CI = dyn_cast<ConstantInt>(val);
assert(CI && CI->getBitWidth() <= 32);
return static_cast<int>(CI->getSExtValue());
}
// State threaded through the recursive store() helper below.
struct StoreInfo {
// Function called to obtain the i32 pointer for a stack slot.
Function *stackIntPtrFunc;
// First argument passed to stackIntPtrFunc.
Value *runtimeDataArg;
// Second argument passed to stackIntPtrFunc.
Value *baseOffset;
// All generated instructions are inserted before this instruction.
Instruction *insertBefore;
// Value being stored. When idxList is non-empty it is used as the base
// pointer of a GEP that addresses the element to store.
Value *val;
// GEP indices selecting the element that is currently being stored.
std::vector<Value *> idxList;
};
// Takes the offset at which to store the next value.
// Returns the next available offset.
static int store(int offset, StoreInfo &SI, Type *ty) {
if (StructType *STy = dyn_cast<StructType>(ty)) {
SI.idxList.push_back(nullptr);
int elIdx = 0;
for (auto &elTy : STy->elements()) {
SI.idxList.back() = makeInt32(elIdx++, ty->getContext());
offset = store(offset, SI, elTy);
}
SI.idxList.pop_back();
} else if (ArrayType *ATy = dyn_cast<ArrayType>(ty)) {
Type *elTy = ATy->getArrayElementType();
SI.idxList.push_back(nullptr);
for (int elIdx = 0; elIdx < (int)ATy->getArrayNumElements(); ++elIdx) {
SI.idxList.back() = makeInt32(elIdx, ty->getContext());
offset = store(offset, SI, elTy);
}
SI.idxList.pop_back();
} else if (PointerType *PTy = dyn_cast<PointerType>(ty)) {
SI.idxList.push_back(makeInt32(0, ty->getContext()));
offset = store(offset, SI, PTy->getPointerElementType());
SI.idxList.pop_back();
} else {
Value *val = SI.val;
if (!SI.idxList.empty()) {
Value *gep = GetElementPtrInst::CreateInBounds(SI.val, SI.idxList, "",
SI.insertBefore);
val = new LoadInst(gep, "", SI.insertBefore);
}
if (VectorType *VTy = dyn_cast<VectorType>(ty)) {
std::vector<Value *> idxList = std::move(SI.idxList);
Type *elTy = VTy->getVectorElementType();
for (int elIdx = 0; elIdx < (int)VTy->getVectorNumElements(); ++elIdx) {
Value *idxVal = makeInt32(elIdx, ty->getContext());
Value *el =
ExtractElementInst::Create(val, idxVal, "", SI.insertBefore);
SI.val = el;
offset = store(offset, SI, elTy);
}
SI.idxList = std::move(idxList);
} else {
Value *idxVal = makeInt32(offset, val->getContext());
Value *intVal = createCastToInt(val, SI.insertBefore);
Value *intPtr = CallInst::Create(
SI.stackIntPtrFunc, {SI.runtimeDataArg, SI.baseOffset, idxVal},
addSuffix(val->getName(), ".ptr"), SI.insertBefore);
new StoreInst(intVal, intPtr, SI.insertBefore);
offset += 1;
}
}
return offset;
}
// Stores val to the stack starting at slot `offset` relative to baseOffset,
// inserting the generated instructions before insertBefore. Aggregates and
// vectors are flattened; for pointer values the pointee is stored. Returns
// the slot offset at which writing stopped.
static int store(Value *val, Function *stackIntPtrFunc, Value *runtimeDataArg,
                 Value *baseOffset, int offset, Instruction *insertBefore) {
  StoreInfo info;
  info.val = val;
  info.insertBefore = insertBefore;
  info.baseOffset = baseOffset;
  info.runtimeDataArg = runtimeDataArg;
  info.stackIntPtrFunc = stackIntPtrFunc;
  return store(offset, info, val->getType());
}
// Loads a value of type ty from the stack, inserting the generated
// instructions before insertBefore, and returns it.
//
// Scalars are read from the single i32 slot addressed by
// m_stackIntPtrFunc(runtimeDataArg, offset, idx) and converted with
// createCastFromInt(). Vectors are rebuilt element by element from
// consecutive slots starting at idx; in that case idx must be a constant
// integer (see getConstantValue()).
static Value *load(llvm::Function *m_stackIntPtrFunc, Value *runtimeDataArg,
                   Value *offset, Value *idx, const std::string &name, Type *ty,
                   Instruction *insertBefore) {
  VectorType *VTy = dyn_cast<VectorType>(ty);
  if (!VTy) {
    // Scalar case: one slot.
    Value *slotPtr =
        CallInst::Create(m_stackIntPtrFunc, {runtimeDataArg, offset, idx},
                         addSuffix(name, ".ptr"), insertBefore);
    Value *slotVal = new LoadInst(slotPtr, name, insertBefore);
    return createCastFromInt(slotVal, ty, insertBefore);
  }

  // Vector case: one slot per element, inserted into an initially undef
  // vector.
  LLVMContext &C = ty->getContext();
  const int firstSlot = getConstantValue(idx);
  Type *elTy = VTy->getVectorElementType();
  Value *result = UndefValue::get(VTy);
  for (int i = 0; i < (int)VTy->getVectorNumElements(); ++i) {
    const std::string elName = stringf("el%d.", i);
    Value *slotPtr =
        CallInst::Create(m_stackIntPtrFunc,
                         {runtimeDataArg, offset, makeInt32(firstSlot + i, C)},
                         elName + "ptr", insertBefore);
    Value *slotVal = new LoadInst(slotPtr, elName, insertBefore);
    Value *element = createCastFromInt(slotVal, elTy, insertBefore);
    result = InsertElementInst::Create(result, element, makeInt32(i, C),
                                       "tmpvec", insertBefore);
  }
  result->setName(name);
  return result;
}
// Demotes inst to a stack slot (alloca) with DemoteRegToStack() unless that
// has already been done, and records the value<->alloca association in both
// maps. Nothing is recorded when DemoteRegToStack() returns null.
static void reg2Mem(DenseMap<Instruction *, AllocaInst *> &valToAlloca,
                    DenseMap<AllocaInst *, Instruction *> &allocaToVal,
                    Instruction *inst) {
  if (valToAlloca.count(inst) != 0)
    return; // already demoted

  if (AllocaInst *slot = DemoteRegToStack(*inst, false)) {
    valToAlloca[inst] = slot;
    allocaToVal[slot] = inst;
  }
}
// Utility class for rematerializing values at a callsite.
//
// One instance is used per call site. It recreates (clones) instructions
// after the call site instead of saving/restoring their results, and
// remembers the clones so that each instruction is rematerialized at most
// once per instance.
class Rematerializer {
public:
  // allocaToVal: maps the allocas created for reg2mem'd live values back to
  //              the original value.
  // liveHere:    values live at this call site.
  // resources:   values for resources like SRVs, UAVs, etc.
  // All three are held by reference and must outlive this object.
  Rematerializer(DenseMap<AllocaInst *, Instruction *> &allocaToVal,
                 const InstructionSetVector &liveHere,
                 const std::set<Value *> &resources)
      : m_allocaToVal(allocaToVal), m_liveHere(liveHere),
        m_resources(resources) {}

  // Returns true if inst can be rematerialized:
  //   - calls to functions whose name starts with "dummyStackFrameSize",
  //     "stack.ptr", "stack.load" or "dx.op.createHandle",
  //   - loads directly from a resource, or from a GEP whose base operand is a
  //     resource (descriptor tables),
  //   - GEPs (which are expected to have only constant indices).
  bool canRematerialize(Instruction *inst) {
    if (CallInst *call = dyn_cast<CallInst>(inst)) {
      StringRef funcName = call->getCalledFunction()->getName();
      if (funcName.startswith("dummyStackFrameSize"))
        return true;
      if (funcName.startswith("stack.ptr"))
        return true;
      if (funcName.startswith("stack.load"))
        return true;
      if (funcName.startswith("dx.op.createHandle"))
        return true;
    } else if (LoadInst *load = dyn_cast<LoadInst>(inst)) {
      Value *op = load->getOperand(0);
      if (GetElementPtrInst *gep =
              dyn_cast<GetElementPtrInst>(op)) // for descriptor tables
        op = gep->getOperand(0);
      if (m_resources.count(op))
        return true;
    } else if (GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(inst)) {
      assert(gep->hasAllConstantIndices() &&
             "Unhandled non-constant index"); // Should have been changed to
                                              // stack.ptr
      return true;
    }
    return false;
  }

  // Rematerializes inst and its dependency graph before insertBefore and
  // returns the rematerialized instruction.
  //
  // Nonrematerializable values that are live in the function but not at this
  // callsite are appended to workList, to ensure that their values get
  // restored. workList is therefore taken by reference: the caller's list
  // must receive these additions.
  //
  // depth is the recursion depth; callers outside this class use the default.
  Instruction *rematerialize(Instruction *inst,
                             std::vector<Instruction *> &workList,
                             Instruction *insertBefore, int depth = 0) {
    // Signal if we hit a complex case. Deep rematerialization needs more
    // analysis. To make this robust we would need to make it possible to run
    // the current value through the live value handling pipeline: figure out
    // where it is live, reg2mem, save/restore at appropriate callsites, etc.
    assert(depth < 8);

    // Reuse an already rematerialized value?
    auto rematIt = m_rematMap.find(inst);
    if (rematIt != m_rematMap.end())
      return rematIt->second;

    // Handle allocas
    if (AllocaInst *alloc = dyn_cast<AllocaInst>(inst)) {
      assert(depth >
             0); // Should only be an operand to another rematerialized value
      auto valIt = m_allocaToVal.find(alloc);
      if (valIt != m_allocaToVal.end()) // Is it a value that is live at some
                                        // callsite (and reg2mem'd)?
      {
        Instruction *val = valIt->second;
        if (canRematerialize(val)) {
          // Rematerialize here and store to the alloca. We may have already
          // rematerialized a load from the alloca. Any future uses will use
          // the rematerialized value directly.
          Instruction *remat =
              rematerialize(val, workList, insertBefore, depth + 1);
          new StoreInst(remat, alloc, insertBefore);
        } else {
          // Value has to be restored, but rematerialization may have
          // extended the liveness of this value to this callsite. Make sure
          // it gets restored.
          if (!m_liveHere.count(val))
            workList.push_back(val);
        }
      }
      // Allocas are not cloned.
      return inst;
    }

    // Clone the instruction and rewire each instruction operand to its
    // rematerialized counterpart.
    Instruction *clone = inst->clone();
    clone->setName(addSuffix(inst->getName(), ".remat"));
    for (unsigned i = 0; i < inst->getNumOperands(); ++i) {
      Value *op = inst->getOperand(i);
      if (Instruction *opInst = dyn_cast<Instruction>(op))
        clone->setOperand(
            i, rematerialize(opInst, workList, insertBefore, depth + 1));
      else
        clone->setOperand(i, op);
    }
    clone->insertBefore(
        insertBefore); // insert after any instructions cloned for operands
    m_rematMap[inst] = clone;
    return clone;
  }

  // Returns the instruction that val was rematerialized to by this object,
  // or nullptr if val has not been rematerialized.
  Instruction *getRematerializedValueFor(Instruction *val) {
    auto rematIt = m_rematMap.find(val);
    if (rematIt != m_rematMap.end())
      return rematIt->second;
    return nullptr;
  }

private:
  DenseMap<Instruction *, Instruction *>
      m_rematMap; // Map instructions to their rematerialized counterparts
  DenseMap<AllocaInst *, Instruction *>
      &m_allocaToVal; // Map allocas for reg2mem'd live values back to the value
  const InstructionSetVector &m_liveHere; // Values live at this callsite
  const std::set<Value *>
      &m_resources; // Values for resources like SRVs, UAVs, etc.
};
// Sets up the transform for func. candidateFuncNames is copied, and the
// cleaned name of func (see cleanName()) must be one of its entries; the
// position of that entry becomes m_functionIdx.
StateFunctionTransform::StateFunctionTransform(
    Function *func, const std::vector<std::string> &candidateFuncNames,
    Type *runtimeDataArgTy)
    : m_function(func), m_candidateFuncNames(candidateFuncNames),
      m_runtimeDataArgTy(runtimeDataArgTy) {
  m_functionName = cleanName(m_function->getName());

  auto first = m_candidateFuncNames.begin();
  auto last = m_candidateFuncNames.end();
  auto match = std::find(first, last, m_functionName);
  assert(match != last);
  m_functionIdx = match - first;
}
// Sets the attribute size in bytes. changeCallingConvention() allocates a
// stack frame and a trace frame when this value is >= 0.
void StateFunctionTransform::setAttributeSize(int size) {
m_attributeSizeInBytes = size;
}
// Records the semantic types of the function's parameters (copied) and the
// useCommittedAttr flag for later use by the transform.
void StateFunctionTransform::setParameterInfo(
const std::vector<ParameterSemanticType> &paramTypes,
bool useCommittedAttr) {
m_paramTypes = paramTypes;
m_useCommittedAttr = useCommittedAttr;
}
// Registers the set of resource values. The set is NOT copied: only its
// address is stored, and it is dereferenced later by
// preserveLiveValuesAcrossCallsites(), so the caller's set has to stay alive
// until then.
void StateFunctionTransform::setResourceGlobals(
const std::set<llvm::Value *> &resources) {
m_resources = &resources;
}
// Builds, via FunctionBuilder, the function named "dummyRuntimeDataArg" in
// mod with runtimeDataArgTy as its type and returns it. init() calls this
// function to obtain the runtime data value.
Function *
StateFunctionTransform::createDummyRuntimeDataArgFunc(Module *mod,
Type *runtimeDataArgTy) {
return FunctionBuilder(mod, "dummyRuntimeDataArg")
.type(runtimeDataArgTy)
.build();
}
// Sets the verbose flag (m_verbose).
void StateFunctionTransform::setVerbose(bool val) { m_verbose = val; }
// Sets the file name stored in m_dumpFilename (copied).
// NOTE(review): presumably the base name for the per-stage IR dumps (see
// createDumpPath()) -- confirm against printFunction().
void StateFunctionTransform::setDumpFilename(const std::string &dumpFilename) {
m_dumpFilename = dumpFilename;
}
// Runs the whole transformation on m_function. The stages are executed in a
// fixed order, and the IR is handed to printFunction()/printFunctions() with
// a stage label before the first stage and after every stage.
//
// stateFunctions:  filled by createSubstateFunctions() with the resulting
//                  state functions.
// shaderStackSize: passed on to preserveLiveValuesAcrossCallsites(), which
//                  writes the computed stack size to it.
void StateFunctionTransform::run(std::vector<Function *> &stateFunctions,
unsigned int &shaderStackSize) {
printFunction("Initial");
init();
printFunction("AfterInit");
changeCallingConvention();
printFunction("AfterCallingConvention");
preserveLiveValuesAcrossCallsites(shaderStackSize);
printFunction("AfterPreserveLiveValues");
// m_function is erased by createSubstateFunctions(), so from here on the
// state functions are printed instead.
createSubstateFunctions(stateFunctions);
printFunctions(stateFunctions, "AfterSubstateFunctions");
lowerStackFuncs();
printFunctions(stateFunctions, "AfterLowerStackFuncs");
}
// Replaces every call to "dummyStateId"(functionIdx, substate) in mod with
// the constant candidateFuncEntryStateIds[functionIdx] + substate, then
// erases the calls and the "dummyStateId" function itself. Does nothing if
// mod has no such function.
//
// NOTE(review): the results of getConstantValue() are not checked, so a
// non-constant argument is silently treated as 0, and functionIdx is not
// range-checked against candidateFuncEntryStateIds.
void StateFunctionTransform::finalizeStateIds(
    llvm::Module *mod, const std::vector<int> &candidateFuncEntryStateIds) {
  Function *stateIdFunc = mod->getFunction("dummyStateId");
  if (stateIdFunc == nullptr)
    return;

  LLVMContext &context = mod->getContext();

  // Calls are collected first and erased afterwards, so that the user list
  // is not modified while it is being iterated.
  std::vector<Instruction *> deadCalls;
  for (User *user : stateIdFunc->users()) {
    if (CallInst *call = dyn_cast<CallInst>(user)) {
      int functionIdx = 0;
      int substate = 0;
      getConstantValue(functionIdx, call->getArgOperand(0));
      getConstantValue(substate, call->getArgOperand(1));

      const int stateId = candidateFuncEntryStateIds[functionIdx] + substate;
      call->replaceAllUsesWith(makeInt32(stateId, context));
      deadCalls.push_back(call);
    }
  }

  for (Instruction *dead : deadCalls)
    dead->eraseFromParent();
  stateIdFunc->eraseFromParent();
}
// Prepares m_function for the transformation:
//   - renames it to its cleaned name and runs mem2reg on it,
//   - names unnamed values and collects call sites, intrinsic calls and
//     returns (findCallSitesIntrinsicsAndReturns()),
//   - declares the helper functions used by the later stages and inserts,
//     after the entry block allocas, the calls that produce the runtime data
//     value, the stack frame size and the payload / committed attribute /
//     pending attribute / stack frame offsets,
//   - lowers the collected SetPendingAttr() calls to stack stores.
void StateFunctionTransform::init() {
Module *mod = m_function->getParent();
m_function->setName(cleanName(m_function->getName()));
// Run preparatory passes
runPasses(m_function, {// createBreakCriticalEdgesPass(),
// createLoopSimplifyPass(),
// createLCSSAPass(),
createPromoteMemoryToRegisterPass()});
// Make debugging a little easier by giving things names
dbgNameUnnamedVals(m_function);
findCallSitesIntrinsicsAndReturns();
// Create a bunch of functions that we are going to need
// stackIntPtr(runtimeData, baseOffset, offset) returns an i32 pointer.
m_stackIntPtrFunc = FunctionBuilder(mod, "stackIntPtr")
.i32Ptr()
.type(m_runtimeDataArgTy, "runtimeData")
.i32("baseOffset")
.i32("offset")
.build();
// All of the calls created below go right after the entry block allocas, in
// the order in which they are created.
Instruction *insertBefore = afterEntryBlockAllocas(m_function);
Function *runtimeDataArgFunc =
createDummyRuntimeDataArgFunc(mod, m_runtimeDataArgTy);
m_runtimeDataArg =
CallInst::Create(runtimeDataArgFunc, "runtimeData", insertBefore);
Function *stackFrameSizeFunc =
FunctionBuilder(mod, "dummyStackFrameSize").i32().build();
m_stackFrameSizeVal =
CallInst::Create(stackFrameSizeFunc, "stackFrame.size", insertBefore);
// TODO only create the values that are actually needed
Function *payloadOffsetFunc = FunctionBuilder(mod, "payloadOffset")
.i32()
.type(m_runtimeDataArgTy, "runtimeData")
.build();
m_payloadOffset = CallInst::Create(payloadOffsetFunc, {m_runtimeDataArg},
"payload.offset", insertBefore);
Function *committedAttrOffsetFunc =
FunctionBuilder(mod, "committedAttrOffset")
.i32()
.type(m_runtimeDataArgTy, "runtimeData")
.build();
m_committedAttrOffset =
CallInst::Create(committedAttrOffsetFunc, {m_runtimeDataArg},
"committedAttr.offset", insertBefore);
Function *pendingAttrOffsetFunc = FunctionBuilder(mod, "pendingAttrOffset")
.i32()
.type(m_runtimeDataArgTy, "runtimeData")
.build();
m_pendingAttrOffset =
CallInst::Create(pendingAttrOffsetFunc, {m_runtimeDataArg},
"pendingAttr.offset", insertBefore);
Function *stackFrameOffsetFunc = FunctionBuilder(mod, "stackFrameOffset")
.i32()
.type(m_runtimeDataArgTy, "runtimeData")
.build();
m_stackFrameOffset =
CallInst::Create(stackFrameOffsetFunc, {m_runtimeDataArg},
"stackFrame.offset", insertBefore);
// lower SetPendingAttr() now
for (CallInst *call : m_setPendingAttrCalls) {
// Get the current pending attribute offset. It can change when a hit is
// committed
// (This local insertBefore shadows the one above: the lowering is emitted
// right before the call that it replaces.)
Instruction *insertBefore = call;
Value *currentPendingAttrOffset =
CallInst::Create(pendingAttrOffsetFunc, {m_runtimeDataArg},
"cur.pendingAttr.offset", insertBefore);
Value *attr = call->getArgOperand(0);
createStackStore(currentPendingAttrOffset, attr, 0, insertBefore);
call->eraseFromParent();
}
}
void StateFunctionTransform::findCallSitesIntrinsicsAndReturns() {
// Create a map for log N lookup
std::map<std::string, int> candidateFuncMap;
for (int i = 0; i < (int)m_candidateFuncNames.size(); ++i)
candidateFuncMap[m_candidateFuncNames[i]] = i;
for (auto &I : inst_range(m_function)) {
if (CallInst *call = dyn_cast<CallInst>(&I)) {
StringRef calledFuncName = call->getCalledFunction()->getName();
if (calledFuncName.startswith(SET_PENDING_ATTR_PREFIX))
m_setPendingAttrCalls.push_back(call);
else if (calledFuncName.startswith("movePayloadToStack"))
m_movePayloadToStackCalls.push_back(call);
else if (calledFuncName == CALL_INDIRECT_NAME)
m_callSites.push_back(call);
else {
auto it = candidateFuncMap.find(cleanName(calledFuncName));
if (it == candidateFuncMap.end())
continue;
assert(call->getCalledFunction()->getReturnType() ==
Type::getVoidTy(call->getContext()) &&
"Continuations with returns not supported");
m_callSites.push_back(call);
m_callSiteFunctionIdx.push_back(it->second);
}
} else if (ReturnInst *ret = dyn_cast<ReturnInst>(&I)) {
m_returns.push_back(ret);
}
}
}
// Switches m_function over to the state function calling convention:
//   1. allocates a stack frame if the function has call sites or an
//      attribute size has been set (>= 0),
//   2. allocates a trace frame if an attribute size has been set (>= 0),
//   3. creates the argument frames and changes the function signature.
void StateFunctionTransform::changeCallingConvention() {
if (!m_callSites.empty() || m_attributeSizeInBytes >= 0)
allocateStackFrame();
if (m_attributeSizeInBytes >= 0)
allocateTraceFrame();
createArgFrames();
changeFunctionSignature();
}
// Returns true if inst is a call to a function whose name starts with
// "stack.ptr".
static bool isCallToStackPtr(Value *inst) {
  if (CallInst *call = dyn_cast<CallInst>(inst))
    return call->getCalledFunction()->getName().startswith("stack.ptr");
  return false;
}
// Makes sure that an alloca is live wherever a live GEP into it is live.
//
// For every live pointer value that is neither an alloca nor a stack.ptr
// call, the value has to be a GEP whose base is either a stack.ptr call
// (nothing to do) or an alloca. If the alloca has no live indices, or its
// indices differ from the GEP's, they are replaced with the GEP's indices.
static void extendAllocaLifetimes(LiveValues &lv) {
  for (Instruction *inst : lv.getAllLiveValues()) {
    const bool isOtherLivePointer = inst->getType()->isPointerTy() &&
                                    !isa<AllocaInst>(inst) &&
                                    !isCallToStackPtr(inst);
    if (!isOtherLivePointer)
      continue;

    GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(inst);
    assert(gep && "Unhandled live pointer");

    Value *base = gep->getPointerOperand();
    if (isCallToStackPtr(base))
      continue;
    AllocaInst *alloc = dyn_cast<AllocaInst>(base);
    assert(alloc && "GEP of non-alloca pointer");

    // TODO: We need to set indices of the uses of the gep, not the gep itself
    const LiveValues::Indices *gepIndices = lv.getIndicesWhereLive(gep);
    const LiveValues::Indices *allocIndices = lv.getIndicesWhereLive(alloc);
    const bool upToDate = allocIndices && *allocIndices == *gepIndices;
    if (!upToDate)
      lv.setIndicesWhereLive(alloc, gepIndices);
  }
}
// Makes every value that is live across a call site survive the call.
//
// Steps:
//   1. Run the liveness analysis over the call sites.
//   2. Move live allocas into the stack frame (replaced by stack pointers).
//   3. Demote all live values to allocas (reg2mem).
//   4. For each call site, either save the value to the stack frame before
//      the call and restore it afterwards, or rematerialize it after the
//      call.
//   5. Rewrite the dummy stack size with the resulting frame size and report
//      it through shaderStackSize.
//
// Stack frame layout, in bytes from the frame offset: argument frame, live
// allocas, per-call-site save area (sized by the largest call site), trace
// frame.
//
// NOTE(review): shaderStackSize is not written on the early-return path (no
// call sites) -- confirm that callers initialize it.
void StateFunctionTransform::preserveLiveValuesAcrossCallsites(
unsigned int &shaderStackSize) {
if (m_callSites.empty()) {
// No stack frame. Nothing to do.
rewriteDummyStackSize(0);
return;
}
// The stack frame offset is always needed; the other offsets only if they
// have users.
SetVector<Instruction *> stackOffsets;
stackOffsets.insert(m_stackFrameOffset);
if (m_payloadOffset && !m_payloadOffset->user_empty())
stackOffsets.insert(m_payloadOffset);
if (m_committedAttrOffset && !m_committedAttrOffset->user_empty())
stackOffsets.insert(m_committedAttrOffset);
if (m_pendingAttrOffset && !m_pendingAttrOffset->user_empty())
stackOffsets.insert(m_pendingAttrOffset);
// Do liveness analysis
ArrayRef<Instruction *> instructions((Instruction **)m_callSites.data(),
m_callSites.size());
LiveValues lv(instructions);
lv.run();
// Make sure alloca lifetimes match their uses
extendAllocaLifetimes(lv);
// Make sure stack offsets get included
for (auto o : stackOffsets)
lv.setLiveAtAllIndices(o, true);
// Add payload allocas, if any
for (CallInst *call : m_movePayloadToStackCalls) {
if (AllocaInst *payloadAlloca =
dyn_cast<AllocaInst>(call->getArgOperand(0)))
lv.setLiveAtAllIndices(payloadAlloca, true);
}
printSet(lv.getAllLiveValues(), "live values");
//
// Carve up the stack frame.
//
uint64_t offsetInBytes = 0;
// ... argument frame
offsetInBytes += m_maxCallerArgFrameSizeInBytes;
// ... live allocas.
Module *mod = m_function->getParent();
DataLayout DL(mod);
DenseMap<Instruction *, Instruction *> allocaToStack;
Instruction *insertBefore = getInstructionAfter(m_stackFrameOffset);
for (Instruction *inst : lv.getAllLiveValues()) {
AllocaInst *alloc = dyn_cast<AllocaInst>(inst);
if (!alloc)
continue;
// Allocate a slot in the stack frame for the alloca
offsetInBytes = align(offsetInBytes, inst, DL);
Instruction *stackAlloca =
createStackPtr(m_stackFrameOffset, alloc, offsetInBytes, insertBefore);
alloc->replaceAllUsesWith(stackAlloca);
allocaToStack[inst] = stackAlloca;
offsetInBytes += DL.getTypeAllocSize(alloc->getAllocatedType());
}
lv.remapLiveValues(allocaToStack); // replace old allocas with stackAllocas
for (auto &kv : allocaToStack)
kv.first->eraseFromParent(); // delete old allocas
// Set payload offsets now that they are all on the stack
for (CallInst *call : m_movePayloadToStackCalls) {
// The payload argument is expected to be a stack.ptr call by now; its first
// two arguments are added up to form the payload offset that replaces the
// movePayloadToStack call.
CallInst *payloadStackPtr = dyn_cast<CallInst>(call->getArgOperand(0));
assert(payloadStackPtr->getCalledFunction()->getName().startswith(
"stack.ptr"));
Value *baseOffset = payloadStackPtr->getArgOperand(0);
Value *idx = payloadStackPtr->getArgOperand(1);
Value *payloadOffset =
BinaryOperator::Create(Instruction::Add, baseOffset, idx, "", call);
call->replaceAllUsesWith(payloadOffset);
payloadOffset->takeName(call);
call->eraseFromParent();
}
// printFunction("AfterStackAllocas");
// ... saves/restores for each call site
// Create allocas for live values. This makes it easier to generate code
// because we don't have to maintain the use-def chains of SSA form. We can
// just load/store from/to the alloca for a particular value. A subsequent
// mem2reg pass will rebuild the SSA form.
DenseMap<Instruction *, AllocaInst *> valToAlloca;
DenseMap<AllocaInst *, Instruction *> allocaToVal;
for (Instruction *inst : lv.getAllLiveValues())
reg2Mem(valToAlloca, allocaToVal, inst);
// printFunction("AfterReg2Mem");
// Every call site starts its save area at the same base offset; the frame
// has to be big enough for the largest one.
uint64_t baseOffsetInBytes = offsetInBytes;
uint64_t maxOffsetInBytes = offsetInBytes;
for (size_t i = 0; i < m_callSites.size(); ++i) {
offsetInBytes = baseOffsetInBytes;
const InstructionSetVector &liveHere = lv.getLiveValues(i);
std::vector<Instruction *> workList(liveHere.begin(), liveHere.end());
std::set<Instruction *> visited;
Rematerializer R(allocaToVal, liveHere, *m_resources);
// Saves go right before the call site, restores right after it.
Instruction *saveInsertBefore = m_callSites[i];
Instruction *restoreInsertBefore = getInstructionAfter(m_callSites[i]);
Instruction *rematInsertBefore = nullptr; // create only if needed
// Rematerialize stack offsets after the continuation before other restores
for (Instruction *inst : stackOffsets) {
visited.insert(inst);
Instruction *remat = R.rematerialize(inst, workList, restoreInsertBefore);
new StoreInst(remat, valToAlloca[inst], restoreInsertBefore);
}
// Saves address the frame through the offset loaded before the call;
// restores use the offset rematerialized after the call.
Instruction *saveStackFrameOffset = new LoadInst(
valToAlloca[m_stackFrameOffset], "stackFrame.offset", saveInsertBefore);
Instruction *restoreStackFrameOffset =
R.getRematerializedValueFor(m_stackFrameOffset);
while (!workList.empty()) {
Instruction *inst = workList.back();
workList.pop_back();
if (!visited.insert(inst).second)
continue;
if (!R.canRematerialize(inst)) {
// Save/restore: copy the value from its alloca to the stack frame before
// the call and back into the alloca after the call.
assert(!inst->getType()->isPointerTy() && "Can not save pointers");
offsetInBytes = align(offsetInBytes, inst, DL);
AllocaInst *alloca = valToAlloca[inst];
Value *saveVal = new LoadInst(
alloca, addSuffix(inst->getName(), ".save"), saveInsertBefore);
createStackStore(saveStackFrameOffset, saveVal, offsetInBytes,
saveInsertBefore);
Value *restoreVal = createStackLoad(restoreStackFrameOffset, inst,
offsetInBytes, restoreInsertBefore);
new StoreInst(restoreVal, alloca, restoreInsertBefore);
offsetInBytes += DL.getTypeAllocSize(inst->getType());
} else if (R.getRematerializedValueFor(inst) == nullptr) {
if (!rematInsertBefore) {
// Create a new block after restores for rematerialized values. This
// ensures that we can use restored values (through their allocas)
// even if we haven't generated the actual restore yet.
rematInsertBefore =
restoreInsertBefore->getParent()
->splitBasicBlock(restoreInsertBefore, "remat_begin")
->begin();
// Later restores go before the terminator of the call site's block, i.e.
// ahead of the new remat block.
restoreInsertBefore = m_callSites[i]->getParent()->getTerminator();
}
Instruction *remat = R.rematerialize(inst, workList, rematInsertBefore);
new StoreInst(remat, valToAlloca[inst], rematInsertBefore);
}
}
// Take the max offset over all call sites
maxOffsetInBytes = std::max(maxOffsetInBytes, offsetInBytes);
}
// ... traceFrame (if any)
maxOffsetInBytes += m_traceFrameSizeInBytes;
// Set the stack size
rewriteDummyStackSize(maxOffsetInBytes);
shaderStackSize = maxOffsetInBytes;
}
// Splits m_function into one state function per substate: the entry state
// plus one per call site. stateFunctions is resized accordingly and filled
// with the results. Afterwards the original function is erased and
// m_function is reset to null.
void StateFunctionTransform::createSubstateFunctions(
    std::vector<Function *> &stateFunctions) {
  // The runtime perf of split() depends on the number of blocks in the
  // function. Simplifying the CFG before the split helps reduce the cost of
  // that operation.
  runPasses(m_function, {createCFGSimplificationPass()});

  stateFunctions.resize(m_callSites.size() + 1);
  BasicBlockVector substateEntryBlocks = replaceCallSites();

  size_t substate = 0;
  for (Function *&stateFunc : stateFunctions) {
    stateFunc = split(m_function, substateEntryBlocks[substate], substate);
    // Add an attribute so we can detect when an intrinsic is not being called
    // from a state function, and thus doesn't have access to the runtimeData
    // pointer.
    stateFunc->addFnAttr("state_function", "true");
    ++substate;
  }

  // Erase base function
  m_function->eraseFromParent();
  m_function = nullptr;
}
void StateFunctionTransform::allocateStackFrame() {
Module *mod = m_function->getParent();
// Push stack frame in entry block.
Instruction *insertBefore = m_stackFrameOffset;
Function *stackFramePushFunc = FunctionBuilder(mod, "stackFramePush")
.voidTy()
.type(m_runtimeDataArgTy, "runtimeData")
.i32("size")
.build();
m_stackFramePush = CallInst::Create(stackFramePushFunc,
{m_runtimeDataArg, m_stackFrameSizeVal},
"", insertBefore);
// Pop the stack frame just before returns.
Function *stackFramePop = FunctionBuilder(mod, "stackFramePop")
.voidTy()
.type(m_runtimeDataArgTy, "runtimeData")
.i32("size")
.build();
for (Instruction *insertBefore : m_returns)
CallInst::Create(stackFramePop, {m_runtimeDataArg, m_stackFrameSizeVal}, "",
insertBefore);
}
void StateFunctionTransform::allocateTraceFrame() {
assert(m_attributeSizeInBytes >= 0 &&
"Attribute size has not been specified");
m_traceFrameSizeInBytes =
2 * m_attributeSizeInBytes // committed and pending attributes
+ 2 * sizeof(int); // old committed/pending attribute offsets
int attrSizeInInts = m_attributeSizeInBytes / sizeof(int);
// Push the trace frame first thing so that the runtime
// can do setup relative to the entry stack offset.
Module *mod = m_function->getParent();
Instruction *insertBefore = afterEntryBlockAllocas(m_function);
Value *attrSize = makeInt32(attrSizeInInts, mod->getContext());
Function *traceFramePushFunc = FunctionBuilder(mod, "traceFramePush")
.voidTy()
.type(m_runtimeDataArgTy, "runtimeData")
.i32("attrSize")
.build();
CallInst::Create(traceFramePushFunc, {m_runtimeDataArg, attrSize}, "",
insertBefore);
// Pop the trace frame just before returns.
Function *traceFramePopFunc = FunctionBuilder(mod, "traceFramePop")
.voidTy()
.type(m_runtimeDataArgTy, "runtimeData")
.build();
for (Instruction *insertBefore : m_returns)
CallInst::Create(traceFramePopFunc, {m_runtimeDataArg}, "", insertBefore);
}
// Placeholder predicate: always answers true and does not inspect `op`.
bool isTemporaryAlloca(Value *op) {
// TODO: Need to do some analysis to figure this out. We can put the alloca on
// the caller stack if:
// there is only a single callsite OR
// if no callsite between stores/loads and this callsite
return true;
}
// Moves argument passing onto the stack.
//  1. Every use of one of this function's own arguments is replaced by a
//     stack load (non-pointer arguments) or a stack pointer (pointer
//     arguments) relative to a base offset chosen by the parameter's
//     semantic type.
//  2. For every call site, the return state id and then the call's operands
//     are written into the stack relative to m_stackFrameOffset, and the
//     operands on the call itself are replaced with undef.
// The largest per-call-site byte count is left in
// m_maxCallerArgFrameSizeInBytes.
void StateFunctionTransform::createArgFrames() {
Module *mod = m_function->getParent();
DataLayout DL(mod);
// Stack pointers that replace allocas are created here, right after the
// stack frame offset query, rather than at the individual call sites.
Instruction *stackAllocaInsertBefore =
getInstructionAfter(m_stackFrameOffset);
// Retrieve this function's arguments from the stack
if (m_function->getFunctionType()->getNumParams() > 0) {
if (m_paramTypes.empty())
m_paramTypes.assign(m_function->getFunctionType()->getNumParams(),
PST_NONE); // assume standard argument types
static_assert(PST_COUNT == 3, "Expected 3 parameter semantic types");
// Running byte offset and base offset value, one slot per semantic type.
int offsetInBytes[PST_COUNT] = {0, 0, 0};
Value *baseOffset[PST_COUNT] = {nullptr, nullptr, nullptr};
Instruction *insertBefore = stackAllocaInsertBefore;
// Establish the base offset once for each semantic type that is in use.
for (auto pst : m_paramTypes) {
if (baseOffset[pst])
continue;
if (pst == PST_NONE) {
// Standard arguments are read at m_stackFrameOffset + m_stackFrameSizeVal.
baseOffset[pst] = BinaryOperator::Create(
Instruction::Add, m_stackFrameOffset, m_stackFrameSizeVal,
"callerArgFrame.offset", insertBefore);
offsetInBytes[pst] = sizeof(
int); // skip the first element in caller arg frame (returnStateID)
} else if (pst == PST_PAYLOAD) {
baseOffset[pst] = m_payloadOffset;
} else if (pst == PST_ATTRIBUTE) {
baseOffset[pst] =
(m_useCommittedAttr) ? m_committedAttrOffset : m_pendingAttrOffset;
} else {
assert(0 && "Bad parameter type");
}
}
int argIdx = 0;
for (auto &arg : m_function->args()) {
ParameterSemanticType pst = m_paramTypes[argIdx];
Value *val = nullptr;
if (arg.getType()->isPointerTy()) {
// Assume that pointed to memory is on the stack.
// The offset advances by the size of the pointee, not of the pointer.
val = createStackPtr(baseOffset[pst], &arg, offsetInBytes[pst],
insertBefore);
offsetInBytes[pst] +=
DL.getTypeAllocSize(arg.getType()->getPointerElementType());
} else {
val = createStackLoad(baseOffset[pst], &arg, offsetInBytes[pst],
insertBefore);
offsetInBytes[pst] += DL.getTypeAllocSize(arg.getType());
}
// Replace use of the argument with the loaded value
if (arg.hasName())
val->takeName(&arg);
else
val->setName("arg" + std::to_string(argIdx));
arg.replaceAllUsesWith(val);
argIdx++;
}
}
// Process function arguments for each call site
m_maxCallerArgFrameSizeInBytes = 0;
for (size_t i = 0; i < m_callSites.size(); ++i) {
int offsetInBytes = 0;
CallInst *call = m_callSites[i];
FunctionType *FT = call->getCalledFunction()->getFunctionType();
StringRef calledFuncName = call->getCalledFunction()->getName();
Instruction *insertBefore = call;
// Set the return stateId (next substate of this function)
// It is stored first, at byte offset 0 from m_stackFrameOffset.
int nextSubstate = i + 1;
Value *nextStateId =
getDummyStateId(m_functionIdx, nextSubstate, insertBefore);
createStackStore(m_stackFrameOffset, nextStateId, offsetInBytes,
insertBefore);
offsetInBytes += DL.getTypeAllocSize(nextStateId->getType());
// Operands of the CALL_INDIRECT_NAME function are left on the call.
if (FT->getNumParams() && calledFuncName != CALL_INDIRECT_NAME) {
for (unsigned index = 0; index < FT->getNumParams(); ++index) {
// Store the call operand to the stack
Value *op = call->getArgOperand(index);
Type *opTy = op->getType();
if (opTy->isPointerTy()) {
// TODO: Until we have callable shaders we should not get here except
// in tests.
if (isTemporaryAlloca(op)) {
// We can just replace the alloca with space in the arg frame
assert(isa<AllocaInst>(op));
Value *stackAlloca = createStackPtr(
m_stackFrameOffset, op, offsetInBytes, stackAllocaInsertBefore);
op->replaceAllUsesWith(stackAlloca);
cast<AllocaInst>(op)->eraseFromParent();
} else {
// copy in/out
assert(0);
}
offsetInBytes += DL.getTypeAllocSize(opTy->getPointerElementType());
} else {
createStackStore(m_stackFrameOffset, op, offsetInBytes, insertBefore);
offsetInBytes += DL.getTypeAllocSize(opTy);
}
// Replace use of the argument with undef
call->setArgOperand(index, UndefValue::get(opTy));
}
}
// Track the largest number of bytes written for any single call site.
if (offsetInBytes > m_maxCallerArgFrameSizeInBytes)
m_maxCallerArgFrameSizeInBytes = offsetInBytes;
}
}
// Re-homes the body of m_function in a new function named
// m_functionName + "_tmp" that returns an i32 and takes a single
// "runtimeData" parameter, then rewrites every recorded return so that it
// returns the int stored at index 0 relative to the stack frame offset.
// m_function, m_runtimeDataArg and the entries of m_returns are updated to
// refer to the new function, argument and return instructions.
void StateFunctionTransform::changeFunctionSignature() {
// Create a new function that takes a state object pointer and returns next
// state ID and splice in the body of the old function into the new one.
Function *newFunc =
FunctionBuilder(m_function->getParent(), m_functionName + "_tmp")
.i32()
.type(m_runtimeDataArgTy, "runtimeData")
.build();
newFunc->getBasicBlockList().splice(newFunc->begin(),
m_function->getBasicBlockList());
m_function = newFunc;
// Set the runtime data pointer and remove the dummy function.
Value *runtimeDataArg = m_function->arg_begin();
replaceValAndRemoveUnusedDummyFunc(m_runtimeDataArg, runtimeDataArg,
m_function);
m_runtimeDataArg = runtimeDataArg;
// Get return stateID from stack on each return.
LLVMContext &context = m_function->getContext();
Value *zero = makeInt32(0, context);
CallInst *retStackFrameOffset = m_stackFrameOffset;
// `ret` is bound by reference so the m_returns entry can be replaced below.
for (ReturnInst *&ret : m_returns) {
Instruction *insertBefore = ret;
// When a stack frame push was emitted, the offset is queried again in
// front of the return instead of reusing m_stackFrameOffset.
// NOTE(review): presumably because the value obtained at entry is no
// longer the current offset once the frame has been popped; confirm.
if (m_stackFramePush)
retStackFrameOffset = CallInst::Create(
m_stackFrameOffset->getCalledFunction(), {m_runtimeDataArg},
"ret.stackFrame.offset", insertBefore);
Instruction *returnStateIdPtr = CallInst::Create(
m_stackIntPtrFunc, {m_runtimeDataArg, retStackFrameOffset, zero},
"ret.stateId.ptr", insertBefore);
Value *returnStateId =
new LoadInst(returnStateIdPtr, "ret.stateId", insertBefore);
ReturnInst *newRet = ReturnInst::Create(context, returnStateId);
ReplaceInstWithInst(ret, newRet);
ret = newRet; // update reference
}
}
void StateFunctionTransform::rewriteDummyStackSize(uint64_t frameSizeInBytes) {
assert(frameSizeInBytes % sizeof(int) == 0);
Value *frameSizeVal =
makeInt32(frameSizeInBytes / sizeof(int), m_function->getContext());
replaceValAndRemoveUnusedDummyFunc(m_stackFrameSizeVal, frameSizeVal,
m_function);
m_stackFrameSizeVal = frameSizeVal;
}
// Emits, before `insertBefore`, a call to a "stack.store" placeholder
// function that takes (val, baseOffset, index). `offsetInBytes` must be a
// multiple of sizeof(int) and is passed on as an index counted in ints.
void StateFunctionTransform::createStackStore(Value *baseOffset, Value *val,
                                              int offsetInBytes,
                                              Instruction *insertBefore) {
  assert(offsetInBytes % sizeof(int) == 0);

  Value *slotIndex =
      makeInt32(offsetInBytes / sizeof(int), insertBefore->getContext());

  // The placeholder's signature is derived from the operand types.
  Value *callArgs[] = {val, baseOffset, slotIndex};
  Type *paramTypes[] = {val->getType(), baseOffset->getType(),
                        slotIndex->getType()};
  FunctionType *storeFuncTy = FunctionType::get(
      Type::getVoidTy(val->getContext()), paramTypes, false);

  Function *storeFunc = getOrCreateFunction(
      "stack.store", insertBefore->getModule(), storeFuncTy, m_stackStoreFuncs);
  CallInst::Create(storeFunc, callArgs, "", insertBefore);
}
// Emits, before `insertBefore`, a call to a "stack.load" placeholder function
// that takes (baseOffset, index) and returns a value of `val`'s type.
// `offsetInBytes` must be a multiple of sizeof(int) and is passed on as an
// index counted in ints. The call is named after `val` with a ".restore"
// suffix and is returned to the caller.
Instruction *
StateFunctionTransform::createStackLoad(Value *baseOffset, Value *val,
                                        int offsetInBytes,
                                        Instruction *insertBefore) {
  assert(offsetInBytes % sizeof(int) == 0);

  Value *slotIndex =
      makeInt32(offsetInBytes / sizeof(int), insertBefore->getContext());

  // The placeholder's signature is derived from the operand types.
  Value *callArgs[] = {baseOffset, slotIndex};
  Type *paramTypes[] = {baseOffset->getType(), slotIndex->getType()};
  FunctionType *loadFuncTy =
      FunctionType::get(val->getType(), paramTypes, false);

  Function *loadFunc = getOrCreateFunction(
      "stack.load", insertBefore->getModule(), loadFuncTy, m_stackLoadFuncs);
  return CallInst::Create(loadFunc, callArgs,
                          addSuffix(val->getName(), ".restore"), insertBefore);
}
// Emits, before `insertBefore`, an unnamed call to a "stack.ptr" placeholder
// function that takes (baseOffset, intIndex) and returns a value of type
// `valTy`. The call is returned to the caller.
Instruction *StateFunctionTransform::createStackPtr(Value *baseOffset,
                                                    Type *valTy,
                                                    Value *intIndex,
                                                    Instruction *insertBefore) {
  // The placeholder's signature is derived from the operand types.
  Value *callArgs[] = {baseOffset, intIndex};
  Type *paramTypes[] = {baseOffset->getType(), intIndex->getType()};
  FunctionType *ptrFuncTy = FunctionType::get(valTy, paramTypes, false);

  Function *ptrFunc = getOrCreateFunction(
      "stack.ptr", insertBefore->getModule(), ptrFuncTy, m_stackPtrFuncs);
  return CallInst::Create(ptrFunc, callArgs, "", insertBefore);
}
// Convenience overload: builds a "stack.ptr" call of `val`'s type at the
// given byte offset (which must be a multiple of sizeof(int)) and transfers
// `val`'s name to the new call.
Instruction *StateFunctionTransform::createStackPtr(Value *baseOffset,
                                                    Value *val,
                                                    int offsetInBytes,
                                                    Instruction *insertBefore) {
  assert(offsetInBytes % sizeof(int) == 0);

  Value *slotIndex =
      makeInt32(offsetInBytes / sizeof(int), insertBefore->getContext());
  Instruction *stackPtr =
      createStackPtr(baseOffset, val->getType(), slotIndex, insertBefore);
  stackPtr->takeName(val);
  return stackPtr;
}
// Returns true when `val` is a direct call to a function whose name starts
// with "stack.ptr". Anything that is not a call, and any call without a
// statically known callee (e.g. an indirect call), yields false.
static bool isStackIntPtr(Value *val) {
  CallInst *call = dyn_cast<CallInst>(val);
  if (!call)
    return false;
  // getCalledFunction() is null for indirect calls; guard before using it.
  Function *callee = call->getCalledFunction();
  return callee && callee->getName().startswith("stack.ptr");
}
// Adapted from GetElementPtrInst::accumulateConstantOffset().
// Emits, before `gep`, the add/mul instructions that compute the GEP's offset
// from its indices and returns the resulting value. Struct field offsets and
// array/vector strides are divided by sizeof(int), so the result is counted
// in ints rather than bytes.
// TODO: Use a single function for both constant and dynamic offsets? Could do
// some constant folding along the way for dynamic offsets.
Value *accumulateDynamicOffset(GetElementPtrInst *gep, const DataLayout &DL) {
  LLVMContext &ctx = gep->getContext();
  Value *total = makeInt32(0, ctx);

  for (gep_type_iterator it = gep_type_begin(gep), itEnd = gep_type_end(gep);
       it != itEnd; ++it) {
    Value *index = it.getOperand();
    ConstantInt *constIndex = dyn_cast<ConstantInt>(index);

    // A constant zero index contributes nothing.
    if (constIndex && constIndex->isZero())
      continue;

    Value *term = nullptr;
    if (StructType *structTy = dyn_cast<StructType>(*it)) {
      // Struct index: contributes the offset of the selected field.
      assert(constIndex && "Structure indices must be constant");
      unsigned fieldIdx = constIndex->getZExtValue();
      const StructLayout *layout = DL.getStructLayout(structTy);
      term = makeInt32(layout->getElementOffset(fieldIdx) / sizeof(int), ctx);
    } else {
      // Array or vector index: scaled by the size of the indexed type.
      Value *stride = makeInt32(
          DL.getTypeAllocSize(it.getIndexedType()) / sizeof(int), ctx);
      term = BinaryOperator::Create(Instruction::Mul, index, stride, "elOffs",
                                    gep);
    }

    total = BinaryOperator::Create(Instruction::Add, total, term, "offs", gep);
  }
  return total;
}
// Adds the offset of `gep` (counted in ints) to `offsetVal` and returns the
// sum. The add instruction, and any instructions needed for a non-constant
// offset, are inserted before `gep`.
static Value *accumulateGepOffset(GetElementPtrInst *gep, Value *offsetVal) {
  Module *M = gep->getModule();
  const DataLayout &DL = M->getDataLayout();
  Value *gepOffsetInInts = nullptr;
  APInt constOffset(DL.getPointerSizeInBits(), 0);
  if (gep->accumulateConstantOffset(DL, constOffset)) {
    // Keep the byte -> int conversion in signed arithmetic. Dividing by the
    // unsigned sizeof(int) directly would convert a negative offset to
    // size_t first, which only produces the right low 32 bits when size_t is
    // wider than int.
    const int64_t offsetInBytes = constOffset.getSExtValue();
    const int64_t intSize = static_cast<int64_t>(sizeof(int));
    gepOffsetInInts =
        makeInt32(static_cast<int>(offsetInBytes / intSize), M->getContext());
  } else {
    gepOffsetInInts = accumulateDynamicOffset(gep, DL);
  }
  return BinaryOperator::Create(Instruction::Add, offsetVal, gepOffsetInInts,
                                "offs", gep);
}
// Turn GEPs on a stack.ptr of aggregate type into stack.ptrs of scalar type
//
// Walks the users of `val`:
//  - a call user is inlined and the walk restarts from the first user;
//  - a GEP user has its offset added to `offsetVal` and is then either
//    flattened recursively (aggregate element type), scalarized (vector
//    element type) or replaced by a stack.ptr call (anything else), after
//    which the GEP is erased;
//  - any other user is left untouched.
// `baseOffset` is passed through unchanged to the stack.ptr calls created.
void StateFunctionTransform::flattenGepsOnValue(Value *val, Value *baseOffset,
Value *offsetVal) {
for (auto U = val->user_begin(), UE = val->user_end(); U != UE;) {
// Advance before the current user can be modified or erased.
User *user = *U++;
if (CallInst *call = dyn_cast<CallInst>(user)) {
// inline the call to expose GEPs and restart the loop.
// NOTE(review): the inlining result is only checked by assert. If it
// fails in a build without asserts, the call stays a user of `val` and
// the walk restarts on it again.
InlineFunctionInfo IFI;
bool success = InlineFunction(call, IFI, false);
assert(success);
(void)success;
U = val->user_begin();
UE = val->user_end();
continue;
}
GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(user);
if (!gep)
continue;
// Offset of this GEP's result: offsetVal plus the GEP's own offset.
Value *elementOffsetVal = accumulateGepOffset(gep, offsetVal);
Type *gepElTy = gep->getType()->getPointerElementType();
if (gepElTy->isAggregateType()) {
// flatten geps on this gep
flattenGepsOnValue(gep, baseOffset, elementOffsetVal);
} else if (isa<VectorType>(gepElTy))
scalarizeVectorStackAccess(gep, baseOffset, elementOffsetVal);
else {
Value *ptr =
createStackPtr(baseOffset, gep->getType(), elementOffsetVal, gep);
ptr->takeName(
gep); // could use a name that encodes the gep type and indices
gep->replaceAllUsesWith(ptr);
}
// In every branch the GEP itself is removed afterwards.
gep->eraseFromParent();
}
}
// Rewrites loads and stores through `vecPtr` (a pointer to a vector) as
// per-element loads and stores.
//  - One stack.ptr call is created per vector element, at `offsetVal`,
//    `offsetVal` + 1, ... relative to `baseOffset`.
//  - Each load user is replaced by element loads assembled with
//    insertelement; each store user is replaced by extractelement plus
//    element stores.
// `offsetVal` is renamed to "offs0." as a side effect. Users other than loads
// and stores trip an assert. `vecPtr` itself is not erased here.
void StateFunctionTransform::scalarizeVectorStackAccess(Instruction *vecPtr,
                                                        Value *baseOffset,
                                                        Value *offsetVal) {
  std::vector<Value *> elPtrs;
  Type *VTy = vecPtr->getType()->getPointerElementType();
  Type *elTy = VTy->getVectorElementType();
  LLVMContext &C = vecPtr->getContext();
  Value *curOffsetVal = offsetVal;
  Value *one = makeInt32(1, C);
  offsetVal->setName("offs0.");
  for (unsigned i = 0; i < VTy->getVectorNumElements(); ++i) {
    // stringf() is printf-style: the argument for %d must be an int.
    const int elIdx = static_cast<int>(i);
    // TODO: If offsetVal is a constant we could just create constants instead
    // of add instructions
    if (i > 0)
      curOffsetVal = BinaryOperator::Create(Instruction::Add, curOffsetVal, one,
                                            stringf("offs%d.", elIdx), vecPtr);
    elPtrs.push_back(
        createStackPtr(baseOffset, elTy->getPointerTo(), curOffsetVal, vecPtr));
    elPtrs.back()->setName(
        addSuffix(vecPtr->getName(), stringf(".el%d.", elIdx)));
  }
  // Scalarize load/stores
  for (auto U = vecPtr->user_begin(), UE = vecPtr->user_end(); U != UE;) {
    // Advance before the current user is erased.
    User *user = *U++;
    if (LoadInst *load = dyn_cast<LoadInst>(user)) {
      Value *vec = UndefValue::get(VTy);
      for (size_t i = 0; i < elPtrs.size(); ++i) {
        // Passing a size_t to %d through varargs is undefined; convert first.
        const int elIdx = static_cast<int>(i);
        Value *el = new LoadInst(elPtrs[i], stringf("el%d.", elIdx), load);
        vec = InsertElementInst::Create(vec, el, makeInt32(elIdx, C), "vec",
                                        load);
      }
      load->replaceAllUsesWith(vec);
      load->eraseFromParent();
    } else if (StoreInst *store = dyn_cast<StoreInst>(user)) {
      Value *vec = store->getOperand(0);
      for (size_t i = 0; i < elPtrs.size(); ++i) {
        const int elIdx = static_cast<int>(i);
        Value *el = ExtractElementInst::Create(vec, makeInt32(elIdx, C),
                                               stringf("el%d.", elIdx), store);
        new StoreInst(el, elPtrs[i], store);
      }
      store->eraseFromParent();
    } else {
      assert(0 && "Unhandled user");
    }
  }
}
// Replaces every call to the stack.store / stack.load / stack.ptr placeholder
// functions recorded in m_stackStoreFuncs, m_stackLoadFuncs and
// m_stackPtrFuncs with code based on m_stackIntPtrFunc, then erases the
// placeholder functions. The runtime data pointer used for the replacement
// code is always the first argument of the function containing the call.
void StateFunctionTransform::lowerStackFuncs() {
LLVMContext &C = m_stackIntPtrFunc->getContext();
const DataLayout &DL = m_stackIntPtrFunc->getParent()->getDataLayout();
// stack.store functions
for (auto &kv : m_stackStoreFuncs) {
Function *F = kv.second;
for (auto U = F->user_begin(); U != F->user_end();) {
// The iterator is advanced before the call is erased below.
CallInst *call = dyn_cast<CallInst>(*(U++));
assert(call);
Value *runtimeDataArg = call->getParent()->getParent()->arg_begin();
// Operand order matches createStackStore: (value, base offset, index).
Value *val = call->getArgOperand(0);
Value *offset = call->getArgOperand(1);
int idx = getConstantValue(call->getArgOperand(2));
Instruction *insertBefore = call;
if (isStackIntPtr(val)) {
// Copy from one part of the stack to another
// The pointee is copied one int at a time with a load/store pair.
CallInst *valCall = dyn_cast<CallInst>(val);
Value *srcOffset = valCall->getArgOperand(0);
int srcIdx = getConstantValue(valCall->getArgOperand(1));
Value *dstOffset = offset;
int dstIdx = idx;
int intCount =
(int)DL.getTypeAllocSize(val->getType()->getPointerElementType()) /
sizeof(int);
for (int i = 0; i < intCount; ++i) {
std::string idxStr = stringf("%d.", i);
Value *srcPtr = CallInst::Create(
m_stackIntPtrFunc,
{runtimeDataArg, srcOffset, makeInt32(srcIdx + i, C)},
addSuffix(val->getName(), ".ptr" + idxStr), insertBefore);
Value *dstPtr = CallInst::Create(
m_stackIntPtrFunc,
{runtimeDataArg, dstOffset, makeInt32(dstIdx + i, C)},
"dst.ptr" + idxStr, insertBefore);
Value *intVal =
new LoadInst(srcPtr, "copy.val" + idxStr, insertBefore);
new StoreInst(intVal, dstPtr, insertBefore);
}
} else {
// Any other value is handed to the store() helper.
store(val, m_stackIntPtrFunc, runtimeDataArg, offset, idx,
insertBefore);
}
call->eraseFromParent();
}
F->eraseFromParent();
}
// stack.load functions
for (auto &kv : m_stackLoadFuncs) {
Function *F = kv.second;
for (auto U = F->user_begin(); U != F->user_end();) {
CallInst *call = dyn_cast<CallInst>(*(U++));
assert(call);
// The call's name without its ".restore" suffix is handed to load();
// the call's own name is cleared first.
std::string name = stripSuffix(call->getName(), ".restore");
call->setName("");
Value *runtimeDataArg = call->getParent()->getParent()->arg_begin();
Value *offset = call->getArgOperand(0);
Value *idx = call->getArgOperand(1);
Instruction *insertBefore = call;
Value *val = load(m_stackIntPtrFunc, runtimeDataArg, offset, idx, name,
call->getType(), insertBefore);
call->replaceAllUsesWith(val);
call->eraseFromParent();
}
F->eraseFromParent();
}
// Scalarize accesses based on a stack.ptr func
// Only stack.ptr functions that return a pointer to an aggregate are
// handled in this pass; the functions themselves are erased in the next loop.
for (auto &kv : m_stackPtrFuncs) {
Function *F = kv.second;
if (!F->getReturnType()->getPointerElementType()->isAggregateType())
continue;
for (auto U = F->user_begin(), UE = F->user_end(); U != UE;) {
CallInst *call = dyn_cast<CallInst>(*(U++));
assert(call);
Value *offset = call->getArgOperand(0);
Value *idx = call->getArgOperand(1);
flattenGepsOnValue(call, offset, idx);
call->eraseFromParent();
}
}
// stack.ptr functions
for (auto &kv : m_stackPtrFuncs) {
Function *F = kv.second;
for (auto U = F->user_begin(); U != F->user_end();) {
CallInst *call = dyn_cast<CallInst>(*(U++));
assert(call);
std::string name = call->getName();
Value *runtimeDataArg = call->getParent()->getParent()->arg_begin();
Value *offset = call->getArgOperand(0);
Value *idx = call->getArgOperand(1);
Instruction *insertBefore = call;
Value *ptr =
CallInst::Create(m_stackIntPtrFunc, {runtimeDataArg, offset, idx},
addSuffix(name, ".ptr"), insertBefore);
// Cast when the stack.ptr call's type differs from what m_stackIntPtrFunc
// returns.
if (ptr->getType() != call->getType())
ptr = new BitCastInst(ptr, call->getType(), "", insertBefore);
ptr->takeName(call);
call->replaceAllUsesWith(ptr);
call->eraseFromParent();
}
F->eraseFromParent();
}
}
// Builds the state function for one substate.
//  - Clones the blocks reachable from `substateEntryBlock` into a new
//    function, appends it to `baseFunc`'s module and names it
//    m_functionName + ".ss_" + substateIndex.
//  - For every substate other than 0, entry-block allocas of `baseFunc` that
//    are still used from the new function are cloned into its entry block and
//    those uses are redirected to the clones.
//  - Makes the result reducible, then runs the verifier and mem2reg on it.
// Returns the new function.
Function *StateFunctionTransform::split(Function *baseFunc,
                                        BasicBlock *substateEntryBlock,
                                        int substateIndex) {
  ValueToValueMapTy VMap;
  Function *substateFunc = cloneBlocksReachableFrom(substateEntryBlock, VMap);
  Module *mod = baseFunc->getParent();
  mod->getFunctionList().push_back(substateFunc);
  substateFunc->setName(m_functionName + ".ss_" +
                        std::to_string(substateIndex));
  if (substateIndex != 0) {
    // Collect allocas from entry block
    SmallVector<Instruction *, 16> allocasToClone;
    for (auto &I : baseFunc->getEntryBlock().getInstList()) {
      if (isa<AllocaInst>(&I))
        allocasToClone.push_back(&I);
    }
    // Clone collected allocas
    BasicBlock *newEntryBlock = &substateFunc->getEntryBlock();
    for (auto I : allocasToClone) {
      // Collect users of original instruction in substateFunc
      std::vector<Instruction *> users;
      for (auto U : I->users()) {
        // dyn_cast yields null for a user that is not an instruction; such a
        // user cannot live in substateFunc, so it is skipped.
        Instruction *inst = dyn_cast<Instruction>(U);
        if (inst && inst->getParent()->getParent() == substateFunc)
          users.push_back(inst);
      }
      if (users.empty())
        continue;
      // Clone instruction
      Instruction *clone = I->clone();
      if (I->hasName())
        clone->setName(I->getName());
      clone->insertBefore(
          newEntryBlock->getFirstInsertionPt()); // allocas first in entry block
      RemapInstruction(clone, VMap,
                       RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
      // Replaces uses
      for (auto user : users)
        user->replaceUsesOfWith(I, clone);
    }
  }
  // printFunction( substateFunc, substateFunc->getName().str() +
  // "-BeforeSplittingOpt", m_dumpId++ );
  makeReducible(substateFunc);
  // Undo the reg2mem done in preserveLiveValuesAcrossCallSites()
  runPasses(substateFunc,
            {createVerifierPass(), createPromoteMemoryToRegisterPass()});
  // printFunction( substateFunc, substateFunc->getName().str() +
  // "-AfterSplitting", m_dumpId++ );
  return substateFunc;
}
// Cuts the function at every call site and returns the entry block of each
// substate: element 0 is the function's entry block, element i + 1 is the
// block that starts right after call site i. Each call is replaced by a
// return of the state id to continue with, and is then erased.
BasicBlockVector StateFunctionTransform::replaceCallSites() {
  LLVMContext &ctx = m_function->getContext();

  BasicBlockVector entryPoints{&m_function->getEntryBlock()};
  entryPoints[0]->setName(m_functionName + ".BB0");

  for (size_t siteIdx = 0; siteIdx < m_callSites.size(); ++siteIdx) {
    CallInst *call = m_callSites[siteIdx];
    BasicBlock *callBlock = call->getParent();
    StringRef calleeName = call->getCalledFunction()->getName();

    // Everything after the call becomes the next substate's entry block.
    BasicBlock *continuation = callBlock->splitBasicBlock(
        call->getNextNode(), m_functionName + ".BB" +
                                 std::to_string(siteIdx + 1) + ".from." +
                                 cleanName(calleeName));
    entryPoints.push_back(continuation);

    // State id to return: the call's first operand for the indirect-call
    // function, otherwise the entry state (substate 0) of the callee.
    Value *returnStateId =
        (calleeName == CALL_INDIRECT_NAME)
            ? call->getArgOperand(0)
            : getDummyStateId(m_callSiteFunctionIdx[siteIdx], 0, call);

    // Swap the terminator created by the split for a return, then drop the
    // call itself.
    ReplaceInstWithInst(callBlock->getTerminator(),
                        ReturnInst::Create(ctx, returnStateId));
    call->eraseFromParent();
  }
  return entryPoints;
}
// Emits, before `insertBefore`, a call named "stateId" to the placeholder
// function "dummyStateId"(functionIdx, substate) and returns it. The
// placeholder function is created on first use and cached in
// m_dummyStateIdFunc.
llvm::Value *
StateFunctionTransform::getDummyStateId(int functionIdx, int substate,
                                        llvm::Instruction *insertBefore) {
  if (!m_dummyStateIdFunc) {
    m_dummyStateIdFunc =
        FunctionBuilder(m_function->getParent(), "dummyStateId")
            .i32()
            .i32("functionIdx")
            .i32("substate")
            .build();
  }

  LLVMContext &ctx = insertBefore->getContext();
  Value *idArgs[] = {makeInt32(functionIdx, ctx), makeInt32(substate, ctx)};
  return CallInst::Create(m_dummyStateIdFunc, idArgs, "stateId", insertBefore);
}
// Chooses where dump output goes.
// Returns DBGS() when no dump filename is configured or when the dump file
// cannot be opened (a message is written to DBGS() in the latter case).
// Otherwise returns a heap-allocated file stream; a caller that receives
// anything other than DBGS() is responsible for deleting it.
raw_ostream &
StateFunctionTransform::getOutputStream(const std::string functionName,
                                        const std::string &suffix,
                                        unsigned int dumpId) {
  if (m_dumpFilename.empty())
    return DBGS();

  const std::string dumpPath =
      createDumpPath(m_dumpFilename, dumpId, suffix, functionName);
  std::error_code openError;
  raw_ostream *fileStream =
      new raw_fd_ostream(dumpPath, openError, sys::fs::OpenFlags::F_None);
  if (!openError)
    return *fileStream;

  // Opening failed: report it, drop the stream and fall back to DBGS().
  DBGS() << "Failed to open " << dumpPath << " for writing sft output. "
         << openError.message() << "\n";
  delete fileStream;
  return DBGS();
}
void StateFunctionTransform::printFunction(const Function *function,
const std::string &suffix,
unsigned int dumpId) {
if (!m_verbose)
return;
raw_ostream &out = getOutputStream(m_functionName, suffix, dumpId);
out << "; ########################### " << suffix << "\n";
out << *function << "\n";
if (&out != &DBGS())
delete &out;
}
// Convenience overload: prints m_function under the next dump id.
// m_dumpId is incremented on every call, including when the three-argument
// overload returns early because verbose output is off.
void StateFunctionTransform::printFunction(const std::string &suffix) {
printFunction(m_function, suffix, m_dumpId++);
}
// In verbose mode, writes a banner containing `suffix` followed by the IR of
// every function in `funcs` to a single dump stream, consuming one dump id.
// Does nothing otherwise.
void StateFunctionTransform::printFunctions(
    const std::vector<Function *> &funcs, const char *suffix) {
  if (m_verbose) {
    raw_ostream &dump = getOutputStream(m_functionName, suffix, m_dumpId++);
    dump << "; ########################### " << suffix << "\n";
    for (size_t idx = 0; idx < funcs.size(); ++idx)
      dump << *funcs[idx] << "\n";
    // Streams other than DBGS() were heap-allocated for this dump.
    if (&dump != &DBGS())
      delete &dump;
  }
}
void StateFunctionTransform::printModule(const Module *mod,
const std::string &suffix) {
if (!m_verbose)
return;
raw_ostream &out = getOutputStream("module", suffix, m_dumpId++);
out << "; ########################### " << suffix << "\n";
out << *mod << "\n";
}
// In verbose mode, lists every instruction in `vals` on DBGS() together with
// the alloc size of its type, followed by the instruction count and the
// total number of bytes. `msg`, when non-null, is printed first as a
// heading. Does nothing when verbose output is off.
void StateFunctionTransform::printSet(const InstructionSetVector &vals,
                                      const char *msg) {
  if (!m_verbose)
    return;
  raw_ostream &out = DBGS();
  if (msg)
    out << msg << " --------------------\n";
  uint64_t totalBytes = 0;
  if (vals.size() > 0) {
    Module *mod = m_function->getParent();
    DataLayout DL(mod);
    for (InstructionSetVector::const_iterator I = vals.begin(), IE = vals.end();
         I != IE; ++I) {
      const Instruction *inst = *I;
      uint64_t size = DL.getTypeAllocSize(inst->getType());
      // stringf() is printf-style, so the argument type must match the
      // conversion: format the 64-bit size with %llu instead of %d.
      out << stringf("%3lluB: ", static_cast<unsigned long long>(size)) << *inst
          << '\n';
      totalBytes += size;
    }
  }
  out << "Count:" << vals.size() << " Bytes:" << totalBytes << "\n\n";
}