|
14 | 14 | from executorch.exir import EdgeCompileConfig, to_edge
|
15 | 15 |
|
16 | 16 | from executorch.exir.dialects._ops import ops
|
| 17 | +from torch import nn |
17 | 18 | from torch._export.verifier import SpecViolationError
|
18 | 19 | from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
19 | 20 | from torch.export import export
|
| 21 | +from torch.export.experimental import _export_forward_backward |
20 | 22 |
|
21 | 23 | from ..verifier import EXIREdgeDialectVerifier
|
22 | 24 |
|
@@ -123,3 +125,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
123 | 125 | dim_order_verifier(stride_edge_model.exported_program())
|
124 | 126 | with self.assertRaises(SpecViolationError):
|
125 | 127 | stride_verifier(dim_order_edge_model.exported_program())
|
| 128 | + |
| 129 | + def test_none_return_verifier(self) -> None: |
| 130 | + class Net(nn.Module): |
| 131 | + def __init__(self): |
| 132 | + super().__init__() |
| 133 | + self.conv1 = nn.Conv2d(6, 6, 5) |
| 134 | + self.linear = nn.Linear(6, 2) |
| 135 | + |
| 136 | + def forward(self, x): |
| 137 | + return self.linear(self.conv1(x).flatten(1)) |
| 138 | + |
| 139 | + class TrainingNet(nn.Module): |
| 140 | + def __init__(self, net): |
| 141 | + super().__init__() |
| 142 | + self.net = net |
| 143 | + self.loss = nn.CrossEntropyLoss() |
| 144 | + |
| 145 | + def forward(self, input, label): |
| 146 | + pred = self.net(input) |
| 147 | + return self.loss(pred, label) |
| 148 | + |
| 149 | + # conv returns (None, Tensor, Tensor) which is uncommon to see since |
| 150 | + # the schema is (Tensor, Tensor, Tensor). This is to test that |
| 151 | + # the verifier just ignores the None return value (since itll be |
| 152 | + # unused in the runtime). |
| 153 | + net = TrainingNet(Net()) |
| 154 | + inputs = (torch.randn(1, 6, 5, 5), torch.ones(1, dtype=torch.int64)) |
| 155 | + |
| 156 | + export_model = export(net, inputs) |
| 157 | + export_model = _export_forward_backward(export_model) |
| 158 | + |
| 159 | + edge = to_edge(export_model) |
| 160 | + |
| 161 | + edge_verifier = EXIREdgeDialectVerifier() |
| 162 | + |
| 163 | + edge_verifier(edge.exported_program()) |
0 commit comments