Skip to content

Commit f12e7af

Browse files
BordaSkafteNicki
andcommitted
fix: compatibility audio do with new scipy (#2733)
* compatibility audio do with new `scipy` * smaller array to fix torch.unique case --------- Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent 80929b5 commit f12e7af

File tree

6 files changed

+15
-2
lines changed

6 files changed

+15
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4545
- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721))
4646

4747

48+
- Fixed compatibility of audio domain with new `scipy` ([#2733](https://github.com/Lightning-AI/torchmetrics/pull/2733))
49+
50+
4851
- Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722))
4952

5053

src/torchmetrics/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020
if not hasattr(PIL, "PILLOW_VERSION"):
2121
PIL.PILLOW_VERSION = PIL.__version__
2222

23+
if package_available("scipy"):
24+
import scipy.signal
25+
26+
# back compatibility patch due to SMRMpy using scipy.signal.hamming
27+
if not hasattr(scipy.signal, "hamming"):
28+
scipy.signal.hamming = scipy.signal.windows.hamming
29+
2330
from torchmetrics import functional # noqa: E402
2431
from torchmetrics.aggregation import ( # noqa: E402
2532
CatMetric,

src/torchmetrics/functional/nominal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix
1516
from torchmetrics.functional.nominal.fleiss_kappa import fleiss_kappa
1617
from torchmetrics.functional.nominal.pearson import (

src/torchmetrics/nominal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
from torchmetrics.nominal.cramers import CramersV
1516
from torchmetrics.nominal.fleiss_kappa import FleissKappa
1617
from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient

src/torchmetrics/utilities/imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
_MECAB_KO_DIC_AVAILABLE = RequirementCache("mecab_ko_dic")
6565
_IPADIC_AVAILABLE = RequirementCache("ipadic")
6666
_SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece")
67+
_SCIPI_AVAILABLE = RequirementCache("scipy")
6768
_SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0")
6869

6970
_LATEX_AVAILABLE: bool = shutil.which("latex") is not None

tests/unittests/classification/test_stat_scores.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,8 @@ def test_support_for_int():
582582
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/1970."""
583583
seed_all(42)
584584
metric = MulticlassStatScores(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0)
585-
prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
586-
label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
585+
prediction = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8)
586+
label = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8)
587587
score = metric(preds=prediction, target=label)
588588
assert score.shape == (1, 4, 5)
589589

0 commit comments

Comments
 (0)