Skip to content

Commit f4b1aa3

Browse files
authored
Upgrade candle3 (#545)
1 parent a38cb0c commit f4b1aa3

File tree

10 files changed

+757
-433
lines changed

10 files changed

+757
-433
lines changed

Cargo.lock

Lines changed: 723 additions & 410 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,20 @@ serde_json = "1.0"
4646
thiserror = "1.0"
4747
rand = "0.9"
4848
serial_test = "2.0.0"
49-
# cudarc = { version = "0.13" , features =["cuda-version-from-build-system"]}
50-
cudarc = { version = "0.10", default-features = false }
51-
candle = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-core" }
52-
candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-nn" }
53-
candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-transformers" }
54-
candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-flash-attn" }
49+
cudarc = { version = "0.13" , features =["cuda-12020"]}
50+
candle = { version = "0.8", package = "candle-core" }
51+
candle-nn = { version = "0.8", package = "candle-nn" }
52+
candle-transformers = { version = "0.8", package = "candle-transformers" }
53+
candle-flash-attn = { version = "0.8", package = "candle-flash-attn" }
5554
half = { version = "2.3.1", features = ["num-traits"] }
5655

56+
[patch.crates-io]
57+
cudarc = { git = "https://github.com/Narsil/cudarc" , rev = "1956436aeddea1da04fc3226282bc07c07eeaa35"}
58+
candle = { git = "https://github.com/Narsil/candle", rev = "2e273ddf31b1b796d3cfcd181ccb98deaa48466e", package = "candle-core" }
59+
candle-nn = { git = "https://github.com/Narsil/candle", rev = "2e273ddf31b1b796d3cfcd181ccb98deaa48466e", package = "candle-nn" }
60+
candle-transformers = { git = "https://github.com/Narsil/candle", rev = "2e273ddf31b1b796d3cfcd181ccb98deaa48466e", package = "candle-transformers" }
61+
candle-flash-attn = { git = "https://github.com/Narsil/candle", rev = "2e273ddf31b1b796d3cfcd181ccb98deaa48466e", package = "candle-flash-attn" }
62+
5763
[profile.release]
5864
debug = 0
5965
# lto = "fat"

backends/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,5 @@ metal = ["text-embeddings-backend-candle?/metal"]
2727
mkl = ["text-embeddings-backend-candle?/mkl"]
2828
mkl-dynamic = ["text-embeddings-backend-candle?/mkl-dynamic"]
2929
accelerate = ["text-embeddings-backend-candle?/accelerate"]
30-
static-linking = ["text-embeddings-backend-candle?/static-linking"]
3130
flash-attn = ["text-embeddings-backend-candle?/flash-attn"]
3231
flash-attn-v1 = ["text-embeddings-backend-candle?/flash-attn-v1"]

backends/candle/Cargo.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ anyhow = { version = "1", features = ["backtrace"] }
4040
[features]
4141
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
4242
metal = ["candle/metal", "candle-nn/metal"]
43-
mkl = ["dep:intel-mkl-src", "intel-mkl-src/mkl-static-lp64-iomp", "candle/mkl", "candle-nn/mkl"]
44-
mkl-dynamic = ["dep:intel-mkl-src", "intel-mkl-src/mkl-dynamic-lp64-iomp", "candle/mkl-dynamic", "candle-nn/mkl-dynamic"]
43+
mkl = ["dep:intel-mkl-src", "intel-mkl-src/mkl-static-lp64-iomp", "candle/mkl"]
44+
mkl-dynamic = ["dep:intel-mkl-src", "intel-mkl-src/mkl-dynamic-lp64-iomp", "candle/mkl"]
4545
cuda = ["candle/cuda", "candle-nn/cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary"]
4646
flash-attn-v1 = ["dep:candle-flash-attn-v1", "cuda"]
4747
flash-attn = ["dep:candle-flash-attn", "cuda"]
48-
static-linking = ["candle-cublaslt?/static-linking"]

candle-extensions/candle-cublaslt/Cargo.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,3 @@ description = "CUBLASLt gemm for the candle ML framework."
99
candle = { workspace=true, features = ["cuda"]}
1010
cudarc = { workspace = true, features = [ "cublaslt", "f16" ]}
1111
half = { workspace = true}
12-
13-
[features]
14-
static-linking = ["cudarc/static-linking"]

candle-extensions/candle-flash-attn-v1/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ rayon = "1.7.0"
2020

2121
[dev-dependencies]
2222
anyhow = { version = "1", features = ["backtrace"] }
23-
candle-nn = { version = "0.3.0", features = ["cuda"] }
23+
candle-nn = { workspace = true }

candle-extensions/candle-rotary/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ bindgen_cuda = "0.1.1"
1919

2020
[dev-dependencies]
2121
anyhow = { version = "1", features = ["backtrace"] }
22-
candle-nn = { version = "0.3.0", features = ["cuda"] }
22+
candle-nn = { workspace = true }

flake.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

flake.nix

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@
148148
# hash = "sha256-1AN2E9t/lZhbXdVznhTcniy+7ZzlaEp/gwLEAucs6EA=";
149149
# # hash = lib.fakeHash;
150150
# };
151+
mkl2024 = import ./nix/mkl.nix;
152+
153+
onnxruntimeGcc13 = pkgs.onnxruntime.override {
154+
stdenv = pkgs.cudaPackages.backendStdenv;
155+
};
151156

152157
in
153158
# cargoDeps = pkgs.rustPlatform.fetchCargoVendor {
@@ -195,7 +200,7 @@
195200
devShells.default =
196201
pkgs.mkShell.override
197202
{
198-
stdenv = pkgs.gcc13Stdenv;
203+
stdenv = pkgs.cudaPackages.backendStdenv;
199204
}
200205
{
201206

@@ -208,6 +213,8 @@
208213
cudaPackages.cudatoolkit
209214
python3Packages.python
210215
python3Packages.venvShellHook
216+
onnxruntimeGcc13
217+
mkl
211218
];
212219
venvDir = "./.venv";
213220
LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib:/run/opengl-driver/lib";

router/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ tonic-health = { version = "0.11.0", optional = true }
5757
tonic-reflection = { version = "0.11.0", optional = true }
5858
tokio-stream = { version = "0.1.14", optional = true }
5959

60+
# Optional
61+
cudarc = { workspace = true, optional = true }
62+
6063
# Malloc trim hack for linux
6164
[target.'cfg(target_os = "linux")'.dependencies]
6265
libc = "0.2.149"
@@ -88,5 +91,5 @@ candle = ["text-embeddings-backend/candle"]
8891
candle-cuda = ["candle", "text-embeddings-backend/flash-attn"]
8992
candle-cuda-turing = ["candle", "text-embeddings-backend/flash-attn-v1"]
9093
candle-cuda-volta = ["candle", "text-embeddings-backend/cuda"]
91-
static-linking = ["text-embeddings-backend/static-linking"]
94+
static-linking = ["cudarc/static-linking"]
9295
google = []

0 commit comments

Comments
 (0)