File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -331,7 +331,9 @@ def semantic_batching(
331
331
# at the average number of token in each clusters
332
332
total_mean = 0
333
333
for lab in labels :
334
- lt = [texts [int (ind )] for ind in np .argwhere (cluster_labels == lab )]
334
+ lt = [
335
+ texts [int (ind .squeeze ())] for ind in np .argwhere (cluster_labels == lab )
336
+ ]
335
337
lsize = sum ([text_sizes [t ] for t in lt ])
336
338
lmean = lsize / len (lt )
337
339
total_mean += lmean
@@ -402,7 +404,7 @@ def semantic_batching(
402
404
assert len (lab_ind ) > 1 , f"{ lab_ind } \n { cluster_labels } "
403
405
assert len (lab_ind ) < len (texts ), f"{ lab_ind } \n { cluster_labels } "
404
406
for clustid in lab_ind :
405
- text = texts [int (clustid )]
407
+ text = texts [int (clustid . squeeze () )]
406
408
size = text_sizes [text ]
407
409
if (current_tokens + size > max_token ) and current_bucket :
408
410
buckets .append (current_bucket )
You can’t perform that action at this time.
0 commit comments