Skip to content

Commit 1be6995

Browse files
committed
[ET-VK] Adding extract_scalar_or function to extract scalar value or return a default if value at index is none.
Pull Request resolved: #12357 This diff adds a new function `extract_scalar_or` to the `ComputeGraph` class, which extracts a scalar value from a `ValueRef` index. If the value at the index is `None`, it returns a default value. ghstack-source-id: 296059970 @exported-using-ghexport Differential Revision: [D78094858](https://our.internmc.facebook.com/intern/diff/D78094858/)
1 parent 1540659 commit 1be6995

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,12 @@ class ComputeGraph final {
424424
// Scalar Value Extraction
425425
//
426426

427+
bool is_scalar_or_none(const ValueRef idx) const {
428+
const Value& value = values_.at(idx);
429+
return value.isInt() || value.isDouble() || value.isBool() ||
430+
value.isNone();
431+
}
432+
427433
template <typename T>
428434
T extract_scalar(const ValueRef idx) {
429435
Value& value = values_.at(idx);
@@ -439,6 +445,15 @@ class ComputeGraph final {
439445
VK_THROW("Cannot extract scalar from Value with type ", value.type());
440446
}
441447

448+
template <typename T>
449+
T extract_scalar_or(const ValueRef idx, const T default_value) {
450+
Value& value = values_.at(idx);
451+
if (value.isNone()) {
452+
return default_value;
453+
}
454+
return extract_scalar<T>(idx);
455+
}
456+
442457
template <typename T>
443458
std::optional<T> extract_optional_scalar(const ValueRef idx) {
444459
if (val_is_none(idx)) {

0 commit comments

Comments
 (0)