Pix: Cope with group shared AS->MS payload (#6619)
This change copes with the AS->MS payload being placed in group-shared
memory by the application (and MSFT's samples do indeed do this). (TIL, thanks
to pow2clk, that the spec says the payload counts against the group-shared
memory total, implying, if not explicitly stating, that on at least some
platforms the payload will be in group-shared memory anyway.)
The MS pass needs to be given data from the AS about the AS's thread
group topology, and this is done by extending the payload struct to add
three uints. This can't be done in place when the payload is resident in
group-shared memory, of course, because that would change the layout of
group-shared memory.
So the new approach here is to allocate a new struct (an alloca in the
default address space) containing the members of the original payload struct
plus the extra data the MS needs, and to copy the original payload into it
piece-wise, because llvm.memcpy isn't appropriate for copies from group-shared
memory to the default address space.
diff --git a/lib/DxilPIXPasses/DxilPIXAddTidToAmplificationShaderPayload.cpp b/lib/DxilPIXPasses/DxilPIXAddTidToAmplificationShaderPayload.cpp
index d002239..e756e98 100644
--- a/lib/DxilPIXPasses/DxilPIXAddTidToAmplificationShaderPayload.cpp
+++ b/lib/DxilPIXPasses/DxilPIXAddTidToAmplificationShaderPayload.cpp
@@ -45,7 +45,6 @@
}
void AddValueToExpandedPayload(OP *HlslOP, llvm::IRBuilder<> &B,
- ExpandedStruct &expanded,
AllocaInst *NewStructAlloca,
unsigned int expandedValueIndex, Value *value) {
Constant *Zero32Arg = HlslOP->GetU32Const(0);
@@ -53,135 +52,147 @@
IndexToAppendedValue.push_back(Zero32Arg);
IndexToAppendedValue.push_back(HlslOP->GetU32Const(expandedValueIndex));
auto *PointerToEmbeddedNewValue = B.CreateInBoundsGEP(
- expanded.ExpandedPayloadStructType, NewStructAlloca, IndexToAppendedValue,
+ NewStructAlloca, IndexToAppendedValue,
"PointerToEmbeddedNewValue" + std::to_string(expandedValueIndex));
B.CreateStore(value, PointerToEmbeddedNewValue);
}
-bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule(Module &M) {
+// Copies Ty at Source[GEPIndices...] to Dest[GEPIndices...] one leaf at a time,
+// recursing into struct/array members; scalars and vectors are copied whole.
+void CopyAggregate(IRBuilder<> &B, Type *Ty, Value *Source, Value *Dest,
+ ArrayRef<Value *> GEPIndices) {
+ auto *ST = dyn_cast<StructType>(Ty);
+ auto *AT = dyn_cast<ArrayType>(Ty);
+ if (ST || AT) {
+ unsigned NumElements = ST ? ST->getNumElements() : AT->getNumElements();
+ SmallVector<Value *, 16> ElementIndices(GEPIndices.begin(), GEPIndices.end());
+ ElementIndices.push_back(nullptr); // Overwritten with each member's index.
+ for (unsigned j = 0; j < NumElements; ++j) {
+ ElementIndices.back() = B.getInt32(j);
+ Type *ElementTy = ST ? ST->getElementType(j) : AT->getArrayElementType();
+ CopyAggregate(B, ElementTy, Source, Dest, ElementIndices);
+ }
+ return;
+ }
+ auto *SourceGEP = B.CreateGEP(Source, GEPIndices, "CopyStructSourceGEP");
+ Value *Val = B.CreateLoad(SourceGEP, "CopyStructLoad");
+ auto *DestGEP = B.CreateGEP(Dest, GEPIndices, "CopyStructDestGEP");
+ // CreateStore's 3rd parameter is `bool isVolatile` (stores have no name).
+ // NOTE(review): volatile looks unneeded for a private alloca, but the pass's
+ // FileCheck tests match `store volatile`; drop both together if desired.
+ B.CreateStore(Val, DestGEP, /*isVolatile=*/true);
+}
+bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule(Module &M) {
DxilModule &DM = M.GetOrCreateDxilModule();
LLVMContext &Ctx = M.getContext();
OP *HlslOP = DM.GetOP();
-
- Type *OriginalPayloadStructPointerType = nullptr;
- Type *OriginalPayloadStructType = nullptr;
- ExpandedStruct expanded;
llvm::Function *entryFunction = PIXPassHelpers::GetEntryFunction(DM);
for (inst_iterator I = inst_begin(entryFunction), E = inst_end(entryFunction);
I != E; ++I) {
- if (auto *Instr = llvm::cast<Instruction>(&*I)) {
- if (hlsl::OP::IsDxilOpFuncCallInst(Instr,
- hlsl::OP::OpCode::DispatchMesh)) {
- DxilInst_DispatchMesh DispatchMesh(Instr);
- OriginalPayloadStructPointerType =
- DispatchMesh.get_payload()->getType();
- OriginalPayloadStructType =
- OriginalPayloadStructPointerType->getPointerElementType();
- expanded = ExpandStructType(Ctx, OriginalPayloadStructType);
- }
+ if (hlsl::OP::IsDxilOpFuncCallInst(&*I, hlsl::OP::OpCode::DispatchMesh)) {
+ DxilInst_DispatchMesh DispatchMesh(&*I);
+ // Replace this call with one whose payload is a local copy of the original
+ // plus three uints for the MS pass: flat AS thread id, group counts Y, Z.
+ Type *OriginalPayloadStructType =
+ DispatchMesh.get_payload()->getType()->getPointerElementType();
+ ExpandedStruct expanded =
+ ExpandStructType(Ctx, OriginalPayloadStructType);
+ llvm::IRBuilder<> B(&*I); // Emits the copy and new call before this call.
+ llvm::IRBuilder<> AllocaBuilder( // Entry block keeps the alloca static.
+ &*entryFunction->getEntryBlock().getFirstInsertionPt());
+ auto *NewStructAlloca =
+ AllocaBuilder.CreateAlloca(expanded.ExpandedPayloadStructType,
+ HlslOP->GetU32Const(1), "NewPayload");
+ NewStructAlloca->setAlignment(4);
+ // Copy member-wise: llvm.memcpy isn't appropriate for group-shared sources.
+ Value *OriginalPayload = DispatchMesh.get_payload();
+ SmallVector<Value *, 16> GEPIndices;
+ GEPIndices.push_back(B.getInt32(0));
+ CopyAggregate(B, OriginalPayloadStructType, OriginalPayload, NewStructAlloca,
+ GEPIndices);
+
+ Constant *Zero32Arg = HlslOP->GetU32Const(0);
+ Constant *One32Arg = HlslOP->GetU32Const(1);
+ Constant *Two32Arg = HlslOP->GetU32Const(2);
+
+ auto GroupIdFunc =
+ HlslOP->GetOpFunc(DXIL::OpCode::GroupId, Type::getInt32Ty(Ctx));
+ Constant *GroupIdOpcode =
+ HlslOP->GetU32Const((unsigned)DXIL::OpCode::GroupId);
+ auto *GroupIdX =
+ B.CreateCall(GroupIdFunc, {GroupIdOpcode, Zero32Arg}, "GroupIdX");
+ auto *GroupIdY =
+ B.CreateCall(GroupIdFunc, {GroupIdOpcode, One32Arg}, "GroupIdY");
+ auto *GroupIdZ =
+ B.CreateCall(GroupIdFunc, {GroupIdOpcode, Two32Arg}, "GroupIdZ");
+
+ // FlatGroupID = z + y*numZ + x*numY*numZ
+ // Where x,y,z are the group ID components, and numZ and numY are the
+ // corresponding AS group-count arguments to the DispatchMesh Direct3D API
+ auto *GroupYxNumZ = B.CreateMul(
+ GroupIdY, HlslOP->GetU32Const(m_DispatchArgumentZ), "GroupYxNumZ");
+ auto *FlatGroupNumZY =
+ B.CreateAdd(GroupIdZ, GroupYxNumZ, "FlatGroupNumZY");
+ auto *GroupXxNumYZ = B.CreateMul(
+ GroupIdX,
+ HlslOP->GetU32Const(m_DispatchArgumentY * m_DispatchArgumentZ),
+ "GroupXxNumYZ");
+ auto *FlatGroupID =
+ B.CreateAdd(GroupXxNumYZ, FlatGroupNumZY, "FlatGroupID");
+
+ // The ultimate goal is a single unique thread ID for this AS thread.
+ // So take the flat group number, multiply it by the number of
+ // threads per group...
+ auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul(
+ FlatGroupID,
+ HlslOP->GetU32Const(DM.GetNumThreads(0) * DM.GetNumThreads(1) *
+ DM.GetNumThreads(2)),
+ "FlatGroupIDWithSpaceForThreadInGroupId");
+
+ auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc(
+ DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty(Ctx));
+ Constant *FlattenedThreadIdInGroupOpcode =
+ HlslOP->GetU32Const((unsigned)DXIL::OpCode::FlattenedThreadIdInGroup);
+ auto FlatThreadIdInGroup = B.CreateCall(FlattenedThreadIdInGroupFunc,
+ {FlattenedThreadIdInGroupOpcode},
+ "FlattenedThreadIdInGroup");
+
+ // ...and add the flat thread id:
+ auto *FlatId = B.CreateAdd(FlatGroupIDWithSpaceForThreadInGroupId,
+ FlatThreadIdInGroup, "FlatId");
+
+ AddValueToExpandedPayload(
+ HlslOP, B, NewStructAlloca,
+ expanded.ExpandedPayloadStructType->getStructNumElements() - 3,
+ FlatId);
+ AddValueToExpandedPayload(
+ HlslOP, B, NewStructAlloca,
+ expanded.ExpandedPayloadStructType->getStructNumElements() - 2,
+ DispatchMesh.get_threadGroupCountY());
+ AddValueToExpandedPayload(
+ HlslOP, B, NewStructAlloca,
+ expanded.ExpandedPayloadStructType->getStructNumElements() - 1,
+ DispatchMesh.get_threadGroupCountZ());
+
+ auto DispatchMeshFn = HlslOP->GetOpFunc(
+ DXIL::OpCode::DispatchMesh, expanded.ExpandedPayloadStructPtrType);
+ Constant *DispatchMeshOpcode =
+ HlslOP->GetU32Const((unsigned)DXIL::OpCode::DispatchMesh);
+ B.CreateCall(DispatchMeshFn,
+ {DispatchMeshOpcode, DispatchMesh.get_threadGroupCountX(),
+ DispatchMesh.get_threadGroupCountY(),
+ DispatchMesh.get_threadGroupCountZ(), NewStructAlloca});
+ // Safe to erase via I: we return below without advancing the iterator.
+ I->eraseFromParent();
+ // Validation requires exactly one DispatchMesh in an AS, so we can exit
+ // after the first one:
+ DM.ReEmitDxilResources();
+ return true;
+ }
}
- AllocaInst *OldStructAlloca = nullptr;
- AllocaInst *NewStructAlloca = nullptr;
- std::vector<AllocaInst *> allocasOfPayloadType;
- for (inst_iterator I = inst_begin(entryFunction), E = inst_end(entryFunction);
- I != E; ++I) {
- auto *Inst = &*I;
- if (llvm::isa<AllocaInst>(Inst)) {
- auto *Alloca = llvm::cast<AllocaInst>(Inst);
- if (Alloca->getType() == OriginalPayloadStructPointerType) {
- allocasOfPayloadType.push_back(Alloca);
- }
- }
- }
- for (auto &Alloca : allocasOfPayloadType) {
- OldStructAlloca = Alloca;
- llvm::IRBuilder<> B(Alloca->getContext());
- NewStructAlloca = B.CreateAlloca(expanded.ExpandedPayloadStructType,
- HlslOP->GetU32Const(1), "NewPayload");
- NewStructAlloca->setAlignment(Alloca->getAlignment());
- NewStructAlloca->insertAfter(Alloca);
-
- ReplaceAllUsesOfInstructionWithNewValueAndDeleteInstruction(
- Alloca, NewStructAlloca, expanded.ExpandedPayloadStructType);
- }
-
- auto F = HlslOP->GetOpFunc(DXIL::OpCode::DispatchMesh,
- expanded.ExpandedPayloadStructPtrType);
- for (auto FI = F->user_begin(); FI != F->user_end();) {
- auto *FunctionUser = *FI++;
- auto *UserInstruction = llvm::cast<Instruction>(FunctionUser);
- DxilInst_DispatchMesh DispatchMesh(UserInstruction);
-
- llvm::IRBuilder<> B(UserInstruction);
-
- Constant *Zero32Arg = HlslOP->GetU32Const(0);
- Constant *One32Arg = HlslOP->GetU32Const(1);
- Constant *Two32Arg = HlslOP->GetU32Const(2);
-
- auto GroupIdFunc =
- HlslOP->GetOpFunc(DXIL::OpCode::GroupId, Type::getInt32Ty(Ctx));
- Constant *GroupIdOpcode =
- HlslOP->GetU32Const((unsigned)DXIL::OpCode::GroupId);
- auto *GroupIdX =
- B.CreateCall(GroupIdFunc, {GroupIdOpcode, Zero32Arg}, "GroupIdX");
- auto *GroupIdY =
- B.CreateCall(GroupIdFunc, {GroupIdOpcode, One32Arg}, "GroupIdY");
- auto *GroupIdZ =
- B.CreateCall(GroupIdFunc, {GroupIdOpcode, Two32Arg}, "GroupIdZ");
-
- // FlatGroupID = z + y*numZ + x*numY*numZ
- // Where x,y,z are the group ID components, and numZ and numY are the
- // corresponding AS group-count arguments to the DispatchMesh Direct3D API
- auto *GroupYxNumZ = B.CreateMul(
- GroupIdY, HlslOP->GetU32Const(m_DispatchArgumentZ), "GroupYxNumZ");
- auto *FlatGroupNumZY = B.CreateAdd(GroupIdZ, GroupYxNumZ, "FlatGroupNumZY");
- auto *GroupXxNumYZ = B.CreateMul(
- GroupIdX,
- HlslOP->GetU32Const(m_DispatchArgumentY * m_DispatchArgumentZ),
- "GroupXxNumYZ");
- auto *FlatGroupID =
- B.CreateAdd(GroupXxNumYZ, FlatGroupNumZY, "FlatGroFlatGroupIDupNum");
-
- // The ultimate goal is a single unique thread ID for this AS thread.
- // So take the flat group number, multiply it by the number of
- // threads per group...
- auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul(
- FlatGroupID,
- HlslOP->GetU32Const(DM.GetNumThreads(0) * DM.GetNumThreads(1) *
- DM.GetNumThreads(2)),
- "FlatGroupIDWithSpaceForThreadInGroupId");
-
- auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc(
- DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty(Ctx));
- Constant *FlattenedThreadIdInGroupOpcode =
- HlslOP->GetU32Const((unsigned)DXIL::OpCode::FlattenedThreadIdInGroup);
- auto FlatThreadIdInGroup = B.CreateCall(FlattenedThreadIdInGroupFunc,
- {FlattenedThreadIdInGroupOpcode},
- "FlattenedThreadIdInGroup");
-
- // ...and add the flat thread id:
- auto *FlatId = B.CreateAdd(FlatGroupIDWithSpaceForThreadInGroupId,
- FlatThreadIdInGroup, "FlatId");
-
- AddValueToExpandedPayload(HlslOP, B, expanded, NewStructAlloca,
- OriginalPayloadStructType->getStructNumElements(),
- FlatId);
- AddValueToExpandedPayload(
- HlslOP, B, expanded, NewStructAlloca,
- OriginalPayloadStructType->getStructNumElements() + 1,
- DispatchMesh.get_threadGroupCountY());
- AddValueToExpandedPayload(
- HlslOP, B, expanded, NewStructAlloca,
- OriginalPayloadStructType->getStructNumElements() + 2,
- DispatchMesh.get_threadGroupCountZ());
- }
-
- DM.ReEmitDxilResources();
-
- return true;
+ return false;
}
char DxilPIXAddTidToAmplificationShaderPayload::ID = 0;
diff --git a/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedComplexPayload.hlsl b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedComplexPayload.hlsl
new file mode 100644
index 0000000..28eff71
--- /dev/null
+++ b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedComplexPayload.hlsl
@@ -0,0 +1,88 @@
+// RUN: %dxc -enable-16bit-types -Od -Emain -Tas_6_6 %s | %opt -S -hlsl-dxil-PIX-add-tid-to-as-payload,dispatchArgY=3,dispatchArgZ=7 | %FileCheck %s
+
+// Check that the group-shared payload is copied piece-wise into a local copy (one addrspace(3) load, then one store, per leaf element).
+// There are 28 leaf elements: 1 (i) + 3 Bigger * (1 (h) + 2 Contained * (1 (j) + 3 (af))):
+
+// CHECK: [[LOAD0:%.*]] = load [[TYPE0:.*]], [[TYPE0]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE0]] [[LOAD0]]
+// CHECK: [[LOAD1:%.*]] = load [[TYPE1:.*]], [[TYPE1]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE1]] [[LOAD1]]
+// CHECK: [[LOAD2:%.*]] = load [[TYPE2:.*]], [[TYPE2]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE2]] [[LOAD2]]
+// CHECK: [[LOAD3:%.*]] = load [[TYPE3:.*]], [[TYPE3]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE3]] [[LOAD3]]
+// CHECK: [[LOAD4:%.*]] = load [[TYPE4:.*]], [[TYPE4]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE4]] [[LOAD4]]
+// CHECK: [[LOAD5:%.*]] = load [[TYPE5:.*]], [[TYPE5]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE5]] [[LOAD5]]
+// CHECK: [[LOAD6:%.*]] = load [[TYPE6:.*]], [[TYPE6]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE6]] [[LOAD6]]
+// CHECK: [[LOAD7:%.*]] = load [[TYPE7:.*]], [[TYPE7]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE7]] [[LOAD7]]
+// CHECK: [[LOAD8:%.*]] = load [[TYPE8:.*]], [[TYPE8]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE8]] [[LOAD8]]
+// CHECK: [[LOAD9:%.*]] = load [[TYPE9:.*]], [[TYPE9]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE9]] [[LOAD9]]
+
+// CHECK: [[LOAD10:%.*]] = load [[TYPE10:.*]], [[TYPE10]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE10]] [[LOAD10]]
+// CHECK: [[LOAD11:%.*]] = load [[TYPE11:.*]], [[TYPE11]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE11]] [[LOAD11]]
+// CHECK: [[LOAD12:%.*]] = load [[TYPE12:.*]], [[TYPE12]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE12]] [[LOAD12]]
+// CHECK: [[LOAD13:%.*]] = load [[TYPE13:.*]], [[TYPE13]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE13]] [[LOAD13]]
+// CHECK: [[LOAD14:%.*]] = load [[TYPE14:.*]], [[TYPE14]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE14]] [[LOAD14]]
+// CHECK: [[LOAD15:%.*]] = load [[TYPE15:.*]], [[TYPE15]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE15]] [[LOAD15]]
+// CHECK: [[LOAD16:%.*]] = load [[TYPE16:.*]], [[TYPE16]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE16]] [[LOAD16]]
+// CHECK: [[LOAD17:%.*]] = load [[TYPE17:.*]], [[TYPE17]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE17]] [[LOAD17]]
+// CHECK: [[LOAD18:%.*]] = load [[TYPE18:.*]], [[TYPE18]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE18]] [[LOAD18]]
+// CHECK: [[LOAD19:%.*]] = load [[TYPE19:.*]], [[TYPE19]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE19]] [[LOAD19]]
+
+// CHECK: [[LOAD20:%.*]] = load [[TYPE20:.*]], [[TYPE20]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE20]] [[LOAD20]]
+// CHECK: [[LOAD21:%.*]] = load [[TYPE21:.*]], [[TYPE21]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE21]] [[LOAD21]]
+// CHECK: [[LOAD22:%.*]] = load [[TYPE22:.*]], [[TYPE22]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE22]] [[LOAD22]]
+// CHECK: [[LOAD23:%.*]] = load [[TYPE23:.*]], [[TYPE23]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE23]] [[LOAD23]]
+// CHECK: [[LOAD24:%.*]] = load [[TYPE24:.*]], [[TYPE24]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE24]] [[LOAD24]]
+// CHECK: [[LOAD25:%.*]] = load [[TYPE25:.*]], [[TYPE25]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE25]] [[LOAD25]]
+// CHECK: [[LOAD26:%.*]] = load [[TYPE26:.*]], [[TYPE26]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE26]] [[LOAD26]]
+// CHECK: [[LOAD27:%.*]] = load [[TYPE27:.*]], [[TYPE27]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE27]] [[LOAD27]]
+
+// ...and no 29th load from group-shared memory:
+// CHECK-NOT: [[LOAD28:%.*]] = load [[TYPE28:.*]], [[TYPE28]] addrspace(3)* getelementptr inbounds
+
+struct Contained {
+ uint j;
+ float af[3];
+};
+
+struct Bigger {
+ half h;
+ Contained a[2];
+};
+
+struct MyPayload {
+ uint i;
+ Bigger big[3];
+};
+
+groupshared MyPayload payload;
+
+[numthreads(1, 1, 1)] void main(uint gid
+ : SV_GroupID) {
+ DispatchMesh(1, 1, 1, payload);
+}
diff --git a/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedPayload.hlsl b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedPayload.hlsl
new file mode 100644
index 0000000..7de78a8
--- /dev/null
+++ b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedPayload.hlsl
@@ -0,0 +1,21 @@
+// RUN: %dxc -Od -Emain -Tas_6_6 %s | %opt -S -hlsl-dxil-PIX-add-tid-to-as-payload,dispatchArgY=3,dispatchArgZ=7 | %FileCheck %s
+
+// main dispatches a local copy of the group-shared struct, so check that the pass copies that local (not addrspace(3) memory) piece-wise into its new payload:
+// CHECK: [[LOADGEP:%.*]] = getelementptr %struct.MyPayload
+// CHECK: [[LOAD:%.*]] = load i32, i32* [[LOADGEP]]
+// CHECK: store volatile i32 [[LOAD]]
+
+struct MyPayload
+{
+ uint i;
+};
+
+groupshared MyPayload payload;
+
+[numthreads(1, 1, 1)]
+void main(uint gid : SV_GroupID)
+{
+ MyPayload copy;
+ copy = payload;
+ DispatchMesh(1, 1, 1, copy);
+}
diff --git a/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedTrickyTypesPayload.hlsl b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedTrickyTypesPayload.hlsl
new file mode 100644
index 0000000..6f3e70d
--- /dev/null
+++ b/tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedTrickyTypesPayload.hlsl
@@ -0,0 +1,28 @@
+// RUN: %dxc -enable-16bit-types -Od -Emain -Tas_6_6 %s | %opt -S -hlsl-dxil-PIX-add-tid-to-as-payload,dispatchArgY=3,dispatchArgZ=7 | %FileCheck %s
+
+// Check that the group-shared payload is copied piece-wise into a local copy (one addrspace(3) load, then one store, per leaf element).
+// There are only 2 leaf elements: i, and the bitfields, which should share one 32-bit slot (Init() adds no storage).
+
+// CHECK: [[LOAD0:%.*]] = load [[TYPE0:.*]], [[TYPE0]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE0]] [[LOAD0]]
+// CHECK: [[LOAD1:%.*]] = load [[TYPE1:.*]], [[TYPE1]] addrspace(3)* getelementptr inbounds
+// CHECK:store volatile [[TYPE1]] [[LOAD1]]
+
+// ...and no third load from group-shared memory:
+// CHECK-NOT: [[LOAD2:%.*]] = load {{.*}}, {{.*}} addrspace(3)* getelementptr inbounds
+
+struct MyPayload {
+ uint i;
+ void Init() { i = 27; }
+struct {
+ int bf0 : 7;
+ int bf1 : 11;
+} bitfields;
+};
+
+groupshared MyPayload payload;
+
+[numthreads(1, 1, 1)] void main(uint gid
+ : SV_GroupID) {
+ DispatchMesh(1, 1, 1, payload);
+}
diff --git a/tools/clang/unittests/HLSL/PixTest.cpp b/tools/clang/unittests/HLSL/PixTest.cpp
index 55dad73..5828976 100644
--- a/tools/clang/unittests/HLSL/PixTest.cpp
+++ b/tools/clang/unittests/HLSL/PixTest.cpp
@@ -102,7 +102,8 @@
TEST_METHOD(CompileDebugDisasmPDB)
TEST_METHOD(AddToASPayload)
-
+ TEST_METHOD(AddToASGroupSharedPayload)
+ TEST_METHOD(AddToASGroupSharedPayload_MeshletCullSample)
TEST_METHOD(SignatureModification_Empty)
TEST_METHOD(SignatureModification_VertexIdAlready)
TEST_METHOD(SignatureModification_SomethingElseFirst)
@@ -565,7 +566,7 @@
TEST_F(PixTest, AddToASPayload) {
- const char *dynamicResourceDecriptorHeapAccess = R"(
+ const char *hlsl = R"(
struct MyPayload
{
float f1;
@@ -603,12 +604,10 @@
)";
- auto as = Compile(m_dllSupport, dynamicResourceDecriptorHeapAccess, L"as_6_6",
- {}, L"ASMain");
+ auto as = Compile(m_dllSupport, hlsl, L"as_6_6", {}, L"ASMain");
RunDxilPIXAddTidToAmplificationShaderPayloadPass(as);
- auto ms = Compile(m_dllSupport, dynamicResourceDecriptorHeapAccess, L"ms_6_6",
- {}, L"MSMain");
+ auto ms = Compile(m_dllSupport, hlsl, L"ms_6_6", {}, L"MSMain");
RunDxilPIXMeshShaderOutputPass(ms);
}
unsigned FindOrAddVSInSignatureElementForInstanceOrVertexID(
@@ -704,6 +703,63 @@
VERIFY_ARE_EQUAL(sig.GetElement(2).GetStartRow(), 2);
}
+TEST_F(PixTest, AddToASGroupSharedPayload) {
+ // Runs the AS payload pass on a group-shared payload of nested structs/arrays.
+ const char *shaderSource = R"(
+struct Contained
+{
+ uint j;
+ float af[3];
+};
+
+struct Bigger
+{
+ half h;
+ void Init() { h = 1.f; }
+ Contained a[2];
+};
+
+struct MyPayload
+{
+ uint i;
+ Bigger big[3];
+};
+
+groupshared MyPayload payload;
+
+[numthreads(1, 1, 1)]
+void main(uint gid : SV_GroupID)
+{
+ DispatchMesh(1, 1, 1, payload);
+}
+
+ )";
+ auto asBlob =
+ Compile(m_dllSupport, shaderSource, L"as_6_6", {L"-Od"}, L"main");
+ RunDxilPIXAddTidToAmplificationShaderPayloadPass(asBlob);
+}
+
+TEST_F(PixTest, AddToASGroupSharedPayload_MeshletCullSample) {
+ // Per the test name, mirrors the meshlet-cull sample's group-shared payload.
+ const char *shaderSource = R"(
+struct MyPayload
+{
+ uint i[32];
+};
+
+groupshared MyPayload payload;
+
+[numthreads(1, 1, 1)]
+void main(uint gid : SV_GroupID)
+{
+ DispatchMesh(1, 1, 1, payload);
+}
+
+ )";
+ auto asBlob =
+ Compile(m_dllSupport, shaderSource, L"as_6_6", {L"-Od"}, L"main");
+ RunDxilPIXAddTidToAmplificationShaderPayloadPass(asBlob);
+}
static llvm::DIType *PeelTypedefs(llvm::DIType *diTy) {
using namespace llvm;
const llvm::DITypeIdentifierMap EmptyMap;