@@ -377,3 +377,99 @@ ATEN_INTERPOLATE_STATIC_ONLY_TEST(
377
377
%7 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %6)
378
378
return (%7))IR" ,
379
379
std::vector<int64_t >({10 , 2 , 2 , 2 , 2 }));
380
+
381
+ TEST (Converters, GridSampleConvertsCorrectly) {
382
+ const auto graph = R"IR(
383
+ graph(%input : Tensor, %grid : Tensor):
384
+ %5 : int = prim::Constant[value=2]()
385
+ %6 : int = prim::Constant[value=2]()
386
+ %7 : bool = prim::Constant[value=1]()
387
+ %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
388
+ return (%8))IR" ;
389
+ auto g = std::make_shared<torch::jit::Graph>();
390
+
391
+ torch::jit::parseIR (graph, g.get ());
392
+
393
+ auto input = at::arange (16 ).view ({1 , 1 , 4 , 4 }).to (at::kFloat ).to (at::kCUDA );
394
+ auto d = at::linspace (-1 , 1 , 8 );
395
+ auto mesh = at::meshgrid ({d, d});
396
+ auto mesh_x = mesh[0 ];
397
+ auto mesh_y = mesh[1 ];
398
+ auto grid = at::stack ({mesh_x, mesh_y}, 2 ).unsqueeze (0 ).to (at::kCUDA );
399
+
400
+ auto trt_input = input.clone ();
401
+ auto trt_grid = grid.clone ();
402
+
403
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
404
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {input, grid});
405
+
406
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_input, trt_grid});
407
+
408
+ for (size_t i = 0 ; i < jit_results.size (); i++) {
409
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt_results[i], 2e-6 ));
410
+ }
411
+ }
412
+
413
+ TEST (Converters, GridSampleOptions1ConvertsCorrectly) {
414
+ const auto graph = R"IR(
415
+ graph(%input : Tensor, %grid : Tensor):
416
+ %5 : int = prim::Constant[value=1]()
417
+ %6 : int = prim::Constant[value=1]()
418
+ %7 : bool = prim::Constant[value=0]()
419
+ %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
420
+ return (%8))IR" ;
421
+ auto g = std::make_shared<torch::jit::Graph>();
422
+
423
+ torch::jit::parseIR (graph, g.get ());
424
+
425
+ auto input = at::arange (16 ).view ({1 , 1 , 4 , 4 }).to (at::kFloat ).to (at::kCUDA );
426
+ auto d = at::linspace (-1 , 1 , 8 );
427
+ auto mesh = at::meshgrid ({d, d});
428
+ auto mesh_x = mesh[0 ];
429
+ auto mesh_y = mesh[1 ];
430
+ auto grid = at::stack ({mesh_x, mesh_y}, 2 ).unsqueeze (0 ).to (at::kCUDA );
431
+
432
+ auto trt_input = input.clone ();
433
+ auto trt_grid = grid.clone ();
434
+
435
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
436
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {input, grid});
437
+
438
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_input, trt_grid});
439
+
440
+ for (size_t i = 0 ; i < jit_results.size (); i++) {
441
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt_results[i], 2e-6 ));
442
+ }
443
+ }
444
+
445
+ TEST (Converters, GridSampleOptions2ConvertsCorrectly) {
446
+ const auto graph = R"IR(
447
+ graph(%input : Tensor, %grid : Tensor):
448
+ %5 : int = prim::Constant[value=0]()
449
+ %6 : int = prim::Constant[value=0]()
450
+ %7 : bool = prim::Constant[value=0]()
451
+ %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
452
+ return (%8))IR" ;
453
+ auto g = std::make_shared<torch::jit::Graph>();
454
+
455
+ torch::jit::parseIR (graph, g.get ());
456
+
457
+ auto input = at::arange (16 ).view ({1 , 1 , 4 , 4 }).to (at::kFloat ).to (at::kCUDA );
458
+ auto d = at::linspace (-1 , 1 , 8 );
459
+ auto mesh = at::meshgrid ({d, d});
460
+ auto mesh_x = mesh[0 ];
461
+ auto mesh_y = mesh[1 ];
462
+ auto grid = at::stack ({mesh_x, mesh_y}, 2 ).unsqueeze (0 ).to (at::kCUDA );
463
+
464
+ auto trt_input = input.clone ();
465
+ auto trt_grid = grid.clone ();
466
+
467
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
468
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {input, grid});
469
+
470
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_input, trt_grid});
471
+
472
+ for (size_t i = 0 ; i < jit_results.size (); i++) {
473
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt_results[i], 2e-6 ));
474
+ }
475
+ }
0 commit comments