Skip to content

Commit 5638657

Browse files
authored
Add MSE numerical comparator
Differential Revision: D76781331 Pull Request resolved: #11759
1 parent 7503bb3 commit 5638657

File tree

5 files changed

+110
-3
lines changed

5 files changed

+110
-3
lines changed

devtools/inspector/numerical_comparator/TARGETS

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ python_library(
99
deps = [],
1010
)
1111

12-
1312
python_library(
1413
name = "l1_numerical_comparator",
1514
srcs = ["l1_numerical_comparator.py"],
@@ -19,12 +18,20 @@ python_library(
1918
],
2019
)
2120

22-
21+
python_library(
22+
name = "mse_numerical_comparator",
23+
srcs = ["mse_numerical_comparator.py"],
24+
deps = [
25+
"//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base",
26+
"//executorch/devtools/inspector:lib",
27+
],
28+
)
2329

2430
python_library(
2531
name = "lib",
2632
srcs = ["__init__.py"],
2733
deps = [
2834
":l1_numerical_comparator",
35+
":mse_numerical_comparator",
2936
],
3037
)

devtools/inspector/numerical_comparator/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,9 @@
99
L1Comparator,
1010
)
1111

12+
from executorch.devtools.inspector.numerical_comparator.mse_numerical_comparator import (
13+
MSEComparator,
14+
)
15+
1216

13-
__all__ = ["L1Comparator"]
17+
__all__ = ["L1Comparator", "MSEComparator"]
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any
8+
9+
import torch
10+
from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor
11+
from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
12+
NumericalComparatorBase,
13+
)
14+
15+
16+
class MSEComparator(NumericalComparatorBase):
17+
def compare(self, a: Any, b: Any) -> float:
18+
"""Compare mean squared difference between two outputs."""
19+
20+
t_a = convert_to_float_tensor(a)
21+
t_b = convert_to_float_tensor(b)
22+
if torch.isnan(t_a).any() or torch.isnan(t_b).any():
23+
t_a = torch.nan_to_num(t_a)
24+
t_b = torch.nan_to_num(t_b)
25+
26+
try:
27+
res = float(torch.mean(torch.square(t_a - t_b)))
28+
except Exception as e:
29+
raise ValueError(
30+
f"Error computing MSE difference between tensors: {str(e)}"
31+
)
32+
return res

devtools/inspector/tests/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ python_unittest(
6262
],
6363
)
6464

65+
python_unittest(
66+
name = "mse_comparator_test",
67+
srcs = ["mse_comparator_test.py"],
68+
deps = [
69+
"//executorch/devtools/inspector/numerical_comparator:lib",
70+
],
71+
)
72+
6573
python_library(
6674
name = "inspector_test_utils",
6775
srcs = [
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from executorch.devtools.inspector.numerical_comparator import MSEComparator
12+
13+
14+
class TestMSEComparator(unittest.TestCase):
15+
mse_comparator = MSEComparator()
16+
17+
def test_identical_tensors(self):
18+
a = torch.tensor([[10, 4], [3, 4]])
19+
b = torch.tensor([[10, 4], [3, 4]])
20+
expected = 0.0
21+
result = self.mse_comparator.compare(a, b)
22+
self.assertAlmostEqual(result, expected)
23+
24+
def test_scalar(self):
25+
a = 10
26+
b = 2
27+
expected = 64.0
28+
result = self.mse_comparator.compare(a, b)
29+
self.assertAlmostEqual(result, expected)
30+
31+
def test_with_nans_replaced_with_zero(self):
32+
a = torch.tensor([3, 1, -3, float("nan")])
33+
b = torch.tensor([float("nan"), 0, -3, 2])
34+
expected = (9.0 + 1.0 + 0.0 + 4.0) / 4.0
35+
result = self.mse_comparator.compare(a, b)
36+
self.assertAlmostEqual(result, expected)
37+
38+
def test_shape_mismatch_raises_exception(self):
39+
a = torch.tensor([0, 2, -1])
40+
b = torch.tensor([1, 1, -3, 4])
41+
with self.assertRaises(ValueError):
42+
self.mse_comparator.compare(a, b)
43+
44+
def test_2D_tensors(self):
45+
a = torch.tensor([[4, 9], [6, 4]])
46+
b = torch.tensor([[1, 2], [3, 10]])
47+
expected = (9.0 + 49.0 + 9.0 + 36.0) / 4.0
48+
result = self.mse_comparator.compare(a, b)
49+
self.assertAlmostEqual(result, expected)
50+
51+
def test_list_of_tensors(self):
52+
a = [torch.tensor([2, 4]), torch.tensor([15, 2])]
53+
b = [torch.tensor([1, 2]), torch.tensor([9, 5])]
54+
expected = (1.0 + 4.0 + 36.0 + 9.0) / 4.0
55+
result = self.mse_comparator.compare(a, b)
56+
self.assertAlmostEqual(result, expected)

0 commit comments

Comments
 (0)