blob: 5f6fd54cde3f97baa69245601cdf147456426c6f [file]
/* Copyright (c) 2024-2026 LunarG, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "trace_ray_pass.h"
#include "containers/container_utils.h"
#include "gpuav/shaders/gpuav_error_codes.h"
#include "gpuav/shaders/gpuav_shaders_constants.h"
#include "module.h"
#include <spirv/unified1/spirv.hpp>
#include <iostream>
#include "generated/gpuav_offline_spirv.h"
namespace gpuav {
namespace spirv {
// Offline-compiled instrumentation shader (instrumentation_trace_ray_comp) whose functions get linked into the
// instrumented module; the UseErrorPayloadVariable flag is passed straight through to OfflineModule.
static const OfflineModule kOfflineModule = {instrumentation_trace_ray_comp, instrumentation_trace_ray_comp_size,
                                             UseErrorPayloadVariable};

// Functions inside the offline module, each given as {function name, instrumentation_trace_ray_comp_function_N_offset}.

// TLAS descriptor check shared by OpTraceRayKHR and OpRayQueryInitializeKHR.
static const OfflineFunction kTraceRayAccelerationStructureValidationFunction = {
    "inst_trace_ray_acceleration_structure", instrumentation_trace_ray_comp_function_0_offset};

// Ray parameter check for OpTraceRayKHR.
static const OfflineFunction kTraceRayValidationFunction = {"inst_trace_ray",
                                                            instrumentation_trace_ray_comp_function_1_offset};

// Ray parameter check for the OpHitObjectTrace* family.
static const OfflineFunction kRayHitObjectValidationFunction = {"inst_ray_hit_object",
                                                                instrumentation_trace_ray_comp_function_2_offset};

// SBT record index check for OpHitObjectSetShaderBindingTableRecordIndexEXT.
static const OfflineFunction kRayHitObjectSbtIndexValidationFunction = {
    "inst_ray_hit_object_sbt_index_check", instrumentation_trace_ray_comp_function_3_offset};

// Ray parameter check for OpRayQueryInitializeKHR.
static const OfflineFunction kRayQueryInitializeValidationFunction = {"inst_ray_query_comp",
                                                                      instrumentation_trace_ray_comp_function_4_offset};

// HitKind check for OpReportIntersectionKHR.
static const OfflineFunction kReportIntersectionValidationFunction = {"inst_report_intersection_comp",
                                                                      instrumentation_trace_ray_comp_function_5_offset};
// Constructs the pass on top of the trace-ray offline instrumentation module.
TraceRayPass::TraceRayPass(Module& module) : Pass(module, kOfflineModule) {
    // Opt the module into use_bda_ (presumably buffer device address, needed by the linked checks — confirm).
    module.use_bda_ = true;
}
// Builds the OpFunctionCall operand list for the inst_trace_ray_acceleration_structure check of the acceleration
// structure (TLAS) operand of a ray-tracing / ray-query instruction.
//
// The TLAS operand is traced back through its OpLoad and any access chains to the descriptor variable. That
// variable's DescriptorSet/Binding decorations select the binding layout whose start offset is passed to the check.
//
// @param function          Function containing the instruction; used to resolve the ids feeding the TLAS operand.
// @param tlas_operand_pos  Operand index of the acceleration structure in the instrumented instruction.
// @param error_sub_code    Error sub-code forwarded to the check.
// @param block             Block holding the instruction; a cast of the descriptor index may be inserted into it.
// @param trace_ray_inst_it Iterator to the instrumented instruction; inserted instructions are placed relative to it.
// @return {bool result type, result id, function, inst position, set, descriptor index, binding start, sub-code},
//         or an empty vector when the TLAS cannot be mapped to a classic (non-runtime-array) descriptor. In that
//         case no instruction is inserted into block.
std::vector<uint32_t> TraceRayPass::GetTlasValidationFunctionCallInstructions(const Function& function, uint32_t tlas_operand_pos,
                                                                              uint32_t error_sub_code, BasicBlock& block,
                                                                              InstructionIt* trace_ray_inst_it) {
    if (module_.interface_.descriptor_mode != vvl::DescriptorMode::DescriptorModeClassic) {
        return {};
    }

    const uint32_t as_op_load_id = (*trace_ray_inst_it)->get()->Operand(tlas_operand_pos);
    const Instruction* as_op_load_inst = function.FindInstruction(as_op_load_id);
    // The walk below treats Operand(0) as the pointer being loaded, which only holds for OpLoad. Any other producer
    // of the acceleration structure value (including one with no operands at all) cannot be traced to a descriptor.
    if (!as_op_load_inst || as_op_load_inst->Opcode() != spv::OpLoad) {
        return {};
    }

    // AS descriptors use UniformConstant storage class (unlike buffers which use Uniform/StorageBuffer),
    // so a non-array AS can be loaded directly from its variable with no access chain in between.
    // Pre-initialize variable for that case; the loop below handles the access-chain case.
    const uint32_t loaded_pointer_id = as_op_load_inst->Operand(0);
    const Variable* variable = type_manager_.FindVariableById(loaded_pointer_id);
    std::vector<const Instruction*> access_chain_insts;
    const Instruction* next_access_chain = function.FindInstruction(loaded_pointer_id);
    // Walk down possibly multiple chained access chains (as recognized by IsNonPtrAccessChain) to reach the variable.
    while (next_access_chain && next_access_chain->IsNonPtrAccessChain()) {
        access_chain_insts.push_back(next_access_chain);
        const uint32_t access_chain_base_id = next_access_chain->Operand(0);
        variable = type_manager_.FindVariableById(access_chain_base_id);
        if (variable) {
            break;  // found
        }
        next_access_chain = function.FindInstruction(access_chain_base_id);
    }
    if (!variable) {
        return {};
    }

    const Type* descriptor_type = variable->PointerType(type_manager_);
    if (!descriptor_type || descriptor_type->spv_type_ == SpvType::kRuntimeArray) {
        return {};  // TODO - Currently we mark these as "bindless"
    }
    const bool is_descriptor_array = descriptor_type->IsArray();
    if (is_descriptor_array && access_chain_insts.empty()) {
        return {};  // array descriptor without an access chain is invalid SPIR-V
    }

    uint32_t descriptor_index_id = 0;
    if (is_descriptor_array) {
        // Because you can't have 2D array of descriptors, the first index of the last accessChain is the descriptor index
        descriptor_index_id = access_chain_insts.back()->Operand(1);
    } else {
        // There is no array of this descriptor, so we essentially have an array of 1
        descriptor_index_id = type_manager_.GetConstantZeroUint32().Id();
    }

    uint32_t descriptor_set = 0;
    uint32_t descriptor_binding = 0;
    for (const auto& annotation : module_.annotations_) {
        if (annotation->Opcode() == spv::OpDecorate && annotation->Word(1) == variable->Id()) {
            if (annotation->Word(2) == spv::DecorationDescriptorSet) {
                descriptor_set = annotation->Word(3);
            } else if (annotation->Word(2) == spv::DecorationBinding) {
                descriptor_binding = annotation->Word(3);
            }
        }
    }

    if (descriptor_set >= glsl::kDebugInputBindlessMaxDescSets) {
        module_.InternalWarning(Name(), "Tried to use a descriptor slot over the current max limit");
        return {};
    }

    // Look the binding up before emitting any instruction, so bailing out leaves the block untouched. The set or
    // binding may be missing from the LUT (presumably when the shader references a set/binding the pipeline layout
    // does not declare). Indexing it unchecked would read out of bounds.
    const auto& layout_lut = module_.interface_.instrumentation_dsl.set_index_to_bindings_layout_lut;
    if (descriptor_set >= layout_lut.size() || descriptor_binding >= layout_lut[descriptor_set].size()) {
        return {};
    }
    const BindingLayout& binding_layout = layout_lut[descriptor_set][descriptor_binding];

    const Constant& desc_set_constant = type_manager_.GetConstantUInt32(descriptor_set);
    const uint32_t desc_index_id = CastToUint32(descriptor_index_id, block, trace_ray_inst_it);  // might be int32
    const Constant& binding_layout_offset = type_manager_.GetConstantUInt32(binding_layout.start);

    const uint32_t function_result = module_.TakeNextId();
    const uint32_t function_def = GetLinkFunction(trace_ray_as_link_function_id_, kTraceRayAccelerationStructureValidationFunction);
    const uint32_t bool_type = type_manager_.GetTypeBool().Id();
    // inst_position is unique per instruction, so a fresh constant is created rather than searched for.
    const uint32_t inst_position = (*trace_ray_inst_it)->get()->GetPositionOffset();
    const uint32_t inst_position_id = type_manager_.CreateConstantUInt32(inst_position).Id();
    const uint32_t error_sub_code_id = type_manager_.GetConstantUInt32(error_sub_code).Id();
    return {bool_type,
            function_result,
            function_def,
            inst_position_id,
            desc_set_constant.Id(),
            desc_index_id,
            binding_layout_offset.Id(),
            error_sub_code_id};
}
// Builds the OpFunctionCall operand list for the inst_trace_ray check of an OpTraceRayKHR.
//
// Layout of the returned list:
//   {bool result type, result id, function, inst position, ray flags, cull mask, SBT record offset,
//    SBT record stride, miss index, ray origin, tmin, ray direction, tmax, pipeline flags}
//
// The last argument packs module_.interface_.pipeline_has_skip_aabbs_flag into bit 0 and
// pipeline_has_skip_triangles_flag into bit 1.
std::vector<uint32_t> TraceRayPass::GetTraceRayValidationFunctionCallInstructions(InstructionIt* trace_ray_inst_it) {
    auto* trace_ray_inst = (*trace_ray_inst_it)->get();
    const uint32_t function_result = module_.TakeNextId();
    const uint32_t function_def = GetLinkFunction(trace_ray_link_function_id_, kTraceRayValidationFunction);
    const uint32_t bool_type = type_manager_.GetTypeBool().Id();

    // OpTraceRayKHR operand 0 (the acceleration structure) is not forwarded to this check.
    const uint32_t ray_flags_id = trace_ray_inst->Operand(1);
    const uint32_t cull_mask_id = trace_ray_inst->Operand(2);
    const uint32_t sbt_record_offset_id = trace_ray_inst->Operand(3);
    const uint32_t sbt_record_stride_id = trace_ray_inst->Operand(4);
    const uint32_t miss_index_id = trace_ray_inst->Operand(5);
    const uint32_t origin_id = trace_ray_inst->Operand(6);
    const uint32_t t_min_id = trace_ray_inst->Operand(7);
    const uint32_t direction_id = trace_ray_inst->Operand(8);
    const uint32_t t_max_id = trace_ray_inst->Operand(9);

    const uint32_t pipeline_flags = (module_.interface_.pipeline_has_skip_aabbs_flag ? 0x1 : 0u) |
                                    (module_.interface_.pipeline_has_skip_triangles_flag ? 0x2 : 0u);
    // The mask is the same for every call site in this module. Reuse an existing OpConstant (find-or-create) rather
    // than emitting a new duplicate for each instrumented instruction.
    const uint32_t pipeline_flags_id = type_manager_.GetConstantUInt32(pipeline_flags).Id();
    // inst_position is unique per instruction, so creating it directly is fine.
    const uint32_t inst_position_id = type_manager_.CreateConstantUInt32(trace_ray_inst->GetPositionOffset()).Id();
    return {bool_type,
            function_result,
            function_def,
            inst_position_id,
            ray_flags_id,
            cull_mask_id,
            sbt_record_offset_id,
            sbt_record_stride_id,
            miss_index_id,
            origin_id,
            t_min_id,
            direction_id,
            t_max_id,
            pipeline_flags_id};
}
// Builds the OpFunctionCall operand list for the inst_ray_hit_object check of an OpHitObjectTrace* instruction.
//
// Layout of the returned list:
//   {bool result type, result id, function, inst position, opcode, ray flags, origin, tmin, direction, tmax,
//    pipeline flags, time}
//
// `time` is the Current Time operand for the motion variants and a 0.0f constant otherwise.
std::vector<uint32_t> TraceRayPass::GetRayHitObjectValidationFunctionCallInstructions(InstructionIt* ray_hit_object_inst_it) {
    auto* hit_object_inst = (*ray_hit_object_inst_it)->get();
    const uint32_t function_result = module_.TakeNextId();
    const uint32_t function_def = GetLinkFunction(hit_object_link_function_id_, kRayHitObjectValidationFunction);
    const uint32_t bool_type = type_manager_.GetTypeBool().Id();
    const uint32_t opcode = hit_object_inst->Opcode();

    // All HitObject opcodes have ray parameters at the same positions
    const uint32_t ray_flags_id = hit_object_inst->Operand(2);
    const uint32_t ray_origin_id = hit_object_inst->Operand(7);
    const uint32_t ray_tmin_id = hit_object_inst->Operand(8);
    const uint32_t ray_direction_id = hit_object_inst->Operand(9);
    const uint32_t ray_tmax_id = hit_object_inst->Operand(10);
    // Only the motion variants carry a Current Time operand (directly after TMax).
    const bool has_time_operand =
        opcode == spv::OpHitObjectTraceRayMotionEXT || opcode == spv::OpHitObjectTraceMotionReorderExecuteEXT;

    // inst_position is unique per instruction, so creating it directly is fine.
    const uint32_t inst_position_id = type_manager_.CreateConstantUInt32(hit_object_inst->GetPositionOffset()).Id();
    // The opcode and the pipeline flag mask repeat across call sites. Reuse existing constants (find-or-create)
    // instead of emitting a duplicate OpConstant for every instrumented instruction.
    const uint32_t opcode_type_id = type_manager_.GetConstantUInt32(opcode).Id();
    const uint32_t pipeline_flags = (module_.interface_.pipeline_has_skip_aabbs_flag ? 0x1 : 0u) |
                                    (module_.interface_.pipeline_has_skip_triangles_flag ? 0x2 : 0u);
    const uint32_t pipeline_flags_id = type_manager_.GetConstantUInt32(pipeline_flags).Id();
    // For non-motion opcodes, pass 0.0 as time (valid value, won't trigger error)
    const uint32_t time_id =
        has_time_operand ? hit_object_inst->Operand(11) : type_manager_.GetConstantZeroFloat32().Id();
    return {bool_type,     function_result,  function_def,     inst_position_id, opcode_type_id,    ray_flags_id,
            ray_origin_id, ray_tmin_id,      ray_direction_id, ray_tmax_id,      pipeline_flags_id, time_id};
}
// Builds the OpFunctionCall operand list for the inst_ray_hit_object_sbt_index_check of an
// OpHitObjectSetShaderBindingTableRecordIndexEXT.
//
// Layout of the returned list:
//   {bool result type, result id, function, inst position, SBT record index, max SBT record index}
std::vector<uint32_t> TraceRayPass::GetRayHitObjectSbtIndexValidationFunctionCallInstructions(
    InstructionIt* ray_hit_object_sbt_index_inst_it) {
    auto* set_sbt_index_inst = (*ray_hit_object_sbt_index_inst_it)->get();
    const uint32_t function_result = module_.TakeNextId();
    const uint32_t function_def = GetLinkFunction(hit_object_sbt_index_link_function_id_, kRayHitObjectSbtIndexValidationFunction);
    const uint32_t bool_type = type_manager_.GetTypeBool().Id();
    // Operand 0 is the hit object; operand 1 is the SBT record index being set.
    const uint32_t sbt_index_id = set_sbt_index_inst->Operand(1);
    // inst_position is unique per instruction, so creating it directly is fine.
    const uint32_t inst_position_id = type_manager_.CreateConstantUInt32(set_sbt_index_inst->GetPositionOffset()).Id();
    // maxShaderBindingTableRecordIndex is the same for every call site in this module. Reuse an existing constant
    // (find-or-create) rather than emitting a duplicate OpConstant per instrumented instruction.
    const uint32_t max_sbt_index = module_.interface_.max_shader_binding_table_record_index;
    const uint32_t max_sbt_index_id = type_manager_.GetConstantUInt32(max_sbt_index).Id();
    return {bool_type, function_result, function_def, inst_position_id, sbt_index_id, max_sbt_index_id};
}
// Builds the OpFunctionCall operand list for the inst_ray_query_comp check of an OpRayQueryInitializeKHR.
//
// Layout of the returned list:
//   {bool result type, result id, function, inst position, ray flags, origin, tmin, direction, tmax}
std::vector<uint32_t> TraceRayPass::GetRayQueryInitializeValidationFunctionCallInstructions(InstructionIt* ray_query_init_inst_it) {
    auto* init_inst = (*ray_query_init_inst_it)->get();

    const uint32_t result_id = module_.TakeNextId();
    const uint32_t callee_id = GetLinkFunction(ray_query_initialize_function_id_, kRayQueryInitializeValidationFunction);
    const uint32_t bool_type_id = type_manager_.GetTypeBool().Id();
    const uint32_t position_id = type_manager_.CreateConstantUInt32(init_inst->GetPositionOffset()).Id();

    std::vector<uint32_t> call_words = {bool_type_id, result_id, callee_id, position_id};
    // Forward Ray Flags (2), Ray Origin (4), Ray Tmin (5), Ray Direction (6) and Ray Tmax (7), in that order.
    // Operands 0, 1 and 3 are not passed to this check.
    for (const uint32_t operand_index : {2u, 4u, 5u, 6u, 7u}) {
        call_words.push_back(init_inst->Operand(operand_index));
    }
    return call_words;
}
// Builds the OpFunctionCall operand list for the inst_report_intersection_comp check of an OpReportIntersectionKHR.
//
// Layout of the returned list:
//   {bool result type, result id, function, inst position, hit kind}
std::vector<uint32_t> TraceRayPass::GetReportIntersectionValidationFunctionCallInstructions(
    InstructionIt* report_intersection_inst_it) {
    auto* report_inst = (*report_intersection_inst_it)->get();
    const uint32_t result_id = module_.TakeNextId();
    const uint32_t callee_id = GetLinkFunction(report_intersection_function_id_, kReportIntersectionValidationFunction);
    const uint32_t bool_type_id = type_manager_.GetTypeBool().Id();
    // Operand 1 of OpReportIntersectionKHR is HitKind; operand 0 (Hit) is not checked here.
    const uint32_t hit_kind_operand_id = report_inst->Operand(1);
    const uint32_t position_id = type_manager_.CreateConstantUInt32(report_inst->GetPositionOffset()).Id();
    return {bool_type_id, result_id, callee_id, position_id, hit_kind_operand_id};
}
// Emits an OpFunctionCall into `block` from an operand list shaped {result type, result id, function, args...}.
// The call is placed relative to `inst_it`, or appended when it is nullptr.
// Returns the call's result id (the second word of the list).
uint32_t TraceRayPass::AddFunctionCall(BasicBlock& block, std::vector<uint32_t>&& instructions, InstructionIt* inst_it) {
    // Grab the result id before the operand list is moved into the new instruction.
    const uint32_t call_result_id = instructions[1];
    block.CreateInstruction(spv::OpFunctionCall, std::move(instructions), inst_it);
    // Record that the module now contains a check call that may log an error.
    module_.need_log_error_ = true;
    return call_result_id;
}
// Walks the blocks of every function marked called_from_target_ and, for each supported ray-tracing / ray-query
// instruction, emits calls to the linked validation functions.
//
// Modes of operation:
//  - Without settings_.safe_mode, the calls are inserted at inst_it and the scan continues in the same block.
//  - With safe_mode, the block is split via InjectFunctionPre. The calls are appended to the original block, and their
//    results (AND-ed together when there is more than one) become ic_data.function_result_id for
//    InjectFunctionPost. Scanning then resumes in the newly split merge block.
//
// Returns true if at least one instruction was instrumented.
bool TraceRayPass::Instrument() {
    // Can safely loop function list as there is no injecting of new Functions until linking time
    for (Function& function : module_.functions_) {
        if (!function.called_from_target_) {
            continue;
        }
        for (auto block_it = function.blocks_.begin(); block_it != function.blocks_.end(); ++block_it) {
            BasicBlock& current_block = **block_it;
            // Keep the control-flow tracker in sync with the visited block; cf_.in_loop is consulted right below.
            cf_.Update(current_block);
            if (debug_disable_loops_ && cf_.in_loop) {
                continue;
            }
            if (current_block.IsLoopHeader()) {
                continue;  // Currently can't properly handle injecting CFG logic into a loop header block
            }
            auto& block_instructions = current_block.instructions_;
            for (auto inst_it = block_instructions.begin(); inst_it != block_instructions.end(); ++inst_it) {
                if (MaxInstrumentationsCountReached()) {
                    return instrumentations_count_ != 0;
                }
                const spv::Op opcode = spv::Op(inst_it->get()->Opcode());
                // One OpFunctionCall operand list per check to emit for this instruction. A Get*FunctionCallInstructions
                // helper returns an empty list when it decides not to instrument; such lists are dropped.
                std::vector<std::vector<uint32_t>> func_calls;
                auto add_func_call = [&](std::vector<uint32_t>&& func_call) {
                    if (!func_call.empty()) {
                        func_calls.emplace_back(std::move(func_call));
                    }
                };
                // The helpers receive &inst_it because they may insert instructions (e.g. an index cast) in front of it.
                switch (opcode) {
                    case spv::OpRayQueryInitializeKHR: {
                        // The acceleration structure is operand 1 (operand 0 is the ray query object).
                        add_func_call(GetTlasValidationFunctionCallInstructions(
                            function, 1, glsl::kErrorSubCode_RayQuery_TlasNotBuilt, current_block, &inst_it));
                        add_func_call(GetRayQueryInitializeValidationFunctionCallInstructions(&inst_it));
                        break;
                    }
                    case spv::OpTraceRayKHR: {
                        // The acceleration structure is operand 0.
                        add_func_call(GetTraceRayValidationFunctionCallInstructions(&inst_it));
                        add_func_call(GetTlasValidationFunctionCallInstructions(
                            function, 0, glsl::kErrorSubCode_TraceRay_TlasNotBuilt, current_block, &inst_it));
                        break;
                    }
                    case spv::OpHitObjectSetShaderBindingTableRecordIndexEXT: {
                        add_func_call(GetRayHitObjectSbtIndexValidationFunctionCallInstructions(&inst_it));
                        break;
                    }
                    case spv::OpHitObjectTraceRayEXT:
                    case spv::OpHitObjectTraceReorderExecuteEXT:
                    case spv::OpHitObjectTraceRayMotionEXT:
                    case spv::OpHitObjectTraceMotionReorderExecuteEXT: {
                        add_func_call(GetRayHitObjectValidationFunctionCallInstructions(&inst_it));
                        break;
                    }
                    case spv::OpReportIntersectionKHR: {
                        add_func_call(GetReportIntersectionValidationFunctionCallInstructions(&inst_it));
                        break;
                    }
                    default:
                        break;
                }
                if (func_calls.empty()) {
                    continue;
                }
                // Counted once per instrumented instruction, regardless of how many checks it got.
                ++instrumentations_count_;
                if (!module_.settings_.safe_mode) {
                    // Emit each call at inst_it; AddFunctionCall takes &inst_it so the iterator stays usable afterwards.
                    for (auto& func_call : func_calls) {
                        AddFunctionCall(current_block, std::move(func_call), &inst_it);
                    }
                } else {
                    // Split the block at the instruction, then append the calls (nullptr = end of block) to what remains
                    // of current_block, chaining their bool results with OpLogicalAnd.
                    InjectConditionalData ic_data = InjectFunctionPre(function, block_it, inst_it);
                    uint32_t combined_func_results_id = AddFunctionCall(current_block, std::move(func_calls[0]), nullptr);
                    for (size_t i = 1; i < func_calls.size(); ++i) {
                        const uint32_t next_func_result_id = AddFunctionCall(current_block, std::move(func_calls[i]), nullptr);
                        const uint32_t next_combined_func_results_id = module_.TakeNextId();
                        current_block.CreateInstruction(spv::OpLogicalAnd,
                                                        {type_manager_.GetTypeBool().Id(), next_combined_func_results_id,
                                                         combined_func_results_id, next_func_result_id});
                        combined_func_results_id = next_combined_func_results_id;
                    }
                    ic_data.function_result_id = combined_func_results_id;
                    InjectFunctionPost(current_block, ic_data);
                    // Skip the newly added valid and invalid block. Start searching again from newly split merge block
                    // (the outer loop's ++block_it performs the final step onto it).
                    ++block_it;
                    ++block_it;
                    break;
                }
            }
        }
    }
    return instrumentations_count_ != 0;
}
// Prints a one-line summary of how many instructions this pass instrumented.
void TraceRayPass::PrintDebugInfo() const {
    std::ostream& out = std::cout;
    out << "TraceRayPass instrumentation count: ";
    out << instrumentations_count_;
    out << '\n';
}
} // namespace spirv
} // namespace gpuav