| /* Copyright (c) 2024-2026 LunarG, Inc. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "containers/limits.h" |
| #include "containers/small_vector.h" |
| #include "generated/spirv_grammar_helper.h" |
| #include "module.h" |
| #include <cassert> |
| #include <cstdint> |
| #include <cstring> |
| #include <spirv/unified1/spirv.hpp> |
| #include "containers/custom_containers.h" |
| #include "function_basic_block.h" |
| |
| namespace gpuav { |
| namespace spirv { |
| |
| void Module::SetSpecConstantValue(Instruction* inst, const Type& type, vvl::unordered_map<uint32_t, uint32_t>& id_to_spec_id) { |
| const uint32_t opcode = inst->Opcode(); |
| |
| if (opcode == spv::OpSpecConstantComposite) { |
| return inst->FreezeSpecConstant(); |
| } |
| |
| const bool has_spec_info = interface_.specialization_info && interface_.specialization_info->mapEntryCount > 0; |
| if (!has_spec_info) { |
| return inst->FreezeSpecConstant(); |
| } |
| |
| const uint32_t result_id = inst->ResultId(); |
| const auto it = id_to_spec_id.find(result_id); |
| if (it == id_to_spec_id.end()) { |
| // OpDecorate SpecId was not set, so using default |
| return inst->FreezeSpecConstant(); |
| } |
| const uint32_t spec_id = it->second; |
| |
| VkSpecializationMapEntry map_entry = {0, 0, 0}; |
| bool found = false; |
| for (uint32_t i = 0; i < interface_.specialization_info->mapEntryCount; i++) { |
| if (interface_.specialization_info->pMapEntries[i].constantID == spec_id) { |
| map_entry = interface_.specialization_info->pMapEntries[i]; |
| found = true; |
| break; |
| } |
| } |
| |
| if (!found) { |
| return inst->FreezeSpecConstant(); |
| } |
| |
| if ((map_entry.offset + map_entry.size) <= interface_.specialization_info->dataSize) { |
| // Spec constants at most can be a int64/float64 |
| assert(map_entry.size <= 8); |
| const uint8_t* out_p = static_cast<const uint8_t*>(interface_.specialization_info->pData); |
| const uint8_t* target_addr = out_p + map_entry.offset; |
| |
| if (opcode == spv::OpSpecConstantTrue || opcode == spv::OpSpecConstantFalse) { |
| // For Boolean, just swap from True <-> False spec constant, then will be frozen |
| VkBool32 raw_value = 0; |
| std::memcpy(&raw_value, target_addr, std::min(map_entry.size, sizeof(VkBool32))); |
| inst->SetNewOpcode(raw_value ? spv::OpSpecConstantTrue : spv::OpSpecConstantFalse); |
| } else { |
| assert(opcode == spv::OpSpecConstant); |
| const bool is_signed = type.IsSignedInt(); |
| |
| uint64_t raw_value = 0; |
| std::memcpy(&raw_value, target_addr, map_entry.size); |
| |
| // Sign-extend 8-bit and 16-bit negative numbers up to 32 bits |
| if (is_signed && map_entry.size < 4) { |
| const uint32_t bit_width = static_cast<uint32_t>(map_entry.size * 8); |
| const uint64_t sign_bit = 1ULL << (bit_width - 1); |
| |
| // If the sign bit is 1, fill the upper bits with 1s |
| if (raw_value & sign_bit) { |
| uint64_t mask = ~0ULL << bit_width; |
| raw_value |= mask; |
| } |
| } |
| |
| const uint32_t word_0 = static_cast<uint32_t>(raw_value & 0xFFFFFFFF); |
| inst->UpdateWord(3, word_0); |
| if (map_entry.size == 8) { |
| const uint32_t word_1 = static_cast<uint32_t>(raw_value >> 32); |
| inst->UpdateWord(4, word_1); |
| } |
| } |
| } |
| |
| return inst->FreezeSpecConstant(); |
| } |
| |
| static uint64_t EvaluateArithmetic(spv::Op opcode, uint64_t arg0, uint64_t arg1, uint32_t bit_width) { |
| uint64_t mask = (bit_width == 64) ? ~0ULL : (1ULL << bit_width) - 1; |
| |
| uint64_t u_a = arg0; |
| uint64_t u_b = arg1; |
| int64_t s_a = static_cast<int64_t>(arg0); |
| int64_t s_b = static_cast<int64_t>(arg1); |
| |
| uint64_t res = 0; |
| |
| switch (opcode) { |
| case spv::OpSNegate: |
| res = static_cast<uint64_t>(-s_a); |
| break; |
| case spv::OpSDiv: |
| if (s_b != 0) { |
| if (bit_width == 64 && s_a == vvl::kI64Min && s_b == -1) { |
| res = static_cast<uint64_t>(vvl::kI64Min); |
| } else { |
| res = static_cast<uint64_t>(s_a / s_b); |
| } |
| } |
| break; |
| case spv::OpSRem: |
| if (s_b != 0) { |
| if (bit_width == 64 && s_a == vvl::kI64Min && s_b == -1) { |
| res = 0; |
| } else { |
| res = static_cast<uint64_t>(s_a % s_b); |
| } |
| } |
| break; |
| case spv::OpSMod: |
| if (s_b != 0) { |
| if (bit_width == 64 && s_a == vvl::kI64Min && s_b == -1) { |
| res = 0; |
| } else { |
| res = static_cast<uint64_t>(s_a % s_b); |
| if ((res > 0 && s_b < 0) || s_b > 0) { |
| res += u_b; |
| } |
| } |
| } |
| break; |
| case spv::OpShiftRightArithmetic: |
| if (u_b >= 64) { |
| res = (s_a < 0) ? -1 : 0; |
| } else { |
| res = static_cast<uint64_t>(s_a >> u_b); |
| } |
| break; |
| case spv::OpSLessThan: |
| res = (s_a < s_b) ? 1 : 0; |
| break; |
| case spv::OpSGreaterThan: |
| res = (s_a > s_b) ? 1 : 0; |
| break; |
| case spv::OpSLessThanEqual: |
| res = (s_a <= s_b) ? 1 : 0; |
| break; |
| case spv::OpSGreaterThanEqual: |
| res = (s_a >= s_b) ? 1 : 0; |
| break; |
| |
| case spv::OpIAdd: |
| res = u_a + u_b; |
| break; |
| case spv::OpISub: |
| res = u_a - u_b; |
| break; |
| case spv::OpIMul: |
| res = u_a * u_b; |
| break; |
| case spv::OpUDiv: |
| if (u_b != 0) res = u_a / u_b; |
| break; |
| case spv::OpUMod: |
| if (u_b != 0) res = u_a % u_b; |
| break; |
| case spv::OpShiftRightLogical: |
| res = (u_b >= 64) ? 0 : (u_a >> u_b); |
| break; |
| case spv::OpShiftLeftLogical: |
| res = (u_b >= 64) ? 0 : (u_a << u_b); |
| break; |
| case spv::OpBitwiseOr: |
| res = u_a | u_b; |
| break; |
| case spv::OpBitwiseXor: |
| res = u_a ^ u_b; |
| break; |
| case spv::OpBitwiseAnd: |
| res = u_a & u_b; |
| break; |
| case spv::OpNot: |
| res = ~u_a; |
| break; |
| |
| case spv::OpLogicalOr: |
| res = (u_a || u_b) ? 1 : 0; |
| break; |
| case spv::OpLogicalAnd: |
| res = (u_a && u_b) ? 1 : 0; |
| break; |
| case spv::OpLogicalNot: |
| res = (!u_a) ? 1 : 0; |
| break; |
| case spv::OpLogicalEqual: |
| case spv::OpIEqual: |
| res = (u_a == u_b) ? 1 : 0; |
| break; |
| case spv::OpLogicalNotEqual: |
| case spv::OpINotEqual: |
| res = (u_a != u_b) ? 1 : 0; |
| break; |
| case spv::OpULessThan: |
| res = (u_a < u_b) ? 1 : 0; |
| break; |
| case spv::OpUGreaterThan: |
| res = (u_a > u_b) ? 1 : 0; |
| break; |
| case spv::OpULessThanEqual: |
| res = (u_a <= u_b) ? 1 : 0; |
| break; |
| case spv::OpUGreaterThanEqual: |
| res = (u_a >= u_b) ? 1 : 0; |
| break; |
| default: |
| assert(false); // Only a limited set of instructions allowed |
| return 0; |
| } |
| |
| return res & mask; |
| } |
| |
| bool Module::ConstantFoldVectorShuffle(Instruction* inst, const Type& result_type) { |
| assert(result_type.spv_type_ == SpvType::kVector); |
| |
| const Constant* vec1 = type_manager_.FindConstantById(inst->Word(4)); |
| const Constant* vec2 = type_manager_.FindConstantById(inst->Word(5)); |
| if (!vec1 || !vec2) { |
| return false; |
| } |
| |
| // LongVectors should not be possible as it would require the user to know the number of literal operands to use |
| const uint32_t vec1_length = type_manager_.FindTypeById(vec1->inst_.TypeId())->VectorSize(); |
| |
| std::vector<uint32_t> words = {result_type.Id(), inst->ResultId()}; |
| |
| for (uint32_t i = 6; i < inst->Length(); i++) { |
| const uint32_t index = inst->Word(i); |
| |
| if (index < vec1_length) { |
| if (vec1->inst_.Opcode() == spv::OpConstantNull) { |
| words.emplace_back(type_manager_.GetConstantZeroUint32().Id()); |
| } else { |
| const uint32_t constant_id = vec1->inst_.Word(3 + index); |
| words.emplace_back(constant_id); |
| } |
| } else { |
| if (vec2->inst_.Opcode() == spv::OpConstantNull) { |
| words.emplace_back(type_manager_.GetConstantZeroUint32().Id()); |
| } else { |
| const uint32_t vec2_index = index - vec1_length; |
| const uint32_t constant_id = vec2->inst_.Word(3 + vec2_index); |
| words.emplace_back(constant_id); |
| } |
| } |
| } |
| |
| auto new_inst = std::make_unique<Instruction>(1 + (uint32_t)words.size(), spv::OpConstantComposite); |
| new_inst->Fill(words); |
| type_manager_.AddConstant(std::move(new_inst), result_type); |
| return true; |
| } |
| |
| bool Module::ConstantFoldCompositeExtract(Instruction* inst, const Type& result_type) { |
| // There might be multiple indices |
| uint32_t current_id = inst->Word(4); |
| for (uint32_t i = 5; i < inst->Length(); i++) { |
| const uint32_t index = inst->Word(i); |
| const Constant* composite = type_manager_.FindConstantById(current_id); |
| if (!composite) { |
| assert(false); |
| return false; |
| } else if (composite->inst_.Opcode() == spv::OpConstantNull) { |
| auto new_inst = std::make_unique<Instruction>(3, spv::OpConstantNull); |
| new_inst->Fill({result_type.Id(), inst->ResultId()}); |
| type_manager_.AddConstant(std::move(new_inst), result_type); |
| return true; |
| } else if (composite->inst_.Opcode() == spv::OpConstantComposite) { |
| // Move down the tree to the next component ID |
| current_id = composite->inst_.Word(3 + index); |
| } else { |
| assert(false); |
| return false; |
| } |
| } |
| |
| // current_id now points to the exact constant instruction we want, so just make a copy |
| const Constant* final_composite = type_manager_.FindConstantById(current_id); |
| if (!final_composite) { |
| return false; |
| } |
| // TODO - Create a helper to make a copy more easily |
| const uint32_t* raw_words = final_composite->inst_.GetRawBytes(); |
| std::vector<uint32_t> words(raw_words + 1, raw_words + final_composite->inst_.Length()); |
| words[0] = result_type.Id(); |
| words[1] = inst->ResultId(); |
| |
| auto new_inst = std::make_unique<Instruction>(1 + (uint32_t)words.size(), (spv::Op)final_composite->inst_.Opcode()); |
| new_inst->Fill(words); |
| type_manager_.AddConstant(std::move(new_inst), result_type); |
| return true; |
| } |
| |
| bool Module::ConstantFoldCompositeInsert(Instruction* inst, const Type& result_type) { |
| const Constant* base_composite = type_manager_.FindConstantById(inst->Word(5)); |
| if (!base_composite) { |
| assert(false); |
| return false; |
| } |
| |
| // This assumes 1-level insertion (like inserting a scalar into a vector) |
| if (inst->Length() > 7 && result_type.spv_type_ == SpvType::kVector) { |
| assert(false); |
| return false; |
| } |
| |
| const uint32_t index = inst->Word(6); |
| |
| std::vector<uint32_t> words = {result_type.Id(), inst->ResultId()}; |
| |
| const uint32_t vec_length = result_type.VectorSize(); |
| for (uint32_t i = 0; i < vec_length; i++) { |
| if (i == index) { |
| uint32_t insert_id = inst->Word(4); // |object| operand |
| words.emplace_back(insert_id); |
| } else if (base_composite->inst_.Opcode() == spv::OpConstantNull) { |
| words.emplace_back(type_manager_.GetConstantZeroUint32().Id()); |
| } else { |
| const uint32_t constant_id = base_composite->inst_.Word(3 + i); |
| words.emplace_back(constant_id); |
| } |
| } |
| |
| auto new_inst = std::make_unique<Instruction>(1 + (uint32_t)words.size(), spv::OpConstantComposite); |
| new_inst->Fill(words); |
| type_manager_.AddConstant(std::move(new_inst), result_type); |
| return true; |
| } |
| |
| bool Module::ConstantFold(Instruction* inst, const Type& result_type) { |
| assert(inst->Opcode() == spv::OpSpecConstantOp); |
| if (result_type.spv_type_ == SpvType::kStruct) { |
| return false; // TODO - Add support |
| } |
| |
| spv::Op target_opcode = (spv::Op)inst->Word(3); |
| |
| // OpSpecConstantOp has a limited set of instructions it can be, these are not handled the special |
| if (target_opcode == spv::OpVectorShuffle) { |
| return ConstantFoldVectorShuffle(inst, result_type); |
| } else if (target_opcode == spv::OpCompositeExtract) { |
| return ConstantFoldCompositeExtract(inst, result_type); |
| } else if (target_opcode == spv::OpCompositeInsert) { |
| return ConstantFoldCompositeInsert(inst, result_type); |
| } |
| |
| const bool is_vector = result_type.spv_type_ == SpvType::kVector; |
| uint32_t vector_length = is_vector ? result_type.inst_.Word(3) : 1; |
| const Type& scalar_type = is_vector ? *type_manager_.FindTypeById(result_type.inst_.Word(2)) : result_type; |
| |
| small_vector<uint32_t, 4> new_composite_components; |
| |
| const uint32_t start_operand_index = 4; // into OpSpecConstantOp |
| const uint32_t final_operand_index = inst->Length(); |
| // Scalar will have a single lane |
| for (uint32_t lane = 0; lane < vector_length; lane++) { |
| // might be up to 3 for OpSelect |
| small_vector<uint64_t, 3> args; |
| |
| for (uint32_t i = start_operand_index; i < final_operand_index; i++) { |
| const uint32_t operand_id = inst->Word(i); |
| const Constant* constant = type_manager_.FindConstantById(operand_id); |
| if (!constant) { |
| assert(false); // Something is wrong |
| return false; |
| } |
| |
| const Constant* lane_constant = constant; |
| if (is_vector && constant->inst_.Opcode() == spv::OpConstantComposite) { |
| uint32_t comp_id = constant->inst_.Word(3 + lane); |
| lane_constant = type_manager_.FindConstantById(comp_id); |
| if (!lane_constant) { |
| assert(false); |
| return false; |
| } |
| } |
| |
| if (lane_constant->inst_.Opcode() == spv::OpConstantNull) { |
| args.emplace_back(0ul); |
| } else if (lane_constant->type_.spv_type_ == SpvType::kBool) { |
| if (lane_constant->inst_.Opcode() == spv::OpConstantTrue) { |
| args.emplace_back(1ul); |
| } else { |
| assert(lane_constant->inst_.Opcode() == spv::OpConstantFalse); |
| args.emplace_back(0ul); |
| } |
| } else { |
| bool op_is_signed = lane_constant->type_.inst_.Word(3) != 0; |
| |
| if (target_opcode == spv::OpSConvert) { |
| // OpSConvert implies the source is treated as signed and must sign-extend |
| op_is_signed = true; |
| } else if (target_opcode == spv::OpUConvert) { |
| // OpUConvert implies the source is treated as unsigned (zero-extends) |
| op_is_signed = false; |
| } |
| |
| const uint64_t value64 = lane_constant->GetValueUint64(op_is_signed); |
| args.emplace_back(value64); |
| } |
| } |
| |
| uint64_t lane_result = 0; |
| const uint32_t result_bit_width = scalar_type.meta_.scalar.bit_width; |
| if (target_opcode == spv::OpSConvert || target_opcode == spv::OpUConvert) { |
| uint64_t mask = (result_bit_width == 64) ? ~0ULL : (1ULL << result_bit_width) - 1; |
| lane_result = args[0] & mask; |
| } else if (target_opcode == spv::OpSelect) { |
| lane_result = (args[0] != 0) ? args[1] : args[2]; |
| } else { |
| const uint64_t arg1 = args.size() > 1 ? args[1] : 0; |
| lane_result = EvaluateArithmetic(target_opcode, args[0], arg1, result_bit_width); |
| } |
| |
| if (is_vector) { |
| // If we have |
| // %12 = OpSpecConstantComposite %v2short %short_4 %short_4 |
| // %13 = OpSpecConstantOp %v2short IMul %12 %12 |
| // we will want |
| // %12 = OpConstantComposite %v2short %short_4 %short_4 |
| // %short_16 = OpConstant %short 16 |
| // %13 = OpConstantComposite %v2short %short_16 %short_16 |
| // Which means we need to need to "create" the new OpConstant here |
| uint32_t scalar_id = type_manager_.CreateConstantScalar(lane_result, scalar_type).Id(); |
| new_composite_components.emplace_back(scalar_id); |
| } else if (scalar_type.spv_type_ == SpvType::kBool) { |
| const spv::Op new_opcode = (lane_result == 0) ? spv::OpConstantFalse : spv::OpConstantTrue; |
| auto new_inst = std::make_unique<Instruction>(3, new_opcode); |
| new_inst->Fill({scalar_type.Id(), inst->ResultId()}); |
| type_manager_.AddConstant(std::move(new_inst), scalar_type); |
| return true; |
| } else { |
| type_manager_.CreateConstantScalar(lane_result, scalar_type, inst->ResultId()); |
| return true; |
| } |
| } |
| |
| assert(is_vector); |
| auto new_inst = std::make_unique<Instruction>(3 + new_composite_components.size(), spv::OpConstantComposite); |
| std::vector<uint32_t> words = {result_type.Id(), inst->ResultId()}; |
| words.insert(words.end(), new_composite_components.begin(), new_composite_components.end()); |
| new_inst->Fill(words); |
| type_manager_.AddConstant(std::move(new_inst), result_type); |
| return true; |
| } |
| |
| } // namespace spirv |
| } // namespace gpuav |