
Commit 9263eb1

fix: limit peak memory to build cuda-all docker image (huggingface#246)
1 parent d33d44a commit 9263eb1

File tree

11 files changed: +132 -96 lines

.github/workflows/build_all.yaml

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@
 on:
   workflow_dispatch:
   push:
+    branches:
+      - 'main'
     tags:
       - 'v*'
 

Dockerfile-cuda-all

Lines changed: 27 additions & 32 deletions
@@ -33,49 +33,46 @@ FROM base-builder AS builder
 
 ARG GIT_SHA
 ARG DOCKER_LABEL
-ARG VERTEX
+ARG VERTEX="false"
 
 # sccache specific variables
 ARG ACTIONS_CACHE_URL
 ARG ACTIONS_RUNTIME_TOKEN
 ARG SCCACHE_GHA_ENABLED
 
 # limit the number of kernels built at the same time
-ARG RAYON_NUM_THREADS=2
+ARG RAYON_NUM_THREADS=4
 
 WORKDIR /usr/src
 
 COPY --from=planner /usr/src/recipe.json recipe.json
 
-FROM builder as builder-75
-
 RUN if [ $VERTEX = "true" ]; \
     then \
-    CUDA_COMPUTE_CAP=75 cargo chef cook --release --features google --features candle-cuda-turing --features http --no-default-features --recipe-path recipe.json && sccache -s; \
+    cargo chef cook --release --features google --recipe-path recipe.json && sccache -s; \
     else \
-    CUDA_COMPUTE_CAP=75 cargo chef cook --release --features candle-cuda-turing --no-default-features --features http --recipe-path recipe.json && sccache -s; \
+    cargo chef cook --release --recipe-path recipe.json && sccache -s; \
     fi;
 
-COPY backends backends
-COPY core core
-COPY router router
-COPY Cargo.toml ./
-COPY Cargo.lock ./
-
 RUN if [ $VERTEX = "true" ]; \
     then \
-    CUDA_COMPUTE_CAP=75 cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F http -F google --no-default-features && sccache -s; \
+    CUDA_COMPUTE_CAP=75 cargo chef cook --release --features google --features candle-cuda-turing --recipe-path recipe.json && sccache -s; \
    else \
-    CUDA_COMPUTE_CAP=75 cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F http --no-default-features && sccache -s; \
+    CUDA_COMPUTE_CAP=75 cargo chef cook --release --features candle-cuda-turing --recipe-path recipe.json && sccache -s; \
     fi;
 
-FROM builder as builder-80
+RUN if [ $VERTEX = "true" ]; \
+    then \
+    CUDA_COMPUTE_CAP=80 cargo chef cook --release --features google --features candle-cuda --recipe-path recipe.json && sccache -s; \
+    else \
+    CUDA_COMPUTE_CAP=80 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s; \
+    fi;
 
 RUN if [ $VERTEX = "true" ]; \
     then \
-    CUDA_COMPUTE_CAP=80 cargo chef cook --release --features google --features candle-cuda --features http --no-default-features --recipe-path recipe.json && sccache -s; \
+    CUDA_COMPUTE_CAP=90 cargo chef cook --release --features google --features candle-cuda --recipe-path recipe.json && sccache -s; \
     else \
-    CUDA_COMPUTE_CAP=80 cargo chef cook --release --features candle-cuda --no-default-features --features http --recipe-path recipe.json && sccache -s; \
+    CUDA_COMPUTE_CAP=90 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s; \
     fi;
 
 COPY backends backends
@@ -86,33 +83,31 @@ COPY Cargo.lock ./
 
 RUN if [ $VERTEX = "true" ]; \
     then \
-    CUDA_COMPUTE_CAP=80 cargo build --release --bin text-embeddings-router -F candle-cuda -F http -F google --no-default-features && sccache -s; \
+    CUDA_COMPUTE_CAP=75 cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F google && sccache -s; \
     else \
-    CUDA_COMPUTE_CAP=80 cargo build --release --bin text-embeddings-router -F candle-cuda -F http --no-default-features && sccache -s; \
+    CUDA_COMPUTE_CAP=75 cargo build --release --bin text-embeddings-router -F candle-cuda-turing && sccache -s; \
     fi;
 
-FROM builder as builder-90
+RUN mv /usr/src/target/release/text-embeddings-router /usr/src/target/release/text-embeddings-router-75
 
 RUN if [ $VERTEX = "true" ]; \
     then \
-    CUDA_COMPUTE_CAP=90 cargo chef cook --release --features google --features candle-cuda --features http --no-default-features --recipe-path recipe.json && sccache -s; \
+    CUDA_COMPUTE_CAP=80 cargo build --release --bin text-embeddings-router -F candle-cuda -F google && sccache -s; \
     else \
-    CUDA_COMPUTE_CAP=90 cargo chef cook --release --features candle-cuda --features http --no-default-features --recipe-path recipe.json && sccache -s; \
+    CUDA_COMPUTE_CAP=80 cargo build --release --bin text-embeddings-router -F candle-cuda && sccache -s; \
     fi;
 
-COPY backends backends
-COPY core core
-COPY router router
-COPY Cargo.toml ./
-COPY Cargo.lock ./
+RUN mv /usr/src/target/release/text-embeddings-router /usr/src/target/release/text-embeddings-router-80
 
 RUN if [ $VERTEX = "true" ]; \
     then \
-    CUDA_COMPUTE_CAP=90 cargo build --release --bin text-embeddings-router -F candle-cuda -F http -F google --no-default-features && sccache -s; \
+    CUDA_COMPUTE_CAP=90 cargo build --release --bin text-embeddings-router -F candle-cuda -F google && sccache -s; \
     else \
-    CUDA_COMPUTE_CAP=90 cargo build --release --bin text-embeddings-router -F candle-cuda -F http --no-default-features && sccache -s; \
+    CUDA_COMPUTE_CAP=90 cargo build --release --bin text-embeddings-router -F candle-cuda && sccache -s; \
     fi;
 
+RUN mv /usr/src/target/release/text-embeddings-router /usr/src/target/release/text-embeddings-router-90
+
 FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04 as base
 
 ARG DEFAULT_USE_FLASH_ATTENTION=True
@@ -121,9 +116,9 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
     PORT=80 \
     USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION
 
-COPY --from=builder-75 /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router-75
-COPY --from=builder-80 /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router-80
-COPY --from=builder-90 /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router-90
+COPY --from=builder /usr/src/target/release/text-embeddings-router-75 /usr/local/bin/text-embeddings-router-75
+COPY --from=builder /usr/src/target/release/text-embeddings-router-80 /usr/local/bin/text-embeddings-router-80
+COPY --from=builder /usr/src/target/release/text-embeddings-router-90 /usr/local/bin/text-embeddings-router-90
 
 # Amazon SageMaker compatible image
 FROM base AS sagemaker

backends/candle/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+anyhow = "^1.0"
 accelerate-src = { version = "0.3.2", optional = true }
 intel-mkl-src = { version = "0.8.1", optional = true }
 candle = { version = "*", package = "candle-core", default-features = false }

backends/candle/src/compute_cap.rs

Lines changed: 23 additions & 35 deletions
@@ -1,41 +1,26 @@
+use anyhow::Context;
+use candle::cuda_backend::cudarc::driver;
 use candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::{
     CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
 };
 use candle::cuda_backend::cudarc::driver::CudaDevice;
-use std::sync::Once;
 
-static INIT: Once = Once::new();
-static mut RUNTIME_COMPUTE_CAP: usize = 0;
-static mut COMPILE_COMPUTE_CAP: usize = 0;
-
-fn init_compute_caps() {
-    unsafe {
-        INIT.call_once(|| {
-            let device = CudaDevice::new(0).expect("cuda is not available");
-            let major = device
-                .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
-                .unwrap();
-            let minor = device
-                .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
-                .unwrap();
-            RUNTIME_COMPUTE_CAP = (major * 10 + minor) as usize;
-            COMPILE_COMPUTE_CAP = env!("CUDA_COMPUTE_CAP").parse::<usize>().unwrap();
-        });
-    }
-}
-
-pub fn get_compile_compute_cap() -> usize {
-    unsafe {
-        init_compute_caps();
-        COMPILE_COMPUTE_CAP
-    }
+pub fn get_compile_compute_cap() -> Result<usize, anyhow::Error> {
+    env!("CUDA_COMPUTE_CAP")
+        .parse::<usize>()
+        .context("Could not retrieve compile time CUDA_COMPUTE_CAP")
 }
 
-pub fn get_runtime_compute_cap() -> usize {
-    unsafe {
-        init_compute_caps();
-        RUNTIME_COMPUTE_CAP
-    }
+pub fn get_runtime_compute_cap() -> Result<usize, anyhow::Error> {
+    driver::result::init().context("CUDA is not available")?;
+    let device = CudaDevice::new(0).context("CUDA is not available")?;
+    let major = device
+        .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
+        .context("Could not retrieve device compute capability major")?;
+    let minor = device
+        .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
+        .context("Could not retrieve device compute capability minor")?;
+    Ok((major * 10 + minor) as usize)
 }
 
 fn compute_cap_matching(runtime_compute_cap: usize, compile_compute_cap: usize) -> bool {
@@ -49,10 +34,13 @@ fn compute_cap_matching(runtime_compute_cap: usize, compile_compute_cap: usize)
     }
 }
 
-pub fn incompatible_compute_cap() -> bool {
-    let compile_compute_cap = get_compile_compute_cap();
-    let runtime_compute_cap = get_runtime_compute_cap();
-    !compute_cap_matching(runtime_compute_cap, compile_compute_cap)
+pub fn compatible_compute_cap() -> Result<bool, anyhow::Error> {
+    let compile_compute_cap = get_compile_compute_cap()?;
+    let runtime_compute_cap = get_runtime_compute_cap()?;
+    Ok(compute_cap_matching(
+        runtime_compute_cap,
+        compile_compute_cap,
+    ))
 }
 
 #[cfg(test)]
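
The compute-cap helpers now return `Result` instead of panicking, so callers can tell a capability mismatch apart from the absence of a usable CUDA driver. Below is a minimal, self-contained sketch of how the new fallible API composes. The stubbed probe values and the simplified matching rule are assumptions for illustration; only the shape of `compatible_compute_cap` mirrors the diff above.

```rust
use anyhow::{bail, Context, Result};

// Stand-ins for the functions in backends/candle/src/compute_cap.rs, stubbed so
// the sketch compiles and runs without a CUDA toolchain. The real versions read
// the CUDA_COMPUTE_CAP env var set at build time and query the driver at runtime.
fn get_compile_compute_cap() -> Result<usize> {
    "80".parse::<usize>()
        .context("Could not retrieve compile time CUDA_COMPUTE_CAP")
}

fn get_runtime_compute_cap() -> Result<usize> {
    Ok(75) // pretend the host exposes a Turing (sm_75) GPU
}

// Simplified matching rule; the real one also accepts e.g. newer caps running
// kernels compiled for an older compatible cap.
fn compute_cap_matching(runtime: usize, compile: usize) -> bool {
    runtime == compile
}

fn compatible_compute_cap() -> Result<bool> {
    Ok(compute_cap_matching(
        get_runtime_compute_cap()?,
        get_compile_compute_cap()?,
    ))
}

fn main() -> Result<()> {
    // Callers can now distinguish "caps mismatch" (hard error) from
    // "no CUDA at all" (fall back to CPU) instead of panicking.
    match compatible_compute_cap() {
        Ok(true) => println!("CUDA kernels are usable on this host"),
        Ok(false) => bail!(
            "Runtime compute cap {} is not compatible with compile time compute cap {}",
            get_runtime_compute_cap()?,
            get_compile_compute_cap()?
        ),
        Err(err) => println!("Could not probe CUDA ({err}), using CPU instead"),
    }
    Ok(())
}
```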

backends/candle/src/flash_attn.rs

Lines changed: 19 additions & 1 deletion
@@ -1,5 +1,23 @@
-use crate::compute_cap::get_runtime_compute_cap;
 use candle::Tensor;
+use std::sync::Once;
+
+static INIT: Once = Once::new();
+static mut RUNTIME_COMPUTE_CAP: usize = 0;
+fn init_runtime_compute_cap() {
+    unsafe {
+        INIT.call_once(|| {
+            use crate::compute_cap::get_runtime_compute_cap;
+            RUNTIME_COMPUTE_CAP = get_runtime_compute_cap().unwrap();
+        });
+    }
+}
+
+pub fn get_runtime_compute_cap() -> usize {
+    unsafe {
+        init_runtime_compute_cap();
+        RUNTIME_COMPUTE_CAP
+    }
+}
 
 #[allow(clippy::too_many_arguments, unused)]
 pub(crate) fn flash_attn_varlen(
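
Because `get_runtime_compute_cap` in `compute_cap.rs` is now fallible, the flash-attention module probes the value once and keeps exposing an infallible `usize` to the kernels. A small self-contained sketch of the same `Once` pattern, with the driver probe stubbed out and a purely illustrative dispatch at the end (the actual branching inside `flash_attn_varlen` may differ):

```rust
use std::sync::Once;

static INIT: Once = Once::new();
static mut RUNTIME_COMPUTE_CAP: usize = 0;

// Stub for the fallible driver probe; the real module calls
// crate::compute_cap::get_runtime_compute_cap() here and unwraps it, since a
// CUDA device must exist by the time a flash-attention kernel is launched.
fn probe_runtime_compute_cap() -> Result<usize, &'static str> {
    Ok(80)
}

pub fn get_runtime_compute_cap() -> usize {
    unsafe {
        // The probe runs at most once; later calls read the cached value.
        INIT.call_once(|| {
            RUNTIME_COMPUTE_CAP = probe_runtime_compute_cap().unwrap();
        });
        RUNTIME_COMPUTE_CAP
    }
}

fn main() {
    // Illustrative kernel selection based on the cached capability.
    match get_runtime_compute_cap() {
        75 => println!("dispatch to the Turing (sm_75) flash-attention kernels"),
        cap if cap >= 80 => println!("dispatch to the Ampere+ (sm_{cap}) kernels"),
        cap => println!("flash attention is unsupported on sm_{cap}"),
    }
}
```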

backends/candle/src/layers/cublaslt.rs

Lines changed: 20 additions & 14 deletions
@@ -11,21 +11,27 @@ static mut CUBLASLT: Option<CublasLtWrapper> = None;
 pub fn get_cublas_lt_wrapper() -> Option<&'static CublasLtWrapper> {
     unsafe {
         INIT.call_once(|| {
-            CUBLASLT = match Device::cuda_if_available(0) {
-                Ok(device) => {
-                    #[cfg(feature = "cuda")]
-                    {
-                        Some(CublasLtWrapper {
+            #[cfg(not(feature = "cuda"))]
+            {
+                CUBLASLT = None;
+            }
+
+            #[cfg(feature = "cuda")]
+            {
+                // Check if we can call the driver
+                // Then check if we can create a device
+                // Then check that the device is CUDA
+                use candle::cuda_backend::cudarc::driver;
+                CUBLASLT = driver::result::init()
+                    .ok()
+                    .and_then(|_| Device::cuda_if_available(0).ok())
+                    .and_then(|device| match device {
+                        Device::Cuda(_) => Some(CublasLtWrapper {
                             cublaslt: CublasLt::new(&device).unwrap(),
-                        })
-                    }
-                    #[cfg(not(feature = "cuda"))]
-                    {
-                        None
-                    }
-                }
-                Err(_) => None,
-            };
+                        }),
+                        _ => None,
+                    });
+            }
         });
         CUBLASLT.as_ref()
     }
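
The wrapper is now only constructed when the driver can be loaded, a device can be created, and that device is actually CUDA; any failure in the chain quietly yields `None`. A self-contained sketch of that probe chain with the three steps stubbed out (the real steps are `driver::result::init()`, `Device::cuda_if_available(0)` and the `Device::Cuda(_)` match):

```rust
// Each step of the probe is fallible; `.ok()`/`.and_then()` collapse the whole
// chain to `None` if any step fails, so callers fall back to the default
// matmul path instead of panicking on CPU-only hosts.
#[allow(dead_code)]
#[derive(Debug)]
enum Device {
    Cpu,
    Cuda(usize),
}

// Stubbed driver initialisation: pretend libcuda cannot be loaded.
fn driver_init() -> Result<(), &'static str> {
    Err("libcuda not found")
}

// Stubbed device creation.
fn cuda_if_available(ordinal: usize) -> Result<Device, &'static str> {
    Ok(Device::Cuda(ordinal))
}

fn main() {
    let cublaslt_handle = driver_init()
        .ok()
        .and_then(|_| cuda_if_available(0).ok())
        .and_then(|device| match device {
            Device::Cuda(ordinal) => Some(format!("cuBLASLt handle on cuda:{ordinal}")),
            _ => None,
        });
    // Prints `None` with the stubbed driver failure above.
    println!("{cublaslt_handle:?}");
}
```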

backends/candle/src/lib.rs

Lines changed: 30 additions & 9 deletions
@@ -8,7 +8,7 @@ mod models;
 
 #[cfg(feature = "cuda")]
 use crate::compute_cap::{
-    get_compile_compute_cap, get_runtime_compute_cap, incompatible_compute_cap,
+    compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
     BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, Model, NomicBertModel,
@@ -43,6 +43,7 @@ enum Config {
 }
 
 pub struct CandleBackend {
+    device: Device,
     model: Box<dyn Model + Send>,
 }
 
@@ -61,14 +62,23 @@ impl CandleBackend {
         // Get candle device
         let device = if candle::utils::cuda_is_available() {
             #[cfg(feature = "cuda")]
-            if incompatible_compute_cap() {
-                return Err(BackendError::Start(format!(
-                    "Runtime compute cap {} is not compatible with compile time compute cap {}",
-                    get_runtime_compute_cap(),
-                    get_compile_compute_cap()
-                )));
+            match compatible_compute_cap() {
+                Ok(true) => Device::new_cuda(0),
+                Ok(false) => {
+                    return Err(BackendError::Start(format!(
+                        "Runtime compute cap {} is not compatible with compile time compute cap {}",
+                        get_runtime_compute_cap().unwrap(),
+                        get_compile_compute_cap().unwrap()
+                    )))
+                }
+                Err(err) => {
+                    tracing::warn!("Could not find a compatible CUDA device on host: {err}");
+                    tracing::warn!("Using CPU instead");
+                    Ok(Device::Cpu)
+                }
             }
-            Device::new_cuda(0)
+            #[cfg(not(feature = "cuda"))]
+            Ok(Device::Cpu)
         } else if candle::utils::metal_is_available() {
             Device::new_metal(0)
         } else {
@@ -225,11 +235,22 @@ impl CandleBackend {
             }
         };
 
-        Ok(Self { model: model? })
+        Ok(Self {
+            device,
+            model: model?,
+        })
     }
 }
 
 impl Backend for CandleBackend {
+    fn max_batch_size(&self) -> Option<usize> {
+        // Limit max batch size to 4 on CPU
+        if matches!(self.device, Device::Cpu) {
+            return Some(4);
+        }
+        None
+    }
+
     fn health(&self) -> Result<(), BackendError> {
         Ok(())
     }
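
`CandleBackend` now remembers which device it ended up on and advertises a batch-size limit when that device is the CPU. A hypothetical sketch of how a scheduler could consume that hint; the `Backend` trait and the clamping helper below are stand-ins, not the router's actual types:

```rust
// Stand-in trait mirroring only the new hint; the real Backend trait in
// text-embeddings-inference has more methods.
trait Backend {
    fn max_batch_size(&self) -> Option<usize> {
        None
    }
}

struct CpuBackend;

impl Backend for CpuBackend {
    fn max_batch_size(&self) -> Option<usize> {
        // Mirrors the diff: CPU inference is capped at 4 requests per batch.
        Some(4)
    }
}

// Clamp a requested batch size to whatever the backend can handle.
fn effective_batch_size(backend: &dyn Backend, requested: usize) -> usize {
    match backend.max_batch_size() {
        Some(limit) => requested.min(limit),
        None => requested,
    }
}

fn main() {
    let backend = CpuBackend;
    // A burst of 32 queued requests would be served in batches of at most 4 on CPU.
    assert_eq!(effective_batch_size(&backend, 32), 4);
    println!("effective batch size: {}", effective_batch_size(&backend, 32));
}
```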

backends/candle/src/models/bert.rs

Lines changed: 5 additions & 3 deletions
@@ -405,13 +405,14 @@ impl ClassificationHead for BertClassificationHead {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
 
-        let mut hidden_states = hidden_states.clone();
+        let mut hidden_states = hidden_states.unsqueeze(1)?;
         if let Some(pooler) = self.pooler.as_ref() {
             hidden_states = pooler.forward(&hidden_states)?;
             hidden_states = hidden_states.tanh()?;
         }
 
         let hidden_states = self.output.forward(&hidden_states)?;
+        let hidden_states = hidden_states.squeeze(1)?;
         Ok(hidden_states)
     }
 }
@@ -453,10 +454,11 @@ impl ClassificationHead for RobertaClassificationHead {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
 
-        let hidden_states = self.intermediate.forward(hidden_states)?;
+        let hidden_states = hidden_states.unsqueeze(1)?;
+        let hidden_states = self.intermediate.forward(&hidden_states)?;
         let hidden_states = hidden_states.tanh()?;
         let hidden_states = self.output.forward(&hidden_states)?;
-
+        let hidden_states = hidden_states.squeeze(1)?;
         Ok(hidden_states)
     }
 }
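
The classification heads now lift the pooled `[batch, hidden]` activations to `[batch, 1, hidden]` before the linear layers and squeeze the extra dimension back out afterwards, presumably so the (possibly cuBLASLt-backed) linear path always sees rank-3 inputs. A shape-only illustration using candle-core (assumed as a dependency; the sizes are arbitrary):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // A [batch, hidden] activation, as produced by pooling.
    let hidden_states = Tensor::zeros((8, 768), DType::F32, &Device::Cpu)?;
    assert_eq!(hidden_states.dims(), &[8, 768]);

    // Lift to rank 3 so the linear layers see an explicit sequence dimension of length 1.
    let lifted = hidden_states.unsqueeze(1)?;
    assert_eq!(lifted.dims(), &[8, 1, 768]);

    // After the head has run, drop the dummy dimension again.
    let restored = lifted.squeeze(1)?;
    assert_eq!(restored.dims(), &[8, 768]);
    Ok(())
}
```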

router/src/http/server.rs

Lines changed: 3 additions & 1 deletion
@@ -1548,7 +1548,9 @@ pub async fn run(
     }
 
     // Run server
-    let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
+    let listener = tokio::net::TcpListener::bind(&addr)
+        .await
+        .context(format!("Could not bind TCP Listener on {addr}"))?;
 
     tracing::info!("Starting HTTP server: {}", &addr);
     tracing::info!("Ready");
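
Binding the listener now surfaces a descriptive error instead of panicking. A minimal sketch of the same `anyhow::Context` pattern, assuming `tokio` (with the `macros`, `rt-multi-thread` and `net` features) and `anyhow` as dependencies; the address is arbitrary:

```rust
use anyhow::Context;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let addr = "127.0.0.1:8080";
    // On failure (e.g. the port is already taken) the error now says which
    // address could not be bound instead of a bare panic from `unwrap()`.
    let _listener = tokio::net::TcpListener::bind(addr)
        .await
        .context(format!("Could not bind TCP Listener on {addr}"))?;
    println!("listening on {addr}");
    Ok(())
}
```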
