blob: ca8a8a33fdb475542b3705f3f7a8b8af2554a21f [file] [log] [blame] [edit]
///////////////////////////////////////////////////////////////////////////////
// //
// HLMatrixLowerPass.cpp //
// Copyright (C) Microsoft Corporation. All rights reserved. //
// This file is distributed under the University of Illinois Open Source //
// License. See LICENSE.TXT for details. //
// //
// HLMatrixLowerPass implementation. //
// //
///////////////////////////////////////////////////////////////////////////////
#include "dxc/HLSL/HLMatrixLowerPass.h"
#include "HLMatrixSubscriptUseReplacer.h"
#include "dxc/DXIL/DxilModule.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/DXIL/DxilTypeSystem.h"
#include "dxc/DXIL/DxilUtil.h"
#include "dxc/HLSL/HLMatrixLowerHelper.h"
#include "dxc/HLSL/HLMatrixType.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/HLSL/HLOperations.h"
#include "dxc/HlslIntrinsicOp.h"
#include "dxc/Support/Global.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Local.h"
#include <unordered_set>
#include <vector>
using namespace llvm;
using namespace hlsl;
using namespace hlsl::HLMatrixLower;
namespace hlsl {
namespace HLMatrixLower {
// Assembles a vector value with element type EltTy out of the given scalar
// values. The elements are inserted one at a time, in order, into an
// initially undefined vector using Builder.
Value *BuildVector(Type *EltTy, ArrayRef<llvm::Value *> elts,
                   IRBuilder<> &Builder) {
  const unsigned NumElts = static_cast<unsigned>(elts.size());
  Value *Result = UndefValue::get(VectorType::get(EltTy, NumElts));
  unsigned Lane = 0;
  for (Value *Elt : elts)
    Result = Builder.CreateInsertElement(Result, Elt, Lane++);
  return Result;
}
} // namespace HLMatrixLower
} // namespace hlsl
namespace {
// Creates and manages a set of temporary overloaded functions keyed on the
// function type, and which should be destroyed when the pool gets out of scope.
//
// The pool erases every function it created from the module when it is
// cleared or destroyed. It is therefore non-copyable: a copy would try to
// erase the very same functions a second time.
class TempOverloadPool {
public:
  // Module: the module in which the temporary functions are declared.
  // BaseName: name prefix of the functions; the pointer is stored as-is, so
  // the string must outlive the pool.
  TempOverloadPool(llvm::Module &Module, const char *BaseName)
      : Module(Module), BaseName(BaseName) {}
  TempOverloadPool(const TempOverloadPool &) = delete;
  TempOverloadPool &operator=(const TempOverloadPool &) = delete;
  ~TempOverloadPool() { clear(); }

  // Returns the function of the given type, creating it on first request.
  Function *get(FunctionType *Ty);
  // Returns whether a function of the given type has been created.
  bool contains(FunctionType *Ty) const { return Funcs.count(Ty) != 0; }
  // Returns whether Func is one of the functions created by this pool.
  bool contains(Function *Func) const;
  // Erases all functions created by this pool; they must no longer be used.
  void clear();

private:
  llvm::Module &Module;
  const char *BaseName;
  llvm::DenseMap<FunctionType *, Function *> Funcs;
};
// Returns the pool's function for the given type. On first request, the
// function is declared in the module under the name "<BaseName>.<printed
// type>" and remembered for subsequent lookups.
Function *TempOverloadPool::get(FunctionType *Ty) {
  const auto Cached = Funcs.find(Ty);
  if (Cached != Funcs.end())
    return Cached->second;

  // Build the overload name from the base name and the printed type.
  std::string StubName;
  raw_string_ostream NameOS(StubName);
  NameOS << BaseName << '.';
  Ty->print(NameOS);
  NameOS.flush();

  Function *Stub = cast<Function>(Module.getOrInsertFunction(StubName, Ty));
  Funcs[Ty] = Stub;
  return Stub;
}
// Returns whether Func is one of the functions created by this pool.
// Func may be null, as callers pass the result of
// CallInst::getCalledFunction() directly; a null function is never a member.
bool TempOverloadPool::contains(Function *Func) const {
  if (Func == nullptr)
    return false;
  auto It = Funcs.find(Func->getFunctionType());
  return It != Funcs.end() && It->second == Func;
}
void TempOverloadPool::clear() {
for (auto Entry : Funcs) {
DXASSERT(Entry.second->use_empty(),
"Temporary function still used during pool destruction.");
Entry.second->eraseFromParent();
}
Funcs.clear();
}
// High-level matrix lowering pass.
//
// This pass converts matrices to their lowered vector representations,
// including global variables, local variables and operations,
// but not function signatures (arguments and return types) - left to
// HLSignatureLower and HLMatrixBitcastLower, nor matrices obtained from
// resources or constant - left to HLOperationLower.
//
// Algorithm overview:
// 1. Find all matrix and matrix array global variables and lower them to
// vectors.
// Walk any GEPs and insert vec-to-mat translation stubs so that consuming
// instructions keep dealing with matrix types for the moment.
// 2. For each function
// 2a. Lower all matrix and matrix array allocas, just like global variables.
// 2b. Lower all other instructions producing or consuming matrices
//
// Conversion stubs are used to allow converting instructions in isolation,
// and in an order-independent manner:
//
// Initial: MatInst1(MatInst2(MatInst3))
// After lowering MatInst2: MatInst1(VecToMat(VecInst2(MatToVec(MatInst3))))
// After lowering MatInst1: VecInst1(VecInst2(MatToVec(MatInst3)))
// After lowering MatInst3: VecInst1(VecInst2(VecInst3))
class HLMatrixLowerPass : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit HLMatrixLowerPass() : ModulePass(ID) {}
  StringRef getPassName() const override { return "HL matrix lower"; }
  bool runOnModule(Module &M) override;

private:
  void runOnFunction(Function &Func);

  // Dead instruction bookkeeping: instructions are queued while lowering and
  // erased in bulk by deleteDeadInsts.
  void addToDeadInsts(Instruction *Inst) { m_deadInsts.emplace_back(Inst); }
  void deleteDeadInsts();

  void getMatrixAllocasAndOtherInsts(Function &Func,
                                     std::vector<AllocaInst *> &MatAllocas,
                                     std::vector<Instruction *> &MatInsts);

  // Operand lowering helpers.
  Value *getLoweredByValOperand(Value *Val, IRBuilder<> &Builder,
                                bool DiscardStub = false);
  Value *tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder,
                                 bool DiscardStub = false);
  Value *bitCastValue(Value *SrcVal, Type *DstTy, bool DstTyAlloca,
                      IRBuilder<> &Builder);

  // Use replacement helpers.
  void replaceAllUsesByLoweredValue(Instruction *MatInst, Value *VecVal);
  void replaceAllVariableUses(Value *MatPtr, Value *LoweredPtr);
  void replaceAllVariableUses(SmallVectorImpl<Value *> &GEPIdxStack,
                              Value *StackTopPtr, Value *LoweredPtr);

  // Lhs and Rhs are the multiplication operands in source order;
  // isLhsScalar tells which of the two is the scalar (the other one being
  // the matrix).
  Value *translateScalarMatMul(Value *Lhs, Value *Rhs, IRBuilder<> &Builder,
                               bool isLhsScalar = true);

  // Variable lowering.
  void lowerGlobal(GlobalVariable *Global);
  Constant *lowerConstInitVal(Constant *Val);
  AllocaInst *lowerAlloca(AllocaInst *MatAlloca);

  // Instruction lowering.
  void lowerInstruction(Instruction *Inst);
  void lowerReturn(ReturnInst *Return);
  Value *lowerCall(CallInst *Call);
  Value *lowerNonHLCall(CallInst *Call);
  void lowerPreciseCall(CallInst *Call, IRBuilder<> Builder);
  Value *lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup);
  Value *lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode);
  Value *lowerHLMulIntrinsic(Value *Lhs, Value *Rhs, bool Unsigned,
                             IRBuilder<> &Builder);
  Value *lowerHLTransposeIntrinsic(Value *MatVal, IRBuilder<> &Builder);
  Value *lowerHLDeterminantIntrinsic(Value *MatVal, IRBuilder<> &Builder);
  Value *lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode,
                               IRBuilder<> &Builder);
  Value *lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode,
                                IRBuilder<> &Builder);
  Value *lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode);
  Value *lowerHLLoad(CallInst *Call, Value *MatPtr, bool RowMajor,
                     IRBuilder<> &Builder);
  Value *lowerHLStore(CallInst *Call, Value *MatVal, Value *MatPtr,
                      bool RowMajor, bool Return, IRBuilder<> &Builder);
  Value *lowerHLCast(CallInst *Call, Value *Src, Type *DstTy,
                     HLCastOpcode Opcode, IRBuilder<> &Builder);
  Value *lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
  Value *lowerHLMatElementSubscript(CallInst *Call, bool RowMajor);
  Value *lowerHLMatSubscript(CallInst *Call, bool RowMajor);
  void lowerHLMatSubscript(CallInst *Call, Value *MatPtr,
                           SmallVectorImpl<Value *> &ElemIndices);
  Value *lowerHLInit(CallInst *Call);
  Value *lowerHLSelect(CallInst *Call);

private:
  // Per-run state, set at the start of runOnModule. Everything is given a
  // well-defined initial value so that the pass never holds indeterminate
  // pointers or flags before its first run.
  Module *m_pModule = nullptr;
  HLModule *m_pHLModule = nullptr;
  bool m_HasDbgInfo = false;
  // Pools for the translation stubs
  TempOverloadPool *m_matToVecStubs = nullptr;
  TempOverloadPool *m_vecToMatStubs = nullptr;
  // Instructions queued for deletion by deleteDeadInsts.
  std::vector<Instruction *> m_deadInsts;
};
} // namespace
// Definition of the pass identification token declared in the class.
char HLMatrixLowerPass::ID = 0;
// Factory creating a new, heap-allocated instance of the matrix lowering pass.
ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
// Registers the pass under the "hlmatrixlower" argument name with its
// human-readable description.
INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower",
"HLSL High-Level Matrix Lower", false, false)
// Pass entry point: lowers matrix globals first, then every function
// definition of the module. Always reports the module as modified.
bool HLMatrixLowerPass::runOnModule(Module &M) {
  // The translation stub pools only live for the duration of this run.
  // If you hit an assert during TempOverloadPool destruction (when this
  // function returns), it means that either a matrix producer was lowered,
  // causing a translation stub to be created,
  // but the consumer of that matrix was never (properly) lowered.
  // Or the opposite: a matrix consumer was lowered and not its producer.
  TempOverloadPool matToVecStubs(M, "hlmatrixlower.mat2vec");
  TempOverloadPool vecToMatStubs(M, "hlmatrixlower.vec2mat");

  m_pModule = &M;
  m_pHLModule = &M.GetOrCreateHLModule();
  // Load up debug information, to cross-reference values and the
  // instructions used to load them.
  m_HasDbgInfo = hasDebugInfo(M);
  m_matToVecStubs = &matToVecStubs;
  m_vecToMatStubs = &vecToMatStubs;

  // Static and shared-memory matrix globals go first. They are gathered
  // up front because lowering them adds new globals to the module.
  std::vector<GlobalVariable *> MatGlobals;
  for (GlobalVariable &GV : M.globals()) {
    const bool IsLowerableStorage =
        dxilutil::IsStaticGlobal(&GV) || dxilutil::IsSharedMemoryGlobal(&GV);
    if (!IsLowerableStorage)
      continue;
    if (!HLMatrixType::isMatrixPtrOrArrayPtr(GV.getType()))
      continue;
    MatGlobals.emplace_back(&GV);
  }
  for (GlobalVariable *MatGlobal : MatGlobals)
    lowerGlobal(MatGlobal);

  // Then lower the body of every function definition.
  for (Function &F : M.functions()) {
    if (!F.isDeclaration())
      runOnFunction(F);
  }

  // Drop the per-run state before the pools get destroyed.
  m_pModule = nullptr;
  m_pHLModule = nullptr;
  m_matToVecStubs = nullptr;
  m_vecToMatStubs = nullptr;
  return true;
}
// Lowers all matrix allocas and matrix instructions of a function
// definition, then erases the instructions which became dead in the process.
void HLMatrixLowerPass::runOnFunction(Function &Func) {
  // Skip hl function definition (like createhandle)
  if (hlsl::GetHLOpcodeGroupByName(&Func) != HLOpcodeGroup::NotHL)
    return;

  // Save the matrix instructions first since the translation process
  // will temporarily create other instructions consuming/producing matrix
  // types.
  std::vector<AllocaInst *> MatAllocas;
  std::vector<Instruction *> MatInsts;
  getMatrixAllocasAndOtherInsts(Func, MatAllocas, MatInsts);

  // First lower all allocas and take care of their GEP chains.
  // lowerAlloca both creates the lowered alloca and redirects every use of
  // the original one to it, so the original can be queued for deletion
  // right away without replacing its uses a second time.
  for (AllocaInst *MatAlloca : MatAllocas) {
    lowerAlloca(MatAlloca);
    addToDeadInsts(MatAlloca);
  }

  // Now lower all other matrix instructions
  for (Instruction *MatInst : MatInsts)
    lowerInstruction(MatInst);

  deleteDeadInsts();
}
void HLMatrixLowerPass::deleteDeadInsts() {
while (!m_deadInsts.empty()) {
Instruction *Inst = m_deadInsts.back();
m_deadInsts.pop_back();
DXASSERT_NOMSG(Inst->use_empty());
for (Value *Operand : Inst->operand_values()) {
Instruction *OperandInst = dyn_cast<Instruction>(Operand);
if (OperandInst &&
++OperandInst->user_begin() == OperandInst->user_end()) {
// We were its only user, erase recursively.
// This will get rid of translation stubs:
// Original: MatConsumer(MatProducer)
// Producer lowered: MatConsumer(VecToMat(VecProducer)), MatProducer
// dead Consumer lowered: VecConsumer(VecProducer)),
// MatConsumer(VecToMat) dead Only by recursing on MatConsumer's operand
// do we delete the VecToMat stub.
DXASSERT_NOMSG(*OperandInst->user_begin() == Inst);
m_deadInsts.emplace_back(OperandInst);
}
}
Inst->eraseFromParent();
}
}
// Collects the instructions of Func which consume or produce matrices,
// directly or through pointers/arrays. Matrix allocas are appended to
// MatAllocas; matrix-related calls and returns are appended to MatInsts,
// in instruction order.
void HLMatrixLowerPass::getMatrixAllocasAndOtherInsts(
    Function &Func, std::vector<AllocaInst *> &MatAllocas,
    std::vector<Instruction *> &MatInsts) {
  for (BasicBlock &BB : Func) {
    for (Instruction &Inst : BB) {
      // GEPs are not lowered directly: they are handled along with their
      // root pointer, typically a global variable or alloca.
      if (isa<GetElementPtrInst>(&Inst))
        continue;

      // Lifetime intrinsics are likewise handled along with their alloca.
      if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(&Inst)) {
        const auto IID = Intrin->getIntrinsicID();
        if (IID == Intrinsic::lifetime_start ||
            IID == Intrinsic::lifetime_end)
          continue;
      }

      if (AllocaInst *Alloca = dyn_cast<AllocaInst>(&Inst)) {
        if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Alloca->getType()))
          MatAllocas.emplace_back(Alloca);
        continue;
      }

      if (ReturnInst *Return = dyn_cast<ReturnInst>(&Inst)) {
        Value *RetVal = Return->getReturnValue();
        if (RetVal != nullptr &&
            HLMatrixType::isMatrixOrPtrOrArrayPtr(RetVal->getType()))
          MatInsts.emplace_back(Return);
        continue;
      }

      // Nothing but calls should produce or consume matrices beyond this
      // point.
      CallInst *Call = dyn_cast<CallInst>(&Inst);
      if (Call == nullptr)
        continue;

      // Lowering of global variables will have introduced
      // vec-to-mat translation stubs, which we deal with indirectly,
      // as we lower the instructions consuming them.
      if (m_vecToMatStubs->contains(Call->getCalledFunction()))
        continue;

      // Mat-to-vec stubs should only be introduced during instruction
      // lowering. Globals lowering won't introduce any because their only
      // operand is their initializer, which we can fully lower without
      // stubbing since it is constant.
      DXASSERT(!m_matToVecStubs->contains(Call->getCalledFunction()),
               "Unexpected mat-to-vec stubbing before function instruction "
               "lowering.");

      // A call is of interest if it produces a matrix...
      bool TouchesMatrix =
          HLMatrixType::isMatrixOrPtrOrArrayPtr(Inst.getType());
      // ...or if it consumes at least one.
      if (!TouchesMatrix) {
        for (Value *Operand : Inst.operand_values()) {
          if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Operand->getType())) {
            TouchesMatrix = true;
            break;
          }
        }
      }
      if (TouchesMatrix)
        MatInsts.emplace_back(Call);
    }
  }
}
// Gets the matrix-lowered representation of a value, potentially adding a
// translation stub. DiscardStub causes any vec-to-mat translation stubs to be
// deleted, it should be true only if the original instruction will be modified
// and kept alive. If a new instruction is created and the original marked as
// dead, then the remove dead instructions pass will take care of removing the
// stub.
//
// Val: the by-value operand; must not be a pointer. Non-matrix values are
//   returned unchanged.
// Builder: used to emit the mat-to-vec translation stub call, if one is
//   needed.
// Returns the value in its lowered vector register type.
Value *HLMatrixLowerPass::getLoweredByValOperand(Value *Val,
                                                 IRBuilder<> &Builder,
                                                 bool DiscardStub) {
  Type *Ty = Val->getType();

  // We're only lowering byval matrices.
  // Since structs and arrays are always accessed by pointer,
  // we do not need to worry about a matrix being hidden inside a more complex
  // type.
  DXASSERT(!Ty->isPointerTy(), "Value cannot be a pointer.");
  HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty);
  if (!MatTy)
    return Val;

  Type *LoweredTy = MatTy.getLoweredVectorTypeForReg();

  // Check if the value is already a vec-to-mat translation stub
  if (CallInst *Call = dyn_cast<CallInst>(Val)) {
    if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
      // hasOneUse() answers without walking the whole use list.
      if (DiscardStub && Call->hasOneUse()) {
        // Detach the stub from its only user so that it can be deleted.
        Call->use_begin()->set(UndefValue::get(Call->getType()));
        addToDeadInsts(Call);
      }

      // The stub's argument is the lowered value we are after.
      Value *LoweredVal = Call->getArgOperand(0);
      DXASSERT(LoweredVal->getType() == LoweredTy,
               "Unexpected already-lowered value type.");
      return LoweredVal;
    }
  }

  // Lower mat 0 to vec 0.
  if (isa<ConstantAggregateZero>(Val))
    return ConstantAggregateZero::get(LoweredTy);

  // Return a mat-to-vec translation stub
  FunctionType *TranslationStubTy =
      FunctionType::get(LoweredTy, {Ty}, /* isVarArg */ false);
  Function *TranslationStub = m_matToVecStubs->get(TranslationStubTy);
  return Builder.CreateCall(TranslationStub, {Val});
}
// Attempts to retrieve the lowered vector pointer equivalent to a matrix
// pointer. Returns nullptr if the pointed-to matrix lives in memory that cannot
// be lowered at this time, for example a buffer or shader inputs/outputs, which
// are lowered during signature lowering.
//
// Ptr: the pointer operand; nullptr is returned if it is not a pointer to a
//   matrix or matrix array.
// Builder: used to emit a bitcast when the pointer derives from an alloca or
//   from an argument of a function which is not a signature-using entry.
// DiscardStub: if true and the pointer is a vec-to-mat stub with a single
//   use, that use is replaced by undef and the stub is queued for deletion.
Value *HLMatrixLowerPass::tryGetLoweredPtrOperand(Value *Ptr,
                                                  IRBuilder<> &Builder,
                                                  bool DiscardStub) {
  if (!HLMatrixType::isMatrixPtrOrArrayPtr(Ptr->getType()))
    return nullptr;

  // Matrix pointers can only be derived from Allocas, GlobalVariables or
  // resource accesses. The first two cases are what this pass must be able to
  // lower, and we should already have replaced their uses by vector to matrix
  // pointer translation stubs.
  if (CallInst *Call = dyn_cast<CallInst>(Ptr)) {
    if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
      // hasOneUse() answers without walking the whole use list.
      if (DiscardStub && Call->hasOneUse()) {
        // Detach the stub from its only user so that it can be deleted.
        Call->use_begin()->set(UndefValue::get(Call->getType()));
        addToDeadInsts(Call);
      }
      return Call->getArgOperand(0);
    }
  }

  // There's one more case to handle.
  // When compiling shader libraries, signatures won't have been lowered yet.
  // So we can have a matrix in a struct as an argument,
  // or an alloca'd struct holding the return value of a call and containing a
  // matrix.
  Value *RootPtr = Ptr;
  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
    RootPtr = GEP->getPointerOperand();

  Argument *Arg = dyn_cast<Argument>(RootPtr);
  bool IsNonShaderArg =
      Arg != nullptr &&
      !m_pHLModule->IsEntryThatUsesSignatures(Arg->getParent());
  if (IsNonShaderArg || isa<AllocaInst>(RootPtr)) {
    // Bitcast the matrix pointer to its lowered equivalent.
    // The HLMatrixBitcast pass will take care of this later.
    return Builder.CreateBitCast(Ptr,
                                 HLMatrixType::getLoweredType(Ptr->getType()));
  }

  // The pointer must be derived from a resource, we don't handle it in this
  // pass.
  return nullptr;
}
// Reinterprets a by-value matrix as its vector form or vice-versa.
// Needed to convert to/from arguments and return values, since signatures
// are not lowered by this pass; the later HLMatrixBitcastLower pass cleans
// these conversions up.
//
// The value goes through a temporary alloca: it is stored, then loaded back,
// with a pointer bitcast applied on one of the two sides. DstTyAlloca picks
// the type of the alloca: DstTy if true (bitcast on the store side), the
// source type otherwise (bitcast on the load side).
Value *HLMatrixLowerPass::bitCastValue(Value *SrcVal, Type *DstTy,
                                       bool DstTyAlloca, IRBuilder<> &Builder) {
  Type *SrcTy = SrcVal->getType();
  DXASSERT_NOMSG(!SrcTy->isPointerTy());

  Type *SlotTy = DstTyAlloca ? DstTy : SrcTy;
  Type *ViewTy = DstTyAlloca ? SrcTy : DstTy;

  IRBuilder<> AllocaBuilder(
      dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
  Value *Slot = AllocaBuilder.CreateAlloca(SlotTy);
  Value *SlotView = Builder.CreateBitCast(Slot, ViewTy->getPointerTo());

  Value *StorePtr = DstTyAlloca ? SlotView : Slot;
  Value *LoadPtr = DstTyAlloca ? Slot : SlotView;
  Builder.CreateStore(SrcVal, StorePtr);
  return Builder.CreateLoad(LoadPtr);
}
// Replaces all uses of a matrix value by its lowered vector form,
// inserting translation stubs for users which still expect a matrix value.
//
// MatInst: the original instruction whose uses get redirected.
// VecVal: the lowered replacement value. Nothing is done if it is null or if
// it is MatInst itself; otherwise MatInst is left without any use.
void HLMatrixLowerPass::replaceAllUsesByLoweredValue(Instruction *MatInst,
Value *VecVal) {
if (VecVal == nullptr || VecVal == MatInst)
return;
DXASSERT(HLMatrixType::getLoweredType(MatInst->getType()) ==
VecVal->getType(),
"Unexpected lowered value type.");
// Vec-to-mat stub call, created lazily and shared by all the users which
// still need a matrix value.
Instruction *VecToMatStub = nullptr;
// Every iteration redirects the first remaining use of MatInst elsewhere,
// so the loop terminates once MatInst is unused.
while (!MatInst->use_empty()) {
Use &ValUse = *MatInst->use_begin();
// Handle non-matrix cases, just point to the new value.
if (MatInst->getType() == VecVal->getType()) {
ValUse.set(VecVal);
continue;
}
// If the user is already a matrix-to-vector translation stub,
// we can now replace it by the proper vector value.
if (CallInst *Call = dyn_cast<CallInst>(ValUse.getUser())) {
if (m_matToVecStubs->contains(Call->getCalledFunction())) {
Call->replaceAllUsesWith(VecVal);
// The stub no longer needs its operand; drop the use so that the loop
// makes progress, and queue the stub for deletion.
ValUse.set(UndefValue::get(MatInst->getType()));
addToDeadInsts(Call);
continue;
}
}
// Otherwise, the user should point to a vector-to-matrix translation
// stub of the new vector value.
if (VecToMatStub == nullptr) {
FunctionType *TranslationStubTy = FunctionType::get(
MatInst->getType(), {VecVal->getType()}, /* isVarArg */ false);
Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
// Insert the stub right after the instruction defining VecVal or,
// if VecVal is not an instruction, right after MatInst.
Instruction *PrevInst = dyn_cast<Instruction>(VecVal);
if (PrevInst == nullptr)
PrevInst = MatInst;
IRBuilder<> Builder(PrevInst->getNextNode());
VecToMatStub = Builder.CreateCall(TranslationStub, {VecVal});
}
ValUse.set(VecToMatStub);
}
}
// Replaces all uses of a matrix or matrix array alloca or global variable by
// its lowered equivalent. This doesn't lower the users, but will insert a
// translation stub from the lowered value pointer back to the matrix value
// pointer, and recreate any GEPs around the new pointer.
// Before: User(GEP(MatrixArrayAlloca))
// After: User(VecToMatPtrStub(GEP'(VectorArrayAlloca)))
void HLMatrixLowerPass::replaceAllVariableUses(Value *MatPtr,
                                               Value *LoweredPtr) {
  DXASSERT_NOMSG(HLMatrixType::isMatrixPtrOrArrayPtr(MatPtr->getType()));
  DXASSERT_NOMSG(LoweredPtr->getType() ==
                 HLMatrixType::getLoweredType(MatPtr->getType()));

  // Seed the GEP index stack with the leading zero index, then let the
  // recursive overload walk the users of the variable.
  Value *ZeroIdx =
      ConstantInt::get(Type::getInt32Ty(MatPtr->getContext()), 0);
  SmallVector<Value *, 4> IdxStack;
  IdxStack.emplace_back(ZeroIdx);
  replaceAllVariableUses(IdxStack, MatPtr, LoweredPtr);
}
// Recursive worker redirecting all uses of StackTopPtr to LoweredPtr.
//
// GEPIdxStack: the GEP indices accumulated between the root variable and
// StackTopPtr; it is extended while recursing into a GEP and restored
// afterwards.
// StackTopPtr: the root variable pointer, or a GEP / address space cast
// derived from it, whose uses are being replaced.
// LoweredPtr: the lowered root variable pointer.
//
// Each loop iteration detaches the first remaining use of StackTopPtr, so
// StackTopPtr is left without any use on return.
void HLMatrixLowerPass::replaceAllVariableUses(
SmallVectorImpl<Value *> &GEPIdxStack, Value *StackTopPtr,
Value *LoweredPtr) {
while (!StackTopPtr->use_empty()) {
llvm::Use &Use = *StackTopPtr->use_begin();
if (GEPOperator *GEP = dyn_cast<GEPOperator>(Use.getUser())) {
DXASSERT(GEP->getNumIndices() >= 1, "Unexpected degenerate GEP.");
DXASSERT(cast<ConstantInt>(*GEP->idx_begin())->isZero(),
"Unexpected non-zero first GEP index.");
// Recurse in GEP to find actual users
// The leading zero index is skipped, only the remaining ones are pushed.
for (auto It = GEP->idx_begin() + 1; It != GEP->idx_end(); ++It)
GEPIdxStack.emplace_back(*It);
replaceAllVariableUses(GEPIdxStack, GEP, LoweredPtr);
// Pop the indices pushed above.
GEPIdxStack.erase(GEPIdxStack.end() - (GEP->getNumIndices() - 1),
GEPIdxStack.end());
// Discard the GEP
DXASSERT_NOMSG(GEP->use_empty());
if (GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEP)) {
// Detach the GEP instruction from StackTopPtr and queue it for deletion.
Use.set(UndefValue::get(Use->getType()));
addToDeadInsts(GEPInst);
} else {
// constant GEP
cast<Constant>(GEP)->destroyConstant();
}
continue;
}
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Use.getUser())) {
DXASSERT(CE->getOpcode() == Instruction::AddrSpaceCast || CE->use_empty(),
"Unexpected constant user");
// Look through the constant expression, then destroy it once unused.
replaceAllVariableUses(GEPIdxStack, CE, LoweredPtr);
DXASSERT_NOMSG(CE->use_empty());
CE->destroyConstant();
continue;
}
if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(Use.getUser())) {
// Look through the cast instruction, then detach it and queue it for
// deletion.
replaceAllVariableUses(GEPIdxStack, CI, LoweredPtr);
Use.set(UndefValue::get(Use->getType()));
addToDeadInsts(CI);
continue;
}
if (BitCastInst *BCI = dyn_cast<BitCastInst>(Use.getUser())) {
// Replace bitcasts to i8* for lifetime intrinsics.
if (BCI->getType()->isPointerTy() &&
BCI->getType()->getPointerElementType()->isIntegerTy(8)) {
DXASSERT(onlyUsedByLifetimeMarkers(BCI),
"bitcast to i8* must only be used by lifetime intrinsics");
Value *NewBCI =
IRBuilder<>(BCI).CreateBitCast(LoweredPtr, BCI->getType());
// Replace all uses of the use.
BCI->replaceAllUsesWith(NewBCI);
// Remove the current use to end iteration.
Use.set(UndefValue::get(Use->getType()));
addToDeadInsts(BCI);
continue;
}
}
// Any other user is expected to be an instruction (enforced by the cast
// below); new instructions are inserted right before it.
// Recreate the same GEP sequence, if any, on the lowered pointer
IRBuilder<> Builder(cast<Instruction>(Use.getUser()));
Value *LoweredStackTopPtr =
GEPIdxStack.size() == 1 ? LoweredPtr
: Builder.CreateGEP(LoweredPtr, GEPIdxStack);
// Generate a stub translating the vector pointer back to a matrix pointer,
// such that consuming instructions are unaffected.
FunctionType *TranslationStubTy = FunctionType::get(
StackTopPtr->getType(), {LoweredStackTopPtr->getType()},
/* isVarArg */ false);
Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
Use.set(Builder.CreateCall(TranslationStub, {LoweredStackTopPtr}));
}
}
// Replaces a matrix or matrix array global variable by a new global of the
// lowered type (named after the original, with a ".v" suffix), redirects
// all of its uses and erases it. Globals without users are left untouched.
void HLMatrixLowerPass::lowerGlobal(GlobalVariable *Global) {
  if (Global->user_empty())
    return;

  PointerType *LoweredPtrTy =
      cast<PointerType>(HLMatrixType::getLoweredType(Global->getType()));
  DXASSERT_NOMSG(LoweredPtrTy != Global->getType());

  // The initializer, if any, is lowered along with the variable.
  Constant *LoweredInit = nullptr;
  if (Global->hasInitializer())
    LoweredInit = lowerConstInitVal(Global->getInitializer());

  GlobalVariable *VecGlobal = new GlobalVariable(
      *m_pModule, LoweredPtrTy->getElementType(), Global->isConstant(),
      Global->getLinkage(), LoweredInit, Global->getName() + ".v",
      /*InsertBefore*/ nullptr, Global->getThreadLocalMode(),
      Global->getType()->getAddressSpace());

  // Carry the debug info over to the new variable.
  if (m_HasDbgInfo) {
    DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
    HLModule::UpdateGlobalVariableDebugInfo(Global, Finder, VecGlobal);
  }

  replaceAllVariableUses(Global, VecGlobal);
  Global->removeDeadConstantUsers();
  Global->eraseFromParent();
}
// Lowers the constant initializer of a matrix or (possibly nested) matrix
// array variable to the equivalent constant of lowered type, in memory
// representation.
Constant *HLMatrixLowerPass::lowerConstInitVal(Constant *Val) {
  Type *Ty = Val->getType();

  // Arrays of matrices: lower every element (or nested array) recursively.
  if (ArrayType *ArrayTy = dyn_cast<ArrayType>(Ty)) {
    const unsigned NumElems = ArrayTy->getNumElements();
    SmallVector<Constant *, 4> LoweredElems;
    LoweredElems.reserve(NumElems);
    for (unsigned Idx = 0; Idx != NumElems; ++Idx)
      LoweredElems.emplace_back(
          lowerConstInitVal(Val->getAggregateElement(Idx)));

    Type *LoweredElemTy = HLMatrixType::getLoweredType(
        ArrayTy->getElementType(), /*MemRepr*/ true);
    return ConstantArray::get(ArrayType::get(LoweredElemTy, NumElems),
                              LoweredElems);
  }

  // Otherwise it's a matrix, lower it to a vector
  HLMatrixType MatTy = HLMatrixType::cast(Ty);
  DXASSERT_NOMSG(isa<StructType>(Ty));
  Constant *Rows = Val->getAggregateElement((unsigned)0);

  // Original initializer should have been produced in row/column-major order
  // depending on the qualifiers of the target variable, so preserve the order.
  const unsigned NumRows = MatTy.getNumRows();
  const unsigned NumCols = MatTy.getNumColumns();
  SmallVector<Constant *, 16> FlatElems;
  for (unsigned Row = 0; Row != NumRows; ++Row) {
    Constant *RowVal = Rows->getAggregateElement(Row);
    for (unsigned Col = 0; Col != NumCols; ++Col)
      FlatElems.emplace_back(RowVal->getAggregateElement(Col));
  }
  Constant *RegVec = ConstantVector::get(FlatElems);

  // Matrix elements are always in register representation,
  // but the lowered global variable is of vector type in
  // its memory representation, so we must convert here.
  // This will produce a constant so we can use an IRBuilder without a valid
  // insertion point.
  IRBuilder<> DummyBuilder(Val->getContext());
  return cast<Constant>(MatTy.emitLoweredRegToMem(RegVec, DummyBuilder));
}
// Creates the lowered equivalent of a matrix or matrix array alloca right
// before it, carries over its debug declaration and precise marker, and
// redirects all uses of the original alloca to the new one.
// Returns the lowered alloca; the original one is not erased here.
AllocaInst *HLMatrixLowerPass::lowerAlloca(AllocaInst *MatAlloca) {
  Type *LoweredPtrTy = HLMatrixType::getLoweredType(MatAlloca->getType());
  Type *LoweredElemTy = cast<PointerType>(LoweredPtrTy)->getElementType();

  IRBuilder<> AllocaBuilder(MatAlloca);
  AllocaInst *VecAlloca = AllocaBuilder.CreateAlloca(LoweredElemTy, nullptr,
                                                     MatAlloca->getName());

  // Update debug info.
  if (DbgDeclareInst *OldDeclare = llvm::FindAllocaDbgDeclare(MatAlloca)) {
    DIBuilder DIB(*MatAlloca->getModule());
    DIB.insertDeclare(VecAlloca, OldDeclare->getVariable(),
                      OldDeclare->getExpression(), OldDeclare->getDebugLoc(),
                      OldDeclare);
  }

  // Preserve the precise marker, if any.
  if (HLModule::HasPreciseAttributeWithMetadata(MatAlloca))
    HLModule::MarkPreciseAttributeWithMetadata(VecAlloca);

  replaceAllVariableUses(MatAlloca, VecAlloca);
  return VecAlloca;
}
// Lowers a single matrix instruction, which must be a call or a return.
void HLMatrixLowerPass::lowerInstruction(Instruction *Inst) {
  if (ReturnInst *Return = dyn_cast<ReturnInst>(Inst)) {
    lowerReturn(Return);
    return;
  }

  CallInst *Call = dyn_cast<CallInst>(Inst);
  if (Call == nullptr)
    llvm_unreachable("Unexpected matrix instruction type.");

  // lowerCall returns the lowered value iff we should discard
  // the original matrix instruction and replace all of its uses
  // by the lowered value. It returns nullptr to opt-out of this.
  Value *Lowered = lowerCall(Call);
  if (Lowered == nullptr)
    return;

  replaceAllUsesByLoweredValue(Call, Lowered);
  addToDeadInsts(Inst);
}
// Lowers the matrix operand of a return instruction. The return instruction
// itself is kept and its operand replaced in place.
void HLMatrixLowerPass::lowerReturn(ReturnInst *Return) {
  Value *MatRetVal = Return->getReturnValue();
  Type *MatRetTy = MatRetVal->getType();
  DXASSERT_LOCALVAR(MatRetTy, !MatRetTy->isPointerTy(),
                    "Unexpected matrix returned by pointer.");

  IRBuilder<> Builder(Return);
  Value *VecRetVal =
      getLoweredByValOperand(MatRetVal, Builder, /* DiscardStub */ true);

  // Since we're not lowering the signature, we can't return the lowered value
  // directly, so insert a bitcast, which HLMatrixBitcastLower knows how to
  // eliminate.
  Return->setOperand(0, bitCastValue(VecRetVal, MatRetVal->getType(),
                                     /* DstTyAlloca */ false, Builder));
}
// Lowers a call involving matrices, dispatching on whether the callee is a
// high-level operation or not. Returns the value to replace the call with,
// or nullptr if the call must be neither replaced nor discarded.
Value *HLMatrixLowerPass::lowerCall(CallInst *Call) {
  const HLOpcodeGroup Group =
      GetHLOpcodeGroupByName(Call->getCalledFunction());
  if (Group == HLOpcodeGroup::NotHL)
    return lowerNonHLCall(Call);
  return lowerHLOperation(Call, Group);
}
// Lowers a precise marker call applied to a matrix: the single matrix
// argument is lowered, the marker is re-emitted on the lowered value, and
// the original call is queued for deletion.
void HLMatrixLowerPass::lowerPreciseCall(CallInst *Call, IRBuilder<> Builder) {
  DXASSERT(Call->getNumArgOperands() == 1,
           "Only one arg expected for precise matrix call");
  Value *LoweredMat = getLoweredByValOperand(Call->getArgOperand(0), Builder);
  HLModule::MarkPreciseAttributeOnValWithFunctionCall(LoweredMat, Builder,
                                                      *m_pModule);
  addToDeadInsts(Call);
}
// Lowers a call to a function which is not a high-level operation.
// The call itself is kept and only its operands and the uses of its return
// value are adjusted, so this always returns nullptr (see the comments on
// the individual return statements).
Value *HLMatrixLowerPass::lowerNonHLCall(CallInst *Call) {
// First, handle any operand of matrix-derived type
// We don't lower the callee's signature in this pass,
// so, for any matrix-typed parameter, we create a bitcast from the
// lowered vector back to the matrix type, which the later
// HLMatrixBitcastLower pass knows how to eliminate.
IRBuilder<> PreCallBuilder(Call);
unsigned NumArgs = Call->getNumArgOperands();
Function *Func = Call->getCalledFunction();
// Calls to functions carrying the precise attribute are handled separately;
// lowerPreciseCall queues the call for deletion itself.
if (Func && HLModule::HasPreciseAttribute(Func)) {
lowerPreciseCall(Call, PreCallBuilder);
return nullptr;
}
for (unsigned ArgIdx = 0; ArgIdx < NumArgs; ++ArgIdx) {
Use &ArgUse = Call->getArgOperandUse(ArgIdx);
if (ArgUse->getType()->isPointerTy()) {
// Byref arg
Value *LoweredArg = tryGetLoweredPtrOperand(ArgUse.get(), PreCallBuilder,
/* DiscardStub */ true);
if (LoweredArg != nullptr) {
// Pointer to a matrix we've lowered, insert a bitcast back to matrix
// pointer type.
Value *BitCastedArg =
PreCallBuilder.CreateBitCast(LoweredArg, ArgUse->getType());
ArgUse.set(BitCastedArg);
}
} else {
// Byvalue arg
Value *LoweredArg = getLoweredByValOperand(ArgUse.get(), PreCallBuilder,
/* DiscardStub */ true);
// Getting the same value back means the argument needs no conversion.
if (LoweredArg == ArgUse.get())
continue;
Value *BitCastedArg =
bitCastValue(LoweredArg, ArgUse->getType(), /* DstTyAlloca */ false,
PreCallBuilder);
ArgUse.set(BitCastedArg);
}
}
// Now check the return type
HLMatrixType RetMatTy = HLMatrixType::dyn_cast(Call->getType());
if (!RetMatTy) {
DXASSERT(!HLMatrixType::isMatrixPtrOrArrayPtr(Call->getType()),
"Unexpected user call returning a matrix by pointer.");
// Nothing to replace, other instructions can consume a non-matrix return
// type.
return nullptr;
}
// The callee returns a matrix, and we don't lower signatures in this pass.
// We perform a sketchy bitcast to the lowered register-representation type,
// which the later HLMatrixBitcastLower pass knows how to eliminate.
IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Call));
Value *LoweredAlloca =
AllocaBuilder.CreateAlloca(RetMatTy.getLoweredVectorTypeForReg());
// Everything below is inserted after the call.
IRBuilder<> PostCallBuilder(Call->getNextNode());
Value *BitCastedAlloca = PostCallBuilder.CreateBitCast(
LoweredAlloca, Call->getType()->getPointerTo());
// This is slightly tricky
// We want to replace all uses of the matrix-returning call by the bitcasted
// value, but the store to the bitcasted pointer itself is a use of that
// matrix, so we need to create the load, replace the uses, and then insert
// the store.
LoadInst *LoweredVal = PostCallBuilder.CreateLoad(LoweredAlloca);
replaceAllUsesByLoweredValue(Call, LoweredVal);
// Now we can insert the store. Make sure to do so before the load.
PostCallBuilder.SetInsertPoint(LoweredVal);
PostCallBuilder.CreateStore(Call, BitCastedAlloca);
// Return nullptr since we did our own uses replacement and we don't want
// the matrix instruction to be marked as dead since we're still using it.
return nullptr;
}
// Lowers a call to a high-level operation by dispatching to the lowering
// routine matching its opcode group. The opcode read from the call is cast
// to the opcode enum of that group where the routine takes one.
// Returns whatever the selected routine returns; any other opcode group is
// unexpected here.
Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call,
HLOpcodeGroup OpcodeGroup) {
// New instructions are inserted right before the call being lowered.
IRBuilder<> Builder(Call);
switch (OpcodeGroup) {
case HLOpcodeGroup::HLIntrinsic:
return lowerHLIntrinsic(Call, static_cast<IntrinsicOp>(GetHLOpcode(Call)));
case HLOpcodeGroup::HLBinOp:
return lowerHLBinaryOperation(
Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
static_cast<HLBinaryOpcode>(GetHLOpcode(Call)), Builder);
case HLOpcodeGroup::HLUnOp:
return lowerHLUnaryOperation(
Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx),
static_cast<HLUnaryOpcode>(GetHLOpcode(Call)), Builder);
case HLOpcodeGroup::HLMatLoadStore:
return lowerHLLoadStore(
Call, static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(Call)));
case HLOpcodeGroup::HLCast:
// The cast source is read from the unary operand slot, and the call's
// own type is the destination type.
return lowerHLCast(
Call, Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx),
Call->getType(), static_cast<HLCastOpcode>(GetHLOpcode(Call)), Builder);
case HLOpcodeGroup::HLSubscript:
return lowerHLSubscript(Call,
static_cast<HLSubscriptOpcode>(GetHLOpcode(Call)));
case HLOpcodeGroup::HLInit:
return lowerHLInit(Call);
case HLOpcodeGroup::HLSelect:
return lowerHLSelect(Call);
default:
llvm_unreachable("Unexpected matrix opcode");
}
}
// Lowers a call to a high-level intrinsic involving matrices.
// Matrix-specific intrinsics (mul, transpose, determinant) are expanded
// here; any other intrinsic is re-emitted as a call to the same intrinsic
// opcode, with lowered argument and return types.
// Returns the lowered value meant to replace the call.
Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
  IRBuilder<> Builder(Call);

  // See if this is a matrix-specific intrinsic which we should expand here
  switch (Opcode) {
  case IntrinsicOp::IOP_umul:
  case IntrinsicOp::IOP_mul:
    return lowerHLMulIntrinsic(
        Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
        Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
        /* Unsigned */ Opcode == IntrinsicOp::IOP_umul, Builder);
  case IntrinsicOp::IOP_transpose:
    return lowerHLTransposeIntrinsic(
        Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
  case IntrinsicOp::IOP_determinant:
    return lowerHLDeterminantIntrinsic(
        Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
  default:
    // Not matrix-specific: fall through to the generic delegation below.
    break;
  }

  // Delegate to a lowered intrinsic call
  SmallVector<Value *, 4> LoweredArgs;
  LoweredArgs.reserve(Call->getNumArgOperands());
  for (Value *Arg : Call->arg_operands()) {
    if (Arg->getType()->isPointerTy()) {
      // ByRef parameter (for example, frexp's second parameter)
      // If the argument points to a lowered matrix variable, replace it here,
      // otherwise preserve the matrix type and let further passes handle the
      // lowering.
      Value *LoweredArg = tryGetLoweredPtrOperand(Arg, Builder);
      if (LoweredArg == nullptr)
        LoweredArg = Arg;
      LoweredArgs.emplace_back(LoweredArg);
    } else {
      LoweredArgs.emplace_back(getLoweredByValOperand(Arg, Builder));
    }
  }

  // The new call keeps the function attributes of the original callee.
  Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
  return callHLFunction(
      *m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode),
      LoweredRetTy, LoweredArgs,
      Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
}
// Handles multiplication of a scalar with a matrix.
// Exactly one of Lhs/Rhs is the scalar, as selected by isLhsScalar; the other
// operand is lowered to its vector form. The scalar is splatted to a vector of
// the same element count and multiplied element-wise, keeping the original
// operand order. Returns the resulting lowered vector.
Value *HLMatrixLowerPass::translateScalarMatMul(Value *Lhs, Value *Rhs,
                                                IRBuilder<> &Builder,
                                                bool isLhsScalar) {
  Value *Mat = isLhsScalar ? Rhs : Lhs;
  Value *Scalar = isLhsScalar ? Lhs : Rhs;
  Value *LoweredMat = getLoweredByValOperand(Mat, Builder);
  Type *ScalarTy = Scalar->getType();
  // The lowered matrix must be a vector. cast<> (rather than an unchecked
  // dyn_cast<>) asserts on violation instead of handing back a null pointer
  // that would be dereferenced just below.
  FixedVectorType *VT = cast<FixedVectorType>(LoweredMat->getType());

  // Perform the scalar-matrix multiplication!
  Type *ElemTy = VT->getElementType();
  bool isIntMulOp = ScalarTy->isIntegerTy() && ElemTy->isIntegerTy();
  bool isFloatMulOp =
      ScalarTy->isFloatingPointTy() && ElemTy->isFloatingPointTy();
  DXASSERT(ScalarTy == ElemTy,
           "Scalar type must match the matrix component type.");

  Value *Result = Builder.CreateVectorSplat(VT->getNumElements(), Scalar);
  if (isFloatMulOp) {
    // Preserve the order of operation for floats
    Result = isLhsScalar ? Builder.CreateFMul(Result, LoweredMat)
                         : Builder.CreateFMul(LoweredMat, Result);
  } else if (isIntMulOp) {
    // Doesn't matter for integers but still preserve the order of operation
    Result = isLhsScalar ? Builder.CreateMul(Result, LoweredMat)
                         : Builder.CreateMul(LoweredMat, Result);
  } else {
    DXASSERT(
        0, "Unknown type encountered when doing scalar-matrix multiplication.");
  }
  return Result;
}
// Lowers the mul/umul intrinsic.
//  - If either lowered operand is not a vector, the work is delegated to
//    translateScalarMatMul.
//  - Otherwise a matrix*matrix, matrix*vector or vector*matrix product is
//    expanded element by element on the row-major lowered vectors. A plain
//    vector on the right is treated as a column, on the left as a row.
// Unsigned selects the umad intrinsic (instead of mad) for accumulation.
// Returns the lowered (row-major) result vector.
Value *HLMatrixLowerPass::lowerHLMulIntrinsic(Value *Lhs, Value *Rhs,
                                              bool Unsigned,
                                              IRBuilder<> &Builder) {
  HLMatrixType LhsMatTy = HLMatrixType::dyn_cast(Lhs->getType());
  HLMatrixType RhsMatTy = HLMatrixType::dyn_cast(Rhs->getType());
  Value *LoweredLhs = getLoweredByValOperand(Lhs, Builder);
  Value *LoweredRhs = getLoweredByValOperand(Rhs, Builder);

  // Translate multiplication of scalar with matrix
  bool isLhsScalar = !LoweredLhs->getType()->isVectorTy();
  bool isRhsScalar = !LoweredRhs->getType()->isVectorTy();
  bool isScalar = isLhsScalar || isRhsScalar;
  if (isScalar)
    return translateScalarMatMul(Lhs, Rhs, Builder, isLhsScalar);

  DXASSERT(LoweredLhs->getType()->getScalarType() ==
               LoweredRhs->getType()->getScalarType(),
           "Unexpected element type mismatch in mul intrinsic.");
  // Use a real predicate here: cast<> never yields null, so using its result
  // as the asserted condition could not fail with this message.
  DXASSERT(LoweredLhs->getType()->isVectorTy() &&
               LoweredRhs->getType()->isVectorTy(),
           "Unexpected scalar in lowered matrix mul intrinsic operands.");

  Type *ElemTy = LoweredLhs->getType()->getScalarType();

  // Figure out the dimensions of each side
  unsigned LhsNumRows, LhsNumCols, RhsNumRows, RhsNumCols;
  if (LhsMatTy && RhsMatTy) {
    LhsNumRows = LhsMatTy.getNumRows();
    LhsNumCols = LhsMatTy.getNumColumns();
    RhsNumRows = RhsMatTy.getNumRows();
    RhsNumCols = RhsMatTy.getNumColumns();
  } else if (LhsMatTy) {
    // matrix * vector: the vector acts as a single column.
    LhsNumRows = LhsMatTy.getNumRows();
    LhsNumCols = LhsMatTy.getNumColumns();
    FixedVectorType *VT = cast<FixedVectorType>(LoweredRhs->getType());
    RhsNumRows = VT->getNumElements();
    RhsNumCols = 1;
  } else if (RhsMatTy) {
    // vector * matrix: the vector acts as a single row.
    LhsNumRows = 1;
    FixedVectorType *VT = cast<FixedVectorType>(LoweredLhs->getType());
    LhsNumCols = VT->getNumElements();
    RhsNumRows = RhsMatTy.getNumRows();
    RhsNumCols = RhsMatTy.getNumColumns();
  } else {
    llvm_unreachable("mul intrinsic was identified as a matrix operation but "
                     "neither operand is a matrix.");
  }

  DXASSERT(LhsNumCols == RhsNumRows,
           "Matrix mul intrinsic operands dimensions mismatch.");
  HLMatrixType ResultMatTy(ElemTy, LhsNumRows, RhsNumCols);
  unsigned AccCount = LhsNumCols;

  // Get the multiply-and-add intrinsic function, we'll need it
  IntrinsicOp MadOpcode =
      Unsigned ? IntrinsicOp::IOP_umad : IntrinsicOp::IOP_mad;
  FunctionType *MadFuncTy = FunctionType::get(
      ElemTy, {Builder.getInt32Ty(), ElemTy, ElemTy, ElemTy}, false);
  Function *MadFunc = GetOrCreateHLFunction(
      *m_pModule, MadFuncTy, HLOpcodeGroup::HLIntrinsic, (unsigned)MadOpcode);
  Constant *MadOpcodeVal = Builder.getInt32((unsigned)MadOpcode);

  // Perform the multiplication!
  Value *Result =
      UndefValue::get(VectorType::get(ElemTy, LhsNumRows * RhsNumCols));
  for (unsigned ResultRowIdx = 0; ResultRowIdx < ResultMatTy.getNumRows();
       ++ResultRowIdx) {
    for (unsigned ResultColIdx = 0; ResultColIdx < ResultMatTy.getNumColumns();
         ++ResultColIdx) {
      unsigned ResultElemIdx =
          ResultMatTy.getRowMajorIndex(ResultRowIdx, ResultColIdx);
      Value *ResultElem = nullptr;

      // Dot product of the Lhs row with the Rhs column.
      for (unsigned AccIdx = 0; AccIdx < AccCount; ++AccIdx) {
        unsigned LhsElemIdx = HLMatrixType::getRowMajorIndex(
            ResultRowIdx, AccIdx, LhsNumRows, LhsNumCols);
        unsigned RhsElemIdx = HLMatrixType::getRowMajorIndex(
            AccIdx, ResultColIdx, RhsNumRows, RhsNumCols);
        Value *LhsElem = Builder.CreateExtractElement(
            LoweredLhs, static_cast<uint64_t>(LhsElemIdx));
        Value *RhsElem = Builder.CreateExtractElement(
            LoweredRhs, static_cast<uint64_t>(RhsElemIdx));
        if (ResultElem == nullptr) {
          // First term: a plain multiply seeds the accumulator.
          ResultElem = ElemTy->isFloatingPointTy()
                           ? Builder.CreateFMul(LhsElem, RhsElem)
                           : Builder.CreateMul(LhsElem, RhsElem);
        } else {
          // Following terms: accumulate through the HL mad/umad intrinsic.
          ResultElem = Builder.CreateCall(
              MadFunc, {MadOpcodeVal, LhsElem, RhsElem, ResultElem});
        }
      }

      Result = Builder.CreateInsertElement(
          Result, ResultElem, static_cast<uint64_t>(ResultElemIdx));
    }
  }
  return Result;
}
// Lowers the transpose intrinsic: the lowered operand vector is run through
// the matrix type's row-to-column conversion and that result is returned.
Value *HLMatrixLowerPass::lowerHLTransposeIntrinsic(Value *MatVal,
                                                    IRBuilder<> &Builder) {
  HLMatrixType SrcMatTy = HLMatrixType::cast(MatVal->getType());
  Value *LoweredSrc = getLoweredByValOperand(MatVal, Builder);
  Value *Transposed = SrcMatTy.emitLoweredVectorRowToCol(LoweredSrc, Builder);
  return Transposed;
}
// Emits floating-point instructions computing the determinant of
//   | M00 M01 |
//   | M10 M11 |
// as M00*M11 - M01*M10.
static Value *determinant2x2(Value *M00, Value *M01, Value *M10, Value *M11,
                             IRBuilder<> &Builder) {
  // Emit the two products as separate statements so that the order of the
  // generated instructions is fixed.
  Value *MainDiag = Builder.CreateFMul(M00, M11);
  Value *AntiDiag = Builder.CreateFMul(M01, M10);
  Value *Det = Builder.CreateFSub(MainDiag, AntiDiag);
  return Det;
}
// Emits floating-point instructions computing a 3x3 determinant by expansion
// along the first row: M00*minor0 - M01*minor1 + M02*minor2, where each minor
// is the 2x2 determinant left after removing row 0 and the matching column.
static Value *determinant3x3(Value *M00, Value *M01, Value *M02, Value *M10,
                             Value *M11, Value *M12, Value *M20, Value *M21,
                             Value *M22, IRBuilder<> &Builder) {
  // All minors are emitted first, in column order.
  Value *Terms[3];
  Terms[0] = determinant2x2(M11, M12, M21, M22, Builder);
  Terms[1] = determinant2x2(M10, M12, M20, M22, Builder);
  Terms[2] = determinant2x2(M10, M11, M20, M21, Builder);

  // Then each minor is scaled by its first-row element.
  Value *const FirstRow[3] = {M00, M01, M02};
  for (unsigned Col = 0; Col < 3; ++Col)
    Terms[Col] = Builder.CreateFMul(FirstRow[Col], Terms[Col]);

  // Alternating-sign sum.
  Value *Det = Builder.CreateFSub(Terms[0], Terms[1]);
  return Builder.CreateFAdd(Det, Terms[2]);
}
// Emits floating-point instructions computing a 4x4 determinant by expansion
// along the first row: M00*minor0 - M01*minor1 + M02*minor2 - M03*minor3,
// where each minor is the 3x3 determinant left after removing row 0 and the
// matching column.
static Value *determinant4x4(Value *M00, Value *M01, Value *M02, Value *M03,
                             Value *M10, Value *M11, Value *M12, Value *M13,
                             Value *M20, Value *M21, Value *M22, Value *M23,
                             Value *M30, Value *M31, Value *M32, Value *M33,
                             IRBuilder<> &Builder) {
  // All minors are emitted first, in column order.
  Value *Terms[4];
  Terms[0] =
      determinant3x3(M11, M12, M13, M21, M22, M23, M31, M32, M33, Builder);
  Terms[1] =
      determinant3x3(M10, M12, M13, M20, M22, M23, M30, M32, M33, Builder);
  Terms[2] =
      determinant3x3(M10, M11, M13, M20, M21, M23, M30, M31, M33, Builder);
  Terms[3] =
      determinant3x3(M10, M11, M12, M20, M21, M22, M30, M31, M32, Builder);

  // Then each minor is scaled by its first-row element.
  Value *const FirstRow[4] = {M00, M01, M02, M03};
  for (unsigned Col = 0; Col < 4; ++Col)
    Terms[Col] = Builder.CreateFMul(FirstRow[Col], Terms[Col]);

  // Alternating-sign sum.
  Value *Det = Builder.CreateFSub(Terms[0], Terms[1]);
  Det = Builder.CreateFAdd(Det, Terms[2]);
  return Builder.CreateFSub(Det, Terms[3]);
}
// Lowers the determinant intrinsic for a square matrix of dimension 1 to 4.
// Every element of the lowered vector is extracted, then the determinant
// helper matching the dimension is used to produce the scalar result.
Value *HLMatrixLowerPass::lowerHLDeterminantIntrinsic(Value *MatVal,
                                                      IRBuilder<> &Builder) {
  HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
  DXASSERT_NOMSG(MatTy.getNumColumns() == MatTy.getNumRows());
  Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);

  // Pull out all matrix elements up front.
  const unsigned NumElems = MatTy.getNumElements();
  SmallVector<Value *, 16> E;
  for (unsigned Idx = 0; Idx != NumElems; ++Idx) {
    Value *Elem =
        Builder.CreateExtractElement(LoweredVal, static_cast<uint64_t>(Idx));
    E.push_back(Elem);
  }

  // Pick the determinant routine for this dimension.
  const unsigned Dim = MatTy.getNumColumns();
  if (Dim == 1)
    return E[0];
  if (Dim == 2)
    return determinant2x2(E[0], E[1], E[2], E[3], Builder);
  if (Dim == 3)
    return determinant3x3(E[0], E[1], E[2], E[3], E[4], E[5], E[6], E[7], E[8],
                          Builder);
  if (Dim == 4)
    return determinant4x4(E[0], E[1], E[2], E[3], E[4], E[5], E[6], E[7], E[8],
                          E[9], E[10], E[11], E[12], E[13], E[14], E[15],
                          Builder);

  llvm_unreachable("Unexpected matrix dimensions.");
}
// Lowers a unary matrix operator to the equivalent operation on the lowered
// vector, choosing the floating-point or integer instruction according to the
// vector's element type.
Value *HLMatrixLowerPass::lowerHLUnaryOperation(Value *MatVal,
                                                HLUnaryOpcode Opcode,
                                                IRBuilder<> &Builder) {
  Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  VectorType *VecTy = cast<VectorType>(LoweredVal->getType());
  const bool IsFloat = VecTy->getElementType()->isFloatingPointTy();

  switch (Opcode) {
  case HLUnaryOpcode::Plus:
    // Unary plus leaves the value untouched.
    return LoweredVal;

  case HLUnaryOpcode::Minus: {
    // Negation is emitted as (0 - x).
    Constant *Zero = Constant::getNullValue(VecTy);
    if (IsFloat)
      return Builder.CreateFSub(Zero, LoweredVal);
    return Builder.CreateSub(Zero, LoweredVal);
  }

  case HLUnaryOpcode::LNot: {
    // Logical not is a comparison against zero.
    Constant *Zero = Constant::getNullValue(VecTy);
    if (IsFloat)
      return Builder.CreateFCmp(CmpInst::FCMP_UEQ, LoweredVal, Zero);
    return Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredVal, Zero);
  }

  case HLUnaryOpcode::Not:
    // Bitwise not is an xor with all ones.
    return Builder.CreateXor(LoweredVal, Constant::getAllOnesValue(VecTy));

  case HLUnaryOpcode::PostInc:
  case HLUnaryOpcode::PreInc:
  case HLUnaryOpcode::PostDec:
  case HLUnaryOpcode::PreDec: {
    Constant *ScalarOne = nullptr;
    if (IsFloat)
      ScalarOne = ConstantFP::get(VecTy->getElementType(), 1);
    else
      ScalarOne = ConstantInt::get(VecTy->getElementType(), 1);
    Constant *VecOne =
        ConstantVector::getSplat(VecTy->getNumElements(), ScalarOne);

    // CodeGen already emitted the load and following store, our job is only to
    // produce the updated value.
    const bool IsIncrement =
        Opcode == HLUnaryOpcode::PostInc || Opcode == HLUnaryOpcode::PreInc;
    if (IsIncrement) {
      if (IsFloat)
        return Builder.CreateFAdd(LoweredVal, VecOne);
      return Builder.CreateAdd(LoweredVal, VecOne);
    }
    if (IsFloat)
      return Builder.CreateFSub(LoweredVal, VecOne);
    return Builder.CreateSub(LoweredVal, VecOne);
  }

  default:
    llvm_unreachable("Unsupported unary matrix operator");
  }
}
// Lowers a binary matrix operator to the equivalent element-wise operation on
// the two lowered vectors, which must have the same vector type.
// For opcodes with both forms, the floating-point instruction is used when the
// element type is floating point and the signed integer instruction otherwise;
// the U-prefixed opcodes always map to unsigned integer instructions.
// Returns the value produced by the emitted instruction.
Value *HLMatrixLowerPass::lowerHLBinaryOperation(Value *Lhs, Value *Rhs,
                                                 HLBinaryOpcode Opcode,
                                                 IRBuilder<> &Builder) {
  Value *LoweredLhs = getLoweredByValOperand(Lhs, Builder);
  Value *LoweredRhs = getLoweredByValOperand(Rhs, Builder);

  DXASSERT(LoweredLhs->getType()->isVectorTy() &&
               LoweredRhs->getType()->isVectorTy(),
           "Expected lowered binary operation operands to be vectors");
  DXASSERT(
      LoweredLhs->getType() == LoweredRhs->getType(),
      "Expected lowered binary operation operands to have matching types.");

  // Only the element type is needed; getScalarType() gives it without an
  // unchecked dyn_cast<> whose null result would be dereferenced.
  bool IsFloat = LoweredLhs->getType()->getScalarType()->isFloatingPointTy();

  switch (Opcode) {
  case HLBinaryOpcode::Add:
    return IsFloat ? Builder.CreateFAdd(LoweredLhs, LoweredRhs)
                   : Builder.CreateAdd(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Sub:
    return IsFloat ? Builder.CreateFSub(LoweredLhs, LoweredRhs)
                   : Builder.CreateSub(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Mul:
    return IsFloat ? Builder.CreateFMul(LoweredLhs, LoweredRhs)
                   : Builder.CreateMul(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Div:
    return IsFloat ? Builder.CreateFDiv(LoweredLhs, LoweredRhs)
                   : Builder.CreateSDiv(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Rem:
    return IsFloat ? Builder.CreateFRem(LoweredLhs, LoweredRhs)
                   : Builder.CreateSRem(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::And:
    return Builder.CreateAnd(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Or:
    return Builder.CreateOr(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Xor:
    return Builder.CreateXor(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Shl:
    return Builder.CreateShl(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Shr:
    // Signed shift right is arithmetic; UShr below is the logical one.
    return Builder.CreateAShr(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::LT:
    return IsFloat
               ? Builder.CreateFCmp(CmpInst::FCMP_OLT, LoweredLhs, LoweredRhs)
               : Builder.CreateICmp(CmpInst::ICMP_SLT, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::GT:
    return IsFloat
               ? Builder.CreateFCmp(CmpInst::FCMP_OGT, LoweredLhs, LoweredRhs)
               : Builder.CreateICmp(CmpInst::ICMP_SGT, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::LE:
    return IsFloat
               ? Builder.CreateFCmp(CmpInst::FCMP_OLE, LoweredLhs, LoweredRhs)
               : Builder.CreateICmp(CmpInst::ICMP_SLE, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::GE:
    return IsFloat
               ? Builder.CreateFCmp(CmpInst::FCMP_OGE, LoweredLhs, LoweredRhs)
               : Builder.CreateICmp(CmpInst::ICMP_SGE, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::EQ:
    return IsFloat
               ? Builder.CreateFCmp(CmpInst::FCMP_OEQ, LoweredLhs, LoweredRhs)
               : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::NE:
    return IsFloat
               ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, LoweredRhs)
               : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::UDiv:
    return Builder.CreateUDiv(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::URem:
    return Builder.CreateURem(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::UShr:
    return Builder.CreateLShr(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::ULT:
    return Builder.CreateICmp(CmpInst::ICMP_ULT, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::UGT:
    return Builder.CreateICmp(CmpInst::ICMP_UGT, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::ULE:
    return Builder.CreateICmp(CmpInst::ICMP_ULE, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::UGE:
    return Builder.CreateICmp(CmpInst::ICMP_UGE, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::LAnd:
  case HLBinaryOpcode::LOr: {
    // Logical and/or: compare each operand against zero first, then combine
    // the two boolean vectors.
    Value *Zero = Constant::getNullValue(LoweredLhs->getType());
    Value *LhsCmp =
        IsFloat ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, Zero)
                : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, Zero);
    Value *RhsCmp =
        IsFloat ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredRhs, Zero)
                : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredRhs, Zero);
    return Opcode == HLBinaryOpcode::LOr ? Builder.CreateOr(LhsCmp, RhsCmp)
                                         : Builder.CreateAnd(LhsCmp, RhsCmp);
  }
  default:
    llvm_unreachable("Unsupported binary matrix operator");
  }
}
// Splits matrix load/store HL calls between lowerHLLoad and lowerHLStore,
// passing along whether the opcode is the row-major variant and, for stores,
// whether the call produces a value.
Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call,
                                           HLMatLoadStoreOpcode Opcode) {
  IRBuilder<> Builder(Call);

  const bool IsLoad = Opcode == HLMatLoadStoreOpcode::RowMatLoad ||
                      Opcode == HLMatLoadStoreOpcode::ColMatLoad;
  if (IsLoad) {
    Value *SrcPtr = Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
    const bool RowMajor = Opcode == HLMatLoadStoreOpcode::RowMatLoad;
    return lowerHLLoad(Call, SrcPtr, RowMajor, Builder);
  }

  const bool IsStore = Opcode == HLMatLoadStoreOpcode::RowMatStore ||
                       Opcode == HLMatLoadStoreOpcode::ColMatStore;
  if (IsStore) {
    Value *StoredVal = Call->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
    Value *DstPtr = Call->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
    const bool RowMajor = Opcode == HLMatLoadStoreOpcode::RowMatStore;
    // A non-void store call hands the stored value back to its users.
    const bool HasResult = !Call->getType()->isVoidTy();
    return lowerHLStore(Call, StoredVal, DstPtr, RowMajor, HasResult, Builder);
  }

  llvm_unreachable("Unsupported matrix load/store operation");
}
// Lowers a matrix load.
// When the pointer has a lowered counterpart, a lowered load is emitted from
// it. Otherwise the load intrinsic is kept, but re-emitted so that it returns
// the lowered vector type while still taking the original matrix pointer.
Value *HLMatrixLowerPass::lowerHLLoad(CallInst *Call, Value *MatPtr,
                                      bool RowMajor, IRBuilder<> &Builder) {
  HLMatrixType MatTy =
      HLMatrixType::cast(MatPtr->getType()->getPointerElementType());

  if (Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder))
    return MatTy.emitLoweredLoad(LoweredPtr, Builder);

  // Can't lower this here, defer to HL signature lower
  HLMatLoadStoreOpcode Opcode = HLMatLoadStoreOpcode::ColMatLoad;
  if (RowMajor)
    Opcode = HLMatLoadStoreOpcode::RowMatLoad;

  return callHLFunction(
      *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
      MatTy.getLoweredVectorTypeForReg(),
      {Builder.getInt32(static_cast<uint32_t>(Opcode)), MatPtr},
      Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
}
// Lowers a matrix store.
// When the destination pointer has a lowered counterpart, a lowered store is
// emitted to it. Otherwise the store intrinsic is kept, but re-emitted with the
// lowered value and the original matrix pointer.
// Return indicates that the original call produced a value; in that case the
// result is the lowered stored value (or a call returning its type).
Value *HLMatrixLowerPass::lowerHLStore(CallInst *Call, Value *MatVal,
                                       Value *MatPtr, bool RowMajor,
                                       bool Return, IRBuilder<> &Builder) {
  DXASSERT(MatVal->getType() == MatPtr->getType()->getPointerElementType(),
           "Matrix store value/pointer type mismatch.");

  Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
  Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);

  if (LoweredPtr != nullptr) {
    HLMatrixType MatTy =
        HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
    StoreInst *LoweredStore =
        MatTy.emitLoweredStore(LoweredVal, LoweredPtr, Builder);
    // If the intrinsic returned a value, return the stored lowered value
    if (Return)
      return LoweredVal;
    return LoweredStore;
  }

  // Can't lower the pointer here, defer to HL signature lower
  HLMatLoadStoreOpcode Opcode = HLMatLoadStoreOpcode::ColMatStore;
  if (RowMajor)
    Opcode = HLMatLoadStoreOpcode::RowMatStore;

  return callHLFunction(
      *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
      Return ? LoweredVal->getType() : Builder.getVoidTy(),
      {Builder.getInt32(static_cast<uint32_t>(Opcode)), MatPtr, LoweredVal},
      Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
}
// Converts a scalar or vector value to DstTy (scalar-to-scalar or
// vector-to-vector only).
//  - Identical source/destination types: SrcVal is returned as is.
//  - Destination with 1-bit elements: emits a "not equal to zero" comparison.
//  - Otherwise: emits the numeric cast picked by HLModule::GetNumericCastOp,
//    with source/destination signedness derived from Opcode.
// Builder is taken by reference, like every other lowering helper in this
// file, so that no copy of the builder is made per call.
static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy,
                                    HLCastOpcode Opcode,
                                    IRBuilder<> &Builder) {
  DXASSERT(SrcVal->getType()->isVectorTy() == DstTy->isVectorTy(),
           "Scalar/vector type mismatch in numerical conversion.");
  Type *SrcTy = SrcVal->getType();

  // Conversions between equivalent types are no-ops,
  // even between signed/unsigned variants.
  if (SrcTy == DstTy)
    return SrcVal;

  // Conversions to bools are comparisons
  if (DstTy->getScalarSizeInBits() == 1) {
    // fcmp une is what regular clang uses in C++ for (bool)f;
    return SrcTy->isIntOrIntVectorTy()
               ? Builder.CreateICmpNE(
                     SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool")
               : Builder.CreateFCmpUNE(
                     SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool");
  }

  // Cast necessary
  bool SrcIsUnsigned = Opcode == HLCastOpcode::FromUnsignedCast ||
                       Opcode == HLCastOpcode::UnsignedUnsignedCast;
  bool DstIsUnsigned = Opcode == HLCastOpcode::ToUnsignedCast ||
                       Opcode == HLCastOpcode::UnsignedUnsignedCast;
  auto CastOp = static_cast<Instruction::CastOps>(
      HLModule::GetNumericCastOp(SrcTy, SrcIsUnsigned, DstTy, DstIsUnsigned));
  return Builder.CreateCast(CastOp, SrcVal, DstTy);
}
// Lowers a cast HL operation where the source and/or destination is a matrix.
// The handling is driven by the kinds of types involved:
//   scalar -> matrix : convert the element, then splat it
//   vector -> matrix : truncate if needed, then convert the elements
//   matrix -> scalar : take element 0, then convert it
//   matrix -> vector : truncate if needed, then convert the elements
//   matrix -> matrix : change orientation or truncate, then convert elements
Value *HLMatrixLowerPass::lowerHLCast(CallInst *Call, Value *Src, Type *DstTy,
                                      HLCastOpcode Opcode,
                                      IRBuilder<> &Builder) {
  // The opcode really doesn't mean much here, the types involved are what drive
  // most of the casting.
  DXASSERT(Opcode != HLCastOpcode::HandleToResCast,
           "Unexpected matrix cast opcode.");

  // ---- Scalar source: splat into a matrix --------------------------------
  if (dxilutil::IsIntegerOrFloatingPointType(Src->getType())) {
    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);

    // Element conversion happens on the scalar, before splatting.
    Value *Elem = convertScalarOrVector(Src, MatDstTy.getElementTypeForReg(),
                                        Opcode, Builder);

    // Build a one-element vector and broadcast lane 0 with a shuffle.
    Value *OneElemVec = Builder.CreateInsertElement(
        UndefValue::get(VectorType::get(Elem->getType(), 1)), Elem,
        static_cast<uint64_t>(0));
    Value *SplatMask = ConstantVector::getSplat(MatDstTy.getNumElements(),
                                                Builder.getInt32(0));
    return Builder.CreateShuffleVector(OneElemVec, OneElemVec, SplatMask);
  }

  // ---- Vector source: reinterpret as a matrix ----------------------------
  if (VectorType *SrcVecTy = dyn_cast<VectorType>(Src->getType())) {
    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
    Value *Vec = Src;

    // Keep only the leading elements when the vector is too long.
    if (MatDstTy.getNumElements() < SrcVecTy->getNumElements()) {
      SmallVector<int, 4> Mask;
      for (unsigned Idx = 0; Idx < MatDstTy.getNumElements(); ++Idx)
        Mask.push_back(static_cast<int>(Idx));
      Vec = Builder.CreateShuffleVector(Src, Src, Mask);
    }

    return convertScalarOrVector(Vec, MatDstTy.getLoweredVectorTypeForReg(),
                                 Opcode, Builder);
  }

  // ---- Matrix source ------------------------------------------------------
  HLMatrixType MatSrcTy = HLMatrixType::cast(Src->getType());
  VectorType *LoweredSrcTy = MatSrcTy.getLoweredVectorTypeForReg();

  Value *LoweredSrc = nullptr;
  if (!isa<Argument>(Src)) {
    LoweredSrc = getLoweredByValOperand(Src, Builder);
  } else {
    // Function arguments are lowered in HLSignatureLower.
    // Initial codegen first generates those cast intrinsics to tell us how to
    // lower them into vectors. Preserve them, but change the return type to
    // vector.
    DXASSERT(Opcode == HLCastOpcode::ColMatrixToVecCast ||
                 Opcode == HLCastOpcode::RowMatrixToVecCast,
             "Unexpected cast of matrix argument.");
    LoweredSrc = callHLFunction(
        *m_pModule, HLOpcodeGroup::HLCast, static_cast<unsigned>(Opcode),
        LoweredSrcTy, {Builder.getInt32((uint32_t)Opcode), Src},
        Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
  }
  DXASSERT_NOMSG(LoweredSrc->getType() == LoweredSrcTy);

  // Matrix to scalar: the first element is the result.
  if (dxilutil::IsIntegerOrFloatingPointType(DstTy)) {
    Value *First =
        Builder.CreateExtractElement(LoweredSrc, static_cast<uint64_t>(0));
    return convertScalarOrVector(First, DstTy, Opcode, Builder);
  }

  // Matrix to vector: keep the leading elements.
  if (FixedVectorType *DstVecTy = dyn_cast<FixedVectorType>(DstTy)) {
    DXASSERT(DstVecTy->getNumElements() <= LoweredSrcTy->getNumElements(),
             "Cannot cast matrix to a larger vector.");
    Value *Vec = LoweredSrc;
    if (DstVecTy->getNumElements() < LoweredSrcTy->getNumElements()) {
      SmallVector<int, 3> Mask;
      for (unsigned Idx = 0; Idx < DstVecTy->getNumElements(); ++Idx)
        Mask.push_back(static_cast<int>(Idx));
      Vec = Builder.CreateShuffleVector(Vec, Vec, Mask);
    }
    return convertScalarOrVector(Vec, DstTy, Opcode, Builder);
  }

  // Matrix to matrix.
  HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
  Value *Result = LoweredSrc;

  // Apply any changes at the matrix level: orientation changes and truncation
  switch (Opcode) {
  case HLCastOpcode::ColMatrixToRowMatrix:
    Result = MatSrcTy.emitLoweredVectorColToRow(Result, Builder);
    break;
  case HLCastOpcode::RowMatrixToColMatrix:
    Result = MatSrcTy.emitLoweredVectorRowToCol(Result, Builder);
    break;
  default:
    if (MatDstTy.getNumRows() != MatSrcTy.getNumRows() ||
        MatDstTy.getNumColumns() != MatSrcTy.getNumColumns()) {
      // Truncation: keep the top-left sub-matrix of the source.
      DXASSERT(MatDstTy.getNumRows() <= MatSrcTy.getNumRows() &&
                   MatDstTy.getNumColumns() <= MatSrcTy.getNumColumns(),
               "Unexpected matrix cast between incompatible dimensions.");
      SmallVector<int, 16> Mask;
      for (unsigned RowIdx = 0; RowIdx < MatDstTy.getNumRows(); ++RowIdx) {
        for (unsigned ColIdx = 0; ColIdx < MatDstTy.getNumColumns();
             ++ColIdx) {
          const unsigned SrcIdx = MatSrcTy.getRowMajorIndex(RowIdx, ColIdx);
          Mask.push_back(static_cast<int>(SrcIdx));
        }
      }
      Result = Builder.CreateShuffleVector(Result, Result, Mask);
    }
    break;
  }

  Type *LoweredDstTy = MatDstTy.getLoweredVectorTypeForReg();
  DXASSERT(cast<FixedVectorType>(Result->getType())->getNumElements() ==
               cast<FixedVectorType>(LoweredDstTy)->getNumElements(),
           "Unexpected matrix src/dst lowered element count mismatch after "
           "truncation.");

  // Apply element conversion
  return convertScalarOrVector(Result, LoweredDstTy, Opcode, Builder);
}
// Dispatches a subscript HL operation on a matrix.
// Element and row/column subscripts are lowered here; default and cbuffer
// subscripts are left alone (nullptr is returned for them).
Value *HLMatrixLowerPass::lowerHLSubscript(CallInst *Call,
                                           HLSubscriptOpcode Opcode) {
  const bool IsElementAccess = Opcode == HLSubscriptOpcode::RowMatElement ||
                               Opcode == HLSubscriptOpcode::ColMatElement;
  if (IsElementAccess) {
    const bool RowMajor = Opcode == HLSubscriptOpcode::RowMatElement;
    return lowerHLMatElementSubscript(Call, RowMajor);
  }

  const bool IsMatSubscript = Opcode == HLSubscriptOpcode::RowMatSubscript ||
                              Opcode == HLSubscriptOpcode::ColMatSubscript;
  if (IsMatSubscript) {
    const bool RowMajor = Opcode == HLSubscriptOpcode::RowMatSubscript;
    return lowerHLMatSubscript(Call, RowMajor);
  }

  if (Opcode == HLSubscriptOpcode::DefaultSubscript ||
      Opcode == HLSubscriptOpcode::CBufferSubscript) {
    // Those get lowered during HLOperationLower,
    // and the return type must stay unchanged (as a matrix)
    // to provide the metadata to properly emit the loads.
    return nullptr;
  }

  llvm_unreachable("Unexpected matrix subscript opcode.");
}
// Lowers an element-style matrix subscript, whose indices come as a single
// constant vector operand. The vector is unpacked into individual index
// values, which are then handed to the shared subscript lowering.
// Always returns nullptr because the uses are replaced in here.
Value *HLMatrixLowerPass::lowerHLMatElementSubscript(CallInst *Call,
                                                     bool RowMajor) {
  (void)RowMajor; // It doesn't look like we actually need this?

  Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  Constant *IdxVec = cast<Constant>(
      Call->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx));
  VectorType *IdxVecTy = cast<VectorType>(IdxVec->getType());

  // Unpack the constant index vector, one lowered-vector index per element.
  const unsigned NumIndices = IdxVecTy->getNumElements();
  SmallVector<Value *, 4> ElemIndices;
  ElemIndices.reserve(NumIndices);
  unsigned VecIdx = 0;
  while (VecIdx < NumIndices) {
    ElemIndices.push_back(IdxVec->getAggregateElement(VecIdx));
    ++VecIdx;
  }

  lowerHLMatSubscript(Call, MatPtr, ElemIndices);

  // We did our own replacement of uses, so opt out of having the caller do it
  // for us.
  return nullptr;
}
// Lowers a row/column matrix subscript, whose indices come as separate call
// operands starting at kMatSubscriptSubOpIdx. The operands are collected and
// handed to the shared subscript lowering.
// Always returns nullptr because the uses are replaced in here.
Value *HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, bool RowMajor) {
  (void)RowMajor; // It doesn't look like we actually need this?

  Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);

  // Collect every index operand of the call, in order.
  SmallVector<Value *, 4> ElemIndices;
  unsigned OperandIdx = HLOperandIndex::kMatSubscriptSubOpIdx;
  while (OperandIdx < Call->getNumArgOperands()) {
    ElemIndices.push_back(Call->getArgOperand(OperandIdx));
    ++OperandIdx;
  }

  lowerHLMatSubscript(Call, MatPtr, ElemIndices);

  // We did our own replacement of uses, so opt out of having the caller do it
  // for us.
  return nullptr;
}
// Shared implementation of matrix subscript lowering.
// Replaces all uses of the subscript call with accesses to either the lowered
// matrix pointer or, for pointers rooted at a function argument that have no
// lowered counterpart, a lowered matrix value loaded through the HL load
// intrinsic. If there is no lowered pointer and the root is not an argument,
// the call is left untouched. Otherwise the call ends up in the dead
// instruction list.
void HLMatrixLowerPass::lowerHLMatSubscript(
    CallInst *Call, Value *MatPtr, SmallVectorImpl<Value *> &ElemIndices) {
  DXASSERT_NOMSG(HLMatrixType::isMatrixPtr(MatPtr->getType()));

  IRBuilder<> CallBuilder(Call);
  Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, CallBuilder);

  // Strip GEPs to find the object the pointer is ultimately based on.
  Value *RootPtr = LoweredPtr != nullptr ? LoweredPtr : MatPtr;
  for (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr); GEP != nullptr;
       GEP = dyn_cast<GEPOperator>(RootPtr))
    RootPtr = GEP->getPointerOperand();

  Value *LoweredMatrix = nullptr;
  if (!LoweredPtr) {
    if (!isa<Argument>(RootPtr))
      return;

    // For a shader input, load the matrix into a lowered ptr
    // The load will be handled by LowerSignature
    const bool IsRowSubscript =
        (HLSubscriptOpcode)GetHLOpcode(Call) ==
        HLSubscriptOpcode::RowMatSubscript;
    HLMatLoadStoreOpcode Opcode = IsRowSubscript
                                      ? HLMatLoadStoreOpcode::RowMatLoad
                                      : HLMatLoadStoreOpcode::ColMatLoad;
    HLMatrixType MatTy =
        HLMatrixType::cast(MatPtr->getType()->getPointerElementType());

    // Don't pass attributes from subscript (ReadNone) - load is ReadOnly.
    // Attributes will be set when HL function is created.
    // FIXME: This seems to indicate a potential bug, since the load should be
    // placed where pointer users would have loaded from the pointer.
    LoweredMatrix = callHLFunction(
        *m_pModule, HLOpcodeGroup::HLMatLoadStore,
        static_cast<unsigned>(Opcode), MatTy.getLoweredVectorTypeForReg(),
        {CallBuilder.getInt32((uint32_t)Opcode), MatPtr}, AttributeSet(),
        CallBuilder);
  }

  // For global variables, we can GEP directly into the lowered vector pointer.
  // This is necessary to support group shared memory atomics and the likes.
  bool AllowLoweredPtrGEPs = isa<GlobalVariable>(RootPtr);

  // The replacer performs the whole rewrite from its constructor; the object
  // itself is not used afterwards.
  HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, LoweredMatrix,
                                           ElemIndices, AllowLoweredPtrGEPs,
                                           m_deadInsts);

  DXASSERT(Call->use_empty(),
           "Expected all matrix subscript uses to have been replaced.");
  addToDeadInsts(Call);
}
// Lowers a matrix init HL operation to the lowered vector it describes.
// A single vector argument is returned directly; otherwise the scalar
// arguments are inserted one by one into a fresh vector.
Value *HLMatrixLowerPass::lowerHLInit(CallInst *Call) {
  DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");

  // Figure out the result type
  HLMatrixType MatTy = HLMatrixType::cast(Call->getType());
  VectorType *LoweredTy = MatTy.getLoweredVectorTypeForReg();

  const unsigned FirstArgIdx = HLOperandIndex::kInitFirstArgOpIdx;
  const unsigned NumInitArgs = Call->getNumArgOperands() - FirstArgIdx;

  // Handle case where produced by EmitHLSLFlatConversion where there's one
  // vector argument, instead of scalar arguments.
  if (NumInitArgs == 1) {
    Value *OnlyArg = Call->getArgOperand(FirstArgIdx);
    if (OnlyArg->getType()->isVectorTy()) {
      DXASSERT(LoweredTy->getNumElements() ==
                   cast<FixedVectorType>(OnlyArg->getType())->getNumElements(),
               "Invalid matrix init argument vector element count.");
      return OnlyArg;
    }
  }

  DXASSERT(LoweredTy->getNumElements() == NumInitArgs,
           "Invalid matrix init argument count.");

  // Build the result vector from the init args.
  // Both the args and the result vector are in row-major order, so no shuffling
  // is necessary.
  IRBuilder<> Builder(Call);
  Value *Accum = UndefValue::get(LoweredTy);
  const unsigned NumElems = LoweredTy->getNumElements();
  for (unsigned ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
    Value *ArgVal = Call->getArgOperand(FirstArgIdx + ElemIdx);
    DXASSERT(dxilutil::IsIntegerOrFloatingPointType(ArgVal->getType()),
             "Expected only scalars in matrix initialization.");
    Accum = Builder.CreateInsertElement(Accum, ArgVal,
                                        static_cast<uint64_t>(ElemIdx));
  }
  return Accum;
}
// Lowers the matrix ternary (select) HL operation.
// The true/false operands are lowered to vectors and combined element by
// element with scalar select instructions. The condition may lower to a scalar
// (reused for every element) or to a vector (one condition per element).
// Returns the lowered result vector.
Value *HLMatrixLowerPass::lowerHLSelect(CallInst *Call) {
  // The message previously read "init", a leftover from lowerHLInit.
  DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix select opcode.");

  Value *Cond = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
  Value *TrueMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx);
  Value *FalseMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx);

  DXASSERT(TrueMat->getType() == FalseMat->getType(),
           "Unexpected type mismatch between matrix ternary operator values.");

#ifndef NDEBUG
  // Assert that if the condition is a matrix, it matches the dimensions of the
  // values
  if (HLMatrixType MatCondTy = HLMatrixType::dyn_cast(Cond->getType())) {
    HLMatrixType ValMatTy = HLMatrixType::cast(TrueMat->getType());
    DXASSERT(MatCondTy.getNumRows() == ValMatTy.getNumRows() &&
                 MatCondTy.getNumColumns() == ValMatTy.getNumColumns(),
             "Unexpected mismatch between ternary operator condition and value "
             "matrix dimensions.");
  }
#endif

  IRBuilder<> Builder(Call);
  Value *LoweredCond = getLoweredByValOperand(Cond, Builder);
  Value *LoweredTrueVec = getLoweredByValOperand(TrueMat, Builder);
  Value *LoweredFalseVec = getLoweredByValOperand(FalseMat, Builder);
  Value *Result = UndefValue::get(LoweredTrueVec->getType());

  bool IsScalarCond = !LoweredCond->getType()->isVectorTy();

  unsigned NumElems =
      cast<FixedVectorType>(Result->getType())->getNumElements();
  for (uint64_t ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
    // A scalar condition applies to all elements; a vector one is indexed.
    Value *ElemCond = IsScalarCond
                          ? LoweredCond
                          : Builder.CreateExtractElement(LoweredCond, ElemIdx);
    Value *ElemTrueVal = Builder.CreateExtractElement(LoweredTrueVec, ElemIdx);
    Value *ElemFalseVal =
        Builder.CreateExtractElement(LoweredFalseVec, ElemIdx);
    Value *ResultElem =
        Builder.CreateSelect(ElemCond, ElemTrueVal, ElemFalseVal);
    Result = Builder.CreateInsertElement(Result, ResultElem, ElemIdx);
  }

  return Result;
}