| #include "dxc/DxrFallback/DxrFallbackCompiler.h" |
| |
| #include "dxc/DXIL/DxilFunctionProps.h" |
| #include "dxc/DXIL/DxilInstructions.h" |
| #include "dxc/DXIL/DxilModule.h" |
| #include "dxc/DXIL/DxilOperations.h" |
| #include "dxc/HLSL/DxilLinker.h" |
| #include "dxc/Support/FileIOHelper.h" |
| #include "dxc/Support/Global.h" |
| #include "dxc/Support/Unicode.h" |
| #include "dxc/Support/WinIncludes.h" |
| #include "dxc/Support/dxcapi.impl.h" |
| #include "dxc/Support/dxcapi.use.h" |
| #include "dxc/dxcapi.h" |
| #include "dxc/dxcdxrfallbackcompiler.h" |
| |
| #include "llvm/Analysis/CallGraph.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/InstIterator.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/IR/LegacyPassManager.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/Linker/Linker.h" |
| #include "llvm/Transforms/IPO.h" |
| #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| #include "llvm/Transforms/Utils/Cloning.h" |
| |
| #include "FunctionBuilder.h" |
| #include "LLVMUtils.h" |
| #include "StateFunctionTransform.h" |
| #include "runtime.h" |
| |
| #include <queue> |
| |
| using namespace hlsl; |
| using namespace llvm; |
| |
| static std::vector<Function *> |
| getFunctionsWithPrefix(Module *mod, const std::string &prefix) { |
| std::vector<Function *> functions; |
| for (auto F = mod->begin(), E = mod->end(); F != E; ++F) { |
| StringRef name = F->getName(); |
| if (name.startswith(prefix)) |
| functions.push_back(F); |
| } |
| return functions; |
| } |
| |
| static bool inlineFunc(CallInst *call, Function *Fimpl) { |
| // Note. LLVM inlining may not be sufficient if the function references DX |
| // resources because the corresponding metadata is not created if the function |
| // comes from another module. |
| |
| // Make sure that we have a definition for the called function in this module |
| Function *F = call->getCalledFunction(); |
| Module *dstM = F->getParent(); |
| if (F->isDeclaration()) { |
| // Map called functions in impl module to functions in this one (because the |
| // cloning step doesn't do this automatically) |
| ValueToValueMapTy VMap; |
| for (auto &I : inst_range(Fimpl)) { |
| if (CallInst *c = dyn_cast<CallInst>(&I)) { |
| Function *calledFimpl = c->getCalledFunction(); |
| if (VMap.count(calledFimpl)) |
| continue; |
| |
| Constant *calledF = dstM->getOrInsertFunction( |
| calledFimpl->getName(), calledFimpl->getFunctionType(), |
| calledFimpl->getAttributes()); |
| VMap[calledFimpl] = calledF; |
| } |
| } |
| |
| // Map arguments |
| for (auto SI = Fimpl->arg_begin(), SE = Fimpl->arg_end(), |
| DI = F->arg_begin(); |
| SI != SE; ++SI, ++DI) |
| VMap[SI] = DI; |
| |
| SmallVector<ReturnInst *, 4> returns; |
| CloneFunctionInto(F, Fimpl, VMap, true, returns); |
| F->setLinkage(GlobalValue::InternalLinkage); |
| } |
| |
| InlineFunctionInfo IFI; |
| return InlineFunction(call, IFI, false); |
| } |
| |
| // Remove ELF mangling |
| static std::string cleanName(StringRef name) { |
| if (!name.startswith("\x1?")) |
| return name; |
| |
| size_t pos = name.find("@@"); |
| if (pos == name.npos) |
| return name; |
| |
| std::string newName = name.substr(2, pos - 2); |
| return newName; |
| } |
| |
| static inline Function *getOrInsertFunction(Module *mod, Function *F) { |
| return dyn_cast<Function>( |
| mod->getOrInsertFunction(F->getName(), F->getFunctionType())); |
| } |
| |
// Map lookup with a fallback: returns the value stored under `key`, or
// `defaultVal` (nullptr converted to V unless specified) when the key is
// absent. Never inserts into the map, so const maps are accepted as well.
template <typename K, typename V>
V get(const std::map<K, V> &theMap, const K &key,
      V defaultVal = static_cast<V>(nullptr)) {
  auto it = theMap.find(key);
  if (it == theMap.end())
    return defaultVal;
  return it->second;
}
| |
| DxrFallbackCompiler::DxrFallbackCompiler( |
| llvm::Module *mod, const std::vector<std::string> &shaderNames, |
| unsigned maxAttributeSize, unsigned stackSizeInBytes, |
| bool findCalledShaders /*= false*/) |
| : m_module(mod), m_entryShaderNames(shaderNames), |
| m_stackSizeInBytes(stackSizeInBytes), |
| m_maxAttributeSize(maxAttributeSize), |
| m_findCalledShaders(findCalledShaders) {} |
| |
| void DxrFallbackCompiler::compile(std::vector<int> &shaderEntryStateIds, |
| std::vector<unsigned int> &shaderStackSizes, |
| IntToFuncNameMap *pCachedMap) { |
| std::vector<std::string> shaderNames = m_entryShaderNames; |
| initShaderMap(shaderNames); |
| |
| // Bring in runtime so we can get the runtime data type |
| linkRuntime(); |
| Type *runtimeDataArgTy = getRuntimeDataArgType(); |
| |
| // Make sure all calls to intrinsics and shaders are at function scope and |
| // fix up control flow. |
| lowerAnyHitControlFlowFuncs(); |
| lowerReportHit(); |
| lowerTraceRay(runtimeDataArgTy); |
| |
| // Create state functions |
| IntToFuncMap stateFunctionMap; // stateID -> state function |
| const int baseStateId = |
| 1000; // could be anything but this makes stateIds more recognizable |
| createStateFunctions(stateFunctionMap, shaderEntryStateIds, shaderStackSizes, |
| baseStateId, shaderNames, runtimeDataArgTy); |
| |
| if (pCachedMap) { |
| for (auto &entry : stateFunctionMap) { |
| (*pCachedMap)[entry.first] = entry.second->getName().str(); |
| } |
| } |
| } |
| |
| void DxrFallbackCompiler::link(std::vector<int> &shaderEntryStateIds, |
| std::vector<unsigned int> &shaderStackSizes, |
| IntToFuncNameMap *pCachedMap) { |
| IntToFuncMap stateFunctionMap; // stateID -> state function |
| if (pCachedMap) { |
| for (auto entry : *pCachedMap) { |
| stateFunctionMap[entry.first] = m_module->getFunction(entry.second); |
| } |
| } else { |
| for (UINT i = 0; i < shaderEntryStateIds.size(); i++) { |
| UINT substateIndex = 0; |
| UINT baseStateId = shaderEntryStateIds[i]; |
| while (true) { |
| auto substateName = |
| m_entryShaderNames[i] + ".ss_" + std::to_string(substateIndex); |
| |
| auto function = m_module->getFunction(substateName); |
| if (!function) |
| break; |
| stateFunctionMap[baseStateId + substateIndex] = |
| m_module->getFunction(substateName); |
| substateIndex++; |
| } |
| } |
| } |
| |
| // Fix up scheduler |
| Function *schedulerFunc = m_module->getFunction("fb_Fallback_Scheduler"); |
| createLaunchParams(schedulerFunc); |
| |
| Type *runtimeDataArgTy = getRuntimeDataArgType(); |
| createStateDispatch(schedulerFunc, stateFunctionMap, runtimeDataArgTy); |
| createStack(schedulerFunc); |
| |
| lowerIntrinsics(); |
| } |
| |
| void DxrFallbackCompiler::setDebugOutputLevel(int val) { |
| m_debugOutputLevel = val; |
| } |
| |
| static bool isShader(Function *F) { |
| if (F->hasFnAttribute("exp-shader")) |
| return true; |
| |
| DxilModule &DM = F->getParent()->GetDxilModule(); |
| return (DM.HasDxilFunctionProps(F) && DM.GetDxilFunctionProps(F).IsRay()); |
| } |
| |
| DXIL::ShaderKind getRayShaderKind(Function *F) { |
| if (F->hasFnAttribute("exp-shader")) |
| return DXIL::ShaderKind::RayGeneration; |
| |
| DxilModule &DM = F->getParent()->GetDxilModule(); |
| if (DM.HasDxilFunctionProps(F) && DM.GetDxilFunctionProps(F).IsRay()) |
| return DM.GetDxilFunctionProps(F).shaderKind; |
| |
| return DXIL::ShaderKind::Invalid; |
| } |
| |
| // Some shaders should use the "pending" values of intrinsics instead of the |
| // committed ones. In particular anyhit and intersection shaders use the |
| // pending values with the exception that the committed rayTCurrent should be |
| // used in intersection. |
| static bool shouldUsePendingValue(Function *F, StringRef instrinsicName) { |
| DxilModule &DM = F->getParent()->GetDxilModule(); |
| if (!DM.HasDxilFunctionProps(F)) |
| return false; |
| const hlsl::DxilFunctionProps &props = DM.GetDxilFunctionProps(F); |
| |
| return props.IsAnyHit() || |
| (props.IsIntersection() && instrinsicName != "rayTCurrent"); |
| } |
| |
| void DxrFallbackCompiler::initShaderMap(std::vector<std::string> &shaderNames) { |
| // Clean names and initialize shaderMap |
| StringToFuncMap allShadersMap; |
| for (Function &F : m_module->functions()) { |
| if (isShader(&F)) { |
| if (!F.isDeclaration()) |
| allShadersMap[cleanName(F.getName())] = &F; |
| } |
| |
| F.removeFnAttr(Attribute::NoInline); |
| } |
| |
| for (auto &name : shaderNames) |
| m_shaderMap[name] = allShadersMap[name]; |
| |
| if (!m_findCalledShaders) |
| return; |
| |
| // Create a map from shader name to CallGraphNode |
| CallGraph callGraph(*m_module); |
| std::map<std::string, CallGraphNode *> allShaderNodes; |
| for (auto &kv : m_shaderMap) { |
| const std::string &name = kv.first; |
| Function *func = kv.second; |
| allShaderNodes[name] = callGraph[func]; |
| } |
| |
| // Start traversing the call graph from given shaderNames |
| std::deque<CallGraphNode *> workList; |
| for (auto &name : shaderNames) |
| workList.push_back(allShaderNodes[name]); |
| while (!workList.empty()) { |
| CallGraphNode *cur = workList.front(); |
| workList.pop_front(); |
| for (size_t i = 0; i < cur->size(); ++i) { |
| Function *nextFunc = (*cur)[i]->getFunction(); |
| if (!nextFunc) |
| continue; |
| if (isShader(nextFunc)) { |
| const std::string nextName = cleanName(nextFunc->getName()); |
| if (m_shaderMap.count(nextName) == 0) // not in the shaderMap yet? |
| { |
| workList.push_back(allShaderNodes[nextName]); |
| shaderNames.push_back(nextName); |
| m_shaderMap[nextName] = workList.back()->getFunction(); |
| } |
| } |
| } |
| } |
| } |
| |
| void DxrFallbackCompiler::linkRuntime() { |
| Linker linker(m_module); |
| std::unique_ptr<Module> runtimeModule = |
| loadModuleFromAsmString(m_module->getContext(), getRuntimeString()); |
| bool linkErr = linker.linkInModule(runtimeModule.get()); |
| assert(!linkErr && "Error linking runtime"); |
| UNREFERENCED_PARAMETER(linkErr); |
| } |
| |
| static void inlineFuncAndAddRet(CallInst *call, Function *F) { |
| // Add a return after the function call. |
| // Should be followed immediately by "unreachable". Turn that into a "ret |
| // void". |
| Instruction *ret = ReturnInst::Create(call->getContext()); |
| ReplaceInstWithInst(call->getParent()->getTerminator(), ret); |
| |
| bool success = inlineFunc(call, F); |
| assert(success); |
| UNREFERENCED_PARAMETER(success); |
| } |
| |
| void DxrFallbackCompiler::lowerAnyHitControlFlowFuncs() { |
| std::vector<CallInst *> callsToIgnoreHit = |
| getCallsInShadersToFunction("dx.op.ignoreHit"); |
| if (!callsToIgnoreHit.empty()) { |
| Function *ignoreHitFunc = |
| m_module->getFunction("\x1?Fallback_IgnoreHit@@YAXXZ"); |
| assert(ignoreHitFunc && "IgnoreHit() implementation not found"); |
| for (CallInst *call : callsToIgnoreHit) |
| inlineFuncAndAddRet(call, ignoreHitFunc); |
| } |
| |
| std::vector<CallInst *> callsToAcceptHitAndEndSearch = |
| getCallsInShadersToFunction("dx.op.acceptHitAndEndSearch"); |
| if (!callsToAcceptHitAndEndSearch.empty()) { |
| Function *acceptHitAndEndSearchFunc = |
| m_module->getFunction("\x1?Fallback_AcceptHitAndEndSearch@@YAXXZ"); |
| assert(acceptHitAndEndSearchFunc && |
| "AcceptHitAndEndSearch() implementation not found"); |
| for (CallInst *call : callsToAcceptHitAndEndSearch) |
| inlineFuncAndAddRet(call, acceptHitAndEndSearchFunc); |
| } |
| } |
| |
| void DxrFallbackCompiler::lowerReportHit() { |
| std::vector<CallInst *> callsToReportHit = |
| getCallsInShadersToFunctionWithPrefix("dx.op.reportHit"); |
| if (callsToReportHit.empty()) |
| return; |
| |
| Function *reportHitFunc = |
| m_module->getFunction("\x1?Fallback_ReportHit@@YAHMI@Z"); |
| assert(reportHitFunc && "ReportHit() implementation not found"); |
| |
| LLVMContext &C = m_module->getContext(); |
| for (CallInst *call : callsToReportHit) { |
| // Wrap attribute arguments in Fallback_SetPendingAttr() call |
| Instruction *insertBefore = call; |
| hlsl::DxilInst_ReportHit reportHitCall(call); |
| |
| Value *attr = reportHitCall.get_Attributes(); |
| Function *setPendingAttrFunc = |
| FunctionBuilder(m_module, "\x1?Fallback_SetPendingAttr@@") |
| .voidTy() |
| .type(attr->getType(), "attr") |
| .build(); |
| CallInst::Create(setPendingAttrFunc, {attr}, "", insertBefore); |
| |
| // Make call to implementation and load result |
| CallInst *callImpl = CallInst::Create( |
| reportHitFunc, {reportHitCall.get_THit(), reportHitCall.get_HitKind()}, |
| "reportHit.result", insertBefore); |
| Value *result = callImpl; |
| |
| // Result < 0 ==> ret |
| Value *zero = makeInt32(0, C); |
| Value *ltz = new ICmpInst(insertBefore, CmpInst::ICMP_SLT, result, zero, |
| "endSearch"); |
| BasicBlock *prevBlock = call->getParent(); |
| BasicBlock *retBlock = prevBlock->splitBasicBlock(call, "endSearch"); |
| BasicBlock *nextBlock = retBlock->splitBasicBlock(call, "afterReportHit"); |
| ReplaceInstWithInst(prevBlock->getTerminator(), |
| BranchInst::Create(retBlock, nextBlock, ltz)); |
| ReplaceInstWithInst(retBlock->getTerminator(), ReturnInst::Create(C)); |
| |
| // Compare result to zero and store into original result |
| Value *gtz = |
| new ICmpInst(insertBefore, CmpInst::ICMP_SGT, result, zero, "accepted"); |
| call->replaceAllUsesWith(gtz); |
| |
| bool success = inlineFunc(callImpl, reportHitFunc); |
| assert(success); |
| (void)success; |
| |
| call->eraseFromParent(); |
| } |
| } |
| |
| void DxrFallbackCompiler::lowerTraceRay(Type *runtimeDataArgTy) { |
| std::vector<CallInst *> callsToTraceRay = |
| getCallsInShadersToFunctionWithPrefix("dx.op.traceRay"); |
| if (callsToTraceRay.empty()) { |
| // TODO: It might be worth dropping this from the tests eventually |
| callsToTraceRay = |
| getCallsInShadersToFunctionWithPrefix("\x1?TraceRayTest@@"); |
| if (callsToTraceRay.empty()) |
| return; |
| } |
| |
| std::vector<Function *> traceRayImpl = |
| getFunctionsWithPrefix(m_module, "\x1?Fallback_TraceRay@@"); |
| assert(traceRayImpl.size() == 1 && |
| "Could not find Fallback_TraceRay() implementation"); |
| |
| enum { CLOSEST_HIT = 0, MISS = 1 }; |
| Function *traceRaySave[] = {m_module->getFunction("traceRaySave_ClosestHit"), |
| m_module->getFunction("traceRaySave_Miss")}; |
| Function *traceRayRestore[] = { |
| m_module->getFunction("traceRayRestore_ClosestHit"), |
| m_module->getFunction("traceRayRestore_Miss")}; |
| assert(traceRaySave[CLOSEST_HIT] && traceRayRestore[CLOSEST_HIT] && |
| traceRaySave[MISS] && traceRayRestore[MISS] && |
| "Could not find TraceRay spill functions"); |
| |
| Function *dummyRuntimeDataArgFunc = |
| StateFunctionTransform::createDummyRuntimeDataArgFunc(m_module, |
| runtimeDataArgTy); |
| assert(dummyRuntimeDataArgFunc && |
| "dummyRuntimeDataArg function could not be created."); |
| |
| // Process calls |
| LLVMContext &C = m_module->getContext(); |
| Type *int32Ty = Type::getInt32Ty(C); |
| std::map<FunctionType *, Function *> movePayloadToStackFuncs; |
| std::map<Function *, AllocaInst *> funcToSpillAlloca; |
| for (CallInst *call : callsToTraceRay) { |
| Instruction *insertBefore = call; |
| |
| // Spill runtime data values, if necessary (closesthit and miss shaders) |
| Function *caller = call->getParent()->getParent(); |
| DXIL::ShaderKind kind = getRayShaderKind(caller); |
| if (kind == DXIL::ShaderKind::ClosestHit || |
| kind == DXIL::ShaderKind::Miss) { |
| int sh = (kind == DXIL::ShaderKind::ClosestHit) ? CLOSEST_HIT : MISS; |
| AllocaInst *spillAlloca = get(funcToSpillAlloca, caller); |
| if (!spillAlloca) { |
| Argument *spillAllocaArg = (++traceRaySave[sh]->arg_begin()); |
| Type *spillAllocaTy = |
| spillAllocaArg->getType()->getPointerElementType(); |
| spillAlloca = new AllocaInst(spillAllocaTy, "spill.alloca", |
| caller->getEntryBlock().begin()); |
| funcToSpillAlloca[caller] = spillAlloca; |
| } |
| |
| // Create calls. SFT will inline them. |
| Value *runtimeDataArg = CallInst::Create(dummyRuntimeDataArgFunc, |
| "runtimeData", insertBefore); |
| CallInst::Create(traceRaySave[sh], {runtimeDataArg, spillAlloca}, "", |
| insertBefore); |
| CallInst::Create(traceRayRestore[sh], {runtimeDataArg, spillAlloca}, "", |
| getInstructionAfter(call)); |
| } |
| |
| // Get the payload offset to pass to trace implementation |
| // hlsl::DxilInst_TraceRay traceRayCall(call); |
| // TODO: Avoiding the intrinsic to support the test's use of TraceRayTest |
| Value *payload = call->getOperand(call->getNumArgOperands() - 1); |
| FunctionType *funcType = |
| FunctionType::get(int32Ty, {payload->getType()}, false); |
| Function *movePayloadToStackFunc = getOrCreateFunction( |
| "movePayloadToStack", m_module, funcType, movePayloadToStackFuncs); |
| Value *newPayloadOffset = CallInst::Create( |
| movePayloadToStackFunc, {payload}, "new.payload.offset", insertBefore); |
| |
| // Call implementation |
| unsigned i = 0; |
| if (call->getCalledFunction()->getName().startswith("dx.op")) |
| i += 2; // skip intrinsic number and acceleration structure (for now) |
| std::vector<Value *> args; |
| for (; i < call->getNumArgOperands() - 1; ++i) |
| args.push_back(call->getArgOperand(i)); |
| args.push_back(newPayloadOffset); |
| CallInst::Create(traceRayImpl[0], args, "", insertBefore); |
| |
| call->eraseFromParent(); |
| } |
| } |
| |
| static std::vector<StateFunctionTransform::ParameterSemanticType> |
| getParameterTypes(Function *F, DXIL::ShaderKind shaderKind) { |
| std::vector<StateFunctionTransform::ParameterSemanticType> paramTypes; |
| if (shaderKind == DXIL::ShaderKind::AnyHit || |
| shaderKind == DXIL::ShaderKind::ClosestHit) { |
| paramTypes.push_back(StateFunctionTransform::PST_PAYLOAD); |
| paramTypes.push_back(StateFunctionTransform::PST_ATTRIBUTE); |
| } else if (shaderKind == DXIL::ShaderKind::Miss) { |
| paramTypes.push_back(StateFunctionTransform::PST_PAYLOAD); |
| } else { |
| paramTypes.assign(F->getNumOperands(), StateFunctionTransform::PST_NONE); |
| } |
| return paramTypes; |
| } |
| |
| static void collectResources(DxilModule &DM, std::set<Value *> &resources) { |
| for (auto &r : DM.GetCBuffers()) |
| resources.insert(r->GetGlobalSymbol()); |
| for (auto &r : DM.GetUAVs()) |
| resources.insert(r->GetGlobalSymbol()); |
| for (auto &r : DM.GetSRVs()) |
| resources.insert(r->GetGlobalSymbol()); |
| for (auto &r : DM.GetSamplers()) |
| resources.insert(r->GetGlobalSymbol()); |
| } |
| |
| void DxrFallbackCompiler::createStateFunctions( |
| IntToFuncMap &stateFunctionMap, std::vector<int> &shaderEntryStateIds, |
| std::vector<unsigned int> &shaderStackSizes, int baseStateId, |
| const std::vector<std::string> &shaderNames, Type *runtimeDataArgTy) { |
| for (auto &kv : m_shaderMap) { |
| if (kv.second == nullptr) |
| errs() << "Function not found for shader " << kv.first << "\n"; |
| } |
| |
| DxilModule &DM = m_module->GetOrCreateDxilModule(); |
| std::set<Value *> resources; |
| collectResources(DM, resources); |
| |
| shaderEntryStateIds.clear(); |
| shaderStackSizes.clear(); |
| int stateId = baseStateId; |
| for (auto &shader : shaderNames) { |
| std::vector<Function *> stateFunctions; |
| Function *F = m_shaderMap[shader]; |
| StateFunctionTransform sft(F, shaderNames, runtimeDataArgTy); |
| if (m_debugOutputLevel >= 2) |
| sft.setVerbose(true); |
| if (m_debugOutputLevel >= 3) |
| sft.setDumpFilename("dump.ll"); |
| if (shader == "Fallback_TraceRay") |
| sft.setAttributeSize(m_maxAttributeSize); |
| DXIL::ShaderKind shaderKind = getRayShaderKind(F); |
| if (shaderKind != DXIL::ShaderKind::Invalid) |
| sft.setParameterInfo(getParameterTypes(F, shaderKind), |
| shaderKind == DXIL::ShaderKind::ClosestHit); |
| sft.setResourceGlobals(resources); |
| UINT shaderStackSize = 0; |
| sft.run(stateFunctions, shaderStackSize); |
| |
| shaderEntryStateIds.push_back(stateId); |
| shaderStackSizes.push_back(shaderStackSize); |
| for (Function *stateF : stateFunctions) { |
| stateFunctionMap[stateId++] = stateF; |
| if (DM.HasDxilFunctionProps(F)) { |
| DM.CloneDxilEntryProps(F, stateF); |
| } |
| } |
| } |
| |
| StateFunctionTransform::finalizeStateIds(m_module, shaderEntryStateIds); |
| } |
| |
| void DxrFallbackCompiler::createLaunchParams(Function *func) { |
| Module *mod = func->getParent(); |
| Function *rewrite_setLaunchParams = |
| mod->getFunction("rewrite_setLaunchParams"); |
| CallInst *call = dyn_cast<CallInst>(*rewrite_setLaunchParams->user_begin()); |
| |
| LLVMContext &context = mod->getContext(); |
| Instruction *insertBefore = call; |
| |
| Function *DTidFunc = |
| FunctionBuilder(mod, "dx.op.threadId.i32").i32().i32().i32().build(); |
| Value *DTidx = |
| CallInst::Create(DTidFunc, |
| {makeInt32((int)hlsl::OP::OpCode::ThreadId, context), |
| makeInt32(0, context)}, |
| "DTidx", insertBefore); |
| Value *DTidy = |
| CallInst::Create(DTidFunc, |
| {makeInt32((int)hlsl::OP::OpCode::ThreadId, context), |
| makeInt32(1, context)}, |
| "DTidy", insertBefore); |
| |
| Value *dimx = call->getArgOperand(1); |
| Value *dimy = call->getArgOperand(2); |
| |
| Function *groupIndexFunc = |
| FunctionBuilder(mod, "dx.op.flattenedThreadIdInGroup.i32") |
| .i32() |
| .i32() |
| .build(); |
| Value *groupIndex = CallInst::Create(groupIndexFunc, {makeInt32(96, context)}, |
| "groupIndex", insertBefore); |
| |
| Function *fb_setLaunchParams = |
| mod->getFunction("fb_Fallback_SetLaunchParams"); |
| Value *runtimeDataArg = call->getArgOperand(0); |
| CallInst::Create(fb_setLaunchParams, |
| {runtimeDataArg, DTidx, DTidy, dimx, dimy, groupIndex}, "", |
| insertBefore); |
| |
| call->eraseFromParent(); |
| rewrite_setLaunchParams->eraseFromParent(); |
| } |
| |
| void DxrFallbackCompiler::createStateDispatch( |
| Function *func, const IntToFuncMap &stateFunctionMap, |
| Type *runtimeDataArgTy) { |
| Module *mod = func->getParent(); |
| Function *dispatchFunc = |
| createDispatchFunction(stateFunctionMap, runtimeDataArgTy); |
| Function *rewrite_dispatchFunc = mod->getFunction("rewrite_dispatch"); |
| rewrite_dispatchFunc->replaceAllUsesWith(dispatchFunc); |
| rewrite_dispatchFunc->eraseFromParent(); |
| } |
| |
| void DxrFallbackCompiler::createStack(Function *func) { |
| LLVMContext &context = func->getContext(); |
| |
| // We would like to allocate the properly sized stack here, but DXIL doesn't |
| // allow bitcasts between objects of different sizes. So we have to use the |
| // default size from the runtime and replace all the accesses later. |
| Function *rewrite_createStack = m_module->getFunction("rewrite_createStack"); |
| CallInst *call = dyn_cast<CallInst>(*rewrite_createStack->user_begin()); |
| AllocaInst *stack = new AllocaInst(call->getType()->getPointerElementType(), |
| "theStack", call); |
| stack->setAlignment(sizeof(int)); |
| call->replaceAllUsesWith(stack); |
| call->eraseFromParent(); |
| rewrite_createStack->eraseFromParent(); |
| |
| if (m_stackSizeInBytes == 0) // Take the default |
| m_stackSizeInBytes = |
| stack->getType()->getPointerElementType()->getArrayNumElements() * |
| sizeof(int); |
| Function *rewrite_getStackSize = |
| m_module->getFunction("rewrite_getStackSize"); |
| call = dyn_cast<CallInst>(*rewrite_getStackSize->user_begin()); |
| Value *stackSizeVal = makeInt32(m_stackSizeInBytes, context); |
| call->replaceAllUsesWith(stackSizeVal); |
| call->eraseFromParent(); |
| rewrite_getStackSize->eraseFromParent(); |
| } |
| |
| // WAR to avoid crazy <3 x float> code emitted by vanilla clang in the runtime |
| static bool expandFloat3(std::vector<Value *> &args, Value *arg, |
| Instruction *insertBefore) { |
| VectorType *argTy = dyn_cast<VectorType>(arg->getType()); |
| if (!argTy || argTy->getVectorNumElements() != 3) |
| return false; |
| |
| LLVMContext &C = arg->getContext(); |
| args.push_back( |
| ExtractElementInst::Create(arg, makeInt32(0, C), "vec.x", insertBefore)); |
| args.push_back( |
| ExtractElementInst::Create(arg, makeInt32(1, C), "vec.y", insertBefore)); |
| args.push_back( |
| ExtractElementInst::Create(arg, makeInt32(2, C), "vec.z", insertBefore)); |
| |
| return true; |
| } |
| |
| static bool float3x4ToFloat12(std::vector<Value *> &args, Value *arg, |
| Instruction *insertBefore) { |
| StructType *STy = dyn_cast<StructType>(arg->getType()); |
| if (!STy || STy->getName() != "class.matrix.float.3.4") |
| return false; |
| |
| BasicBlock &entryBlock = |
| insertBefore->getParent()->getParent()->getEntryBlock(); |
| AllocaInst *alloca = |
| new AllocaInst(arg->getType(), "tmp", entryBlock.begin()); |
| new StoreInst(arg, alloca, insertBefore); |
| VectorType *VTy = VectorType::get(Type::getFloatTy(arg->getContext()), 12); |
| Value *vec12Ptr = |
| new BitCastInst(alloca, VTy->getPointerTo(), "vec12.ptr", insertBefore); |
| Value *vec12 = new LoadInst(vec12Ptr, "vec12.", insertBefore); |
| args.push_back(vec12); |
| |
| return true; |
| } |
| |
| void DxrFallbackCompiler::lowerIntrinsics() { |
| std::vector<Function *> intrinsics = getFunctionsWithPrefix(m_module, "fb_"); |
| assert(intrinsics.size() > 0); |
| |
| // Replace intrinsics in anyhit shaders with their pending versions |
| LLVMContext &C = m_module->getContext(); |
| std::map<std::string, Function *> pendingIntrinsics; |
| std::string pendingPrefixes[] = {"fb_dxop_pending_", "fb_Fallback_Pending"}; |
| for (auto &F : intrinsics) { |
| std::string intrinsicName; |
| if (F->getName().startswith(pendingPrefixes[0])) |
| intrinsicName = F->getName().substr(pendingPrefixes[0].length()); |
| else if (F->getName().startswith(pendingPrefixes[1])) |
| intrinsicName = |
| "Fallback_" + F->getName().substr(pendingPrefixes[1].length()).str(); |
| else |
| continue; |
| |
| pendingIntrinsics[intrinsicName] = F; |
| } |
| |
| for (Function *func : intrinsics) { |
| StringRef intrinsicName; |
| std::string name; |
| bool isDxilOp = false; |
| if (func->getName().startswith("fb_Fallback_")) { |
| intrinsicName = func->getName().substr(3); // after the "fb_" prefix |
| name = "\x1?" + intrinsicName.str(); |
| } else if (func->getName().startswith("fb_dxop_")) { |
| intrinsicName = func->getName().substr(8); |
| name = "dx.op." + intrinsicName.str(); |
| isDxilOp = true; |
| } else { |
| assert(0 && "Bad intrinsic"); |
| } |
| std::vector<Function *> calledFunc = getFunctionsWithPrefix(m_module, name); |
| if (calledFunc.empty()) |
| continue; |
| std::vector<CallInst *> calls = getCallsToFunction(calledFunc[0]); |
| if (calls.empty()) |
| continue; |
| |
| bool needsRuntimeDataArg = (intrinsicName != "Fallback_Scheduler"); |
| Function *pendingFunc = get(pendingIntrinsics, intrinsicName.str()); |
| Function *funcInModule = nullptr; |
| Function *pendingFuncInModule = nullptr; |
| for (CallInst *call : calls) { |
| Function *caller = call->getParent()->getParent(); |
| if (needsRuntimeDataArg && !caller->hasFnAttribute("state_function")) |
| continue; |
| |
| Function *F = nullptr; |
| if (pendingFunc && shouldUsePendingValue(caller, intrinsicName)) { |
| if (!pendingFuncInModule) |
| pendingFuncInModule = getOrInsertFunction(m_module, pendingFunc); |
| F = pendingFuncInModule; |
| } else { |
| if (!funcInModule) |
| funcInModule = getOrInsertFunction(m_module, func); |
| F = funcInModule; |
| } |
| |
| // insert runtime data and the rest of the arguments |
| std::vector<Value *> args; |
| if (needsRuntimeDataArg) |
| args.push_back(caller->arg_begin()); |
| int argIdx = 0; |
| for (auto &arg : call->arg_operands()) { |
| if (argIdx++ == 0 && isDxilOp) |
| continue; // skip the intrinsic number |
| if (!expandFloat3(args, arg, call) && |
| !float3x4ToFloat12(args, arg, call)) |
| args.push_back(arg); |
| } |
| |
| CallInst *newCall = CallInst::Create(F, args, "", call); |
| if (F->getFunctionType()->getReturnType() != Type::getVoidTy(C)) { |
| newCall->takeName(call); |
| call->replaceAllUsesWith(newCall); |
| } |
| call->eraseFromParent(); |
| } |
| } |
| } |
| |
| Type *DxrFallbackCompiler::getRuntimeDataArgType() { |
| // Get the first argument from a known runtime function (assuming the runtime |
| // has already been linked in). |
| Function *F = m_module->getFunction("stackIntPtr"); |
| return F->arg_begin()->getType(); |
| } |
| |
| Function *DxrFallbackCompiler::createDispatchFunction( |
| const IntToFuncMap &stateFunctionMap, Type *runtimeDataArgTy) { |
| LLVMContext &context = m_module->getContext(); |
| FunctionType *stateFuncTy = |
| FunctionType::get(Type::getInt32Ty(context), {runtimeDataArgTy}, false); |
| |
| Function *dispatchFunc = FunctionBuilder(m_module, "dispatch") |
| .i32() |
| .type(runtimeDataArgTy, "runtimeData") |
| .i32("stateID") |
| .build(); |
| Value *runtimeDataArg = dispatchFunc->arg_begin(); |
| Value *stateIdArg = ++dispatchFunc->arg_begin(); |
| BasicBlock *entryBlock = BasicBlock::Create(context, "entry", dispatchFunc); |
| BasicBlock *badBlock = |
| BasicBlock::Create(context, "badStateID", dispatchFunc); |
| IRBuilder<> builder(badBlock); |
| builder.SetInsertPoint(badBlock); |
| builder.CreateRet(makeInt32(-3, context)); // return an error value |
| |
| builder.SetInsertPoint(entryBlock); |
| SwitchInst *switchInst = |
| builder.CreateSwitch(stateIdArg, badBlock, stateFunctionMap.size()); |
| BasicBlock *endBlock = badBlock; |
| for (auto &kv : stateFunctionMap) { |
| int stateId = kv.first; |
| Function *stateFunc = kv.second; |
| |
| Value *stateFuncInModule = |
| m_module->getOrInsertFunction(stateFunc->getName(), stateFuncTy); |
| BasicBlock *block = BasicBlock::Create( |
| context, "state_" + Twine(stateId) + "." + stateFunc->getName(), |
| dispatchFunc, endBlock); |
| builder.SetInsertPoint(block); |
| Value *nextStateId = |
| builder.CreateCall(stateFuncInModule, {runtimeDataArg}, "nextStateId"); |
| builder.CreateRet(nextStateId); |
| |
| switchInst->addCase(makeInt32(stateId, context), block); |
| } |
| |
| return dispatchFunc; |
| } |
| |
| std::vector<CallInst *> |
| DxrFallbackCompiler::getCallsInShadersToFunction(const std::string &funcName) { |
| std::vector<CallInst *> calls; |
| Function *F = m_module->getFunction(funcName); |
| if (!F) |
| return calls; |
| |
| for (User *U : F->users()) { |
| CallInst *call = dyn_cast<CallInst>(U); |
| if (!call) |
| continue; |
| |
| Function *caller = call->getParent()->getParent(); |
| auto it = m_shaderMap.find(cleanName(caller->getName())); |
| if (it != m_shaderMap.end()) |
| calls.push_back(call); |
| } |
| return calls; |
| } |
| |
| std::vector<CallInst *> |
| DxrFallbackCompiler::getCallsInShadersToFunctionWithPrefix( |
| const std::string &funcNamePrefix) { |
| std::vector<CallInst *> calls; |
| for (Function *F : getFunctionsWithPrefix(m_module, funcNamePrefix)) { |
| for (User *U : F->users()) { |
| CallInst *call = dyn_cast<CallInst>(U); |
| if (!call) |
| continue; |
| |
| Function *caller = call->getParent()->getParent(); |
| if (m_shaderMap.count(cleanName(caller->getName()))) |
| calls.push_back(call); |
| } |
| } |
| return calls; |
| } |
| |
| void DxrFallbackCompiler::resizeStack(Function *F, unsigned sizeInBytes) { |
| // Find the stack |
| AllocaInst *stack = nullptr; |
| for (auto &I : F->getEntryBlock().getInstList()) { |
| AllocaInst *alloc = dyn_cast<AllocaInst>(&I); |
| if (alloc && alloc->getName().startswith("theStack")) { |
| stack = alloc; |
| break; |
| } |
| } |
| if (!stack) |
| return; |
| |
| // Create a new stack |
| LLVMContext &C = F->getContext(); |
| ArrayType *newStackTy = |
| ArrayType::get(Type::getInt32Ty(C), sizeInBytes / sizeof(int)); |
| AllocaInst *newStack = new AllocaInst(newStackTy, "", stack); |
| newStack->takeName(stack); |
| |
| // Remap all GEPs - replaceAllUsesWith() won't change types |
| for (auto U = stack->user_begin(), UE = stack->user_end(); U != UE;) { |
| GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(*U++); |
| assert(gep && "theStack has non-gep user."); |
| |
| std::vector<Value *> idxList(gep->idx_begin(), gep->idx_end()); |
| GetElementPtrInst *newGep = |
| GetElementPtrInst::CreateInBounds(newStack, idxList, "", gep); |
| newGep->takeName(gep); |
| gep->replaceAllUsesWith(newGep); |
| gep->eraseFromParent(); |
| } |
| |
| stack->eraseFromParent(); |
| } |