Skip to content

Commit 249a8e8

Browse files
committed
switch from bool to int for sparse label dtype
avoids FutureWarnings when creating Pandas dfs from scipy sparse matrix types, but also seems like unnecessary memory use, consider reverting to bool pending info on this issue: pandas-dev/pandas#59739
1 parent 16ea52e commit 249a8e8

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

opensoundscape/annotations.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,19 +1171,22 @@ def integer_to_multi_hot(labels, n_classes, sparse=False):
11711171
labels: list of lists of integer labels, eg [[0,1,2],[3]]
11721172
n_classes: number of classes
11731173
Returns:
1174-
2d np.array with False for absent and True for present
1174+
if sparse is False: 2d np.array with False for absent and True for present
1175+
if sparse is True: scipy.sparse.csr_matrix with 0 for absent and 1 for present
11751176
"""
1177+
# TODO: consider using bool rather than int dtype, much smaller and int is unnecessary
1178+
# but bool leads to FutureWarning, see https://github.com/pandas-dev/pandas/issues/59739
11761179
if sparse:
11771180
vals = []
11781181
rows = []
11791182
cols = []
11801183
for i, row in enumerate(labels):
11811184
for col in row:
1182-
vals.append(True)
1185+
vals.append(1)
11831186
rows.append(i)
11841187
cols.append(col)
11851188
return scipy.sparse.csr_matrix(
1186-
(vals, (rows, cols)), shape=(len(labels), n_classes), dtype=bool
1189+
(vals, (rows, cols)), shape=(len(labels), n_classes), dtype=int
11871190
)
11881191
else:
11891192
multi_hot = np.zeros((len(labels), n_classes), dtype=bool)
@@ -1213,17 +1216,19 @@ def categorical_to_multi_hot(labels, classes=None, sparse=False):
12131216
rows = []
12141217
cols = []
12151218

1219+
# TODO: consider using bool rather than int dtype, much smaller and int is unnecessary
1220+
# but bool leads to FutureWarning, see https://github.com/pandas-dev/pandas/issues/59739
12161221
def add_labels(i, labels):
12171222
for label in labels:
12181223
if label in classes:
1219-
vals.append(True)
1224+
vals.append(1)
12201225
rows.append(i)
12211226
cols.append(label_idx_dict[label])
12221227

12231228
[add_labels(i, l) for i, l in enumerate(labels)]
12241229

12251230
multi_hot = scipy.sparse.csr_matrix(
1226-
(vals, (rows, cols)), shape=(len(labels), len(classes)), dtype=bool
1231+
(vals, (rows, cols)), shape=(len(labels), len(classes)), dtype=int
12271232
)
12281233

12291234
if sparse:
@@ -1394,10 +1399,6 @@ def find_overlapping_idxs_in_clip_df(
13941399

13951400

13961401
from itertools import chain
1397-
from opensoundscape.annotations import (
1398-
multi_hot_to_categorical,
1399-
categorical_to_multi_hot,
1400-
)
14011402

14021403

14031404
class CategoricalLabels:

0 commit comments

Comments
 (0)