From fe774b139bd904f1b71937cb02575feba62056ed Mon Sep 17 00:00:00 2001 From: spencer-lunarg Date: Tue, 9 Jul 2024 11:45:20 -0500 Subject: [PATCH] gpu: Add Shader Conditional Functions --- layers/gpu/spirv/README.md | 3 +- layers/gpu/spirv/bindless_descriptor_pass.cpp | 35 ++++++----- layers/gpu/spirv/bindless_descriptor_pass.h | 11 ++-- .../gpu/spirv/buffer_device_address_pass.cpp | 16 +++-- layers/gpu/spirv/buffer_device_address_pass.h | 33 +--------- layers/gpu/spirv/function_basic_block.cpp | 15 +++++ layers/gpu/spirv/function_basic_block.h | 10 +-- layers/gpu/spirv/pass.cpp | 61 +++++++++++++------ layers/gpu/spirv/pass.h | 56 +++++++++++++++-- layers/gpu/spirv/ray_query_pass.cpp | 12 ++-- layers/gpu/spirv/ray_query_pass.h | 6 +- 11 files changed, 156 insertions(+), 102 deletions(-) diff --git a/layers/gpu/spirv/README.md b/layers/gpu/spirv/README.md index 611d4a3226a..dcf791aa6b1 100644 --- a/layers/gpu/spirv/README.md +++ b/layers/gpu/spirv/README.md @@ -8,7 +8,8 @@ Each pass does logic needed to know if the current instruction needs have check ## Step 2 - Inject a function call -The logic to add a `if-else` control flow logic in SPIR-V is handled by the `Pass::InjectFunctionCheck` function. This will create the various blocks and resolve any ID updates +Functions are added via `Pass::InjectFunctionCheck`, but there are cases were we want to make sure we don't call the invalid instructions. For this we add an `if-else` control flow logic in SPIR-V (all handled by the `Pass::InjectConditionalFunctionCheck`) to inject the function. This will create the various blocks and resolve any ID updates + ## Step 3 - Create the OpFunctionCall diff --git a/layers/gpu/spirv/bindless_descriptor_pass.cpp b/layers/gpu/spirv/bindless_descriptor_pass.cpp index 329158f4216..eb43674ba8e 100644 --- a/layers/gpu/spirv/bindless_descriptor_pass.cpp +++ b/layers/gpu/spirv/bindless_descriptor_pass.cpp @@ -78,7 +78,7 @@ uint32_t BindlessDescriptorPass::FindTypeByteSize(uint32_t type_id, uint32_t mat // Find outermost buffer type and its access chain index. // Because access chains indexes can be runtime values, we need to build arithmetic logic in the SPIR-V to get the runtime value of // the indexing -uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) { +uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block, InstructionIt* inst_it) { const Type* pointer_type = module_.type_manager_.FindTypeById(var_inst_->TypeId()); const Type* descriptor_type = module_.type_manager_.FindTypeById(pointer_type->inst_.Word(3)); @@ -119,7 +119,7 @@ uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) { const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block); current_offset_id = module_.TakeNextId(); - block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, arr_stride_id, ac_index_id_32}); + block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, arr_stride_id, ac_index_id_32}, inst_it); // Get element type for next step current_type_id = current_type->inst_.Operand(0); @@ -142,7 +142,7 @@ uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) { const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block); current_offset_id = module_.TakeNextId(); - block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, col_stride_id, ac_index_id_32}); + block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, col_stride_id, ac_index_id_32}, inst_it); // Get element type for next step current_type_id = vec_type_id; @@ -155,13 +155,14 @@ uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) { const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block); if (in_matrix && !col_major) { current_offset_id = module_.TakeNextId(); - block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, matrix_stride_id, ac_index_id_32}); + block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, matrix_stride_id, ac_index_id_32}, + inst_it); } else { const uint32_t component_type_size = FindTypeByteSize(component_type_id); const uint32_t size_id = module_.type_manager_.GetConstantUInt32(component_type_size).Id(); current_offset_id = module_.TakeNextId(); - block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, size_id, ac_index_id_32}); + block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, size_id, ac_index_id_32}, inst_it); } // Get element type for next step current_type_id = component_type_id; @@ -197,7 +198,7 @@ uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) { sum_id = current_offset_id; } else { const uint32_t new_sum_id = module_.TakeNextId(); - block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, current_offset_id}); + block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, current_offset_id}, inst_it); sum_id = new_sum_id; } ac_word_index++; @@ -210,16 +211,12 @@ uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) { const uint32_t last_id = module_.type_manager_.GetConstantUInt32(last).Id(); const uint32_t new_sum_id = module_.TakeNextId(); - block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, last_id}); + block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, last_id}, inst_it); return new_sum_id; } -uint32_t BindlessDescriptorPass::CreateFunctionCall(BasicBlock& block) { - // Add any debug information to pass into the function call - const uint32_t stage_info_id = GetStageInfo(block.function_); - const uint32_t inst_position = target_instruction_->position_index_; - auto inst_position_constant = module_.type_manager_.CreateConstantUInt32(inst_position); - +uint32_t BindlessDescriptorPass::CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it, + const InjectionData& injection_data) { const Constant& set_constant = module_.type_manager_.GetConstantUInt32(descriptor_set_); const Constant& binding_constant = module_.type_manager_.GetConstantUInt32(descriptor_binding_); const uint32_t descriptor_index_id = CastToUint32(descriptor_index_id_, block); // might be int32 @@ -262,7 +259,7 @@ uint32_t BindlessDescriptorPass::CreateFunctionCall(BasicBlock& block) { auto copied = copy_object_map_.find(image_id); if (copied != copy_object_map_.end()) { image_id = copied->second; - block.CreateInstruction(spv::OpCopyObject, {type_id, copy_id, image_id}); + block.CreateInstruction(spv::OpCopyObject, {type_id, copy_id, image_id}, inst_it); } else { copy_object_map_.emplace(image_id, copy_id); // slower, but need to guarantee it is placed after a OpSampledImage @@ -278,7 +275,7 @@ uint32_t BindlessDescriptorPass::CreateFunctionCall(BasicBlock& block) { const Type* pointee_type = module_.type_manager_.FindTypeById(pointer_type->inst_.Word(3)); if (pointee_type && pointee_type->spv_type_ != SpvType::kArray && pointee_type->spv_type_ != SpvType::kRuntimeArray && pointee_type->spv_type_ != SpvType::kStruct) { - descriptor_offset_id_ = GetLastByte(block); // Get Last Byte Index + descriptor_offset_id_ = GetLastByte(block, inst_it); // Get Last Byte Index } } @@ -290,9 +287,11 @@ uint32_t BindlessDescriptorPass::CreateFunctionCall(BasicBlock& block) { const uint32_t function_def = GetLinkFunctionId(); const uint32_t bool_type = module_.type_manager_.GetTypeBool().Id(); - block.CreateInstruction(spv::OpFunctionCall, - {bool_type, function_result, function_def, inst_position_constant.Id(), stage_info_id, - set_constant.Id(), binding_constant.Id(), descriptor_index_id, descriptor_offset_id_}); + block.CreateInstruction( + spv::OpFunctionCall, + {bool_type, function_result, function_def, injection_data.inst_position_id, injection_data.stage_info_id, set_constant.Id(), + binding_constant.Id(), descriptor_index_id, descriptor_offset_id_}, + inst_it); return function_result; } diff --git a/layers/gpu/spirv/bindless_descriptor_pass.h b/layers/gpu/spirv/bindless_descriptor_pass.h index 7239f9e4662..524b735536d 100644 --- a/layers/gpu/spirv/bindless_descriptor_pass.h +++ b/layers/gpu/spirv/bindless_descriptor_pass.h @@ -27,20 +27,18 @@ struct BasicBlock; // Create a pass to instrument bindless descriptor checking // This pass instruments all bindless references to check that descriptor // array indices are inbounds, and if the descriptor indexing extension is -// enabled, that the descriptor has been initialized. If the reference is -// invalid, a record is written to the debug output buffer (if space allows) -// and a null value is returned. +// enabled, that the descriptor has been initialized. class BindlessDescriptorPass : public Pass { public: - BindlessDescriptorPass(Module& module) : Pass(module) {} + BindlessDescriptorPass(Module& module) : Pass(module, true) {} private: bool AnalyzeInstruction(const Function& function, const Instruction& inst) final; - uint32_t CreateFunctionCall(BasicBlock& block) final; + uint32_t CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it, const InjectionData& injection_data) final; void Reset() final; uint32_t FindTypeByteSize(uint32_t type_id, uint32_t matrix_stride = 0, bool col_major = false, bool in_matrix = false); - uint32_t GetLastByte(BasicBlock& block); + uint32_t GetLastByte(BasicBlock& block, InstructionIt* inst_it); uint32_t link_function_id = 0; uint32_t GetLinkFunctionId(); @@ -49,7 +47,6 @@ class BindlessDescriptorPass : public Pass { const Instruction* var_inst_ = nullptr; const Instruction* image_inst_ = nullptr; - const Instruction* target_instruction_ = nullptr; uint32_t descriptor_set_ = 0; uint32_t descriptor_binding_ = 0; uint32_t descriptor_index_id_ = 0; diff --git a/layers/gpu/spirv/buffer_device_address_pass.cpp b/layers/gpu/spirv/buffer_device_address_pass.cpp index 11a0b85a7f4..b6e4b239ff8 100644 --- a/layers/gpu/spirv/buffer_device_address_pass.cpp +++ b/layers/gpu/spirv/buffer_device_address_pass.cpp @@ -35,17 +35,13 @@ uint32_t BufferDeviceAddressPass::GetLinkFunctionId() { return link_function_id; } -uint32_t BufferDeviceAddressPass::CreateFunctionCall(BasicBlock& block) { - // Add any debug information to pass into the function call - const uint32_t stage_info_id = GetStageInfo(block.function_); - const uint32_t inst_position = target_instruction_->position_index_; - auto inst_position_constant = module_.type_manager_.CreateConstantUInt32(inst_position); - +uint32_t BufferDeviceAddressPass::CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it, + const InjectionData& injection_data) { // Convert reference pointer to uint64 const uint32_t pointer_id = target_instruction_->Operand(0); const Type& uint64_type = module_.type_manager_.GetTypeInt(64, 0); const uint32_t convert_id = module_.TakeNextId(); - block.CreateInstruction(spv::OpConvertPtrToU, {uint64_type.Id(), convert_id, pointer_id}); + block.CreateInstruction(spv::OpConvertPtrToU, {uint64_type.Id(), convert_id, pointer_id}, inst_it); const Constant& length_constant = module_.type_manager_.GetConstantUInt32(type_length_); const Constant& access_opcode = module_.type_manager_.GetConstantUInt32(access_opcode_); @@ -54,8 +50,10 @@ uint32_t BufferDeviceAddressPass::CreateFunctionCall(BasicBlock& block) { const uint32_t function_def = GetLinkFunctionId(); const uint32_t bool_type = module_.type_manager_.GetTypeBool().Id(); - block.CreateInstruction(spv::OpFunctionCall, {bool_type, function_result, function_def, inst_position_constant.Id(), - stage_info_id, convert_id, length_constant.Id(), access_opcode.Id()}); + block.CreateInstruction(spv::OpFunctionCall, + {bool_type, function_result, function_def, injection_data.inst_position_id, + injection_data.stage_info_id, convert_id, length_constant.Id(), access_opcode.Id()}, + inst_it); return function_result; } diff --git a/layers/gpu/spirv/buffer_device_address_pass.h b/layers/gpu/spirv/buffer_device_address_pass.h index 9003cb17130..b77598b57d7 100644 --- a/layers/gpu/spirv/buffer_device_address_pass.h +++ b/layers/gpu/spirv/buffer_device_address_pass.h @@ -26,46 +26,19 @@ struct BasicBlock; // Create a pass to instrument physical buffer address checking // This pass instruments all physical buffer address references to check that -// all referenced bytes fall in a valid buffer. If the reference is -// invalid, a record is written to the debug output buffer (if space allows) -// and a null value is returned. -// -// For OpStore we will just ignore the store if it is invalid, example: -// Before: -// bda.data[index] = value; -// After: -// if (isValid(bda.data, index)) { -// bda.data[index] = value; -// } -// -// For OpLoad we replace the value with Zero (via Phi node) if it is invalid, example -// Before: -// int X = bda.data[index]; -// int Y = bda.data[X]; -// After: -// if (isValid(bda.data, index)) { -// int X = bda.data[index]; -// } else { -// int X = 0; -// } -// if (isValid(bda.data, X)) { -// int Y = bda.data[X]; -// } else { -// int Y = 0; -// } +// all referenced bytes fall in a valid buffer. class BufferDeviceAddressPass : public Pass { public: - BufferDeviceAddressPass(Module& module) : Pass(module) {} + BufferDeviceAddressPass(Module& module) : Pass(module, true) {} private: bool AnalyzeInstruction(const Function& function, const Instruction& inst) final; - uint32_t CreateFunctionCall(BasicBlock& block) final; + uint32_t CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it, const InjectionData& injection_data) final; void Reset() final; uint32_t link_function_id = 0; uint32_t GetLinkFunctionId(); - const Instruction* target_instruction_ = nullptr; uint32_t type_length_ = 0; uint32_t access_opcode_ = 0; }; diff --git a/layers/gpu/spirv/function_basic_block.cpp b/layers/gpu/spirv/function_basic_block.cpp index fea73404797..83dafa5af10 100644 --- a/layers/gpu/spirv/function_basic_block.cpp +++ b/layers/gpu/spirv/function_basic_block.cpp @@ -50,6 +50,16 @@ BasicBlock::BasicBlock(Module& module, Function& function) : function_(function) uint32_t BasicBlock::GetLabelId() { return (*(instructions_[0])).ResultId(); } +InstructionIt BasicBlock::GetFirstInjectableInstrution() { + InstructionIt inst_it; + for (inst_it = instructions_.begin(); inst_it != instructions_.end(); ++inst_it) { + if ((*inst_it)->Opcode() != spv::OpLabel && (*inst_it)->Opcode() != spv::OpVariable) { + break; + } + } + return inst_it; +} + void BasicBlock::CreateInstruction(spv::Op opcode, const std::vector& words, InstructionIt* inst_it) { const bool add_to_end = inst_it == nullptr; InstructionIt last_inst = instructions_.end(); @@ -74,6 +84,11 @@ void BasicBlock::CreateInstruction(spv::Op opcode, const std::vector& } } +Function::Function(Module& module, std::unique_ptr function_inst) : module_(module) { + // Used when loading initial SPIR-V + pre_block_inst_.push_back(std::move(function_inst)); // OpFunction +} + BasicBlockIt Function::InsertNewBlock(BasicBlockIt it) { auto new_block = std::make_unique(module_, (*it)->function_); it++; // make sure it inserted after diff --git a/layers/gpu/spirv/function_basic_block.h b/layers/gpu/spirv/function_basic_block.h index d536732fe8d..19d93f0d6f0 100644 --- a/layers/gpu/spirv/function_basic_block.h +++ b/layers/gpu/spirv/function_basic_block.h @@ -43,6 +43,10 @@ struct BasicBlock { uint32_t GetLabelId(); + // "All OpVariable instructions in a function must be the first instructions in the first block" + // So need to get the first valid location in block. + InstructionIt GetFirstInjectableInstrution(); + // Creates instruction and inserts it before the Instruction, updates poistion after new instruciton. // If no InstructionIt is provided, it will add it to the end of the block. void CreateInstruction(spv::Op opcode, const std::vector& words, InstructionIt* inst_it = nullptr); @@ -57,15 +61,13 @@ using BasicBlockList = std::vector>; using BasicBlockIt = BasicBlockList::iterator; struct Function { - Function(Module& module, std::unique_ptr function_inst) : module_(module) { - // Used when loading initial SPIR-V - pre_block_inst_.push_back(std::move(function_inst)); // OpFunction - } + Function(Module& module, std::unique_ptr function_inst); Function(Module& module) : module_(module) {} void ToBinary(std::vector& out); const Instruction& GetDef() { return *pre_block_inst_[0].get(); } + BasicBlock& GetFirstBlock() { return *blocks_[0]; } // Adds a new block after and returns reference to it BasicBlockIt InsertNewBlock(BasicBlockIt it); diff --git a/layers/gpu/spirv/pass.cpp b/layers/gpu/spirv/pass.cpp index 5d209f4abcb..d6eecbe80da 100644 --- a/layers/gpu/spirv/pass.cpp +++ b/layers/gpu/spirv/pass.cpp @@ -58,22 +58,14 @@ const Variable& Pass::GetBuiltinVariable(uint32_t built_in) { // To reduce having to load this information everytime we do a OpFunctionCall, instead just create it once per Function block and // reference it each time -uint32_t Pass::GetStageInfo(Function& function) { +uint32_t Pass::GetStageInfo(Function& function, BasicBlockIt target_block_it, InstructionIt& target_inst_it) { // Cached so only need to compute this once if (function.stage_info_id_ != 0) { return function.stage_info_id_; } - // Get the first block of function to add stage info - BasicBlock& block = *function.blocks_[0]; - - // All OpVariable instructions in a function must be the first instructions in the first block - InstructionIt inst_it; - for (inst_it = block.instructions_.begin(); inst_it != block.instructions_.end(); ++inst_it) { - if ((*inst_it)->Opcode() != spv::OpLabel && (*inst_it)->Opcode() != spv::OpVariable) { - break; - } - } + BasicBlock& block = function.GetFirstBlock(); + InstructionIt inst_it = block.GetFirstInjectableInstrution(); // Stage info is always passed in as a uvec4 const Type& uint32_type = module_.type_manager_.GetTypeInt(32, false); @@ -184,6 +176,12 @@ uint32_t Pass::GetStageInfo(Function& function) { {uvec4_type.Id(), function.stage_info_id_, stage_info[0], stage_info[1], stage_info[2], stage_info[3]}, &inst_it); + // because we are injecting things in the first block, there is a chance we just destroyed the iterator if the target + // instruction was also in the first block, so need to regain it for the caller + if ((*target_block_it)->GetLabelId() == block.GetLabelId()) { + target_inst_it = FindTargetInstruction(block); + } + return function.stage_info_id_; } @@ -265,7 +263,8 @@ uint32_t Pass::CastToUint32(uint32_t id, BasicBlock& block, InstructionIt* inst_ return new_id; // Return an id to the Uint equivalent. } -BasicBlockIt Pass::InjectFunctionCheck(Function* function, BasicBlockIt block_it, InstructionIt inst_it) { +BasicBlockIt Pass::InjectConditionalFunctionCheck(Function* function, BasicBlockIt block_it, InstructionIt inst_it, + const InjectionData& injection_data) { // We turn the block into 4 separate blocks block_it = function->InsertNewBlock(block_it); block_it = function->InsertNewBlock(block_it); @@ -347,18 +346,34 @@ BasicBlockIt Pass::InjectFunctionCheck(Function* function, BasicBlockIt block_it original_block.instructions_.erase(inst_it, original_block.instructions_.end()); // Go back to original Block and add function call and branch from the bool result - const uint32_t function_result = CreateFunctionCall(original_block); + const uint32_t function_result = CreateFunctionCall(original_block, nullptr, injection_data); original_block.CreateInstruction(spv::OpSelectionMerge, {merge_block_label, spv::SelectionControlMaskNone}); original_block.CreateInstruction(spv::OpBranchConditional, {function_result, valid_block_label, invalid_block_label}); - // clear values incase multiple calls are made Reset(); return block_it; } +void Pass::InjectFunctionCheck(BasicBlockIt block_it, InstructionIt* inst_it, const InjectionData& injection_data) { + CreateFunctionCall(**block_it, inst_it, injection_data); + Reset(); +} + +InstructionIt Pass::FindTargetInstruction(BasicBlock& block) const { + const uint32_t target_id = target_instruction_->ResultId(); + for (auto inst_it = block.instructions_.begin(); inst_it != block.instructions_.end(); ++inst_it) { + if ((*inst_it)->ResultId() == target_id) { + return inst_it; + } + } + assert(false); + return block.instructions_.end(); +} + void Pass::Run() { + // Can safely loop function list as there is no injecting of new Functions until linking time for (const auto& function : module_.functions_) { for (auto block_it = function->blocks_.begin(); block_it != function->blocks_.end(); ++block_it) { if ((*block_it)->loop_header_) { @@ -366,12 +381,24 @@ void Pass::Run() { } auto& block_instructions = (*block_it)->instructions_; for (auto inst_it = block_instructions.begin(); inst_it != block_instructions.end(); ++inst_it) { - if (AnalyzeInstruction(*(function.get()), *(inst_it->get()))) { - block_it = InjectFunctionCheck(function.get(), block_it, inst_it); - + // Every instruction is analyzed by the specific pass and lets us know if we need to inject a function or not + if (!AnalyzeInstruction(*function, *(inst_it->get()))) continue; + + // Add any debug information to pass into the function call + InjectionData injection_data; + injection_data.stage_info_id = GetStageInfo(*function, block_it, inst_it); + const uint32_t inst_position = target_instruction_->position_index_; + auto inst_position_constant = module_.type_manager_.CreateConstantUInt32(inst_position); + injection_data.inst_position_id = inst_position_constant.Id(); + + if (conditional_function_check_) { + block_it = InjectConditionalFunctionCheck(function.get(), block_it, inst_it, injection_data); // will start searching again from newly split merge block block_it--; break; + } else { + // inst_it is updated to the instruction after the new function call, it will not add/remove any Blocks + InjectFunctionCheck(block_it, &inst_it, injection_data); } } } diff --git a/layers/gpu/spirv/pass.h b/layers/gpu/spirv/pass.h index b2f46b88f88..085bef5e798 100644 --- a/layers/gpu/spirv/pass.h +++ b/layers/gpu/spirv/pass.h @@ -25,6 +25,12 @@ class Module; struct Variable; struct BasicBlock; +// Info we know is the same regardless what pass is consuming the CreateFunctionCall() +struct InjectionData { + uint32_t stage_info_id; + uint32_t inst_position_id; +}; + // Common helpers for all passes class Pass { public: @@ -34,7 +40,7 @@ class Pass { const Variable& GetBuiltinVariable(uint32_t built_in); // Returns the ID for OpCompositeConstruct it creates - uint32_t GetStageInfo(Function& function); + uint32_t GetStageInfo(Function& function, BasicBlockIt target_block_it, InstructionIt& target_inst_it); const Instruction* GetDecoration(uint32_t id, spv::Decoration decoration); const Instruction* GetMemeberDecoration(uint32_t id, uint32_t member_index, spv::Decoration decoration); @@ -45,17 +51,59 @@ class Pass { uint32_t CastToUint32(uint32_t id, BasicBlock& block, InstructionIt* inst_it = nullptr); protected: - Pass(Module& module) : module_(module) {} + Pass(Module& module, bool conditional_function_check) + : module_(module), conditional_function_check_(conditional_function_check) {} Module& module_; - BasicBlockIt InjectFunctionCheck(Function* function, BasicBlockIt block_it, InstructionIt inst_it); + BasicBlockIt InjectConditionalFunctionCheck(Function* function, BasicBlockIt block_it, InstructionIt inst_it, + const InjectionData& injection_data); + void InjectFunctionCheck(BasicBlockIt block_it, InstructionIt* inst_it, const InjectionData& injection_data); // Each pass decides if the instruction should needs to have its function check injected virtual bool AnalyzeInstruction(const Function& function, const Instruction& inst) = 0; // A callback from the function injection logic. // Each pass creates a OpFunctionCall and returns its result id. - virtual uint32_t CreateFunctionCall(BasicBlock& block) = 0; + // If |inst_it| is not null, it will update it to instruction post OpFunctionCall + virtual uint32_t CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it, const InjectionData& injection_data) = 0; + // clear values incase multiple injections are made virtual void Reset() = 0; + + // If this is false, we assume through other means (such as robustness) we won't crash on bad values and go + // PassFunction(original_value) + // value = original_value; + // + // Otherwise, we will have wrap the checks to be safe + // For OpStore we will just ignore the store if it is invalid, example: + // Before: + // bda.data[index] = value; + // After: + // if (isValid(bda.data, index)) { + // bda.data[index] = value; + // } + // + // For OpLoad we replace the value with Zero (via Phi node) if it is invalid, example + // Before: + // int X = bda.data[index]; + // int Y = bda.data[X]; + // After: + // if (isValid(bda.data, index)) { + // int X = bda.data[index]; + // } else { + // int X = 0; + // } + // if (isValid(bda.data, X)) { + // int Y = bda.data[X]; + // } else { + // int Y = 0; + // } + const bool conditional_function_check_; + + // As various things are modifiying the instruction streams, we need to get back to where we were. + // Every pass needs to set this in AnalyzeInstruction() + const Instruction* target_instruction_ = nullptr; + + private: + InstructionIt FindTargetInstruction(BasicBlock& block) const; }; } // namespace spirv diff --git a/layers/gpu/spirv/ray_query_pass.cpp b/layers/gpu/spirv/ray_query_pass.cpp index d8fd97ab282..129b2db41ed 100644 --- a/layers/gpu/spirv/ray_query_pass.cpp +++ b/layers/gpu/spirv/ray_query_pass.cpp @@ -35,12 +35,7 @@ uint32_t RayQueryPass::GetLinkFunctionId() { return link_function_id; } -uint32_t RayQueryPass::CreateFunctionCall(BasicBlock& block) { - // Add any debug information to pass into the function call - const uint32_t stage_info_id = GetStageInfo(block.function_); - const uint32_t inst_position = target_instruction_->position_index_; - auto inst_position_constant = module_.type_manager_.CreateConstantUInt32(inst_position); - +uint32_t RayQueryPass::CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it, const InjectionData& injection_data) { const uint32_t function_result = module_.TakeNextId(); const uint32_t function_def = GetLinkFunctionId(); const uint32_t bool_type = module_.type_manager_.GetTypeBool().Id(); @@ -52,8 +47,9 @@ uint32_t RayQueryPass::CreateFunctionCall(BasicBlock& block) { const uint32_t ray_tmax_id = target_instruction_->Operand(7); block.CreateInstruction(spv::OpFunctionCall, - {bool_type, function_result, function_def, inst_position_constant.Id(), stage_info_id, ray_flags_id, - ray_origin_id, ray_tmin_id, ray_direction_id, ray_tmax_id}); + {bool_type, function_result, function_def, injection_data.inst_position_id, + injection_data.stage_info_id, ray_flags_id, ray_origin_id, ray_tmin_id, ray_direction_id, ray_tmax_id}, + inst_it); return function_result; } diff --git a/layers/gpu/spirv/ray_query_pass.h b/layers/gpu/spirv/ray_query_pass.h index 502b5a5d6e8..0a2c3bf2633 100644 --- a/layers/gpu/spirv/ray_query_pass.h +++ b/layers/gpu/spirv/ray_query_pass.h @@ -27,17 +27,15 @@ struct BasicBlock; // Create a pass to instrument SPV_KHR_ray_query instructions class RayQueryPass : public Pass { public: - RayQueryPass(Module& module) : Pass(module) {} + RayQueryPass(Module& module) : Pass(module, true) {} private: bool AnalyzeInstruction(const Function& function, const Instruction& inst) final; - uint32_t CreateFunctionCall(BasicBlock& block) final; + uint32_t CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it, const InjectionData& injection_data) final; void Reset() final; uint32_t link_function_id = 0; uint32_t GetLinkFunctionId(); - - const Instruction* target_instruction_ = nullptr; }; } // namespace spirv