Commit e336c81

Author: pytorchbot
Committed: 2024-04-19 nightly release (e90d070)
1 parent ce61976 · commit e336c81

File tree: 4 files changed, +130 -1 lines changed


core/conversion/converters/impl/interpolate.cpp

Lines changed: 31 additions & 0 deletions
@@ -520,6 +520,37 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                 resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
               }

+              return true;
+            }})
+        .pattern(
+            {"aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto grid = args[1].ITensorOrFreeze(ctx);
+               auto interpolation_mode = args[2].unwrapToInt();
+               auto padding_mode = args[3].unwrapToInt();
+               auto align_corners = args[4].unwrapToBool();
+
+               static const auto sample_map = std::map<int, nvinfer1::SampleMode>{
+                   {0, nvinfer1::SampleMode::kFILL},
+                   {1, nvinfer1::SampleMode::kCLAMP},
+                   {2, nvinfer1::SampleMode::kREFLECT}};
+
+               static const auto interpolation_map = std::map<int, nvinfer1::InterpolationMode>{
+                   {0, nvinfer1::InterpolationMode::kLINEAR},
+                   {1, nvinfer1::InterpolationMode::kNEAREST},
+                   {2, nvinfer1::InterpolationMode::kCUBIC}};
+
+               auto grid_sample_layer = ctx->net->addGridSample(*in, *grid);
+               TORCHTRT_CHECK(
+                   grid_sample_layer, "Unable to create grid_sample layer from node: " << util::node_info(n));
+
+               grid_sample_layer->setAlignCorners(align_corners);
+               grid_sample_layer->setSampleMode(sample_map.at(padding_mode));
+               grid_sample_layer->setInterpolationMode(interpolation_map.at(interpolation_mode));
+
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], grid_sample_layer->getOutput(0));
+               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
               return true;
             }});
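For reference, the integer arguments this converter unwraps are the values PyTorch uses when lowering torch.nn.functional.grid_sample: mode bilinear/nearest/bicubic maps to interpolation_mode 0/1/2 and padding_mode zeros/border/reflection maps to 0/1/2, which is what interpolation_map and sample_map mirror on the TensorRT side. A minimal Python sketch of a module that lowers to this schema (the module and tensor names are illustrative, not part of the diff):

import torch
import torch.nn.functional as F

class GridSampleModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        # Scripting/tracing turns this into roughly aten::grid_sampler(x, grid, 0, 1, True)
        return F.grid_sample(x, grid, mode="bilinear", padding_mode="border", align_corners=True)

x = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)
d = torch.linspace(-1, 1, 8)
mesh_y, mesh_x = torch.meshgrid(d, d, indexing="ij")
grid = torch.stack((mesh_x, mesh_y), dim=-1).unsqueeze(0)  # shape [1, 8, 8, 2]
print(GridSampleModule()(x, grid).shape)  # torch.Size([1, 1, 8, 8])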

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 0 deletions
@@ -2623,6 +2623,7 @@ def aten_ops_pad(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.default)
 @dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.vec)
 def upsample_nearest2d(
     ctx: ConversionContext,
@@ -2644,6 +2645,7 @@ def upsample_nearest2d(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.default)
 @dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.vec)
 def upsample_bilinear2d(
     ctx: ConversionContext,
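The two added decorator lines register the existing upsample converters for the .default overloads alongside the already-registered .vec overloads, presumably because lowered graphs can contain either form. A small illustrative check of the two overload signatures (the tensor shapes are arbitrary, not from the commit):

import torch

# Each upsample op is an overload packet; 'default' and 'vec' are separate entries.
print(torch.ops.aten.upsample_nearest2d.overloads())
print(torch.ops.aten.upsample_bilinear2d.overloads())

x = torch.randn(1, 3, 8, 8)
# .default takes an explicit output_size (plus optional per-axis scales) ...
y_default = torch.ops.aten.upsample_nearest2d.default(x, [16, 16])
# ... while .vec takes an optional output_size and an optional scale_factors list.
y_vec = torch.ops.aten.upsample_nearest2d.vec(x, [16, 16], None)
assert y_default.shape == y_vec.shape == (1, 3, 16, 16)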

tests/core/conversion/converters/test_interpolate.cpp

Lines changed: 96 additions & 0 deletions
@@ -377,3 +377,99 @@ ATEN_INTERPOLATE_STATIC_ONLY_TEST(
        %7 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %6)
        return (%7))IR",
     std::vector<int64_t>({10, 2, 2, 2, 2}));
+
+TEST(Converters, GridSampleConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%input : Tensor, %grid : Tensor):
+      %5 : int = prim::Constant[value=2]()
+      %6 : int = prim::Constant[value=2]()
+      %7 : bool = prim::Constant[value=1]()
+      %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
+      return (%8))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA);
+  auto d = at::linspace(-1, 1, 8);
+  auto mesh = at::meshgrid({d, d});
+  auto mesh_x = mesh[0];
+  auto mesh_y = mesh[1];
+  auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA);
+
+  auto trt_input = input.clone();
+  auto trt_grid = grid.clone();
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid});
+
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6));
+  }
+}
+
+TEST(Converters, GridSampleOptions1ConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%input : Tensor, %grid : Tensor):
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=1]()
+      %7 : bool = prim::Constant[value=0]()
+      %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
+      return (%8))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA);
+  auto d = at::linspace(-1, 1, 8);
+  auto mesh = at::meshgrid({d, d});
+  auto mesh_x = mesh[0];
+  auto mesh_y = mesh[1];
+  auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA);
+
+  auto trt_input = input.clone();
+  auto trt_grid = grid.clone();
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid});
+
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6));
+  }
+}
+
+TEST(Converters, GridSampleOptions2ConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%input : Tensor, %grid : Tensor):
+      %5 : int = prim::Constant[value=0]()
+      %6 : int = prim::Constant[value=0]()
+      %7 : bool = prim::Constant[value=0]()
+      %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7)
+      return (%8))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA);
+  auto d = at::linspace(-1, 1, 8);
+  auto mesh = at::meshgrid({d, d});
+  auto mesh_x = mesh[0];
+  auto mesh_y = mesh[1];
+  auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA);
+
+  auto trt_input = input.clone();
+  auto trt_grid = grid.clone();
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid});
+
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6));
+  }
+}
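The three tests differ only in the constant arguments fed to aten::grid_sampler; in eager terms the first one corresponds to bicubic interpolation with reflection padding and align_corners=True. A rough eager-mode reference, not part of the test file:

import torch
import torch.nn.functional as F

input = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)
d = torch.linspace(-1, 1, 8)
mesh_x, mesh_y = torch.meshgrid(d, d, indexing="ij")  # mirrors at::meshgrid({d, d})
grid = torch.stack((mesh_x, mesh_y), dim=2).unsqueeze(0)

# interpolation_mode=2 -> "bicubic", padding_mode=2 -> "reflection", align_corners=True
ref = F.grid_sample(input, grid, mode="bicubic", padding_mode="reflection", align_corners=True)
print(ref.shape)  # torch.Size([1, 1, 8, 8])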

tools/perf/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ numpy
 argparse
 pyyaml
 onnx
-transformers==4.33.2
+transformers==4.36.0
 diffusers==0.21.4
 pandas==2.0.1
 timm==0.9.8
