Commit ae61caa

Update verifier to handle None Tensor outputs
Differential Revision: D69209059
Pull Request resolved: #8235
Parent: edf3952

2 files changed: +55, -6 lines


exir/verification/arg_validator.py

Lines changed: 17 additions & 6 deletions
@@ -108,12 +108,23 @@ def call_function( # noqa: C901 # pyre-fixme[14]
         for schema_ret in target._schema.returns:
             name = schema_ret.name if schema_ret.name else f"__ret_{ret_index}"
             kernel_ret = next(ret_iter)
-            # Return value should not be in OptionalTensor type, so only check torch.TensorType here.
-            if isinstance(schema_ret.type, torch.TensorType) and isinstance(
-                kernel_ret, torch.Tensor
-            ):
-                tensor_arg_types[name] = kernel_ret.dtype
-                ret_index += 1
+            if isinstance(schema_ret.type, torch.TensorType):
+                if isinstance(kernel_ret, torch.Tensor):
+                    tensor_arg_types[name] = kernel_ret.dtype
+                    ret_index += 1
+                # Exceptionally rarely (basically only backward ops) you might see an OptionalTensor returned.
+                # The schema of these ops, though, is typically -> (Tensor, Tensor, ...), so the actual value
+                # returned in C++ is an empty/undefined tensor. There is no analog to this in Python, so it
+                # gets crudely mapped to None. To properly fix this, core PyTorch would have to change the
+                # schema to (Tensor?, ...), which is just never going to happen, so we have to handle this case
+                # here in the verifier and in memory planning as well.
+                elif kernel_ret is None:
+                    tensor_arg_types[name] = schema_ret.default_value
+                    ret_index += 1
+                else:
+                    raise InternalError(
+                        f"encountered return with type Tensor but value wasn't a tensor or None. schema:{target._schema}, output:{ret_index}"
+                    )
             elif schema_ret.type == torch.ListType.ofTensors() and all(
                 isinstance(kernel_ret[i], torch.Tensor) for i in range(len(kernel_ret))
             ):
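
The comment in the new branch describes why a Tensor-typed schema return can arrive as None: backward ops fill unrequested gradients with undefined C++ tensors, which surface as None in Python. Below is a minimal sketch of that situation; the specific op, shapes, and arguments are only an illustration and are not part of this commit.

# Minimal sketch (illustrative, not part of this commit): a backward op whose
# schema is "-> (Tensor, Tensor, Tensor)" can still hand Python a None when a
# gradient is masked out, because the C++ kernel returns an undefined tensor.
import torch

grad_output = torch.randn(1, 6, 1, 1)  # conv2d(6 -> 6, kernel 5) on a 5x5 input yields a 1x1 output
inp = torch.randn(1, 6, 5, 5)
weight = torch.randn(6, 6, 5, 5)

grads = torch.ops.aten.convolution_backward(
    grad_output,
    inp,
    weight,
    bias_sizes=[6],
    stride=[1, 1],
    padding=[0, 0],
    dilation=[1, 1],
    transposed=False,
    output_padding=[0, 0],
    groups=1,
    output_mask=[False, True, True],  # grad_input not requested
)

print(type(grads[0]))  # <class 'NoneType'>, even though the schema says Tensor
print(grads[1].shape)  # grad_weight
print(grads[2].shape)  # grad_bias

This is the same shape of output that the new test below provokes via _export_forward_backward.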

exir/verification/test/test_verifier.py

Lines changed: 38 additions & 0 deletions
@@ -14,9 +14,11 @@
 from executorch.exir import EdgeCompileConfig, to_edge
 
 from executorch.exir.dialects._ops import ops
+from torch import nn
 from torch._export.verifier import SpecViolationError
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torch.export import export
+from torch.export.experimental import _export_forward_backward
 
 from ..verifier import EXIREdgeDialectVerifier
 
@@ -123,3 +125,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         dim_order_verifier(stride_edge_model.exported_program())
         with self.assertRaises(SpecViolationError):
             stride_verifier(dim_order_edge_model.exported_program())
+
+    def test_none_return_verifier(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(6, 6, 5)
+                self.linear = nn.Linear(6, 2)
+
+            def forward(self, x):
+                return self.linear(self.conv1(x).flatten(1))
+
+        class TrainingNet(nn.Module):
+            def __init__(self, net):
+                super().__init__()
+                self.net = net
+                self.loss = nn.CrossEntropyLoss()
+
+            def forward(self, input, label):
+                pred = self.net(input)
+                return self.loss(pred, label)
+
+        # In the backward graph, conv returns (None, Tensor, Tensor), which is
+        # uncommon to see since the schema is (Tensor, Tensor, Tensor). This tests
+        # that the verifier just ignores the None return value (since it'll be
+        # unused in the runtime).
+        net = TrainingNet(Net())
+        inputs = (torch.randn(1, 6, 5, 5), torch.ones(1, dtype=torch.int64))
+
+        export_model = export(net, inputs)
+        export_model = _export_forward_backward(export_model)
+
+        edge = to_edge(export_model)
+
+        edge_verifier = EXIREdgeDialectVerifier()
+
+        edge_verifier(edge.exported_program())
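
For reference, here is a simplified, standalone sketch of the return-dtype bookkeeping after this change. The helper name collect_return_dtypes and the RuntimeError are illustrative stand-ins; the real logic is the call_function hunk above, which raises InternalError and lives inside the verifier's argument validator.

# Simplified, standalone sketch of the branch added above (helper name and
# structure are illustrative only, not the actual EXIR code).
from typing import Any, Dict

import torch


def collect_return_dtypes(schema: torch._C.FunctionSchema, kernel_rets) -> Dict[str, Any]:
    tensor_arg_types: Dict[str, Any] = {}
    ret_iter = iter(kernel_rets)
    ret_index = 0
    for schema_ret in schema.returns:
        name = schema_ret.name if schema_ret.name else f"__ret_{ret_index}"
        kernel_ret = next(ret_iter)
        if isinstance(schema_ret.type, torch.TensorType):
            if isinstance(kernel_ret, torch.Tensor):
                # Normal case: record the dtype of the returned tensor.
                tensor_arg_types[name] = kernel_ret.dtype
            elif kernel_ret is None:
                # Masked-out gradient from a backward op: keep the schema's
                # default value instead of failing verification.
                tensor_arg_types[name] = schema_ret.default_value
            else:
                raise RuntimeError(
                    f"Tensor-typed return {name} was neither a Tensor nor None"
                )
            ret_index += 1
    return tensor_arg_types

With the convolution_backward example shown earlier, the None gradient simply records the schema default rather than tripping the verifier.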
