44
55import PackageDescription
66
7+ #if os(Linux)
8+ let platformExcludes : [ String ] = [
9+ // Linux specific excludes
10+ " framework " ,
11+ " include-framework " ,
12+ " metal-cpp " ,
13+ // Exclude Metal backend files on Linux, but keep no_metal.cpp for stubs
14+ " mlx/mlx/backend/metal/allocator.cpp " ,
15+ " mlx/mlx/backend/metal/binary.cpp " ,
16+ " mlx/mlx/backend/metal/compiled.cpp " ,
17+ " mlx/mlx/backend/metal/conv.cpp " ,
18+ " mlx/mlx/backend/metal/copy.cpp " ,
19+ " mlx/mlx/backend/metal/custom_kernel.cpp " ,
20+ " mlx/mlx/backend/metal/device.cpp " ,
21+ " mlx/mlx/backend/metal/distributed.cpp " ,
22+ " mlx/mlx/backend/metal/eval.cpp " ,
23+ " mlx/mlx/backend/metal/event.cpp " ,
24+ " mlx/mlx/backend/metal/fence.cpp " ,
25+ " mlx/mlx/backend/metal/fft.cpp " ,
26+ " mlx/mlx/backend/metal/hadamard.cpp " ,
27+ " mlx/mlx/backend/metal/indexing.cpp " ,
28+ " mlx/mlx/backend/metal/jit_kernels.cpp " ,
29+ " mlx/mlx/backend/metal/logsumexp.cpp " ,
30+ " mlx/mlx/backend/metal/matmul.cpp " ,
31+ " mlx/mlx/backend/metal/metal.cpp " ,
32+ " mlx/mlx/backend/metal/normalization.cpp " ,
33+ " mlx/mlx/backend/metal/primitives.cpp " ,
34+ " mlx/mlx/backend/metal/quantized.cpp " ,
35+ " mlx/mlx/backend/metal/reduce.cpp " ,
36+ " mlx/mlx/backend/metal/resident.cpp " ,
37+ " mlx/mlx/backend/metal/rope.cpp " ,
38+ " mlx/mlx/backend/metal/scaled_dot_product_attention.cpp " ,
39+ " mlx/mlx/backend/metal/scan.cpp " ,
40+ " mlx/mlx/backend/metal/slicing.cpp " ,
41+ " mlx/mlx/backend/metal/softmax.cpp " ,
42+ " mlx/mlx/backend/metal/sort.cpp " ,
43+ " mlx/mlx/backend/metal/ternary.cpp " ,
44+ " mlx/mlx/backend/metal/unary.cpp " ,
45+ " mlx/mlx/backend/metal/utils.cpp " ,
46+ " mlx/mlx/backend/metal/kernels " , // Exclude kernels directory
47+ " mlx/mlx/backend/metal/jit " , // Exclude jit directory
48+
49+ " mlx/mlx/backend/gpu " , // Exclude GPU backend on Linux, use no_gpu instead
50+ " mlx/mlx/backend/no_cpu " , // Exclude no_cpu backend on Linux, use cpu instead
51+ " mlx/mlx/backend/cpu/gemms/bnns.cpp " , // macOS Accelerate version
52+ " mlx-conditional " ,
53+ " mlx-c/mlx/c/metal.cpp " ,
54+ // Note: mlx-c/mlx/c/fast.cpp is included on Linux - it calls C++ fast functions
55+ // which have fallback implementations that don't require fast primitives.
56+ // The metal_kernel stub from no_metal.cpp handles any metal_kernel calls.
57+ ]
58+
59+ let cxxSettings : [ CXXSetting ] = [ ]
60+
61+ let linkerSettings : [ LinkerSetting ] = [
62+ . linkedLibrary( " gfortran " , . when( platforms: [ . linux] ) ) ,
63+ . linkedLibrary( " blas " , . when( platforms: [ . linux] ) ) ,
64+ . linkedLibrary( " lapack " , . when( platforms: [ . linux] ) ) ,
65+ . linkedLibrary( " openblas " , . when( platforms: [ . linux] ) ) ,
66+ ]
67+
68+ let mlxSwiftExcludes : [ String ] = [
69+ " GPU+Metal.swift " ,
70+ " MLXArray+Metal.swift " ,
71+ ]
72+ #else
73+ let platformExcludes : [ String ] = [
74+ " mlx/mlx/backend/cpu/compiled.cpp " ,
75+
76+ // opt-out of these backends (using metal)
77+ " mlx/mlx/backend/no_gpu " ,
78+ " mlx/mlx/backend/no_cpu " ,
79+ " mlx/mlx/backend/metal/no_metal.cpp " ,
80+
81+ // bnns instead of simd (accelerate)
82+ " mlx/mlx/backend/cpu/gemms/simd_fp16.cpp " ,
83+ " mlx/mlx/backend/cpu/gemms/simd_bf16.cpp " ,
84+ ]
85+
86+ let cxxSettings : [ CXXSetting ] = [
87+ . headerSearchPath( " metal-cpp " ) ,
88+
89+ . define( " MLX_USE_ACCELERATE " ) ,
90+ . define( " ACCELERATE_NEW_LAPACK " ) ,
91+ . define( " _METAL_ " ) ,
92+ . define( " SWIFTPM_BUNDLE " , to: " \" mlx-swift_Cmlx \" " ) ,
93+ . define( " METAL_PATH " , to: " \" default.metallib \" " ) ,
94+ ]
95+
96+ let linkerSettings : [ LinkerSetting ] = [
97+ . linkedFramework( " Foundation " ) ,
98+ . linkedFramework( " Metal " ) ,
99+ . linkedFramework( " Accelerate " ) ,
100+ ]
101+
102+ let mlxSwiftExcludes : [ String ] = [ ]
103+ #endif
104+
105+ let cmlx = Target . target (
106+ name: " Cmlx " ,
107+ path: " Source/Cmlx " ,
108+ exclude: platformExcludes + [
109+ // vendor docs
110+ " vendor-README.md " ,
111+
112+ // example code + mlx-c distributed
113+ " mlx-c/examples " ,
114+ " mlx-c/mlx/c/distributed.cpp " ,
115+ " mlx-c/mlx/c/distributed_group.cpp " ,
116+
117+ // vendored library, include header only
118+ " json " ,
119+
120+ // vendored library
121+ " fmt/test " ,
122+ " fmt/doc " ,
123+ " fmt/support " ,
124+ " fmt/src/os.cc " ,
125+ " fmt/src/fmt.cc " ,
126+
127+ // these are selected conditionally
128+ " mlx/mlx/backend/no_cpu/compiled.cpp " ,
129+
130+ // mlx files that are not part of the build
131+ " mlx/ACKNOWLEDGMENTS.md " ,
132+ " mlx/CMakeLists.txt " ,
133+ " mlx/CODE_OF_CONDUCT.md " ,
134+ " mlx/CONTRIBUTING.md " ,
135+ " mlx/LICENSE " ,
136+ " mlx/MANIFEST.in " ,
137+ " mlx/README.md " ,
138+ " mlx/benchmarks " ,
139+ " mlx/cmake " ,
140+ " mlx/docs " ,
141+ " mlx/examples " ,
142+ " mlx/mlx.pc.in " ,
143+ " mlx/pyproject.toml " ,
144+ " mlx/python " ,
145+ " mlx/setup.py " ,
146+ " mlx/tests " ,
147+
148+ // special handling for cuda -- we need to keep one file:
149+ // mlx/mlx/backend/cuda/no_cuda.cpp
150+
151+ " mlx/mlx/backend/cuda/allocator.cpp " ,
152+ " mlx/mlx/backend/cuda/compiled.cpp " ,
153+ " mlx/mlx/backend/cuda/conv.cpp " ,
154+ " mlx/mlx/backend/cuda/cublas_utils.cpp " ,
155+ " mlx/mlx/backend/cuda/cuda.cpp " ,
156+ " mlx/mlx/backend/cuda/cudnn_utils.cpp " ,
157+ " mlx/mlx/backend/cuda/custom_kernel.cpp " ,
158+ " mlx/mlx/backend/cuda/device.cpp " ,
159+ " mlx/mlx/backend/cuda/eval.cpp " ,
160+ " mlx/mlx/backend/cuda/fence.cpp " ,
161+ " mlx/mlx/backend/cuda/indexing.cpp " ,
162+ " mlx/mlx/backend/cuda/jit_module.cpp " ,
163+ " mlx/mlx/backend/cuda/load.cpp " ,
164+ " mlx/mlx/backend/cuda/matmul.cpp " ,
165+ " mlx/mlx/backend/cuda/primitives.cpp " ,
166+ " mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp " ,
167+ " mlx/mlx/backend/cuda/slicing.cpp " ,
168+ " mlx/mlx/backend/cuda/utils.cpp " ,
169+ " mlx/mlx/backend/cuda/worker.cpp " ,
170+
171+ " mlx/mlx/backend/cuda/binary " ,
172+ " mlx/mlx/backend/cuda/conv " ,
173+ " mlx/mlx/backend/cuda/copy " ,
174+ " mlx/mlx/backend/cuda/device " ,
175+ " mlx/mlx/backend/cuda/gemms " ,
176+ " mlx/mlx/backend/cuda/quantized " ,
177+ " mlx/mlx/backend/cuda/reduce " ,
178+ " mlx/mlx/backend/cuda/steel " ,
179+ " mlx/mlx/backend/cuda/unary " ,
180+
181+ // build variants (we are opting _out_ of these)
182+ " mlx/mlx/io/no_safetensors.cpp " ,
183+ " mlx/mlx/io/gguf.cpp " ,
184+ " mlx/mlx/io/gguf_quants.cpp " ,
185+
186+ // see PrepareMetalShaders -- don't build the kernels in place
187+ " mlx/mlx/backend/metal/kernels " ,
188+ " mlx/mlx/backend/metal/nojit_kernels.cpp " ,
189+
190+ // do not build distributed support (yet)
191+ " mlx/mlx/distributed/jaccl/jaccl.cpp " ,
192+ " mlx/mlx/distributed/mpi/mpi.cpp " ,
193+ " mlx/mlx/distributed/ring/ring.cpp " ,
194+ " mlx/mlx/distributed/nccl/nccl.cpp " ,
195+ " mlx/mlx/distributed/nccl/nccl_stub " ,
196+ ] ,
197+ cSettings: [
198+ . headerSearchPath( " mlx " ) ,
199+ . headerSearchPath( " mlx-c " ) ,
200+ ] ,
201+ cxxSettings: cxxSettings + [
202+ . headerSearchPath( " mlx " ) ,
203+ . headerSearchPath( " mlx-c " ) ,
204+ . headerSearchPath( " json/single_include/nlohmann " ) ,
205+ . headerSearchPath( " fmt/include " ) ,
206+ . define( " MLX_VERSION " , to: " \" 0.24.2 \" " ) ,
207+ . define( " MLX_ENABLE_NAX " , to: " 1 " ) ,
208+ ] ,
209+ linkerSettings: linkerSettings
210+ )
211+
7212let package = Package (
8213 name: " mlx-swift " ,
9214
@@ -29,140 +234,7 @@ let package = Package(
29234 . package ( url: " https://github.com/apple/swift-numerics " , from: " 1.0.0 " )
30235 ] ,
31236 targets: [
32- . target(
33- name: " Cmlx " ,
34- exclude: [
35- // xcodeproj pieces
36- " framework " ,
37- " include-framework " ,
38-
39- // vendor docs
40- " vendor-README.md " ,
41-
42- // example code + mlx-c distributed
43- " mlx-c/examples " ,
44- " mlx-c/mlx/c/distributed.cpp " ,
45- " mlx-c/mlx/c/distributed_group.cpp " ,
46-
47- // vendored library, include header only
48- " json " ,
49-
50- // vendored library
51- " fmt/test " ,
52- " fmt/doc " ,
53- " fmt/support " ,
54- " fmt/src/os.cc " ,
55- " fmt/src/fmt.cc " ,
56-
57- // these are selected conditionally
58- " mlx/mlx/backend/no_cpu/compiled.cpp " ,
59- " mlx/mlx/backend/cpu/compiled.cpp " ,
60-
61- // mlx files that are not part of the build
62- " mlx/ACKNOWLEDGMENTS.md " ,
63- " mlx/CMakeLists.txt " ,
64- " mlx/CODE_OF_CONDUCT.md " ,
65- " mlx/CONTRIBUTING.md " ,
66- " mlx/LICENSE " ,
67- " mlx/MANIFEST.in " ,
68- " mlx/README.md " ,
69- " mlx/benchmarks " ,
70- " mlx/cmake " ,
71- " mlx/docs " ,
72- " mlx/examples " ,
73- " mlx/mlx.pc.in " ,
74- " mlx/pyproject.toml " ,
75- " mlx/python " ,
76- " mlx/setup.py " ,
77- " mlx/tests " ,
78-
79- // opt-out of these backends (using metal)
80- " mlx/mlx/backend/no_gpu " ,
81- " mlx/mlx/backend/no_cpu " ,
82- " mlx/mlx/backend/metal/no_metal.cpp " ,
83-
84- // special handling for cuda -- we need to keep one file:
85- // mlx/mlx/backend/cuda/no_cuda.cpp
86-
87- " mlx/mlx/backend/cuda/allocator.cpp " ,
88- " mlx/mlx/backend/cuda/compiled.cpp " ,
89- " mlx/mlx/backend/cuda/conv.cpp " ,
90- " mlx/mlx/backend/cuda/cublas_utils.cpp " ,
91- " mlx/mlx/backend/cuda/cuda.cpp " ,
92- " mlx/mlx/backend/cuda/cudnn_utils.cpp " ,
93- " mlx/mlx/backend/cuda/custom_kernel.cpp " ,
94- " mlx/mlx/backend/cuda/device.cpp " ,
95- " mlx/mlx/backend/cuda/eval.cpp " ,
96- " mlx/mlx/backend/cuda/fence.cpp " ,
97- " mlx/mlx/backend/cuda/indexing.cpp " ,
98- " mlx/mlx/backend/cuda/jit_module.cpp " ,
99- " mlx/mlx/backend/cuda/load.cpp " ,
100- " mlx/mlx/backend/cuda/matmul.cpp " ,
101- " mlx/mlx/backend/cuda/primitives.cpp " ,
102- " mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp " ,
103- " mlx/mlx/backend/cuda/slicing.cpp " ,
104- " mlx/mlx/backend/cuda/utils.cpp " ,
105- " mlx/mlx/backend/cuda/worker.cpp " ,
106-
107- " mlx/mlx/backend/cuda/binary " ,
108- " mlx/mlx/backend/cuda/conv " ,
109- " mlx/mlx/backend/cuda/copy " ,
110- " mlx/mlx/backend/cuda/device " ,
111- " mlx/mlx/backend/cuda/gemms " ,
112- " mlx/mlx/backend/cuda/quantized " ,
113- " mlx/mlx/backend/cuda/reduce " ,
114- " mlx/mlx/backend/cuda/steel " ,
115- " mlx/mlx/backend/cuda/unary " ,
116-
117- // build variants (we are opting _out_ of these)
118- " mlx/mlx/io/no_safetensors.cpp " ,
119- " mlx/mlx/io/gguf.cpp " ,
120- " mlx/mlx/io/gguf_quants.cpp " ,
121-
122- // see PrepareMetalShaders -- don't build the kernels in place
123- " mlx/mlx/backend/metal/kernels " ,
124- " mlx/mlx/backend/metal/nojit_kernels.cpp " ,
125-
126- // do not build distributed support (yet)
127- " mlx/mlx/distributed/jaccl/jaccl.cpp " ,
128- " mlx/mlx/distributed/mpi/mpi.cpp " ,
129- " mlx/mlx/distributed/ring/ring.cpp " ,
130- " mlx/mlx/distributed/nccl/nccl.cpp " ,
131- " mlx/mlx/distributed/nccl/nccl_stub " ,
132-
133- // bnns instead of simd (accelerate)
134- " mlx/mlx/backend/cpu/gemms/simd_fp16.cpp " ,
135- " mlx/mlx/backend/cpu/gemms/simd_bf16.cpp " ,
136- ] ,
137-
138- cSettings: [
139- . headerSearchPath( " mlx " ) ,
140- . headerSearchPath( " mlx-c " ) ,
141- ] ,
142-
143- cxxSettings: [
144- . headerSearchPath( " mlx " ) ,
145- . headerSearchPath( " mlx-c " ) ,
146- . headerSearchPath( " metal-cpp " ) ,
147- . headerSearchPath( " json/single_include/nlohmann " ) ,
148- . headerSearchPath( " fmt/include " ) ,
149-
150- . define( " MLX_USE_ACCELERATE " ) ,
151- . define( " ACCELERATE_NEW_LAPACK " ) ,
152- . define( " _METAL_ " ) ,
153- . define( " SWIFTPM_BUNDLE " , to: " \" mlx-swift_Cmlx \" " ) ,
154- . define( " METAL_PATH " , to: " \" default.metallib \" " ) ,
155- . define( " MLX_VERSION " , to: " \" 0.24.2 \" " ) ,
156-
157- // Note: not set yet
158- // .define("MLX_ENABLE_NAX", to: "1"),
159- ] ,
160- linkerSettings: [
161- . linkedFramework( " Foundation " ) ,
162- . linkedFramework( " Metal " ) ,
163- . linkedFramework( " Accelerate " ) ,
164- ]
165- ) ,
237+ cmlx,
166238 . testTarget(
167239 name: " CmlxTests " ,
168240 dependencies: [ " Cmlx " ]
@@ -174,6 +246,7 @@ let package = Package(
174246 " Cmlx " ,
175247 . product( name: " Numerics " , package : " swift-numerics " ) ,
176248 ] ,
249+ exclude: mlxSwiftExcludes,
177250 swiftSettings: [
178251 . enableExperimentalFeature( " StrictConcurrency " )
179252 ]
@@ -187,14 +260,23 @@ let package = Package(
187260 ) ,
188261 . target(
189262 name: " MLXFast " ,
190- dependencies: [ " MLX " , " Cmlx " ] ,
263+ dependencies: [
264+ " MLX " ,
265+ " Cmlx " ,
266+ ] ,
191267 swiftSettings: [
192268 . enableExperimentalFeature( " StrictConcurrency " )
193269 ]
194270 ) ,
195271 . target(
196272 name: " MLXNN " ,
197- dependencies: [ " MLX " ] ,
273+ dependencies: [
274+ " MLX " ,
275+ " MLXRandom " ,
276+ . target(
277+ name: " MLXFast " ,
278+ condition: . when( platforms: [ . macOS, . iOS, . tvOS, . visionOS, . watchOS] ) ) ,
279+ ] ,
198280 swiftSettings: [
199281 . enableExperimentalFeature( " StrictConcurrency " )
200282 ]
@@ -224,7 +306,10 @@ let package = Package(
224306 . testTarget(
225307 name: " MLXTests " ,
226308 dependencies: [
227- " MLX " , " MLXNN " , " MLXOptimizers " ,
309+ " MLX " , " MLXRandom " , " MLXNN " , " MLXOptimizers " , " MLXFFT " , " MLXLinalg " ,
310+ . target(
311+ name: " MLXFast " ,
312+ condition: . when( platforms: [ . macOS, . iOS, . tvOS, . visionOS, . watchOS] ) ) ,
228313 ]
229314 ) ,
230315
0 commit comments