Skip to content

Deform conv2d mps support #9026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
93e044b
Start of branch
goldfishsound Oct 7, 2024
d838bf7
Setting up for development
goldfishsound Oct 10, 2024
95eb1cd
Initial commit for deform_conv2d for MPS
goldfishsound Oct 25, 2024
c53e1bd
New mps kernel for deform_conv2d and updated shader functions in kern…
goldfishsound Nov 12, 2024
1153b84
Renaming source file.
goldfishsound Nov 15, 2024
1c87a26
Changed part of the file name from _kernal to _kernel
goldfishsound Nov 15, 2024
8a984de
Remove files in product dir
goldfishsound Nov 16, 2024
970183d
Removing framework dir and included files.
goldfishsound Nov 16, 2024
2895f4f
Removing build_xcode dir and included files.
goldfishsound Nov 16, 2024
2f06f7f
Changed location references to pytorch
goldfishsound Nov 16, 2024
66d76d3
Clean up git - Removing .DS_Store
goldfishsound Nov 16, 2024
c8eb2ea
Altering the kernel deformable_im2col to mimic the cpp kernel impleme…
goldfishsound Nov 16, 2024
c92eaa4
Re-ordering include sequence
goldfishsound Nov 16, 2024
b445aed
Including mps in TestDeformConv::test_is_leaf_node
goldfishsound Nov 16, 2024
951880c
Updates gitignore
goldfishsound Dec 1, 2024
1aa7c0b
Merge branch 'main' into deform_conv2d_mps_support
goldfishsound Dec 1, 2024
83080da
Merge branch 'pytorch:main' into deform_conv2d_mps_support
goldfishsound Dec 2, 2024
9f68fd4
Update .gitignore
goldfishsound Dec 2, 2024
dc305ae
CleanUp
goldfishsound Dec 2, 2024
e25e620
Cleaned up - removed added exclusions.
goldfishsound Dec 4, 2024
e4fb8c5
Updated
goldfishsound Dec 4, 2024
e39867f
Removed CMakePresets.json
goldfishsound Dec 4, 2024
3e2bc0e
Updated to exclude CMakePresets.json
goldfishsound Dec 4, 2024
bd62ab3
Added bilinear_interpolate_2 function which is identical to the one u…
goldfishsound Mar 4, 2025
350454f
Reorganized the numbering of argument indexes in img2col
goldfishsound Mar 6, 2025
b31a28c
Added threadgroups_per_grid to deformable_col2im and deformable_col2i…
goldfishsound Mar 11, 2025
358dacc
Added printTensor utility function - only temporarily
goldfishsound Mar 11, 2025
7da876a
Modifying TestDeformConv to include mps tests.
goldfishsound Mar 11, 2025
da6134d
Merge branch 'pytorch:main' into deform_conv2d_mps_support
goldfishsound Apr 19, 2025
25a2944
House Cleaning:
goldfishsound Mar 11, 2025
bf7784d
Renaming of bilinear_interpolate2 to bilinear_interpolate_deform_conv2d
goldfishsound Apr 20, 2025
9d3105f
Added constant mps_backward_eps for eps in backward test.
goldfishsound Apr 20, 2025
501d617
Removed unused includes
goldfishsound Apr 20, 2025
a294c2e
Delete
goldfishsound Apr 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# CMAKE
CMakePresets.json

# MacOS
**/.DS_Store

build_xcode/
build/
dist/
framework/
torchvision.egg-info/
torchvision/version.py
*/**/__pycache__
@@ -10,6 +18,9 @@ torchvision/version.py
*/**/*~
*~

#Misc
collect_env.py

docs/build
# sphinx-gallery
docs/source/auto_examples/
5 changes: 3 additions & 2 deletions android/gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#Tue Aug 27 15:56:14 CEST 2024
distributionBase=GRADLE_USER_HOME
distributionUrl=https\://services.gradle.org/distributions/gradle-8.9-bin.zip
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
15 changes: 15 additions & 0 deletions test/optest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# =========================================================
# BEGIN REPRO SCRIPT
# =========================================================
# Standalone reproduction script emitted by PyTorch's opcheck failure
# machinery: it re-runs a single failing opcheck test against a pickled
# (args, kwargs) snapshot.
import torch
from torch.testing._internal.optests import opcheck

# Make sure you have loaded the library that contains the op
# via an import or torch.ops.load_library(...)
# op = torch.ops.torchvision.deform_conv2d.default
op = torch.ops.torchvision.roi_align.default
# NOTE(review): this path points at a machine-local pytest temp directory
# ("safe_to_delete") — the file will not exist on other machines or after
# cleanup; regenerate the repro artifact before running. TODO confirm.
args, kwargs = torch.load("/var/folders/m7/m4jyvbb97ml6nftpw7b6fsk00000gn/T/pytorch_opcheck_safe_to_delete/repro_173109241941725.22.pt")
# Re-run only the autograd-registration check that originally failed.
opcheck(op, args, kwargs, test_utils="test_autograd_registration")
# =========================================================
# END REPRO SCRIPT
# =========================================================
6 changes: 6 additions & 0 deletions test/playground/test_mps_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch
import torchvision as tv


# Smoke check: importing torchvision succeeds, and the MPS (Apple Metal)
# backend is visible to this torch build. Prints True/False.
mps_available = torch.backends.mps.is_available()
print(mps_available)
47 changes: 32 additions & 15 deletions test/test_ops.py
Original file line number Diff line number Diff line change
@@ -929,7 +929,10 @@ def test_batched_nms_implementations(self, seed):

class TestDeformConv:
dtype = torch.float64

mps_dtype = torch.float32
mps_backward_atol = 2e-2
mps_backward_eps = 1e-3

def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
stride_h, stride_w = _pair(stride)
pad_h, pad_w = _pair(padding)
@@ -1041,7 +1044,7 @@ def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2,
)
return DeformConvModuleWrapper(obj) if wrap else obj

@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
def test_is_leaf_node(self, device):
op_obj = self.make_obj(wrap=True).to(device=device)
graph_node_names = get_graph_node_names(op_obj)
@@ -1050,12 +1053,17 @@ def test_is_leaf_node(self, device):
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64)) # , ids=str)
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.opcheck_only_one()
def test_forward(self, device, contiguous, batch_sz, dtype=None):
def test_forward(self, device, contiguous, batch_sz, dtype):
dtype = dtype or self.dtype

if device == "mps" and dtype is torch.float64:
pytest.skip("MPS does not support float64")

x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
in_channels = 6
out_channels = 2
@@ -1103,28 +1111,37 @@ def test_wrong_sizes(self):
wrong_mask = torch.rand_like(mask[:, :2])
layer(x, offset, wrong_mask)

@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.opcheck_only_one()
def test_backward(self, device, contiguous, batch_sz):
def test_backward(self, device, contiguous, batch_sz, deterministic=False):
# Batch size of zero fails a check un OperationUtils.mm because tensors with zero as a dimension
# cause the Placeholder::Placeholder to fail.
if device == "mps" and batch_sz == 0:
pytest.skip("MPS does not currently support zero batch size for backpropagation")

atol = self.mps_backward_atol if device == "mps" else 1e-05
dtype = self.mps_dtype if device == "mps" else self.dtype
eps = self.mps_backward_eps if device == "mps" else 1e-6

x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
device, contiguous, batch_sz, self.dtype
device, contiguous, batch_sz, dtype
)

def func(x_, offset_, mask_, weight_, bias_):
return ops.deform_conv2d(
x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_
)

gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True)

gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True, atol=atol, eps=eps)
def func_no_mask(x_, offset_, weight_, bias_):
return ops.deform_conv2d(
x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None
)

gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True)
gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True, atol=atol, eps=eps)

@torch.jit.script
def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
@@ -1137,7 +1154,7 @@ def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
(x, offset, mask, weight, bias),
nondet_tol=1e-5,
fast_mode=True,
fast_mode=True, eps=eps, atol=atol
)

@torch.jit.script
@@ -1151,7 +1168,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
(x, offset, weight, bias),
nondet_tol=1e-5,
fast_mode=True,
fast_mode=True, eps=eps, atol=atol
)

@needs_cuda
@@ -2035,4 +2052,4 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):


if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__ + "::TestDeformConv::test_forward"])
917 changes: 917 additions & 0 deletions torchvision/csrc/ops/mps/deform_conv2d_kernel.mm

Large diffs are not rendered by default.

473 changes: 472 additions & 1 deletion torchvision/csrc/ops/mps/mps_kernels.h

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions torchvision/installTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include <Accelerate/Accelerate.h>
#include <string>