diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index cb14a41e98a..2dc02b8b800 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -549,6 +549,22 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer( } } +int32_t ComputeGraph::get_or_create_int(const ValueRef idx) { + if (values_.at(idx).isInt()) { + return extract_scalar(idx); + } + VK_THROW("Cannot create a int param buffer for the given value"); +} + +int32_t ComputeGraph::get_or_create_int( + const ValueRef idx, + const int32_t default_val) { + if (values_.at(idx).isNone()) { + return default_val; + } + return get_or_create_int(idx); +} + void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) { get_symint(idx)->set(val); } diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 78135a434e5..7a73ae1dee5 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -424,6 +424,12 @@ class ComputeGraph final { // Scalar Value Extraction // + bool is_scalar(const ValueRef idx) const { + const Value& value = values_.at(idx); + return value.isInt() || value.isDouble() || value.isBool() || + value.isNone(); + } + template T extract_scalar(const ValueRef idx) { Value& value = values_.at(idx); @@ -679,6 +685,10 @@ class ComputeGraph final { const ValueRef idx, const int32_t default_value); + int32_t get_or_create_int(const ValueRef idx); + + int32_t get_or_create_int(const ValueRef idx, const int32_t default_value); + void set_symint(const ValueRef idx, const int32_t val); int32_t read_symint(const ValueRef idx);