Skip to content

Commit 96ceda0

Browse files
BordaSkafteNicki
andauthored
fix: compatibility audio do with new scipy (#2733)
* compatibility audio do with new `scipy` * smaller array to fix torch.unique case --------- Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent 80929b5 commit 96ceda0

File tree

7 files changed

+40
-2
lines changed

7 files changed

+40
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4545
- Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721))
4646

4747

48+
- Fixed compatibility of audio domain with new `scipy` ([#2733](https://github.com/Lightning-AI/torchmetrics/pull/2733))
49+
50+
4851
- Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722))
4952

5053

src/torchmetrics/audio/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,18 @@
2828
_ONNXRUNTIME_AVAILABLE,
2929
_PESQ_AVAILABLE,
3030
_PYSTOI_AVAILABLE,
31+
_SCIPI_AVAILABLE,
3132
_TORCHAUDIO_AVAILABLE,
3233
_TORCHAUDIO_GREATER_EQUAL_0_10,
3334
)
3435

36+
if _SCIPI_AVAILABLE:
37+
import scipy.signal
38+
39+
# back compatibility patch due to SMRMpy using scipy.signal.hamming
40+
if not hasattr(scipy.signal, "hamming"):
41+
scipy.signal.hamming = scipy.signal.windows.hamming
42+
3543
__all__ = [
3644
"PermutationInvariantTraining",
3745
"ScaleInvariantSignalDistortionRatio",

src/torchmetrics/functional/audio/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,18 @@
2828
_ONNXRUNTIME_AVAILABLE,
2929
_PESQ_AVAILABLE,
3030
_PYSTOI_AVAILABLE,
31+
_SCIPI_AVAILABLE,
3132
_TORCHAUDIO_AVAILABLE,
3233
_TORCHAUDIO_GREATER_EQUAL_0_10,
3334
)
3435

36+
if _SCIPI_AVAILABLE:
37+
import scipy.signal
38+
39+
# back compatibility patch due to SMRMpy using scipy.signal.hamming
40+
if not hasattr(scipy.signal, "hamming"):
41+
scipy.signal.hamming = scipy.signal.windows.hamming
42+
3543
__all__ = [
3644
"permutation_invariant_training",
3745
"pit_permutate",

src/torchmetrics/functional/nominal/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix
1516
from torchmetrics.functional.nominal.fleiss_kappa import fleiss_kappa
1617
from torchmetrics.functional.nominal.pearson import (
@@ -19,6 +20,14 @@
1920
)
2021
from torchmetrics.functional.nominal.theils_u import theils_u, theils_u_matrix
2122
from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix
23+
from torchmetrics.utilities.imports import _SCIPI_AVAILABLE
24+
25+
if _SCIPI_AVAILABLE:
26+
import scipy.signal
27+
28+
# back compatibility patch due to SMRMpy using scipy.signal.hamming
29+
if not hasattr(scipy.signal, "hamming"):
30+
scipy.signal.hamming = scipy.signal.windows.hamming
2231

2332
__all__ = [
2433
"cramers_v",

src/torchmetrics/nominal/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,20 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
from torchmetrics.nominal.cramers import CramersV
1516
from torchmetrics.nominal.fleiss_kappa import FleissKappa
1617
from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient
1718
from torchmetrics.nominal.theils_u import TheilsU
1819
from torchmetrics.nominal.tschuprows import TschuprowsT
20+
from torchmetrics.utilities.imports import _SCIPI_AVAILABLE
21+
22+
if _SCIPI_AVAILABLE:
23+
import scipy.signal
24+
25+
# back compatibility patch due to SMRMpy using scipy.signal.hamming
26+
if not hasattr(scipy.signal, "hamming"):
27+
scipy.signal.hamming = scipy.signal.windows.hamming
1928

2029
__all__ = [
2130
"CramersV",

src/torchmetrics/utilities/imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
_MECAB_KO_DIC_AVAILABLE = RequirementCache("mecab_ko_dic")
6565
_IPADIC_AVAILABLE = RequirementCache("ipadic")
6666
_SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece")
67+
_SCIPI_AVAILABLE = RequirementCache("scipy")
6768
_SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0")
6869

6970
_LATEX_AVAILABLE: bool = shutil.which("latex") is not None

tests/unittests/classification/test_stat_scores.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,8 @@ def test_support_for_int():
582582
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/1970."""
583583
seed_all(42)
584584
metric = MulticlassStatScores(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0)
585-
prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
586-
label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8)
585+
prediction = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8)
586+
label = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8)
587587
score = metric(preds=prediction, target=label)
588588
assert score.shape == (1, 4, 5)
589589

0 commit comments

Comments
 (0)