diff --git a/Cargo.lock b/Cargo.lock index 08c7166f..8e3c48d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -524,7 +524,6 @@ dependencies = [ [[package]] name = "candle-cublaslt" version = "0.2.2" -source = "git+https://github.com/huggingface/candle-cublaslt?rev=cf789b7dd6d4abb19b03b9556442f94f0588b4a0#cf789b7dd6d4abb19b03b9556442f94f0588b4a0" dependencies = [ "candle-core", "cudarc", @@ -874,7 +873,8 @@ dependencies = [ [[package]] name = "cudarc" version = "0.10.0" -source = "git+https://github.com/coreylowman/cudarc?rev=c388e724af93a3e8fbe484f5ded2d8b3c1badd8e#c388e724af93a3e8fbe484f5ded2d8b3c1badd8e" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9395df0cab995685664e79cc35ad6302bf08fb9c5d82301875a183affe1278b1" dependencies = [ "half", ] diff --git a/Cargo.toml b/Cargo.toml index ec3b941f..c3c9995f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,18 @@ members = [ "backends", "backends/candle", - # "backends/ort", + "backends/ort", + "backends/core", + "backends/python", + "backends/grpc-client", + "candle-extensions/candle-cublaslt", + "core", + "router", +] +default-members = [ + "backends", + "backends/candle", + "backends/ort", "backends/core", "backends/python", "backends/grpc-client", @@ -14,7 +25,7 @@ resolver = "2" [workspace.package] version = "1.6.0" edition = "2021" -authors = ["Olivier Dehaene"] +authors = ["Olivier Dehaene", "Nicolas Patry", "Alvaro Bartolome"] homepage = "https://github.com/huggingface/text-embeddings-inference" [workspace.dependencies] @@ -32,10 +43,15 @@ serde_json = "1.0" thiserror = "1.0" rand = "0.9" serial_test = "2.0.0" +# cudarc = { version = "0.13" , features =["cuda-version-from-build-system"]} +cudarc = { version = "0.10", default-features = false } +candle = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-core" } +candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = 
"7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-nn" } +candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-transformers" } +candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-flash-attn" } [patch.crates-io] -cudarc = { git = "https://github.com/coreylowman/cudarc", rev = "c388e724af93a3e8fbe484f5ded2d8b3c1badd8e" } candle = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-core" } candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-nn" } candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "7e02ad856104799b73a946ac1e153f0de77feaaf", package = "candle-transformers" } diff --git a/Dockerfile b/Dockerfile index 23364573..7023d409 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,6 +10,7 @@ RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sc FROM chef AS planner +COPY candle-extensions candle-extensions COPY backends backends COPY core core COPY router router diff --git a/Dockerfile-cuda b/Dockerfile-cuda index 3ffc13ee..72420f92 100644 --- a/Dockerfile-cuda +++ b/Dockerfile-cuda @@ -23,6 +23,7 @@ FROM base-builder AS planner WORKDIR /usr/src +COPY candle-extensions candle-extensions COPY backends backends COPY core core COPY router router @@ -73,6 +74,7 @@ RUN --mount=type=secret,id=actions_cache_url,env=ACTIONS_CACHE_URL \ cargo chef cook --release --features candle-cuda --features static-linking --no-default-features --recipe-path recipe.json && sccache -s; \ fi; +COPY candle-extensions candle-extensions COPY backends backends COPY core core COPY router router diff --git a/Dockerfile-cuda-all b/Dockerfile-cuda-all index 76ad31ff..5babbb63 100644 --- a/Dockerfile-cuda-all +++ 
b/Dockerfile-cuda-all @@ -23,6 +23,7 @@ FROM base-builder AS planner WORKDIR /usr/src +COPY candle-extensions candle-extensions COPY backends backends COPY core core COPY router router @@ -85,6 +86,7 @@ RUN --mount=type=secret,id=actions_cache_url,env=ACTIONS_CACHE_URL \ CUDA_COMPUTE_CAP=90 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s; \ fi; +COPY candle-extensions candle-extensions COPY backends backends COPY core core COPY router router diff --git a/Dockerfile-intel b/Dockerfile-intel index 909d1e8e..120e0b90 100644 --- a/Dockerfile-intel +++ b/Dockerfile-intel @@ -10,6 +10,7 @@ RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sc FROM chef AS planner +COPY candle-extensions candle-extensions COPY backends backends COPY core core COPY router router diff --git a/backends/candle/Cargo.toml b/backends/candle/Cargo.toml index 7333ba3d..bb70f563 100644 --- a/backends/candle/Cargo.toml +++ b/backends/candle/Cargo.toml @@ -14,7 +14,8 @@ candle-nn = { version = "*" } candle-transformers = { version = "*" } candle-flash-attn = { version = "*", optional = true } candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "3f1870b0d708579904c76e41745c659c3f9fa038", optional = true } -candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "cf789b7dd6d4abb19b03b9556442f94f0588b4a0", optional = true } +# candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "cf789b7dd6d4abb19b03b9556442f94f0588b4a0", optional = true } +candle-cublaslt = { path = "../../candle-extensions/candle-cublaslt", optional = true } candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "94c2add7d94c2d63aebde77f7534614e04dbaea1", optional = true } candle-rotary = { git = "https://github.com/huggingface/candle-rotary", rev = "0a718a0856569a92f3112e64f10d07e4447822e8", optional = true } nohash-hasher = { workspace = 
true } diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index 27e8ec32..55a225a4 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -443,7 +443,6 @@ impl GTEModel { vb: VarBuilder, config: >EConfig, ) -> Result<(Embedding, Option, GTEEncoder, LayerNorm)> { - let word_embeddings = Embedding::new( vb.pp("embeddings.word_embeddings") .get((config.vocab_size, config.hidden_size), "weight")?, diff --git a/candle-extensions/candle-cublaslt/.gitignore b/candle-extensions/candle-cublaslt/.gitignore new file mode 100644 index 00000000..fbc9a58c --- /dev/null +++ b/candle-extensions/candle-cublaslt/.gitignore @@ -0,0 +1,3 @@ +.idea +target +Cargo.lock diff --git a/candle-extensions/candle-cublaslt/Cargo.toml b/candle-extensions/candle-cublaslt/Cargo.toml new file mode 100644 index 00000000..ab0b240e --- /dev/null +++ b/candle-extensions/candle-cublaslt/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "candle-cublaslt" +version = "0.2.2" +edition = "2021" + +description = "CUBLASLt gemm for the candle ML framework." + +[dependencies] +# candle = { version = "0.8", package = "candle-core", features = ["cuda"]} +candle = { workspace=true, features = ["cuda"]} +cudarc = { workspace = true, features = [ "cublaslt", "f16" ]} +half = { version = "2.3.1", features = ["num-traits"] } + +[features] +static-linking = ["cudarc/static-linking"] diff --git a/candle-extensions/candle-cublaslt/LICENSE-APACHE b/candle-extensions/candle-cublaslt/LICENSE-APACHE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/candle-extensions/candle-cublaslt/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/candle-extensions/candle-cublaslt/LICENSE-MIT b/candle-extensions/candle-cublaslt/LICENSE-MIT new file mode 100644 index 00000000..31aa7938 --- /dev/null +++ b/candle-extensions/candle-cublaslt/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/candle-extensions/candle-cublaslt/README.md b/candle-extensions/candle-cublaslt/README.md new file mode 100644 index 00000000..d6f6ce53 --- /dev/null +++ b/candle-extensions/candle-cublaslt/README.md @@ -0,0 +1,4 @@ +# Candle CublasLt Matmul Layer + +CublasLt Matmul operation for the Candle ML framework. +Allows for bias and Relu/Gelu fusing. diff --git a/candle-extensions/candle-cublaslt/src/lib.rs b/candle-extensions/candle-cublaslt/src/lib.rs new file mode 100644 index 00000000..0ceba217 --- /dev/null +++ b/candle-extensions/candle-cublaslt/src/lib.rs @@ -0,0 +1,936 @@ +pub use cudarc::cublaslt::Activation; +use std::ffi::c_int; + +use candle::backend::BackendStorage; +use candle::cuda_backend::WrapErr; +use candle::{CpuStorage, Device, Layout, Result, Shape, Storage, Tensor}; +use half::{bf16, f16}; +use std::sync::Arc; + +use cudarc::cublaslt::{CudaBlasLT, Matmul, MatmulConfig}; + +#[derive(Debug, Clone)] +pub struct CublasLt(Arc); + +impl CublasLt { + pub fn new(device: &Device) -> Result { + let dev = match &*device { + Device::Cuda(d) => d, + _ => candle::bail!("`device` must be a `cuda` device"), + }; + + let inner = CudaBlasLT::new(dev.cuda_device()).unwrap(); + + Ok(Self(Arc::new(inner))) + } +} + +pub struct CublasLTMatmul { + pub cublaslt: Arc, + pub act: Option, + pub c: Option, + pub alpha: Option, + pub beta: Option, +} + +impl CublasLTMatmul { + pub fn fwd_f16( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: 
Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (m, k) = a_l.shape().dims2()?; + + let (n, b_1) = b_l.shape().dims2()?; + + if b_1 != k { + candle::bail!("This layer only supports TN layout"); + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? != m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let mut out = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + if c_l.shape().dims2()? != (n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + c.clone() + } else { + // Allocate out tensor + unsafe { dev.alloc::(out_shape.elem_count()).w()? 
} + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: None, + stride_b: None, + stride_c: None, + stride_bias: None, + batch_size: None, + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } + + pub fn fwd_bf16( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (m, k) = a_l.shape().dims2()?; + + let (n, b_1) = b_l.shape().dims2()?; + + if b_1 != k { + candle::bail!("This layer only supports TN layout"); + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? 
!= m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let mut out = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + if c_l.shape().dims2()? != (n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + c.clone() + } else { + // Allocate out tensor + unsafe { dev.alloc::(out_shape.elem_count()).w()? } + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: None, + stride_b: None, + stride_c: None, + stride_bias: None, + batch_size: None, + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } + + pub fn fwd_f32( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (m, k) = a_l.shape().dims2()?; + + let (n, b_1) = b_l.shape().dims2()?; + + if b_1 != k { + candle::bail!("This layer only supports TN layout"); + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((n, m)); + + let a = 
a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? != m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let mut out = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + if c_l.shape().dims2()? != (n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + c.clone() + } else { + // Allocate out tensor + unsafe { dev.alloc::(out_shape.elem_count()).w()? 
} + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: None, + stride_b: None, + stride_c: None, + stride_bias: None, + batch_size: None, + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } +} + +impl candle::CustomOp2 for CublasLTMatmul { + fn name(&self) -> &'static str { + "cublaslt-matmul" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for cublaslt-matmul") + } + + fn cuda_fwd( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match a.dtype() { + candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), + candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), + candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), + dt => candle::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"), + } + } +} + +impl candle::CustomOp3 for CublasLTMatmul { + fn name(&self) -> &'static str { + "cublaslt-matmul-add" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for cublaslt-matmul") + } + + fn cuda_fwd( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: &candle::CudaStorage, + bias_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match a.dtype() { + candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + 
candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), + dt => candle::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"), + } + } +} + +/// Fused matmul + add + Relu/Gelu activation using CublasLt +/// +/// # Arguments +/// +/// * `a` - Input tensor of size MxK +/// * `b` - Input tensor of size NxK +/// * `out` - Optional Output tensor of size NxK. +/// If set and beta != 0, will be added to the end result of A*B before `act` +/// * `alpha` - Optional scaling factor for A*B +/// * `beta` - Optional scaling factor for C +/// * `bias` - Optional bias tensor of size M +/// * `act` - Optional Gelu or Relu activation. If set, will be added to the end result +/// * `cublaslt` - CublasLt handle +/// +/// The resulting tensor is of shape NxM +pub fn fused_matmul( + a: &Tensor, + b: &Tensor, + out: Option<&Tensor>, + alpha: Option, + beta: Option, + bias: Option<&Tensor>, + act: Option, + cublaslt: CublasLt, +) -> Result { + let op = CublasLTMatmul { + act, + cublaslt: cublaslt.0, + c: out.cloned(), + alpha, + beta, + }; + + if let Some(bias) = bias { + a.apply_op3(&b, &bias, op) + } else { + a.apply_op2(&b, op) + } +} + +pub struct CublasLTBatchMatmul { + pub cublaslt: Arc, + pub act: Option, + pub c: Option, + pub alpha: Option, + pub beta: Option, +} + +impl CublasLTBatchMatmul { + pub fn fwd_f16( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (batch_size, m, k) = a_l.shape().dims3()?; + let (b_0, n, b_2) = b_l.shape().dims3()?; + + if b_2 != k { + candle::bail!("This layer only supports TN layout"); + } + + if b_0 != batch_size { + candle::bail!("`b` must have the same batch size as `a`") + } + + let lda = k; + let ldb = k; + let ldc = m; + + let 
out_shape = Shape::from((batch_size, n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? != m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let (mut out, stride_c) = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + + if c_l.shape().dims3()? != (batch_size, n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + // Use a clone of `c` as the output buffer (an unset beta defaults to 0.0 in the config below) + (c.clone(), c_l.stride()[0]) + } else { + // Allocate out tensor + ( + unsafe { dev.alloc::(out_shape.elem_count()).w()? 
}, + (n * m), + ) + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: Some(a_l.stride()[0] as i64), + stride_b: Some(b_l.stride()[0] as i64), + stride_c: Some(stride_c as i64), + stride_bias: None, + batch_size: Some(batch_size as c_int), + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } + + pub fn fwd_bf16( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN: `a` is (batch, m, k) and `b` is (batch, n, k) + let (batch_size, m, k) = a_l.shape().dims3()?; + let (b_0, n, b_2) = b_l.shape().dims3()?; + + if b_2 != k { + candle::bail!("This layer only supports TN layout"); + } + + if b_0 != batch_size { + candle::bail!("`b` must have the same batch size as `a`") + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((batch_size, n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? 
!= m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let (mut out, stride_c) = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + + if c_l.shape().dims3()? != (batch_size, n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + // Use a clone of `c` as the output buffer (an unset beta defaults to 0.0 in the config below) + (c.clone(), c_l.stride()[0]) + } else { + // Allocate out tensor + ( + unsafe { dev.alloc::(out_shape.elem_count()).w()? }, + (n * m), + ) + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: Some(a_l.stride()[0] as i64), + stride_b: Some(b_l.stride()[0] as i64), + stride_c: Some(stride_c as i64), + stride_bias: None, + batch_size: Some(batch_size as c_int), + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } + + pub fn fwd_f32( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: Option<&candle::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN: `a` is (batch, m, k) and `b` is (batch, n, k) + let (batch_size, m, k) = a_l.shape().dims3()?; + let (b_0, n, b_2) = 
b_l.shape().dims3()?; + + if b_2 != k { + candle::bail!("This layer only supports TN layout"); + } + + if b_0 != batch_size { + candle::bail!("`b` must have the same batch size as `a`") + } + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((batch_size, n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.shape().dims1()? != m { + candle::bail!("Bias does not have the correct shape"); + } + + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) + } else { + None + }; + + let (mut out, stride_c) = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle::bail!("`c` has to be contiguous"), + }; + + if c_l.shape().dims3()? != (batch_size, n, m) { + candle::bail!("`c` does not have the correct shape"); + } + + // Use a clone of `c` as the output buffer (an unset beta defaults to 0.0 in the config below) + (c.clone(), c_l.stride()[0]) + } else { + // Allocate out tensor + ( + unsafe { dev.alloc::(out_shape.elem_count()).w()? 
}, + (n * m), + ) + }; + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: Some(a_l.stride()[0] as i64), + stride_b: Some(b_l.stride()[0] as i64), + stride_c: Some(stride_c as i64), + stride_bias: None, + batch_size: Some(batch_size as c_int), + }; + + unsafe { + self.cublaslt + .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) + .map_err(|e| candle::Error::Cuda(Box::new(e)))?; + } + + let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone()); + + Ok((out, out_shape)) + } +} + +impl candle::CustomOp2 for CublasLTBatchMatmul { + fn name(&self) -> &'static str { + "cublaslt-batch-matmul" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for cublaslt-batch-matmul") + } + + fn cuda_fwd( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match a.dtype() { + candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), + candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), + candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), + dt => { + candle::bail!("cublaslt-batch-matmul is only supported for f16/bf16/f32 ({dt:?})") + } + } + } +} + +impl candle::CustomOp3 for CublasLTBatchMatmul { + fn name(&self) -> &'static str { + "cublaslt-batch-matmul-add" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for cublaslt-batch-matmul-add") + } + + fn cuda_fwd( + &self, + a: &candle::CudaStorage, + a_l: &Layout, + b: &candle::CudaStorage, + b_l: &Layout, + bias: &candle::CudaStorage, + bias_l: 
&Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + match a.dtype() { + candle::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + candle::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + candle::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), + dt => candle::bail!( + "cublaslt-batch-matmul-add is only supported for f16/bf16/f32 ({dt:?})" + ), + } + } +} + +/// Fused batch matmul + add + Relu/Gelu activation using CublasLt +/// +/// # Arguments +/// +/// * `a` - Input tensor of size BxMxK +/// * `b` - Input tensor of size BxNxK +/// * `out` - Optional Output tensor of size BxNxM. +/// If set and beta != 0, will be added to the end result of A*B before `act` +/// * `alpha` - Optional scaling factor for A*B +/// * `beta` - Optional scaling factor for `out` +/// * `bias` - Optional bias tensor of size M +/// * `act` - Optional Gelu or Relu activation. If set, will be applied to the end result +/// * `cublaslt` - CublasLt handle +/// +/// The resulting tensor is of shape BxNxM +pub fn fused_batch_matmul( + a: &Tensor, + b: &Tensor, + out: Option<&Tensor>, + alpha: Option, + beta: Option, + bias: Option<&Tensor>, + act: Option, + cublaslt: CublasLt, +) -> Result { + let op = CublasLTBatchMatmul { + act, + cublaslt: cublaslt.0, + c: out.cloned(), + alpha, + beta, + }; + + if let Some(bias) = bias { + a.apply_op3(&b, &bias, op) + } else { + a.apply_op2(&b, op) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use candle::{DType, Device}; + + fn to_vec2_round(t: Tensor, digits: i32) -> Result>> { + let b = 10f32.powi(digits); + let t = t.to_vec2::()?; + let t = t + .iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect(); + Ok(t) + } + + fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { + let b = 10f32.powi(digits); + let t = t.to_vec3::()?; + let t = t + .iter() + .map(|t| { + t.iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect() + }) + 
.collect(); + Ok(t) + } + + #[test] + fn test_fused_matmul() -> Result<()> { + let device = Device::new_cuda(0)?; + + let a = Tensor::randn(0., 1., (8, 4), &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., (2, 4), &device)?.to_dtype(DType::F32)?; + let bias = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + + let cublaslt = CublasLt::new(&device)?; + + let res = fused_matmul(&a, &b, None, None, None, Some(&bias), None, cublaslt)?; + let expected = (b.matmul(&a.t()?)? + bias.broadcast_left(2)?)?; + + assert_eq!( + to_vec2_round(res.to_dtype(DType::F32)?, 4)?, + to_vec2_round(expected.to_dtype(DType::F32)?, 4)? + ); + Ok(()) + } + + #[test] + fn test_fused_batch_matmul() -> Result<()> { + let device = Device::new_cuda(0)?; + + let a = Tensor::randn(0., 1., (3, 8, 4), &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., (3, 2, 4), &device)?.to_dtype(DType::F32)?; + let c = Tensor::randn(0., 1., (3, 2, 8), &device)?.to_dtype(DType::F32)?; + let bias = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?; + + let cublaslt = CublasLt::new(&device)?; + + let res = fused_batch_matmul( + &a, + &b, + Some(&c), + None, + Some(1.0), + Some(&bias), + None, + cublaslt, + )?; + let expected = (b.matmul(&a.t()?)?.add(&c)? + bias.broadcast_left((3, 2))?)?; + + assert_eq!( + to_vec3_round(res.to_dtype(DType::F32)?, 4)?, + to_vec3_round(expected.to_dtype(DType::F32)?, 4)? + ); + Ok(()) + } +}