
[ET-VK][ez] enabling fp64->fp32 conversion for vulkan compatibility #12201


Merged
merged 7 commits into from Jul 14, 2025

Changes from 3 commits
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/VulkanBackend.cpp
@@ -83,10 +83,14 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
       return vkapi::kChar;
     case vkgraph::VkDataType::INT32:
       return vkapi::kInt;
+    case vkgraph::VkDataType::INT64:
+      return vkapi::kLong;
     case vkgraph::VkDataType::FLOAT16:
       return vkapi::kHalf;
     case vkgraph::VkDataType::FLOAT32:
       return vkapi::kFloat;
+    case vkgraph::VkDataType::FLOAT64:
+      return vkapi::kDouble;
   }
 }

6 changes: 4 additions & 2 deletions backends/vulkan/serialization/schema.fbs
@@ -16,8 +16,10 @@ enum VkDataType : byte {
   UINT8 = 1,
   INT8 = 2,
   INT32 = 3,
-  FLOAT16 = 4,
-  FLOAT32 = 5,
+  INT64 = 4,
+  FLOAT16 = 5,
+  FLOAT32 = 6,
+  FLOAT64 = 7,
 }

 // Describes what kind of GPU resource should be used to represent a tensor. The
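Since INT64 is inserted at value 4, FLOAT16 and FLOAT32 shift up by one, so delegate blobs serialized against the old schema decode to different dtypes and must be re-exported. The flatbuffer enum also has to stay value-for-value in sync with its Python mirror in vulkan_graph_schema.py; a minimal sanity-check sketch (the import path assumes the standard executorch package layout):

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

# Each tensor dtype is serialized as its raw enum value, so the Python
# mirror must match schema.fbs value-for-value.
expected = {
    "UINT8": 1, "INT8": 2, "INT32": 3, "INT64": 4,
    "FLOAT16": 5, "FLOAT32": 6, "FLOAT64": 7,
}
for name, value in expected.items():
    assert vk_graph_schema.VkDataType[name] == value, name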
20 changes: 16 additions & 4 deletions backends/vulkan/serialization/vulkan_graph_builder.py
@@ -45,9 +45,11 @@ def __init__(
         self,
         program: ExportedProgram,
         delegate_mapping_builder: DelegateMappingBuilder,
+        downcast_64_bit: bool = False,
     ) -> None:
         self.program = program
         self.delegate_mapping_builder = delegate_mapping_builder
+        self.downcast_64_bit = downcast_64_bit
         self.chain = []
         self.values = []
         self.input_ids = []
@@ -72,13 +74,14 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
             return vk_graph_schema.VkDataType.INT8
         elif torch_dtype == torch.int32:
             return vk_graph_schema.VkDataType.INT32
+        elif torch_dtype == torch.int64:
+            return vk_graph_schema.VkDataType.INT64
         elif torch_dtype == torch.float16:
             return vk_graph_schema.VkDataType.FLOAT16
         elif torch_dtype == torch.float32:
             return vk_graph_schema.VkDataType.FLOAT32
-        # Narrowing conversion for index tensor produced by max_poolNd_with_indices.
-        elif torch_dtype == torch.int64:
-            return vk_graph_schema.VkDataType.INT32
+        elif torch_dtype == torch.float64:
+            return vk_graph_schema.VkDataType.FLOAT64
         else:
             raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

@@ -201,11 +204,20 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         # pyre-ignore[16]
         memory_layout = spec.vk_memory_layout

+        # Apply downcast logic before getting VK datatype
+        effective_dtype = spec.dtype
+        if self.downcast_64_bit and spec.dtype == torch.float64:
+            effective_dtype = torch.float32
+        elif self.downcast_64_bit and spec.dtype == torch.int64:
+            effective_dtype = torch.int32
+
+        datatype = self.get_vk_datatype(effective_dtype)
+
         new_id = len(self.values)
         self.values.append(
             vk_graph_schema.VkValue(
                 value=vk_graph_schema.VkTensor(
-                    datatype=self.get_vk_datatype(spec.dtype),
+                    datatype=datatype,
                     dims=spec.shape,
                     constant_id=constant_id,
                     mem_obj_id=mem_obj_id,
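The downcast in create_tensor_value is a pure function of the tensor's dtype and the flag, which makes it easy to check in isolation. A minimal standalone sketch of the same rule (effective_dtype is a hypothetical helper for illustration, not part of this PR):

import torch

def effective_dtype(dtype: torch.dtype, downcast_64_bit: bool) -> torch.dtype:
    # Mirror of the builder logic: 64-bit types narrow to their 32-bit
    # counterparts only when downcast_64_bit is set.
    if downcast_64_bit and dtype == torch.float64:
        return torch.float32
    if downcast_64_bit and dtype == torch.int64:
        return torch.int32
    return dtype

assert effective_dtype(torch.float64, True) == torch.float32
assert effective_dtype(torch.int64, True) == torch.int32
assert effective_dtype(torch.float64, False) == torch.float64

Note that with the flag off, float64 now serializes as FLOAT64 instead of tripping the AssertionError, and int64 keeps its width instead of the old unconditional narrowing to INT32.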
6 changes: 4 additions & 2 deletions backends/vulkan/serialization/vulkan_graph_schema.py
@@ -27,8 +27,10 @@ class VkDataType(IntEnum):
     UINT8 = 1
     INT8 = 2
     INT32 = 3
-    FLOAT16 = 4
-    FLOAT32 = 5
+    INT64 = 4
+    FLOAT16 = 5
+    FLOAT32 = 6
+    FLOAT64 = 7


 class VkStorageType(IntEnum):
9 changes: 7 additions & 2 deletions backends/vulkan/vulkan_preprocess.py
@@ -67,7 +67,6 @@
 # pyre-ignore
 def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
     for p in passes:
-
         if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
             new_gm = program.graph_module
             # This is a workaround to allow the memory planning pass to work without
@@ -110,6 +109,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]:
         if spec.key == "skip_tag_memory_metadata":
             options[spec.key] = bool.from_bytes(spec.value, byteorder="little")

+        if spec.key == "downcast_64_bit":
+            options[spec.key] = bool.from_bytes(spec.value, byteorder="little")
+
     # Unhandled options are ignored

     return options
@@ -142,6 +144,7 @@ def preprocess( # noqa: C901
         default_memory_layout = compile_options.get(
             "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED
         )
+        downcast_64_bit = compile_options.get("downcast_64_bit", False)

         program = unsafe_remove_auto_functionalized_pass(program)
@@ -213,7 +216,9 @@
         )

         graph_builder = VkGraphBuilder(
-            program, DelegateMappingBuilder(generated_identifiers=True)
+            program,
+            DelegateMappingBuilder(generated_identifiers=True),
+            downcast_64_bit=downcast_64_bit,
         )
         vk_graph = graph_builder.build_graph()
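Because parse_compile_spec decodes the value with bool.from_bytes(..., byteorder="little"), a caller turns the option on by passing a one-byte CompileSpec. A hedged usage sketch (the CompileSpec import follows the usual ExecuTorch layout; how the specs reach preprocess depends on your partitioner setup):

from executorch.exir.backend.compile_spec_schema import CompileSpec

# A single little-endian byte: bool.from_bytes(b"\x01", byteorder="little")
# evaluates to True on the consumer side.
downcast_spec = CompileSpec("downcast_64_bit", (True).to_bytes(1, byteorder="little"))
compile_specs = [downcast_spec]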