diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 7077a9df59c..28e7574537c 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -83,10 +83,14 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) { return vkapi::kChar; case vkgraph::VkDataType::INT32: return vkapi::kInt; + case vkgraph::VkDataType::INT64: + return vkapi::kLong; case vkgraph::VkDataType::FLOAT16: return vkapi::kHalf; case vkgraph::VkDataType::FLOAT32: return vkapi::kFloat; + case vkgraph::VkDataType::FLOAT64: + return vkapi::kDouble; } } diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index f112581c498..99ba6a86594 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -18,6 +18,8 @@ enum VkDataType : byte { INT32 = 3, FLOAT16 = 4, FLOAT32 = 5, + FLOAT64 = 6, + INT64 = 7, } // Describes what kind of GPU resource should be used to represent a tensor. 
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 5bae0475c28..cd876bd6305 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -45,9 +45,11 @@ def __init__( self, program: ExportedProgram, delegate_mapping_builder: DelegateMappingBuilder, + downcast_64_bit: bool = True, ) -> None: self.program = program self.delegate_mapping_builder = delegate_mapping_builder + self.downcast_64_bit = downcast_64_bit self.chain = [] self.values = [] self.input_ids = [] @@ -72,13 +74,14 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: return vk_graph_schema.VkDataType.INT8 elif torch_dtype == torch.int32: return vk_graph_schema.VkDataType.INT32 + elif torch_dtype == torch.int64: + return vk_graph_schema.VkDataType.INT64 elif torch_dtype == torch.float16: return vk_graph_schema.VkDataType.FLOAT16 elif torch_dtype == torch.float32: return vk_graph_schema.VkDataType.FLOAT32 - # Narrowing conversion for index tensor produced by max_poolNd_with_indices. 
- elif torch_dtype == torch.int64: - return vk_graph_schema.VkDataType.INT32 + elif torch_dtype == torch.float64: + return vk_graph_schema.VkDataType.FLOAT64 else: raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") @@ -201,11 +204,20 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: # pyre-ignore[16] memory_layout = spec.vk_memory_layout + # Apply downcast logic before getting VK datatype + effective_dtype = spec.dtype + if self.downcast_64_bit and spec.dtype == torch.float64: + effective_dtype = torch.float32 + elif self.downcast_64_bit and spec.dtype == torch.int64: + effective_dtype = torch.int32 + + datatype = self.get_vk_datatype(effective_dtype) + new_id = len(self.values) self.values.append( vk_graph_schema.VkValue( value=vk_graph_schema.VkTensor( - datatype=self.get_vk_datatype(spec.dtype), + datatype=datatype, dims=spec.shape, constant_id=constant_id, mem_obj_id=mem_obj_id, diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index 35113bc623a..f845e5601a7 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -29,6 +29,8 @@ class VkDataType(IntEnum): INT32 = 3 FLOAT16 = 4 FLOAT32 = 5 + FLOAT64 = 6 + INT64 = 7 class VkStorageType(IntEnum): diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index a22afc3f42e..a6d5737dbb8 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -67,7 +67,6 @@ # pyre-ignore def apply_passes(program: ExportedProgram, passes) -> ExportedProgram: for p in passes: - if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase): new_gm = program.graph_module # This is a workaround to allow the memory planning pass to work without @@ -110,6 +109,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]: if spec.key == 
"skip_tag_memory_metadata": options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + if spec.key == "downcast_64_bit": + options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + # Unhandled options are ignored return options @@ -142,6 +144,7 @@ def preprocess( # noqa: C901 default_memory_layout = compile_options.get( "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED ) + downcast_64_bit = compile_options.get("downcast_64_bit", True) program = unsafe_remove_auto_functionalized_pass(program) @@ -213,7 +216,9 @@ def preprocess( # noqa: C901 ) graph_builder = VkGraphBuilder( - program, DelegateMappingBuilder(generated_identifiers=True) + program, + DelegateMappingBuilder(generated_identifiers=True), + downcast_64_bit=downcast_64_bit, ) vk_graph = graph_builder.build_graph()