Skip to content

Commit f9e6510

Browse files
jiayisunxEikanWang
authored andcommitted
add external OPs from MaskRCNN
1 parent 0564a19 commit f9e6510

File tree

7 files changed

+748
-0
lines changed

7 files changed

+748
-0
lines changed

intel_pytorch_extension_py/ops/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@
66
from .jit import *
77
from .save import *
88
from .to import *
9+
from .roi_align import ROIAlign
10+
from .roi_align import roi_align
11+
from .nms import nms
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import _torch_ipex as core
2+
3+
nms = core.nms
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2+
import torch
3+
from torch import nn
4+
from torch.autograd import Function
5+
from torch.autograd.function import once_differentiable
6+
from torch.nn.modules.utils import _pair
7+
8+
import _torch_ipex as core
9+
10+
11+
class _ROIAlign(Function):
12+
@staticmethod
13+
def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
14+
ctx.save_for_backward(roi)
15+
ctx.output_size = _pair(output_size)
16+
ctx.spatial_scale = spatial_scale
17+
ctx.sampling_ratio = sampling_ratio
18+
ctx.input_shape = input.size()
19+
output = core.roi_align_forward(
20+
input, roi, spatial_scale, output_size[0], output_size[1], sampling_ratio
21+
)
22+
return output
23+
24+
@staticmethod
25+
@once_differentiable
26+
def backward(ctx, grad_output):
27+
rois, = ctx.saved_tensors
28+
output_size = ctx.output_size
29+
spatial_scale = ctx.spatial_scale
30+
sampling_ratio = ctx.sampling_ratio
31+
bs, ch, h, w = ctx.input_shape
32+
grad_input = core.roi_align_backward(
33+
grad_output,
34+
rois,
35+
spatial_scale,
36+
output_size[0],
37+
output_size[1],
38+
bs,
39+
ch,
40+
h,
41+
w,
42+
sampling_ratio,
43+
)
44+
return grad_input, None, None, None, None
45+
46+
47+
roi_align = _ROIAlign.apply
48+
49+
50+
class ROIAlign(nn.Module):
51+
def __init__(self, output_size, spatial_scale, sampling_ratio):
52+
super(ROIAlign, self).__init__()
53+
self.output_size = output_size
54+
self.spatial_scale = spatial_scale
55+
self.sampling_ratio = sampling_ratio
56+
57+
def forward(self, input, rois):
58+
return roi_align(
59+
input, rois, self.output_size, self.spatial_scale, self.sampling_ratio
60+
)
61+
62+
def __repr__(self):
63+
tmpstr = self.__class__.__name__ + "("
64+
tmpstr += "output_size=" + str(self.output_size)
65+
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
66+
tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
67+
tmpstr += ")"
68+
return tmpstr

torch_ipex/csrc/cpu/ExternalOPs.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Autogenerated file by gen-common-ops.py. Do not edit directly!
2+
#pragma once
3+
4+
#include <ATen/Tensor.h>
5+
6+
namespace torch_ipex {
7+
8+
class IpexExternal {
9+
public:
10+
static at::Tensor ROIAlign_forward(const at::Tensor& input,
11+
const at::Tensor& rois,
12+
const float spatial_scale,
13+
const int pooled_height,
14+
const int pooled_width,
15+
const int sampling_ratio);
16+
17+
static at::Tensor ROIAlign_backward(const at::Tensor& grad,
18+
const at::Tensor& rois,
19+
const float spatial_scale,
20+
const int pooled_height,
21+
const int pooled_width,
22+
const int batch_size,
23+
const int channels,
24+
const int height,
25+
const int width,
26+
const int sampling_ratio);
27+
28+
static at::Tensor nms(const at::Tensor& dets,
29+
const at::Tensor& scores,
30+
const float threshold);
31+
};
32+
33+
} // namespace torch_ipex

0 commit comments

Comments
 (0)