Skip to content

Commit

Permalink
layers: Handle VkShaderModuleCreateInfo from pNext
Browse files Browse the repository at this point in the history
  • Loading branch information
spencer-lunarg committed Jun 10, 2024
1 parent 14da7bb commit 97af4dd
Show file tree
Hide file tree
Showing 20 changed files with 563 additions and 62 deletions.
2 changes: 2 additions & 0 deletions layers/core_checks/cc_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,8 @@ bool CoreChecks::ValidatePipelineShaderStage(const vvl::Pipeline &pipeline,
"VkPipelineShaderStageModuleIdentifierCreateInfoEXT or VkShaderModuleCreateInfo found in the "
"pNext chain. (stage %s).",
string_VkShaderStageFlagBits(stage_ci.stage));
} else {
skip |= ValidateShaderModuleCreateInfo(*module_create_info, loc.pNext(Struct::VkShaderModuleCreateInfo));
}
}
return skip;
Expand Down
61 changes: 34 additions & 27 deletions layers/core_checks/cc_spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2135,13 +2135,15 @@ bool CoreChecks::ValidateShaderStage(const ShaderStageState &stage_state, const
if ((pipeline && pipeline->uses_shader_module_id) || !stage_state.spirv_state) {
return skip; // these edge cases should be validated already
}

const spirv::Module &module_state = *stage_state.spirv_state.get();
if (!module_state.valid_spirv) return skip; // checked elsewhere

if (!stage_state.entrypoint) {
const char *vuid = pipeline ? "VUID-VkPipelineShaderStageCreateInfo-pName-00707" : "VUID-VkShaderCreateInfoEXT-pName-08440";
return LogError(vuid, device, loc.dot(Field::pName), "`%s` entrypoint not found for stage %s.", stage_state.GetPName(),
string_VkShaderStageFlagBits(stage));
}

const spirv::Module &module_state = *stage_state.spirv_state.get();
const spirv::EntryPoint &entrypoint = *stage_state.entrypoint;

// to prevent const_cast on pipeline object, just store here as not needed outside function anyway
Expand Down Expand Up @@ -2478,8 +2480,10 @@ bool CoreChecks::RunSpirvValidation(spv_const_binary_t &binary, const Location &
spv_diagnostic diag = nullptr;
const spv_result_t spv_valid = spvValidateWithOptions(ctx, spirv_val_options, &binary, &diag);
if (spv_valid != SPV_SUCCESS) {
const char *vuid = loc.function == Func::vkCreateShaderModule ? "VUID-VkShaderModuleCreateInfo-pCode-08737"
: "VUID-VkShaderCreateInfoEXT-pCode-08737";
// VkShaderModuleCreateInfo can come from many functions
const char *vuid = loc.function == Func::vkCreateShadersEXT ? "VUID-VkShaderCreateInfoEXT-pCode-08737"
: "VUID-VkShaderModuleCreateInfo-pCode-08737";

if (spv_valid == SPV_WARNING) {
skip |= LogWarning(vuid, device, loc.dot(Field::pCode), "(spirv-val produced a warning):\n%s",
diag && diag->error ? diag->error : "(no error text)");
Expand All @@ -2498,45 +2502,48 @@ bool CoreChecks::RunSpirvValidation(spv_const_binary_t &binary, const Location &
return skip;
}

bool CoreChecks::PreCallValidateCreateShaderModule(VkDevice device, const VkShaderModuleCreateInfo *pCreateInfo,
const VkAllocationCallbacks *pAllocator, VkShaderModule *pShaderModule,
const ErrorObject &error_obj) const {
bool CoreChecks::ValidateShaderModuleCreateInfo(const VkShaderModuleCreateInfo &create_info,
const Location &create_info_loc) const {
bool skip = false;

if (disabled[shader_validation]) {
return false;
return skip;
}

const Location create_info_loc = error_obj.location.dot(Field::pCreateInfo);

if (pCreateInfo->pCode[0] != spv::MagicNumber) {
if (!create_info.pCode) {
return skip; // will be caught elsewhere
} else if (create_info.pCode[0] != spv::MagicNumber) {
if (!IsExtEnabled(device_extensions.vk_nv_glsl_shader)) {
skip |= LogError("VUID-VkShaderModuleCreateInfo-pCode-07912", device, create_info_loc.dot(Field::pCode),
"doesn't point to a SPIR-V module.");
"doesn't point to a SPIR-V module (The first dword is not the SPIR-V MagicNumber 0x07230203).");
}
} else if (SafeModulo(pCreateInfo->codeSize, 4) != 0) {
} else if (SafeModulo(create_info.codeSize, 4) != 0) {
skip |= LogError("VUID-VkShaderModuleCreateInfo-codeSize-08735", device, create_info_loc.dot(Field::codeSize),
"(%zu) must be a multiple of 4.", pCreateInfo->codeSize);
}
"(%zu) must be a multiple of 4.", create_info.codeSize);
} else {
// if pCode is garbage, don't pass along to spirv-val

if (skip) {
return skip; // if pCode is garbage, don't pass along to spirv-val
}
const auto validation_cache_ci = vku::FindStructInPNextChain<VkShaderModuleValidationCacheCreateInfoEXT>(create_info.pNext);
ValidationCache *cache =
validation_cache_ci ? CastFromHandle<ValidationCache *>(validation_cache_ci->validationCache) : nullptr;
// If app isn't using a shader validation cache, use the default one from CoreChecks
if (!cache) {
cache = CastFromHandle<ValidationCache *>(core_validation_cache);
}

const auto validation_cache_ci = vku::FindStructInPNextChain<VkShaderModuleValidationCacheCreateInfoEXT>(pCreateInfo->pNext);
ValidationCache *cache =
validation_cache_ci ? CastFromHandle<ValidationCache *>(validation_cache_ci->validationCache) : nullptr;
// If app isn't using a shader validation cache, use the default one from CoreChecks
if (!cache) {
cache = CastFromHandle<ValidationCache *>(core_validation_cache);
spv_const_binary_t binary{create_info.pCode, create_info.codeSize / sizeof(uint32_t)};
skip |= RunSpirvValidation(binary, create_info_loc, cache);
}

spv_const_binary_t binary{pCreateInfo->pCode, pCreateInfo->codeSize / sizeof(uint32_t)};
skip |= RunSpirvValidation(binary, create_info_loc, cache);

return skip;
}

bool CoreChecks::PreCallValidateCreateShaderModule(VkDevice device, const VkShaderModuleCreateInfo *pCreateInfo,
const VkAllocationCallbacks *pAllocator, VkShaderModule *pShaderModule,
const ErrorObject &error_obj) const {
return ValidateShaderModuleCreateInfo(*pCreateInfo, error_obj.location.dot(Field::pCreateInfo));
}

bool CoreChecks::PreCallValidateGetShaderModuleIdentifierEXT(VkDevice device, VkShaderModule shaderModule,
VkShaderModuleIdentifierEXT *pIdentifier,
const ErrorObject &error_obj) const {
Expand Down
1 change: 1 addition & 0 deletions layers/core_checks/core_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,7 @@ class CoreChecks : public ValidationStateTracker {
bool RunSpirvValidation(spv_const_binary_t& binary, const Location& loc, ValidationCache* cache) const;
bool ValidateSpirvStateless(const spirv::Module& module_state, const spirv::StatelessData& stateless_data,
const Location& loc) const;
bool ValidateShaderModuleCreateInfo(const VkShaderModuleCreateInfo& create_info, const Location& create_info_loc) const;
bool PreCallValidateCreateShaderModule(VkDevice device, const VkShaderModuleCreateInfo* pCreateInfo,
const VkAllocationCallbacks* pAllocator, VkShaderModule* pShaderModule,
const ErrorObject& error_obj) const override;
Expand Down
4 changes: 2 additions & 2 deletions layers/state_tracker/pipeline_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ std::vector<ShaderStageState> Pipeline::GetStageStates(const ValidationStateTrac
if (!module_state || !module_state->spirv) {
// If module is null and there is a VkShaderModuleCreateInfo in the pNext chain of the stage info, then this
// module is part of a library and the state must be created
const auto shader_ci = vku::FindStructInPNextChain<VkShaderModuleCreateInfo>(stage_ci.pNext);
// This support was also added in VK_KHR_maintenance5
const uint32_t unique_shader_id = (shader_unique_id_map) ? (*shader_unique_id_map)[stage] : 0;
if (shader_ci) {
if (const auto shader_ci = vku::FindStructInPNextChain<VkShaderModuleCreateInfo>(stage_ci.pNext)) {
// don't need to worry about GroupDecoration in GPL
auto spirv_module = std::make_shared<spirv::Module>(shader_ci->codeSize, shader_ci->pCode);
module_state = std::make_shared<vvl::ShaderModule>(VK_NULL_HANDLE, spirv_module, unique_shader_id);
Expand Down
10 changes: 5 additions & 5 deletions layers/state_tracker/pipeline_sub_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ PreRasterState::PreRasterState(const vvl::Pipeline &p, const ValidationStateTrac
auto module_state = state_data.Get<vvl::ShaderModule>(stage_ci.module);
if (!module_state) {
// If module is null and there is a VkShaderModuleCreateInfo in the pNext chain of the stage info, then this
// module is part of a library and the state must be created
const auto shader_ci = vku::FindStructInPNextChain<VkShaderModuleCreateInfo>(stage_ci.pNext);
if (shader_ci) {
// module is part of a library and the state must be created.
// This support was also added in VK_KHR_maintenance5
if (const auto shader_ci = vku::FindStructInPNextChain<VkShaderModuleCreateInfo>(stage_ci.pNext)) {
// don't need to worry about GroupDecoration in GPL
auto spirv_module = std::make_shared<spirv::Module>(shader_ci->codeSize, shader_ci->pCode);
module_state = std::make_shared<vvl::ShaderModule>(VK_NULL_HANDLE, spirv_module, 0);
Expand Down Expand Up @@ -188,8 +188,8 @@ void SetFragmentShaderInfoPrivate(FragmentShaderState &fs_state, const Validatio
if (!module_state) {
// If module is null and there is a VkShaderModuleCreateInfo in the pNext chain of the stage info, then this
// module is part of a library and the state must be created
const auto shader_ci = vku::FindStructInPNextChain<VkShaderModuleCreateInfo>(create_info.pStages[i].pNext);
if (shader_ci) {
// This support was also added in VK_KHR_maintenance5
if (const auto shader_ci = vku::FindStructInPNextChain<VkShaderModuleCreateInfo>(create_info.pStages[i].pNext)) {
// don't need to worry about GroupDecoration in GPL
auto spirv_module = std::make_shared<spirv::Module>(shader_ci->codeSize, shader_ci->pCode);
module_state = std::make_shared<vvl::ShaderModule>(VK_NULL_HANDLE, spirv_module, 0);
Expand Down
3 changes: 3 additions & 0 deletions layers/state_tracker/shader_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,8 @@ std::optional<VkPrimitiveTopology> Module::GetTopology(const EntryPoint& entrypo
}

Module::StaticData::StaticData(const Module& module_state, StatelessData* stateless_data) {
if (!module_state.valid_spirv) return;

// Parse the words first so we have instruction class objects to use
{
std::vector<uint32_t>::const_iterator it = module_state.words_.cbegin();
Expand Down Expand Up @@ -1331,6 +1333,7 @@ std::string Module::DescribeVariable(uint32_t id) const {
}

std::shared_ptr<const EntryPoint> Module::FindEntrypoint(char const* name, VkShaderStageFlagBits stageBits) const {
if (!name) return nullptr;
for (const auto& entry_point : static_data_.entry_points) {
if (entry_point->name.compare(name) == 0 && entry_point->stage == stageBits) {
return entry_point;
Expand Down
11 changes: 9 additions & 2 deletions layers/state_tracker/shader_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,11 @@ struct Module {
vvl::unordered_map<const Instruction *, uint32_t> image_write_load_id_map; // <OpImageWrite, load id>
};

// VK_KHR_maintenance5 allows VkShaderModuleCreateInfo (the SPIR-V binary) to be passed at pipeline creation time, because the
// way we create our pipeline state objects first, we need to still create a valid Module object, but can signal that the
// underlying spirv is not worth validating further
const bool valid_spirv;

// This is the SPIR-V module data content
const std::vector<uint32_t> words_;

Expand All @@ -616,11 +621,13 @@ struct Module {
VulkanTypedHandle handle() const { return handle_; } // matches normal convention to get handle

// Used for when modifying the SPIR-V (spirv-opt, GPU-AV instrumentation, etc) and need reparse it for VVL validaiton
Module(vvl::span<const uint32_t> code) : words_(code.begin(), code.end()), static_data_(*this) {}
Module(vvl::span<const uint32_t> code) : valid_spirv(true), words_(code.begin(), code.end()), static_data_(*this) {}

// StatelessData is a pointer as we have cases were we don't need it and simpler to just null check the few cases that use it
Module(size_t codeSize, const uint32_t *pCode, StatelessData *stateless_data = nullptr)
: words_(pCode, pCode + codeSize / sizeof(uint32_t)), static_data_(*this, stateless_data) {}
: valid_spirv(pCode && pCode[0] == spv::MagicNumber && ((codeSize % 4) == 0)),
words_(pCode, pCode + codeSize / sizeof(uint32_t)),
static_data_(*this, stateless_data) {}

const Instruction *FindDef(uint32_t id) const {
auto it = static_data_.definitions.find(id);
Expand Down
26 changes: 13 additions & 13 deletions layers/stateless/sl_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,16 @@ bool StatelessValidation::manual_PreCallValidateCreatePipelineLayout(VkDevice de
return skip;
}

bool StatelessValidation::ValidatePipelineShaderStageCreateInfo(const VkPipelineShaderStageCreateInfo &create_info,
const Location &loc) const {
// Called from graphics, compute, raytracing, etc
bool StatelessValidation::ValidatePipelineShaderStageCreateInfoCommon(const VkPipelineShaderStageCreateInfo &create_info,
const Location &loc) const {
bool skip = false;

const auto *required_subgroup_size_features =
vku::FindStructInPNextChain<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>(create_info.pNext);
if (create_info.pName) {
skip |= ValidateString(loc.dot(Field::pName), "VUID-VkPipelineShaderStageCreateInfo-pName-parameter", create_info.pName);
}

if (required_subgroup_size_features) {
if (vku::FindStructInPNextChain<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>(create_info.pNext)) {
if ((create_info.flags & VK_PIPELINE_SHADER_STAGE_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT_EXT) != 0) {
skip |= LogError("VUID-VkPipelineShaderStageCreateInfo-pNext-02754", device, loc.dot(Field::flags),
"(%s) includes VK_PIPELINE_SHADER_STAGE_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT_EXT while "
Expand Down Expand Up @@ -464,15 +466,13 @@ bool StatelessValidation::manual_PreCallValidateCreateGraphicsPipelines(
active_shaders |= create_info.pStages[stage_index].stage;
const Location stage_loc = create_info_loc.dot(Field::pStages, stage_index);

skip |= ValidateRequiredPointer(stage_loc.dot(Field::pName), create_info.pStages[stage_index].pName,
"VUID-VkPipelineShaderStageCreateInfo-pName-parameter");

if (create_info.pStages[stage_index].pName) {
skip |= ValidateString(stage_loc.dot(Field::pName), "VUID-VkPipelineShaderStageCreateInfo-pName-parameter",
create_info.pStages[stage_index].pName);
}
skip |= ValidateStructType(stage_loc, "VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO",
&create_info.pStages[stage_index], VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
false, kVUIDUndefined, "VUID-VkPipelineShaderStageCreateInfo-sType-sType");

// special graphics-only generated call
ValidatePipelineShaderStageCreateInfo(create_info.pStages[stage_index], stage_loc);
ValidatePipelineShaderStageCreateInfoCommon(create_info.pStages[stage_index], stage_loc);
}
}

Expand Down Expand Up @@ -1245,7 +1245,7 @@ bool StatelessValidation::manual_PreCallValidateCreateComputePipelines(VkDevice
}
}

ValidatePipelineShaderStageCreateInfo(create_info.stage, create_info_loc.dot(Field::stage));
ValidatePipelineShaderStageCreateInfoCommon(create_info.stage, create_info_loc.dot(Field::stage));
}
return skip;
}
Expand Down
6 changes: 3 additions & 3 deletions layers/stateless/sl_ray_tracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,8 @@ bool StatelessValidation::manual_PreCallValidateCreateRayTracingPipelinesNV(
const VkRayTracingPipelineCreateInfoNV &create_info = pCreateInfos[i];

for (uint32_t stage_index = 0; stage_index < create_info.stageCount; ++stage_index) {
ValidatePipelineShaderStageCreateInfo(create_info.pStages[stage_index],
create_info_loc.dot(Field::pStages, stage_index));
ValidatePipelineShaderStageCreateInfoCommon(create_info.pStages[stage_index],
create_info_loc.dot(Field::pStages, stage_index));
}
auto feedback_struct = vku::FindStructInPNextChain<VkPipelineCreationFeedbackCreateInfoEXT>(create_info.pNext);
if ((feedback_struct != nullptr) && (feedback_struct->pipelineStageCreationFeedbackCount != 0) &&
Expand Down Expand Up @@ -488,7 +488,7 @@ bool StatelessValidation::manual_PreCallValidateCreateRayTracingPipelinesKHR(

for (uint32_t stage_index = 0; stage_index < create_info.stageCount; ++stage_index) {
const Location stage_loc = create_info_loc.dot(Field::pStages, stage_index);
ValidatePipelineShaderStageCreateInfo(create_info.pStages[stage_index], stage_loc);
ValidatePipelineShaderStageCreateInfoCommon(create_info.pStages[stage_index], stage_loc);

const auto stage = create_info.pStages[stage_index].stage;
if ((stage & kShaderStageAllRayTracing) == 0) {
Expand Down
2 changes: 1 addition & 1 deletion layers/stateless/stateless_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ class StatelessValidation : public ValidationObject {
const VkAllocationCallbacks *pAllocator, VkPipelineLayout *pPipelineLayout,
const ErrorObject &error_obj) const;

bool ValidatePipelineShaderStageCreateInfo(const VkPipelineShaderStageCreateInfo &create_info, const Location &loc) const;
bool ValidatePipelineShaderStageCreateInfoCommon(const VkPipelineShaderStageCreateInfo &create_info, const Location &loc) const;
bool ValidatePipelineRenderingCreateInfo(const VkPipelineRenderingCreateInfo &rendering_struct, const Location &loc) const;
bool ValidateCreateGraphicsPipelinesFlags(const VkPipelineCreateFlags2KHR flags, const Location &flags_loc) const;
bool manual_PreCallValidateCreateGraphicsPipelines(VkDevice device, VkPipelineCache pipelineCache, uint32_t createInfoCount,
Expand Down
Loading

0 comments on commit 97af4dd

Please sign in to comment.