Skip to content

Commit bd63fca

Browse files
committed
Support for building through SwiftPM on Linux (CPU only)
1 parent 4dccaed commit bd63fca

2 files changed

Lines changed: 237 additions & 140 deletions

File tree

Package.swift

Lines changed: 222 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,211 @@
44

55
import PackageDescription
66

7+
#if os(Linux)
8+
let platformExcludes: [String] = [
9+
// Linux specific excludes
10+
"framework",
11+
"include-framework",
12+
"metal-cpp",
13+
// Exclude Metal backend files on Linux, but keep no_metal.cpp for stubs
14+
"mlx/mlx/backend/metal/allocator.cpp",
15+
"mlx/mlx/backend/metal/binary.cpp",
16+
"mlx/mlx/backend/metal/compiled.cpp",
17+
"mlx/mlx/backend/metal/conv.cpp",
18+
"mlx/mlx/backend/metal/copy.cpp",
19+
"mlx/mlx/backend/metal/custom_kernel.cpp",
20+
"mlx/mlx/backend/metal/device.cpp",
21+
"mlx/mlx/backend/metal/distributed.cpp",
22+
"mlx/mlx/backend/metal/eval.cpp",
23+
"mlx/mlx/backend/metal/event.cpp",
24+
"mlx/mlx/backend/metal/fence.cpp",
25+
"mlx/mlx/backend/metal/fft.cpp",
26+
"mlx/mlx/backend/metal/hadamard.cpp",
27+
"mlx/mlx/backend/metal/indexing.cpp",
28+
"mlx/mlx/backend/metal/jit_kernels.cpp",
29+
"mlx/mlx/backend/metal/logsumexp.cpp",
30+
"mlx/mlx/backend/metal/matmul.cpp",
31+
"mlx/mlx/backend/metal/metal.cpp",
32+
"mlx/mlx/backend/metal/normalization.cpp",
33+
"mlx/mlx/backend/metal/primitives.cpp",
34+
"mlx/mlx/backend/metal/quantized.cpp",
35+
"mlx/mlx/backend/metal/reduce.cpp",
36+
"mlx/mlx/backend/metal/resident.cpp",
37+
"mlx/mlx/backend/metal/rope.cpp",
38+
"mlx/mlx/backend/metal/scaled_dot_product_attention.cpp",
39+
"mlx/mlx/backend/metal/scan.cpp",
40+
"mlx/mlx/backend/metal/slicing.cpp",
41+
"mlx/mlx/backend/metal/softmax.cpp",
42+
"mlx/mlx/backend/metal/sort.cpp",
43+
"mlx/mlx/backend/metal/ternary.cpp",
44+
"mlx/mlx/backend/metal/unary.cpp",
45+
"mlx/mlx/backend/metal/utils.cpp",
46+
"mlx/mlx/backend/metal/kernels", // Exclude kernels directory
47+
"mlx/mlx/backend/metal/jit", // Exclude jit directory
48+
49+
"mlx/mlx/backend/gpu", // Exclude GPU backend on Linux, use no_gpu instead
50+
"mlx/mlx/backend/no_cpu", // Exclude no_cpu backend on Linux, use cpu instead
51+
"mlx/mlx/backend/cpu/gemms/bnns.cpp", // macOS Accelerate version
52+
"mlx-conditional",
53+
"mlx-c/mlx/c/metal.cpp",
54+
// Note: mlx-c/mlx/c/fast.cpp is included on Linux - it calls C++ fast functions
55+
// which have fallback implementations that don't require fast primitives.
56+
// The metal_kernel stub from no_metal.cpp handles any metal_kernel calls.
57+
]
58+
59+
let cxxSettings: [CXXSetting] = []
60+
61+
let linkerSettings: [LinkerSetting] = [
62+
.linkedLibrary("gfortran", .when(platforms: [.linux])),
63+
.linkedLibrary("blas", .when(platforms: [.linux])),
64+
.linkedLibrary("lapack", .when(platforms: [.linux])),
65+
.linkedLibrary("openblas", .when(platforms: [.linux])),
66+
]
67+
68+
let mlxSwiftExcludes: [String] = [
69+
"GPU+Metal.swift",
70+
"MLXArray+Metal.swift",
71+
]
72+
#else
73+
let platformExcludes: [String] = [
74+
"mlx/mlx/backend/cpu/compiled.cpp",
75+
76+
// opt-out of these backends (using metal)
77+
"mlx/mlx/backend/no_gpu",
78+
"mlx/mlx/backend/no_cpu",
79+
"mlx/mlx/backend/metal/no_metal.cpp",
80+
81+
// bnns instead of simd (accelerate)
82+
"mlx/mlx/backend/cpu/gemms/simd_fp16.cpp",
83+
"mlx/mlx/backend/cpu/gemms/simd_bf16.cpp",
84+
]
85+
86+
let cxxSettings: [CXXSetting] = [
87+
.headerSearchPath("metal-cpp"),
88+
89+
.define("MLX_USE_ACCELERATE"),
90+
.define("ACCELERATE_NEW_LAPACK"),
91+
.define("_METAL_"),
92+
.define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cmlx\""),
93+
.define("METAL_PATH", to: "\"default.metallib\""),
94+
]
95+
96+
let linkerSettings: [LinkerSetting] = [
97+
.linkedFramework("Foundation"),
98+
.linkedFramework("Metal"),
99+
.linkedFramework("Accelerate"),
100+
]
101+
102+
let mlxSwiftExcludes: [String] = []
103+
#endif
104+
105+
let cmlx = Target.target(
106+
name: "Cmlx",
107+
path: "Source/Cmlx",
108+
exclude: platformExcludes + [
109+
// vendor docs
110+
"vendor-README.md",
111+
112+
// example code + mlx-c distributed
113+
"mlx-c/examples",
114+
"mlx-c/mlx/c/distributed.cpp",
115+
"mlx-c/mlx/c/distributed_group.cpp",
116+
117+
// vendored library, include header only
118+
"json",
119+
120+
// vendored library
121+
"fmt/test",
122+
"fmt/doc",
123+
"fmt/support",
124+
"fmt/src/os.cc",
125+
"fmt/src/fmt.cc",
126+
127+
// these are selected conditionally
128+
"mlx/mlx/backend/no_cpu/compiled.cpp",
129+
130+
// mlx files that are not part of the build
131+
"mlx/ACKNOWLEDGMENTS.md",
132+
"mlx/CMakeLists.txt",
133+
"mlx/CODE_OF_CONDUCT.md",
134+
"mlx/CONTRIBUTING.md",
135+
"mlx/LICENSE",
136+
"mlx/MANIFEST.in",
137+
"mlx/README.md",
138+
"mlx/benchmarks",
139+
"mlx/cmake",
140+
"mlx/docs",
141+
"mlx/examples",
142+
"mlx/mlx.pc.in",
143+
"mlx/pyproject.toml",
144+
"mlx/python",
145+
"mlx/setup.py",
146+
"mlx/tests",
147+
148+
// special handling for cuda -- we need to keep one file:
149+
// mlx/mlx/backend/cuda/no_cuda.cpp
150+
151+
"mlx/mlx/backend/cuda/allocator.cpp",
152+
"mlx/mlx/backend/cuda/compiled.cpp",
153+
"mlx/mlx/backend/cuda/conv.cpp",
154+
"mlx/mlx/backend/cuda/cublas_utils.cpp",
155+
"mlx/mlx/backend/cuda/cuda.cpp",
156+
"mlx/mlx/backend/cuda/cudnn_utils.cpp",
157+
"mlx/mlx/backend/cuda/custom_kernel.cpp",
158+
"mlx/mlx/backend/cuda/device.cpp",
159+
"mlx/mlx/backend/cuda/eval.cpp",
160+
"mlx/mlx/backend/cuda/fence.cpp",
161+
"mlx/mlx/backend/cuda/indexing.cpp",
162+
"mlx/mlx/backend/cuda/jit_module.cpp",
163+
"mlx/mlx/backend/cuda/load.cpp",
164+
"mlx/mlx/backend/cuda/matmul.cpp",
165+
"mlx/mlx/backend/cuda/primitives.cpp",
166+
"mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp",
167+
"mlx/mlx/backend/cuda/slicing.cpp",
168+
"mlx/mlx/backend/cuda/utils.cpp",
169+
"mlx/mlx/backend/cuda/worker.cpp",
170+
171+
"mlx/mlx/backend/cuda/binary",
172+
"mlx/mlx/backend/cuda/conv",
173+
"mlx/mlx/backend/cuda/copy",
174+
"mlx/mlx/backend/cuda/device",
175+
"mlx/mlx/backend/cuda/gemms",
176+
"mlx/mlx/backend/cuda/quantized",
177+
"mlx/mlx/backend/cuda/reduce",
178+
"mlx/mlx/backend/cuda/steel",
179+
"mlx/mlx/backend/cuda/unary",
180+
181+
// build variants (we are opting _out_ of these)
182+
"mlx/mlx/io/no_safetensors.cpp",
183+
"mlx/mlx/io/gguf.cpp",
184+
"mlx/mlx/io/gguf_quants.cpp",
185+
186+
// see PrepareMetalShaders -- don't build the kernels in place
187+
"mlx/mlx/backend/metal/kernels",
188+
"mlx/mlx/backend/metal/nojit_kernels.cpp",
189+
190+
// do not build distributed support (yet)
191+
"mlx/mlx/distributed/jaccl/jaccl.cpp",
192+
"mlx/mlx/distributed/mpi/mpi.cpp",
193+
"mlx/mlx/distributed/ring/ring.cpp",
194+
"mlx/mlx/distributed/nccl/nccl.cpp",
195+
"mlx/mlx/distributed/nccl/nccl_stub",
196+
],
197+
cSettings: [
198+
.headerSearchPath("mlx"),
199+
.headerSearchPath("mlx-c"),
200+
],
201+
cxxSettings: cxxSettings + [
202+
.headerSearchPath("mlx"),
203+
.headerSearchPath("mlx-c"),
204+
.headerSearchPath("json/single_include/nlohmann"),
205+
.headerSearchPath("fmt/include"),
206+
.define("MLX_VERSION", to: "\"0.24.2\""),
207+
.define("MLX_ENABLE_NAX", to: "1"),
208+
],
209+
linkerSettings: linkerSettings
210+
)
211+
7212
let package = Package(
8213
name: "mlx-swift",
9214

@@ -29,140 +234,7 @@ let package = Package(
29234
.package(url: "https://github.com/apple/swift-numerics", from: "1.0.0")
30235
],
31236
targets: [
32-
.target(
33-
name: "Cmlx",
34-
exclude: [
35-
// xcodeproj pieces
36-
"framework",
37-
"include-framework",
38-
39-
// vendor docs
40-
"vendor-README.md",
41-
42-
// example code + mlx-c distributed
43-
"mlx-c/examples",
44-
"mlx-c/mlx/c/distributed.cpp",
45-
"mlx-c/mlx/c/distributed_group.cpp",
46-
47-
// vendored library, include header only
48-
"json",
49-
50-
// vendored library
51-
"fmt/test",
52-
"fmt/doc",
53-
"fmt/support",
54-
"fmt/src/os.cc",
55-
"fmt/src/fmt.cc",
56-
57-
// these are selected conditionally
58-
"mlx/mlx/backend/no_cpu/compiled.cpp",
59-
"mlx/mlx/backend/cpu/compiled.cpp",
60-
61-
// mlx files that are not part of the build
62-
"mlx/ACKNOWLEDGMENTS.md",
63-
"mlx/CMakeLists.txt",
64-
"mlx/CODE_OF_CONDUCT.md",
65-
"mlx/CONTRIBUTING.md",
66-
"mlx/LICENSE",
67-
"mlx/MANIFEST.in",
68-
"mlx/README.md",
69-
"mlx/benchmarks",
70-
"mlx/cmake",
71-
"mlx/docs",
72-
"mlx/examples",
73-
"mlx/mlx.pc.in",
74-
"mlx/pyproject.toml",
75-
"mlx/python",
76-
"mlx/setup.py",
77-
"mlx/tests",
78-
79-
// opt-out of these backends (using metal)
80-
"mlx/mlx/backend/no_gpu",
81-
"mlx/mlx/backend/no_cpu",
82-
"mlx/mlx/backend/metal/no_metal.cpp",
83-
84-
// special handling for cuda -- we need to keep one file:
85-
// mlx/mlx/backend/cuda/no_cuda.cpp
86-
87-
"mlx/mlx/backend/cuda/allocator.cpp",
88-
"mlx/mlx/backend/cuda/compiled.cpp",
89-
"mlx/mlx/backend/cuda/conv.cpp",
90-
"mlx/mlx/backend/cuda/cublas_utils.cpp",
91-
"mlx/mlx/backend/cuda/cuda.cpp",
92-
"mlx/mlx/backend/cuda/cudnn_utils.cpp",
93-
"mlx/mlx/backend/cuda/custom_kernel.cpp",
94-
"mlx/mlx/backend/cuda/device.cpp",
95-
"mlx/mlx/backend/cuda/eval.cpp",
96-
"mlx/mlx/backend/cuda/fence.cpp",
97-
"mlx/mlx/backend/cuda/indexing.cpp",
98-
"mlx/mlx/backend/cuda/jit_module.cpp",
99-
"mlx/mlx/backend/cuda/load.cpp",
100-
"mlx/mlx/backend/cuda/matmul.cpp",
101-
"mlx/mlx/backend/cuda/primitives.cpp",
102-
"mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp",
103-
"mlx/mlx/backend/cuda/slicing.cpp",
104-
"mlx/mlx/backend/cuda/utils.cpp",
105-
"mlx/mlx/backend/cuda/worker.cpp",
106-
107-
"mlx/mlx/backend/cuda/binary",
108-
"mlx/mlx/backend/cuda/conv",
109-
"mlx/mlx/backend/cuda/copy",
110-
"mlx/mlx/backend/cuda/device",
111-
"mlx/mlx/backend/cuda/gemms",
112-
"mlx/mlx/backend/cuda/quantized",
113-
"mlx/mlx/backend/cuda/reduce",
114-
"mlx/mlx/backend/cuda/steel",
115-
"mlx/mlx/backend/cuda/unary",
116-
117-
// build variants (we are opting _out_ of these)
118-
"mlx/mlx/io/no_safetensors.cpp",
119-
"mlx/mlx/io/gguf.cpp",
120-
"mlx/mlx/io/gguf_quants.cpp",
121-
122-
// see PrepareMetalShaders -- don't build the kernels in place
123-
"mlx/mlx/backend/metal/kernels",
124-
"mlx/mlx/backend/metal/nojit_kernels.cpp",
125-
126-
// do not build distributed support (yet)
127-
"mlx/mlx/distributed/jaccl/jaccl.cpp",
128-
"mlx/mlx/distributed/mpi/mpi.cpp",
129-
"mlx/mlx/distributed/ring/ring.cpp",
130-
"mlx/mlx/distributed/nccl/nccl.cpp",
131-
"mlx/mlx/distributed/nccl/nccl_stub",
132-
133-
// bnns instead of simd (accelerate)
134-
"mlx/mlx/backend/cpu/gemms/simd_fp16.cpp",
135-
"mlx/mlx/backend/cpu/gemms/simd_bf16.cpp",
136-
],
137-
138-
cSettings: [
139-
.headerSearchPath("mlx"),
140-
.headerSearchPath("mlx-c"),
141-
],
142-
143-
cxxSettings: [
144-
.headerSearchPath("mlx"),
145-
.headerSearchPath("mlx-c"),
146-
.headerSearchPath("metal-cpp"),
147-
.headerSearchPath("json/single_include/nlohmann"),
148-
.headerSearchPath("fmt/include"),
149-
150-
.define("MLX_USE_ACCELERATE"),
151-
.define("ACCELERATE_NEW_LAPACK"),
152-
.define("_METAL_"),
153-
.define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cmlx\""),
154-
.define("METAL_PATH", to: "\"default.metallib\""),
155-
.define("MLX_VERSION", to: "\"0.24.2\""),
156-
157-
// Note: not set yet
158-
// .define("MLX_ENABLE_NAX", to: "1"),
159-
],
160-
linkerSettings: [
161-
.linkedFramework("Foundation"),
162-
.linkedFramework("Metal"),
163-
.linkedFramework("Accelerate"),
164-
]
165-
),
237+
cmlx,
166238
.testTarget(
167239
name: "CmlxTests",
168240
dependencies: ["Cmlx"]
@@ -174,6 +246,7 @@ let package = Package(
174246
"Cmlx",
175247
.product(name: "Numerics", package: "swift-numerics"),
176248
],
249+
exclude: mlxSwiftExcludes,
177250
swiftSettings: [
178251
.enableExperimentalFeature("StrictConcurrency")
179252
]
@@ -187,14 +260,23 @@ let package = Package(
187260
),
188261
.target(
189262
name: "MLXFast",
190-
dependencies: ["MLX", "Cmlx"],
263+
dependencies: [
264+
"MLX",
265+
"Cmlx",
266+
],
191267
swiftSettings: [
192268
.enableExperimentalFeature("StrictConcurrency")
193269
]
194270
),
195271
.target(
196272
name: "MLXNN",
197-
dependencies: ["MLX"],
273+
dependencies: [
274+
"MLX",
275+
"MLXRandom",
276+
.target(
277+
name: "MLXFast",
278+
condition: .when(platforms: [.macOS, .iOS, .tvOS, .visionOS, .watchOS])),
279+
],
198280
swiftSettings: [
199281
.enableExperimentalFeature("StrictConcurrency")
200282
]
@@ -224,7 +306,10 @@ let package = Package(
224306
.testTarget(
225307
name: "MLXTests",
226308
dependencies: [
227-
"MLX", "MLXNN", "MLXOptimizers",
309+
"MLX", "MLXRandom", "MLXNN", "MLXOptimizers", "MLXFFT", "MLXLinalg",
310+
.target(
311+
name: "MLXFast",
312+
condition: .when(platforms: [.macOS, .iOS, .tvOS, .visionOS, .watchOS])),
228313
]
229314
),
230315

0 commit comments

Comments
 (0)