Skip to content

Commit a38cb0c

Browse files
authored
Upgrade candle2 (#543)
1 parent f8c9852 commit a38cb0c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+10925
-31
lines changed

.github/workflows/build.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ jobs:
5858
steps:
5959
- name: Checkout repository
6060
uses: actions/checkout@v4
61+
with:
62+
submodules: true
6163

6264
- name: Initialize Docker Buildx
6365
uses: docker/setup-buildx-action@v3

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "candle-extensions/candle-flash-attn-v1/cutlass"]
2+
path = candle-extensions/candle-flash-attn-v1/cutlass
3+
url = https://github.com/NVIDIA/cutlass.git

Cargo.lock

Lines changed: 60 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ members = [
77
"backends/python",
88
"backends/grpc-client",
99
"candle-extensions/candle-cublaslt",
10+
"candle-extensions/candle-flash-attn-v1",
11+
"candle-extensions/candle-layer-norm",
12+
"candle-extensions/candle-rotary",
1013
"core",
1114
"router",
1215
]
@@ -49,13 +52,7 @@ candle = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104
4952
candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-nn" }
5053
candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-transformers" }
5154
candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-flash-attn" }
52-
53-
54-
[patch.crates-io]
55-
candle = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-core" }
56-
candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-nn" }
57-
candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-transformers" }
58-
candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-flash-attn" }
55+
half = { version = "2.3.1", features = ["num-traits"] }
5956

6057
[profile.release]
6158
debug = 0

backends/candle/Cargo.toml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@ homepage.workspace = true
99
anyhow = { workspace = true }
1010
accelerate-src = { version = "0.3.2", optional = true }
1111
intel-mkl-src = { version = "0.8.1", optional = true }
12-
candle = { version = "*", package = "candle-core", default-features = false }
13-
candle-nn = { version = "*" }
14-
candle-transformers = { version = "*" }
15-
candle-flash-attn = { version = "*", optional = true }
16-
candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "3f1870b0d708579904c76e41745c659c3f9fa038", optional = true }
17-
# candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "cf789b7dd6d4abb19b03b9556442f94f0588b4a0", optional = true }
12+
candle = { workspace = true }
13+
candle-nn = { workspace = true }
14+
candle-transformers = { workspace = true }
15+
candle-flash-attn = { workspace = true, optional = true}
16+
candle-flash-attn-v1 = { path = "../../candle-extensions/candle-flash-attn-v1", optional = true }
1817
candle-cublaslt = { path = "../../candle-extensions/candle-cublaslt", optional = true }
19-
candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "94c2add7d94c2d63aebde77f7534614e04dbaea1", optional = true }
20-
candle-rotary = { git = "https://github.com/huggingface/candle-rotary", rev = "0a718a0856569a92f3112e64f10d07e4447822e8", optional = true }
18+
candle-layer-norm = { path = "../../candle-extensions/candle-layer-norm", optional = true }
19+
candle-rotary = { path = "../../candle-extensions/candle-rotary", optional = true }
2120
nohash-hasher = { workspace = true }
2221
text-embeddings-backend-core = { path = "../core" }
2322
tracing = { workspace = true }

candle-extensions/candle-cublaslt/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@ edition = "2021"
66
description = "CUBLASLt gemm for the candle ML framework."
77

88
[dependencies]
9-
# candle = { version = "0.8", package = "candle-core", features = ["cuda"]}
109
candle = { workspace=true, features = ["cuda"]}
1110
cudarc = { workspace = true, features = [ "cublaslt", "f16" ]}
12-
half = { version = "2.3.1", features = ["num-traits"] }
11+
half = { workspace = true}
1312

1413
[features]
1514
static-linking = ["cudarc/static-linking"]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
.idea
2+
target
3+
Cargo.lock
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "cutlass"]
2+
path = cutlass
3+
url = https://github.com/NVIDIA/cutlass.git
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
[package]
2+
name = "candle-flash-attn-v1"
3+
version = "0.0.1"
4+
edition = "2021"
5+
6+
description = "Flash attention V1 layer for the candle ML framework."
7+
keywords = ["blas", "tensor", "machine-learning"]
8+
categories = ["science"]
9+
license = "MIT OR Apache-2.0"
10+
readme = "README.md"
11+
12+
[dependencies]
13+
candle = { workspace = true }
14+
half = { workspace = true }
15+
16+
[build-dependencies]
17+
anyhow = { version = "1", features = ["backtrace"] }
18+
num_cpus = "1.15.0"
19+
rayon = "1.7.0"
20+
21+
[dev-dependencies]
22+
anyhow = { version = "1", features = ["backtrace"] }
23+
candle-nn = { version = "0.3.0", features = ["cuda"] }

0 commit comments

Comments
 (0)