@@ -1171,19 +1171,22 @@ def integer_to_multi_hot(labels, n_classes, sparse=False):
1171
1171
labels: list of lists of integer labels, eg [[0,1,2],[3]]
1172
1172
n_classes: number of classes
1173
1173
Returns:
1174
- 2d np.array with False for absent and True for present
1174
+ if sparse is False: 2d np.array with False for absent and True for present
1175
+ if sparse is True: scipy.sparse.csr_matrix with 0 for absent and 1 for present
1175
1176
"""
1177
+ # TODO: consider using bool rather than int dtype; bool is much smaller and int is unnecessary,
1178
+ # but bool leads to a FutureWarning, see https://github.com/pandas-dev/pandas/issues/59739
1176
1179
if sparse :
1177
1180
vals = []
1178
1181
rows = []
1179
1182
cols = []
1180
1183
for i , row in enumerate (labels ):
1181
1184
for col in row :
1182
- vals .append (True )
1185
+ vals .append (1 )
1183
1186
rows .append (i )
1184
1187
cols .append (col )
1185
1188
return scipy .sparse .csr_matrix (
1186
- (vals , (rows , cols )), shape = (len (labels ), n_classes ), dtype = bool
1189
+ (vals , (rows , cols )), shape = (len (labels ), n_classes ), dtype = int
1187
1190
)
1188
1191
else :
1189
1192
multi_hot = np .zeros ((len (labels ), n_classes ), dtype = bool )
@@ -1213,17 +1216,19 @@ def categorical_to_multi_hot(labels, classes=None, sparse=False):
1213
1216
rows = []
1214
1217
cols = []
1215
1218
1219
+ # TODO: consider using bool rather than int dtype; bool is much smaller and int is unnecessary,
1220
+ # but bool leads to a FutureWarning, see https://github.com/pandas-dev/pandas/issues/59739
1216
1221
def add_labels (i , labels ):
1217
1222
for label in labels :
1218
1223
if label in classes :
1219
- vals .append (True )
1224
+ vals .append (1 )
1220
1225
rows .append (i )
1221
1226
cols .append (label_idx_dict [label ])
1222
1227
1223
1228
[add_labels (i , l ) for i , l in enumerate (labels )]
1224
1229
1225
1230
multi_hot = scipy .sparse .csr_matrix (
1226
- (vals , (rows , cols )), shape = (len (labels ), len (classes )), dtype = bool
1231
+ (vals , (rows , cols )), shape = (len (labels ), len (classes )), dtype = int
1227
1232
)
1228
1233
1229
1234
if sparse :
@@ -1394,10 +1399,6 @@ def find_overlapping_idxs_in_clip_df(
1394
1399
1395
1400
1396
1401
from itertools import chain
1397
- from opensoundscape .annotations import (
1398
- multi_hot_to_categorical ,
1399
- categorical_to_multi_hot ,
1400
- )
1401
1402
1402
1403
1403
1404
class CategoricalLabels :
0 commit comments