
Commit 7755758

Fixing the tokenization routes token (offsets are in bytes, not in (#576)
1 parent 1cd65ff commit 7755758
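
The change replaces character-based slicing with byte-based slicing when recovering token text from tokenizer offsets. As a minimal standalone illustration (my own sketch, not code from this commit), slicing a multi-byte UTF-8 string by chars() with byte offsets returns the wrong span, while slicing by bytes returns the intended token text:

// Standalone sketch (not part of the commit): why byte offsets matter.
// The tokenizers library reports offsets as byte positions, so for
// multi-byte UTF-8 text, indexing by chars() selects the wrong span.
fn main() {
    let input = "这是一个文本"; // each of these characters is 3 bytes in UTF-8
    let (start, stop) = (0usize, 3usize); // byte offsets covering the first character

    // Char-based slicing (the old behaviour): treats the byte offsets as
    // character counts and grabs three characters instead of one.
    let by_chars: String = input.chars().skip(start).take(stop - start).collect();
    assert_eq!(by_chars, "这是一");

    // Byte-based slicing (what the commit switches to): takes exactly the
    // bytes the offsets cover and re-validates them as UTF-8.
    let bytes: Vec<u8> = input.bytes().skip(start).take(stop - start).collect();
    let by_bytes = String::from_utf8_lossy(&bytes).to_string();
    assert_eq!(by_bytes, "这");
}

For ASCII-only inputs the two approaches coincide (every character is one byte), which is why the bug only shows up on multi-byte text such as the CJK string used in the new test.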

File tree

3 files changed: +198 -53 lines

core/src/tokenization.rs
router/src/grpc/server.rs
router/src/http/server.rs

core/src/tokenization.rs

Lines changed: 162 additions & 0 deletions
@@ -16,6 +16,16 @@ pub struct Tokenization {
     sender: async_channel::Sender<TokenizerRequest>,
 }
 
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct SimpleToken {
+    pub id: u32,
+    pub text: String,
+    pub special: bool,
+    pub start: Option<usize>,
+    pub stop: Option<usize>,
+}
+
 impl Tokenization {
     pub fn new(
         workers: usize,
@@ -485,3 +495,155 @@ enum TokenizerRequest {
         Span,
     ),
 }
+
+pub fn into_tokens(encoding: tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
+    encoding
+        .get_ids()
+        .iter()
+        .zip(encoding.get_offsets())
+        .zip(encoding.get_special_tokens_mask())
+        .zip(encoding.get_tokens())
+        .map(|(((&id, &(start, stop)), special), token)| {
+            let special = *special == 1;
+            match special {
+                true => SimpleToken {
+                    id,
+                    text: token.clone(),
+                    special,
+                    start: None,
+                    stop: None,
+                },
+                false => {
+                    let text: Vec<u8> = input.bytes().skip(start).take(stop - start).collect();
+                    let text: String = String::from_utf8_lossy(&text).to_string();
+                    SimpleToken {
+                        id,
+                        text,
+                        special,
+                        start: Some(start),
+                        stop: Some(stop),
+                    }
+                }
+            }
+        })
+        .collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use hf_hub::api::sync::ApiBuilder;
+
+    #[test]
+    fn tokenizer() {
+        let api = ApiBuilder::from_env().build().unwrap();
+        let filename = api
+            .model("BAAI/bge-m3".to_string())
+            .get("tokenizer.json")
+            .unwrap();
+        let string = "这是一个文本向量化的测试句子";
+        let tokenizer = Tokenizer::from_file(filename).unwrap();
+
+        let encoded = tokenizer.encode(string, true).unwrap();
+        assert_eq!(
+            encoded.get_offsets(),
+            vec![
+                (0, 0),
+                (0, 3),
+                (0, 12),
+                (12, 18),
+                (18, 21),
+                (21, 24),
+                (24, 30),
+                (30, 36),
+                (36, 39),
+                (39, 42),
+                (0, 0)
+            ]
+        );
+
+        let tokens = into_tokens(encoded, &string);
+        assert_eq!(
+            tokens,
+            vec![
+                SimpleToken {
+                    id: 0,
+                    text: "<s>".to_string(),
+                    special: true,
+                    start: None,
+                    stop: None
+                },
+                SimpleToken {
+                    id: 6,
+                    text: "这".to_string(),
+                    special: false,
+                    start: Some(0),
+                    stop: Some(3)
+                },
+                SimpleToken {
+                    id: 100013,
+                    text: "这是一个".to_string(),
+                    special: false,
+                    start: Some(0),
+                    stop: Some(12)
+                },
+                SimpleToken {
+                    id: 189061,
+                    text: "文本".to_string(),
+                    special: false,
+                    start: Some(12),
+                    stop: Some(18)
+                },
+                SimpleToken {
+                    id: 2110,
+                    text: "向".to_string(),
+                    special: false,
+                    start: Some(18),
+                    stop: Some(21)
+                },
+                SimpleToken {
+                    id: 3272,
+                    text: "量".to_string(),
+                    special: false,
+                    start: Some(21),
+                    stop: Some(24)
+                },
+                SimpleToken {
+                    id: 41904,
+                    text: "化的".to_string(),
+                    special: false,
+                    start: Some(24),
+                    stop: Some(30)
+                },
+                SimpleToken {
+                    id: 49125,
+                    text: "测试".to_string(),
+                    special: false,
+                    start: Some(30),
+                    stop: Some(36)
+                },
+                SimpleToken {
+                    id: 27683,
+                    text: "句".to_string(),
+                    special: false,
+                    start: Some(36),
+                    stop: Some(39)
+                },
+                SimpleToken {
+                    id: 1344,
+                    text: "子".to_string(),
+                    special: false,
+                    start: Some(39),
+                    stop: Some(42)
+                },
+                SimpleToken {
+                    id: 2,
+                    text: "</s>".to_string(),
+                    special: true,
+                    start: None,
+                    stop: None
+                }
+            ]
+        );
+    }
+}

router/src/grpc/server.rs

Lines changed: 19 additions & 27 deletions
@@ -15,7 +15,9 @@ use std::future::Future;
 use std::net::SocketAddr;
 use std::time::{Duration, Instant};
 use text_embeddings_core::infer::Infer;
-use text_embeddings_core::tokenization::EncodingInput;
+use text_embeddings_core::tokenization::{
+    into_tokens, EncodingInput, SimpleToken as CoreSimpleToken,
+};
 use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -340,32 +342,22 @@ impl TextEmbeddingsService {
             .map_err(ErrorResponse::from)?;
         let inputs = encoded_inputs.unwrap_or(inputs);
 
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .zip(encoding.get_special_tokens_mask())
-            .zip(encoding.get_tokens())
-            .map(|(((&id, &(start, stop)), special), token)| {
-                let special = *special == 1;
-                match special {
-                    true => SimpleToken {
-                        id,
-                        text: token.clone(),
-                        special,
-                        start: None,
-                        stop: None,
-                    },
-                    false => {
-                        let text: String = inputs.chars().skip(start).take(stop - start).collect();
-                        SimpleToken {
-                            id,
-                            text,
-                            special,
-                            start: Some(start as u32),
-                            stop: Some(stop as u32),
-                        }
-                    }
+        let tokens: Vec<SimpleToken> = into_tokens(encoding, &inputs)
+            .into_iter()
+            .map(|t| {
+                let CoreSimpleToken {
+                    id,
+                    text,
+                    special,
+                    start,
+                    stop,
+                } = t;
+                SimpleToken {
+                    id,
+                    text,
+                    special,
+                    start: start.map(|s| s as u32),
+                    stop: stop.map(|s| s as u32),
                 }
             })
             .collect();

router/src/http/server.rs

Lines changed: 17 additions & 26 deletions
@@ -34,6 +34,7 @@ use text_embeddings_backend::BackendError;
 use text_embeddings_core::infer::{
     AllEmbeddingsInferResponse, Infer, InferMetadata, PooledEmbeddingsInferResponse,
 };
+use text_embeddings_core::tokenization::{into_tokens, SimpleToken as CoreSimpleToken};
 use text_embeddings_core::TextEmbeddingsError;
 use tokio::sync::OwnedSemaphorePermit;
 use tower_http::cors::{AllowOrigin, CorsLayer};
@@ -1295,32 +1296,22 @@ async fn tokenize(
         .map_err(ErrorResponse::from)?;
     let input = encoded_input.unwrap_or(input);
 
-    let tokens: Vec<SimpleToken> = encoding
-        .get_ids()
-        .iter()
-        .zip(encoding.get_offsets())
-        .zip(encoding.get_special_tokens_mask())
-        .zip(encoding.get_tokens())
-        .map(|(((&id, &(start, stop)), special), token)| {
-            let special = *special == 1;
-            match special {
-                true => SimpleToken {
-                    id,
-                    text: token.clone(),
-                    special,
-                    start: None,
-                    stop: None,
-                },
-                false => {
-                    let text: String = input.chars().skip(start).take(stop - start).collect();
-                    SimpleToken {
-                        id,
-                        text,
-                        special,
-                        start: Some(start),
-                        stop: Some(stop),
-                    }
-                }
+    let tokens: Vec<SimpleToken> = into_tokens(encoding, &input)
+        .into_iter()
+        .map(|t| {
+            let CoreSimpleToken {
+                id,
+                text,
+                special,
+                start,
+                stop,
+            } = t;
+            SimpleToken {
+                id,
+                text,
+                special,
+                start,
+                stop,
             }
         })
        .collect();
