Skip to content

Commit

Permalink
gpu: Add Shader Conditional Functions
Browse files Browse the repository at this point in the history
  • Loading branch information
spencer-lunarg committed Jul 9, 2024
1 parent d03e2bc commit fe774b1
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 102 deletions.
3 changes: 2 additions & 1 deletion layers/gpu/spirv/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ Each pass does logic needed to know if the current instruction needs have check

## Step 2 - Inject a function call

The logic to add a `if-else` control flow logic in SPIR-V is handled by the `Pass::InjectFunctionCheck` function. This will create the various blocks and resolve any ID updates
Functions are added via `Pass::InjectFunctionCheck`, but there are cases where we want to make sure we don't call the invalid instructions. For this we add `if-else` control flow logic in SPIR-V (all handled by `Pass::InjectConditionalFunctionCheck`) to inject the function. This will create the various blocks and resolve any ID updates.


## Step 3 - Create the OpFunctionCall

Expand Down
35 changes: 17 additions & 18 deletions layers/gpu/spirv/bindless_descriptor_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ uint32_t BindlessDescriptorPass::FindTypeByteSize(uint32_t type_id, uint32_t mat
// Find outermost buffer type and its access chain index.
// Because access chains indexes can be runtime values, we need to build arithmetic logic in the SPIR-V to get the runtime value of
// the indexing
uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) {
uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block, InstructionIt* inst_it) {
const Type* pointer_type = module_.type_manager_.FindTypeById(var_inst_->TypeId());
const Type* descriptor_type = module_.type_manager_.FindTypeById(pointer_type->inst_.Word(3));

Expand Down Expand Up @@ -119,7 +119,7 @@ uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) {
const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block);

current_offset_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, arr_stride_id, ac_index_id_32});
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, arr_stride_id, ac_index_id_32}, inst_it);

// Get element type for next step
current_type_id = current_type->inst_.Operand(0);
Expand All @@ -142,7 +142,7 @@ uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) {

const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block);
current_offset_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, col_stride_id, ac_index_id_32});
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, col_stride_id, ac_index_id_32}, inst_it);

// Get element type for next step
current_type_id = vec_type_id;
Expand All @@ -155,13 +155,14 @@ uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) {
const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block);
if (in_matrix && !col_major) {
current_offset_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, matrix_stride_id, ac_index_id_32});
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, matrix_stride_id, ac_index_id_32},
inst_it);
} else {
const uint32_t component_type_size = FindTypeByteSize(component_type_id);
const uint32_t size_id = module_.type_manager_.GetConstantUInt32(component_type_size).Id();

current_offset_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, size_id, ac_index_id_32});
block.CreateInstruction(spv::OpIMul, {uint32_type.Id(), current_offset_id, size_id, ac_index_id_32}, inst_it);
}
// Get element type for next step
current_type_id = component_type_id;
Expand Down Expand Up @@ -197,7 +198,7 @@ uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) {
sum_id = current_offset_id;
} else {
const uint32_t new_sum_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, current_offset_id});
block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, current_offset_id}, inst_it);
sum_id = new_sum_id;
}
ac_word_index++;
Expand All @@ -210,16 +211,12 @@ uint32_t BindlessDescriptorPass::GetLastByte(BasicBlock& block) {
const uint32_t last_id = module_.type_manager_.GetConstantUInt32(last).Id();

const uint32_t new_sum_id = module_.TakeNextId();
block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, last_id});
block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, last_id}, inst_it);
return new_sum_id;
}

uint32_t BindlessDescriptorPass::CreateFunctionCall(BasicBlock& block) {
// Add any debug information to pass into the function call
const uint32_t stage_info_id = GetStageInfo(block.function_);
const uint32_t inst_position = target_instruction_->position_index_;
auto inst_position_constant = module_.type_manager_.CreateConstantUInt32(inst_position);

uint32_t BindlessDescriptorPass::CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it,
const InjectionData& injection_data) {
const Constant& set_constant = module_.type_manager_.GetConstantUInt32(descriptor_set_);
const Constant& binding_constant = module_.type_manager_.GetConstantUInt32(descriptor_binding_);
const uint32_t descriptor_index_id = CastToUint32(descriptor_index_id_, block); // might be int32
Expand Down Expand Up @@ -262,7 +259,7 @@ uint32_t BindlessDescriptorPass::CreateFunctionCall(BasicBlock& block) {
auto copied = copy_object_map_.find(image_id);
if (copied != copy_object_map_.end()) {
image_id = copied->second;
block.CreateInstruction(spv::OpCopyObject, {type_id, copy_id, image_id});
block.CreateInstruction(spv::OpCopyObject, {type_id, copy_id, image_id}, inst_it);
} else {
copy_object_map_.emplace(image_id, copy_id);
// slower, but need to guarantee it is placed after a OpSampledImage
Expand All @@ -278,7 +275,7 @@ uint32_t BindlessDescriptorPass::CreateFunctionCall(BasicBlock& block) {
const Type* pointee_type = module_.type_manager_.FindTypeById(pointer_type->inst_.Word(3));
if (pointee_type && pointee_type->spv_type_ != SpvType::kArray && pointee_type->spv_type_ != SpvType::kRuntimeArray &&
pointee_type->spv_type_ != SpvType::kStruct) {
descriptor_offset_id_ = GetLastByte(block); // Get Last Byte Index
descriptor_offset_id_ = GetLastByte(block, inst_it); // Get Last Byte Index
}
}

Expand All @@ -290,9 +287,11 @@ uint32_t BindlessDescriptorPass::CreateFunctionCall(BasicBlock& block) {
const uint32_t function_def = GetLinkFunctionId();
const uint32_t bool_type = module_.type_manager_.GetTypeBool().Id();

block.CreateInstruction(spv::OpFunctionCall,
{bool_type, function_result, function_def, inst_position_constant.Id(), stage_info_id,
set_constant.Id(), binding_constant.Id(), descriptor_index_id, descriptor_offset_id_});
block.CreateInstruction(
spv::OpFunctionCall,
{bool_type, function_result, function_def, injection_data.inst_position_id, injection_data.stage_info_id, set_constant.Id(),
binding_constant.Id(), descriptor_index_id, descriptor_offset_id_},
inst_it);

return function_result;
}
Expand Down
11 changes: 4 additions & 7 deletions layers/gpu/spirv/bindless_descriptor_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,18 @@ struct BasicBlock;
// Create a pass to instrument bindless descriptor checking
// This pass instruments all bindless references to check that descriptor
// array indices are inbounds, and if the descriptor indexing extension is
// enabled, that the descriptor has been initialized. If the reference is
// invalid, a record is written to the debug output buffer (if space allows)
// and a null value is returned.
// enabled, that the descriptor has been initialized.
class BindlessDescriptorPass : public Pass {
public:
BindlessDescriptorPass(Module& module) : Pass(module) {}
BindlessDescriptorPass(Module& module) : Pass(module, true) {}

private:
bool AnalyzeInstruction(const Function& function, const Instruction& inst) final;
uint32_t CreateFunctionCall(BasicBlock& block) final;
uint32_t CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it, const InjectionData& injection_data) final;
void Reset() final;

uint32_t FindTypeByteSize(uint32_t type_id, uint32_t matrix_stride = 0, bool col_major = false, bool in_matrix = false);
uint32_t GetLastByte(BasicBlock& block);
uint32_t GetLastByte(BasicBlock& block, InstructionIt* inst_it);

uint32_t link_function_id = 0;
uint32_t GetLinkFunctionId();
Expand All @@ -49,7 +47,6 @@ class BindlessDescriptorPass : public Pass {
const Instruction* var_inst_ = nullptr;
const Instruction* image_inst_ = nullptr;

const Instruction* target_instruction_ = nullptr;
uint32_t descriptor_set_ = 0;
uint32_t descriptor_binding_ = 0;
uint32_t descriptor_index_id_ = 0;
Expand Down
16 changes: 7 additions & 9 deletions layers/gpu/spirv/buffer_device_address_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,13 @@ uint32_t BufferDeviceAddressPass::GetLinkFunctionId() {
return link_function_id;
}

uint32_t BufferDeviceAddressPass::CreateFunctionCall(BasicBlock& block) {
// Add any debug information to pass into the function call
const uint32_t stage_info_id = GetStageInfo(block.function_);
const uint32_t inst_position = target_instruction_->position_index_;
auto inst_position_constant = module_.type_manager_.CreateConstantUInt32(inst_position);

uint32_t BufferDeviceAddressPass::CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it,
const InjectionData& injection_data) {
// Convert reference pointer to uint64
const uint32_t pointer_id = target_instruction_->Operand(0);
const Type& uint64_type = module_.type_manager_.GetTypeInt(64, 0);
const uint32_t convert_id = module_.TakeNextId();
block.CreateInstruction(spv::OpConvertPtrToU, {uint64_type.Id(), convert_id, pointer_id});
block.CreateInstruction(spv::OpConvertPtrToU, {uint64_type.Id(), convert_id, pointer_id}, inst_it);

const Constant& length_constant = module_.type_manager_.GetConstantUInt32(type_length_);
const Constant& access_opcode = module_.type_manager_.GetConstantUInt32(access_opcode_);
Expand All @@ -54,8 +50,10 @@ uint32_t BufferDeviceAddressPass::CreateFunctionCall(BasicBlock& block) {
const uint32_t function_def = GetLinkFunctionId();
const uint32_t bool_type = module_.type_manager_.GetTypeBool().Id();

block.CreateInstruction(spv::OpFunctionCall, {bool_type, function_result, function_def, inst_position_constant.Id(),
stage_info_id, convert_id, length_constant.Id(), access_opcode.Id()});
block.CreateInstruction(spv::OpFunctionCall,
{bool_type, function_result, function_def, injection_data.inst_position_id,
injection_data.stage_info_id, convert_id, length_constant.Id(), access_opcode.Id()},
inst_it);

return function_result;
}
Expand Down
33 changes: 3 additions & 30 deletions layers/gpu/spirv/buffer_device_address_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,46 +26,19 @@ struct BasicBlock;

// Create a pass to instrument physical buffer address checking
// This pass instruments all physical buffer address references to check that
// all referenced bytes fall in a valid buffer. If the reference is
// invalid, a record is written to the debug output buffer (if space allows)
// and a null value is returned.
//
// For OpStore we will just ignore the store if it is invalid, example:
// Before:
// bda.data[index] = value;
// After:
// if (isValid(bda.data, index)) {
// bda.data[index] = value;
// }
//
// For OpLoad we replace the value with Zero (via Phi node) if it is invalid, example
// Before:
// int X = bda.data[index];
// int Y = bda.data[X];
// After:
// if (isValid(bda.data, index)) {
// int X = bda.data[index];
// } else {
// int X = 0;
// }
// if (isValid(bda.data, X)) {
// int Y = bda.data[X];
// } else {
// int Y = 0;
// }
// all referenced bytes fall in a valid buffer.
class BufferDeviceAddressPass : public Pass {
public:
BufferDeviceAddressPass(Module& module) : Pass(module) {}
BufferDeviceAddressPass(Module& module) : Pass(module, true) {}

private:
bool AnalyzeInstruction(const Function& function, const Instruction& inst) final;
uint32_t CreateFunctionCall(BasicBlock& block) final;
uint32_t CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it, const InjectionData& injection_data) final;
void Reset() final;

uint32_t link_function_id = 0;
uint32_t GetLinkFunctionId();

const Instruction* target_instruction_ = nullptr;
uint32_t type_length_ = 0;
uint32_t access_opcode_ = 0;
};
Expand Down
15 changes: 15 additions & 0 deletions layers/gpu/spirv/function_basic_block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ BasicBlock::BasicBlock(Module& module, Function& function) : function_(function)

uint32_t BasicBlock::GetLabelId() { return (*(instructions_[0])).ResultId(); }

// Walks the block from its first instruction and returns an iterator to the first one whose opcode is
// neither OpLabel nor OpVariable. If every instruction is one of those two (or the block is empty),
// instructions_.end() is returned.
InstructionIt BasicBlock::GetFirstInjectableInstrution() {
    InstructionIt candidate = instructions_.begin();
    while (candidate != instructions_.end()) {
        const auto opcode = (*candidate)->Opcode();
        const bool must_skip = (opcode == spv::OpLabel) || (opcode == spv::OpVariable);
        if (!must_skip) {
            // First instruction that is safe to inject in front of
            return candidate;
        }
        ++candidate;
    }
    return candidate;
}

void BasicBlock::CreateInstruction(spv::Op opcode, const std::vector<uint32_t>& words, InstructionIt* inst_it) {
const bool add_to_end = inst_it == nullptr;
InstructionIt last_inst = instructions_.end();
Expand All @@ -74,6 +84,11 @@ void BasicBlock::CreateInstruction(spv::Op opcode, const std::vector<uint32_t>&
}
}

// Constructs a Function bound to |module| and takes ownership of |function_inst|, storing it as the
// first entry of pre_block_inst_.
Function::Function(Module& module, std::unique_ptr<Instruction> function_inst) : module_(module) {
// Used when loading initial SPIR-V
// The instruction is moved (not copied) into the list; |function_inst| is empty afterwards.
pre_block_inst_.push_back(std::move(function_inst)); // OpFunction
}

BasicBlockIt Function::InsertNewBlock(BasicBlockIt it) {
auto new_block = std::make_unique<BasicBlock>(module_, (*it)->function_);
it++; // make sure it inserted after
Expand Down
10 changes: 6 additions & 4 deletions layers/gpu/spirv/function_basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ struct BasicBlock {

uint32_t GetLabelId();

// "All OpVariable instructions in a function must be the first instructions in the first block"
// So need to get the first valid location in block.
InstructionIt GetFirstInjectableInstrution();

// Creates instruction and inserts it before the Instruction, updates position after new instruction.
// If no InstructionIt is provided, it will add it to the end of the block.
void CreateInstruction(spv::Op opcode, const std::vector<uint32_t>& words, InstructionIt* inst_it = nullptr);
Expand All @@ -57,15 +61,13 @@ using BasicBlockList = std::vector<std::unique_ptr<BasicBlock>>;
using BasicBlockIt = BasicBlockList::iterator;

struct Function {
Function(Module& module, std::unique_ptr<Instruction> function_inst) : module_(module) {
// Used when loading initial SPIR-V
pre_block_inst_.push_back(std::move(function_inst)); // OpFunction
}
Function(Module& module, std::unique_ptr<Instruction> function_inst);
Function(Module& module) : module_(module) {}

void ToBinary(std::vector<uint32_t>& out);

const Instruction& GetDef() { return *pre_block_inst_[0].get(); }
BasicBlock& GetFirstBlock() { return *blocks_[0]; }

// Adds a new block after and returns reference to it
BasicBlockIt InsertNewBlock(BasicBlockIt it);
Expand Down
Loading

0 comments on commit fe774b1

Please sign in to comment.