@@ -15,10 +15,13 @@ if (NOT CMAKE_BUILD_TYPE)
1515 set (CMAKE_BUILD_TYPE Release)
1616endif ()
1717
# ---- Platform / feature switches -------------------------------------------
# Every switch below is a boolean cache option that defaults to OFF.

# Build the torchao ops for ExecuTorch.
option(
  TORCHAO_BUILD_EXECUTORCH_OPS
  "Building torchao ops for ExecuTorch."
  OFF
)

# Build the torchao MPS ops.
option(
  TORCHAO_BUILD_MPS_OPS
  "Building torchao MPS ops"
  OFF
)

# Build torchao's CPU aarch64 kernels.
option(
  TORCHAO_BUILD_CPU_AARCH64
  "Build torchao's CPU aarch64 kernels"
  OFF
)

# Download, build, and link against the Arm KleidiAI library (arm64 only).
option(
  TORCHAO_BUILD_KLEIDIAI
  "Download, build, and link against Arm KleidiAI library (arm64 only)"
  OFF
)

# Opt in to the ARM Neon Dot Product extension.
option(
  TORCHAO_ENABLE_ARM_NEON_DOT
  "Enable ARM Neon Dot Product extension"
  OFF
)

# Opt in to the ARM 8-bit Integer Matrix Multiply (i8mm) instructions.
option(
  TORCHAO_ENABLE_ARM_I8MM
  "Enable ARM 8-bit Integer Matrix Multiply instructions"
  OFF
)
2225
2326if (NOT TORCHAO_INCLUDE_DIRS)
2427 set (TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR} /../..)
@@ -28,19 +31,49 @@ if(NOT DEFINED TORCHAO_PARALLEL_BACKEND)
2831 set (TORCHAO_PARALLEL_BACKEND aten_openmp)
2932endif ()
3033
31- include (CMakePrintHelpers)
32-
# Default warning policy for everything configured from this directory:
# enable -Wall, promote warnings to errors, and pass -Wno-deprecated.
add_compile_options(
  -Wall
  -Werror
  -Wno-deprecated
)

# Print the include roots, then add them at directory scope so that targets
# created after this point pick them up.
include(CMakePrintHelpers)
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS} ")
include_directories(${TORCHAO_INCLUDE_DIRS})
39-
4041if (TORCHAO_BUILD_CPU_AARCH64)
4142 message (STATUS "Building with cpu/aarch64" )
4243 add_compile_definitions (TORCHAO_BUILD_CPU_AARCH64)
43- add_compile_definitions (TORCHAO_ENABLE_ARM_NEON_DOT)
44+
# aarch64 compiler configuration.
#
# Linux only: position-independent code plus a handful of relaxed warnings,
# and an -march baseline chosen from the requested ARM extensions.
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
  message(STATUS "Add aarch64 linux compiler options")
  add_compile_options(
    -fPIC
    -Wno-error=unknown-pragmas
    -Wno-array-parameter
    -Wno-maybe-uninitialized
    -Wno-sign-compare
  )

  # Architecture levels are cumulative, so one -march value is enough:
  #   * armv8.4-a and later include dotprod by default
  #   * armv8.6-a and later include i8mm by default (and therefore dotprod)
  # i8mm takes precedence because it implies the higher baseline.
  if(TORCHAO_ENABLE_ARM_I8MM)
    message(STATUS "Using armv8.6-a (includes 'i8mm' and 'dotprod' flags)")
    add_compile_options(-march=armv8.6-a)
  elseif(TORCHAO_ENABLE_ARM_NEON_DOT)
    message(STATUS "Using armv8.4-a (includes '+dotprod' flag)")
    add_compile_options(-march=armv8.4-a)
  endif()
endif()

# Expose the selected extensions to the sources as preprocessor definitions.
# NOTE(review): these definitions are added on every platform, while -march is
# only set on Linux above; confirm that non-Linux toolchains enable the
# corresponding instructions by default.
if(TORCHAO_ENABLE_ARM_NEON_DOT)
  message(STATUS "Building with ARM NEON dot product support")
  add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT)
endif()

if(TORCHAO_ENABLE_ARM_I8MM)
  message(STATUS "Building with ARM I8MM support")
  add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM)
endif()

# Defines torchao_kernels_aarch64. Added after the options above so the
# kernel sources are compiled with them.
add_subdirectory(kernels/cpu/aarch64)
@@ -51,26 +84,33 @@ if(TORCHAO_BUILD_CPU_AARCH64)
5184 endif ()
5285endif ()
5386
# Quantized operator directories.
add_subdirectory(ops/linear_8bit_act_xbit_weight)
add_subdirectory(ops/embedding_xbit)

# Shared ATen ops library. It lists no sources of its own here and links the
# two ATen op libraries privately.
add_library(torchao_ops_aten SHARED)
target_link_libraries(torchao_ops_aten
  PRIVATE
    torchao_ops_linear_8bit_act_xbit_weight_aten
    torchao_ops_embedding_xbit_aten
)
98+
# Optional MPS backend: configure ops/mps and link its ATen library into
# torchao_ops_aten when TORCHAO_BUILD_MPS_OPS is ON.
if(TORCHAO_BUILD_MPS_OPS)
  message(STATUS "Building with MPS support")
  add_subdirectory(ops/mps)
  target_link_libraries(torchao_ops_aten
    PRIVATE
      torchao_ops_mps_aten
  )
endif()
68105
# Install the ATen ops library into <prefix>/lib and record it in the
# "_targets" export set.
install(TARGETS torchao_ops_aten EXPORT _targets DESTINATION lib)
112+
113+ # Build executorch lib if enabled
74114if (TORCHAO_BUILD_EXECUTORCH_OPS)
75115 add_library (torchao_ops_executorch STATIC )
76116 target_link_libraries (torchao_ops_executorch PRIVATE
0 commit comments