Skip to content

Commit b1cd9eb

Browse files
committed
[ET-VK] Adding get or create int function to read int value.
Pull Request resolved: #12357 This diff adds a new function `get_or_create_int` to the `ComputeGraph` class, which allows reading an integer value from a `ValueRef` index. The function returns the extracted integer value if the value at the index is an integer, otherwise it throws an error. Additionally, an overload of the function is added to return a default value if the value at the index is `None`. ghstack-source-id: 295655467 @exported-using-ghexport Differential Revision: [D78094858](https://our.internmc.facebook.com/intern/diff/D78094858/)
1 parent 6309119 commit b1cd9eb

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,22 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
549549
}
550550
}
551551

552+
int32_t ComputeGraph::get_or_create_int(const ValueRef idx) {
553+
if (values_.at(idx).isInt()) {
554+
return extract_scalar<int32_t>(idx);
555+
}
556+
VK_THROW("Cannot create a int param buffer for the given value");
557+
}
558+
559+
int32_t ComputeGraph::get_or_create_int(
560+
const ValueRef idx,
561+
const int32_t default_val) {
562+
if (values_.at(idx).isNone()) {
563+
return default_val;
564+
}
565+
return get_or_create_int(idx);
566+
}
567+
552568
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
553569
get_symint(idx)->set(val);
554570
}

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,12 @@ class ComputeGraph final {
424424
// Scalar Value Extraction
425425
//
426426

427+
bool is_scalar(const ValueRef idx) const {
428+
const Value& value = values_.at(idx);
429+
return value.isInt() || value.isDouble() || value.isBool() ||
430+
value.isNone();
431+
}
432+
427433
template <typename T>
428434
T extract_scalar(const ValueRef idx) {
429435
Value& value = values_.at(idx);
@@ -679,6 +685,10 @@ class ComputeGraph final {
679685
const ValueRef idx,
680686
const int32_t default_value);
681687

688+
int32_t get_or_create_int(const ValueRef idx);
689+
690+
int32_t get_or_create_int(const ValueRef idx, const int32_t default_value);
691+
682692
void set_symint(const ValueRef idx, const int32_t val);
683693

684694
int32_t read_symint(const ValueRef idx);

0 commit comments

Comments
 (0)