Skip to content

Commit 29a9dae

Browse files
minor: address deprecation warnings
Signed-off-by: thiswillbeyourgithub <[email protected]>
1 parent bed301f commit 29a9dae

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

wdoc/utils/tasks/query.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,9 @@ def semantic_batching(
331331
# at the average number of token in each clusters
332332
total_mean = 0
333333
for lab in labels:
334-
lt = [texts[int(ind)] for ind in np.argwhere(cluster_labels == lab)]
334+
lt = [
335+
texts[int(ind.squeeze())] for ind in np.argwhere(cluster_labels == lab)
336+
]
335337
lsize = sum([text_sizes[t] for t in lt])
336338
lmean = lsize / len(lt)
337339
total_mean += lmean
@@ -402,7 +404,7 @@ def semantic_batching(
402404
assert len(lab_ind) > 1, f"{lab_ind}\n{cluster_labels}"
403405
assert len(lab_ind) < len(texts), f"{lab_ind}\n{cluster_labels}"
404406
for clustid in lab_ind:
405-
text = texts[int(clustid)]
407+
text = texts[int(clustid.squeeze())]
406408
size = text_sizes[text]
407409
if (current_tokens + size > max_token) and current_bucket:
408410
buckets.append(current_bucket)

0 commit comments

Comments
 (0)