File tree Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Original file line number Diff line number Diff line change @@ -475,14 +475,13 @@ def forward(self, x):
475
475
optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
476
476
torch_model_results = fx_graph (* inputs ).detach ().cpu ()
477
477
478
- max_diff = float (
479
- torch .max (torch .abs (optimized_model_results - torch_model_results ))
480
- )
481
- self .assertAlmostEqual (
482
- max_diff ,
483
- 0 ,
484
- DECIMALS_OF_AGREEMENT ,
485
- f"Select_scatter TRT outputs don't match with the original model." ,
478
+ optimized_model_results_shape = optimized_model_results .size ()
479
+ torch_model_results_shape = torch_model_results .size ()
480
+
481
+ self .assertEquals (
482
+ optimized_model_results_shape ,
483
+ torch_model_results_shape ,
484
+ f"The optimized model results shape and torch model results shape should be equal in empty_like" ,
486
485
)
487
486
488
487
You can’t perform that action at this time.
0 commit comments