From 93e044b0cfa0c831877c161f1409ae87b8398814 Mon Sep 17 00:00:00 2001
From: Thomas Martin
Date: Mon, 7 Oct 2024 16:21:44 +0200
Subject: [PATCH 01/31] Start of branch

---
 .DS_Store | Bin 0 -> 6148 bytes
 .../gradle/wrapper/gradle-wrapper.properties | 5 +-
 framework/.DS_Store | Bin 0 -> 6148 bytes
 framework/include/.DS_Store | Bin 0 -> 6148 bytes
 framework/include/torchvision/.DS_Store | Bin 0 -> 6148 bytes
 .../torchvision/io/image/cpu/common_jpeg.cpp | 26 +
 .../torchvision/io/image/cpu/common_jpeg.h | 27 +
 .../torchvision/io/image/cpu/common_png.h | 6 +
 .../torchvision/io/image/cpu/decode_image.cpp | 41 +
 .../torchvision/io/image/cpu/decode_image.h | 15 +
 .../torchvision/io/image/cpu/decode_jpeg.cpp | 271 ++++
 .../torchvision/io/image/cpu/decode_jpeg.h | 18 +
 .../torchvision/io/image/cpu/decode_png.cpp | 259 ++++
 .../torchvision/io/image/cpu/decode_png.h | 16 +
 .../torchvision/io/image/cpu/encode_jpeg.cpp | 113 ++
 .../torchvision/io/image/cpu/encode_jpeg.h | 13 +
 .../torchvision/io/image/cpu/encode_png.cpp | 180 +++
 .../torchvision/io/image/cpu/encode_png.h | 13 +
 .../include/torchvision/io/image/cpu/exif.h | 264 ++++
 .../io/image/cpu/read_write_file.cpp | 108 ++
 .../io/image/cpu/read_write_file.h | 13 +
 .../io/image/cuda/decode_jpeg_cuda.cpp | 208 +++
 .../io/image/cuda/decode_jpeg_cuda.h | 15 +
 .../include/torchvision/io/image/image.cpp | 39 +
 .../include/torchvision/io/image/image.h | 9 +
 .../torchvision/io/image/image_read_mode.h | 17 +
 framework/include/torchvision/macros.h | 22 +
 framework/include/torchvision/ops/.DS_Store | Bin 0 -> 6148 bytes
 .../ops/autograd/deform_conv2d_kernel.cpp | 266 ++++
 .../ops/autograd/ps_roi_align_kernel.cpp | 167 +++
 .../ops/autograd/ps_roi_pool_kernel.cpp | 152 ++
 .../ops/autograd/roi_align_kernel.cpp | 167 +++
 .../ops/autograd/roi_pool_kernel.cpp | 152 ++
 .../ops/cpu/deform_conv2d_kernel.cpp | 1172 +++++++++++++++
 .../torchvision/ops/cpu/nms_kernel.cpp | 117 ++
 .../ops/cpu/ps_roi_align_kernel.cpp | 429 ++++++
 .../ops/cpu/ps_roi_pool_kernel.cpp | 273 ++++
 .../torchvision/ops/cpu/roi_align_common.h | 128 ++
 .../torchvision/ops/cpu/roi_align_kernel.cpp | 400 +++++
 .../torchvision/ops/cpu/roi_pool_kernel.cpp | 249 ++++
 .../include/torchvision/ops/deform_conv2d.cpp | 172 +++
 .../include/torchvision/ops/deform_conv2d.h | 82 ++
 .../include/torchvision/ops/mps/mps_helpers.h | 6 +
 .../include/torchvision/ops/mps/mps_kernels.h | 1102 ++++++++++++++
 .../include/torchvision/ops/mps/nms_kernel.mm | 109 ++
 .../ops/mps/ps_roi_align_kernel.mm | 205 +++
 .../torchvision/ops/mps/ps_roi_pool_kernel.mm | 200 +++
 .../torchvision/ops/mps/roi_align_kernel.mm | 197 +++
 .../torchvision/ops/mps/roi_pool_kernel.mm | 196 +++
 framework/include/torchvision/ops/nms.cpp | 27 +
 framework/include/torchvision/ops/nms.h | 15 +
 framework/include/torchvision/ops/ops.h | 8 +
 .../include/torchvision/ops/ps_roi_align.cpp | 112 ++
 .../include/torchvision/ops/ps_roi_align.h | 56 +
 .../include/torchvision/ops/ps_roi_pool.cpp | 104 ++
 .../include/torchvision/ops/ps_roi_pool.h | 52 +
 .../include/torchvision/ops/roi_align.cpp | 132 ++
 framework/include/torchvision/ops/roi_align.h | 58 +
 .../include/torchvision/ops/roi_pool.cpp | 102 ++
 framework/include/torchvision/ops/roi_pool.h | 52 +
 framework/include/torchvision/vision.cpp | 41 +
 framework/include/torchvision/vision.h | 16 +
 .../cmake/TorchVision/TorchVisionConfig.cmake | 82 ++
 .../TorchVisionConfigVersion.cmake | 43 +
 .../TorchVisionTargets-noconfig.cmake | 20 +
 .../TorchVision/TorchVisionTargets.cmake | 102 ++
 product/.DS_Store | Bin 0 -> 6148
bytes .../torchvision/io/image/cpu/common_jpeg.cpp | 26 + .../torchvision/io/image/cpu/common_jpeg.h | 27 + .../torchvision/io/image/cpu/common_png.h | 6 + .../torchvision/io/image/cpu/decode_avif.cpp | 92 ++ .../torchvision/io/image/cpu/decode_avif.h | 11 + .../torchvision/io/image/cpu/decode_gif.cpp | 173 +++ .../torchvision/io/image/cpu/decode_gif.h | 12 + .../torchvision/io/image/cpu/decode_image.cpp | 77 + .../torchvision/io/image/cpu/decode_image.h | 15 + .../torchvision/io/image/cpu/decode_jpeg.cpp | 271 ++++ .../torchvision/io/image/cpu/decode_jpeg.h | 18 + .../torchvision/io/image/cpu/decode_png.cpp | 232 +++ .../torchvision/io/image/cpu/decode_png.h | 15 + .../torchvision/io/image/cpu/decode_webp.cpp | 40 + .../torchvision/io/image/cpu/decode_webp.h | 11 + .../torchvision/io/image/cpu/encode_jpeg.cpp | 113 ++ .../torchvision/io/image/cpu/encode_jpeg.h | 13 + .../torchvision/io/image/cpu/encode_png.cpp | 180 +++ .../torchvision/io/image/cpu/encode_png.h | 13 + .../include/torchvision/io/image/cpu/exif.h | 256 ++++ .../io/image/cpu/giflib/dgif_lib.c | 1312 +++++++++++++++++ .../io/image/cpu/giflib/gif_hash.c | 128 ++ .../io/image/cpu/giflib/gif_hash.h | 42 + .../torchvision/io/image/cpu/giflib/gif_lib.h | 291 ++++ .../io/image/cpu/giflib/gif_lib_private.h | 72 + .../io/image/cpu/giflib/gifalloc.c | 425 ++++++ .../image/cpu/giflib/openbsd-reallocarray.c | 73 + .../io/image/cpu/read_write_file.cpp | 108 ++ .../io/image/cpu/read_write_file.h | 13 + .../io/image/cuda/decode_jpegs_cuda.cpp | 603 ++++++++ .../io/image/cuda/decode_jpegs_cuda.h | 45 + .../io/image/cuda/encode_decode_jpegs_cuda.h | 59 + .../io/image/cuda/encode_jpegs_cuda.cpp | 274 ++++ .../io/image/cuda/encode_jpegs_cuda.h | 33 + .../include/torchvision/io/image/image.cpp | 37 + product/include/torchvision/io/image/image.h | 12 + .../torchvision/io/image/image_read_mode.h | 17 + product/include/torchvision/macros.h | 11 + .../ops/autograd/deform_conv2d_kernel.cpp | 266 ++++ .../ops/autograd/ps_roi_align_kernel.cpp | 167 +++ .../ops/autograd/ps_roi_pool_kernel.cpp | 152 ++ .../ops/autograd/roi_align_kernel.cpp | 167 +++ .../ops/autograd/roi_pool_kernel.cpp | 152 ++ .../ops/cpu/deform_conv2d_kernel.cpp | 1172 +++++++++++++++ .../torchvision/ops/cpu/nms_kernel.cpp | 117 ++ .../ops/cpu/ps_roi_align_kernel.cpp | 429 ++++++ .../ops/cpu/ps_roi_pool_kernel.cpp | 273 ++++ .../torchvision/ops/cpu/roi_align_common.h | 128 ++ .../torchvision/ops/cpu/roi_align_kernel.cpp | 400 +++++ .../torchvision/ops/cpu/roi_pool_kernel.cpp | 249 ++++ .../include/torchvision/ops/deform_conv2d.cpp | 172 +++ .../include/torchvision/ops/deform_conv2d.h | 82 ++ .../include/torchvision/ops/mps/mps_helpers.h | 6 + .../include/torchvision/ops/mps/mps_kernels.h | 1102 ++++++++++++++ .../include/torchvision/ops/mps/nms_kernel.mm | 109 ++ .../ops/mps/ps_roi_align_kernel.mm | 205 +++ .../torchvision/ops/mps/ps_roi_pool_kernel.mm | 200 +++ .../torchvision/ops/mps/roi_align_kernel.mm | 197 +++ .../torchvision/ops/mps/roi_pool_kernel.mm | 196 +++ product/include/torchvision/ops/nms.cpp | 28 + product/include/torchvision/ops/nms.h | 15 + product/include/torchvision/ops/ops.h | 8 + .../include/torchvision/ops/ps_roi_align.cpp | 112 ++ .../include/torchvision/ops/ps_roi_align.h | 56 + .../include/torchvision/ops/ps_roi_pool.cpp | 104 ++ product/include/torchvision/ops/ps_roi_pool.h | 52 + product/include/torchvision/ops/roi_align.cpp | 132 ++ product/include/torchvision/ops/roi_align.h | 58 + product/include/torchvision/ops/roi_pool.cpp | 102 ++ 
product/include/torchvision/ops/roi_pool.h | 52 + product/include/torchvision/vision.cpp | 32 + product/include/torchvision/vision.h | 12 + .../cmake/TorchVision/TorchVisionConfig.cmake | 74 + .../TorchVisionConfigVersion.cmake | 43 + .../TorchVisionTargets-noconfig.cmake | 20 + .../TorchVision/TorchVisionTargets.cmake | 102 ++ test/playground/test_mps_import.py | 6 + torchvision/installTest.cpp | 5 + 145 files changed, 20769 insertions(+), 2 deletions(-) create mode 100644 .DS_Store create mode 100644 framework/.DS_Store create mode 100644 framework/include/.DS_Store create mode 100644 framework/include/torchvision/.DS_Store create mode 100644 framework/include/torchvision/io/image/cpu/common_jpeg.cpp create mode 100644 framework/include/torchvision/io/image/cpu/common_jpeg.h create mode 100644 framework/include/torchvision/io/image/cpu/common_png.h create mode 100644 framework/include/torchvision/io/image/cpu/decode_image.cpp create mode 100644 framework/include/torchvision/io/image/cpu/decode_image.h create mode 100644 framework/include/torchvision/io/image/cpu/decode_jpeg.cpp create mode 100644 framework/include/torchvision/io/image/cpu/decode_jpeg.h create mode 100644 framework/include/torchvision/io/image/cpu/decode_png.cpp create mode 100644 framework/include/torchvision/io/image/cpu/decode_png.h create mode 100644 framework/include/torchvision/io/image/cpu/encode_jpeg.cpp create mode 100644 framework/include/torchvision/io/image/cpu/encode_jpeg.h create mode 100644 framework/include/torchvision/io/image/cpu/encode_png.cpp create mode 100644 framework/include/torchvision/io/image/cpu/encode_png.h create mode 100644 framework/include/torchvision/io/image/cpu/exif.h create mode 100644 framework/include/torchvision/io/image/cpu/read_write_file.cpp create mode 100644 framework/include/torchvision/io/image/cpu/read_write_file.h create mode 100644 framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.cpp create mode 100644 framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.h create mode 100644 framework/include/torchvision/io/image/image.cpp create mode 100644 framework/include/torchvision/io/image/image.h create mode 100644 framework/include/torchvision/io/image/image_read_mode.h create mode 100644 framework/include/torchvision/macros.h create mode 100644 framework/include/torchvision/ops/.DS_Store create mode 100644 framework/include/torchvision/ops/autograd/deform_conv2d_kernel.cpp create mode 100644 framework/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp create mode 100644 framework/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp create mode 100644 framework/include/torchvision/ops/autograd/roi_align_kernel.cpp create mode 100644 framework/include/torchvision/ops/autograd/roi_pool_kernel.cpp create mode 100644 framework/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp create mode 100644 framework/include/torchvision/ops/cpu/nms_kernel.cpp create mode 100644 framework/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp create mode 100644 framework/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp create mode 100644 framework/include/torchvision/ops/cpu/roi_align_common.h create mode 100644 framework/include/torchvision/ops/cpu/roi_align_kernel.cpp create mode 100644 framework/include/torchvision/ops/cpu/roi_pool_kernel.cpp create mode 100644 framework/include/torchvision/ops/deform_conv2d.cpp create mode 100644 framework/include/torchvision/ops/deform_conv2d.h create mode 100644 framework/include/torchvision/ops/mps/mps_helpers.h create mode 
100644 framework/include/torchvision/ops/mps/mps_kernels.h create mode 100644 framework/include/torchvision/ops/mps/nms_kernel.mm create mode 100644 framework/include/torchvision/ops/mps/ps_roi_align_kernel.mm create mode 100644 framework/include/torchvision/ops/mps/ps_roi_pool_kernel.mm create mode 100644 framework/include/torchvision/ops/mps/roi_align_kernel.mm create mode 100644 framework/include/torchvision/ops/mps/roi_pool_kernel.mm create mode 100644 framework/include/torchvision/ops/nms.cpp create mode 100644 framework/include/torchvision/ops/nms.h create mode 100644 framework/include/torchvision/ops/ops.h create mode 100644 framework/include/torchvision/ops/ps_roi_align.cpp create mode 100644 framework/include/torchvision/ops/ps_roi_align.h create mode 100644 framework/include/torchvision/ops/ps_roi_pool.cpp create mode 100644 framework/include/torchvision/ops/ps_roi_pool.h create mode 100644 framework/include/torchvision/ops/roi_align.cpp create mode 100644 framework/include/torchvision/ops/roi_align.h create mode 100644 framework/include/torchvision/ops/roi_pool.cpp create mode 100644 framework/include/torchvision/ops/roi_pool.h create mode 100644 framework/include/torchvision/vision.cpp create mode 100644 framework/include/torchvision/vision.h create mode 100644 framework/share/cmake/TorchVision/TorchVisionConfig.cmake create mode 100644 framework/share/cmake/TorchVision/TorchVisionConfigVersion.cmake create mode 100644 framework/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake create mode 100644 framework/share/cmake/TorchVision/TorchVisionTargets.cmake create mode 100644 product/.DS_Store create mode 100644 product/include/torchvision/io/image/cpu/common_jpeg.cpp create mode 100644 product/include/torchvision/io/image/cpu/common_jpeg.h create mode 100644 product/include/torchvision/io/image/cpu/common_png.h create mode 100644 product/include/torchvision/io/image/cpu/decode_avif.cpp create mode 100644 product/include/torchvision/io/image/cpu/decode_avif.h create mode 100644 product/include/torchvision/io/image/cpu/decode_gif.cpp create mode 100644 product/include/torchvision/io/image/cpu/decode_gif.h create mode 100644 product/include/torchvision/io/image/cpu/decode_image.cpp create mode 100644 product/include/torchvision/io/image/cpu/decode_image.h create mode 100644 product/include/torchvision/io/image/cpu/decode_jpeg.cpp create mode 100644 product/include/torchvision/io/image/cpu/decode_jpeg.h create mode 100644 product/include/torchvision/io/image/cpu/decode_png.cpp create mode 100644 product/include/torchvision/io/image/cpu/decode_png.h create mode 100644 product/include/torchvision/io/image/cpu/decode_webp.cpp create mode 100644 product/include/torchvision/io/image/cpu/decode_webp.h create mode 100644 product/include/torchvision/io/image/cpu/encode_jpeg.cpp create mode 100644 product/include/torchvision/io/image/cpu/encode_jpeg.h create mode 100644 product/include/torchvision/io/image/cpu/encode_png.cpp create mode 100644 product/include/torchvision/io/image/cpu/encode_png.h create mode 100644 product/include/torchvision/io/image/cpu/exif.h create mode 100644 product/include/torchvision/io/image/cpu/giflib/dgif_lib.c create mode 100644 product/include/torchvision/io/image/cpu/giflib/gif_hash.c create mode 100644 product/include/torchvision/io/image/cpu/giflib/gif_hash.h create mode 100644 product/include/torchvision/io/image/cpu/giflib/gif_lib.h create mode 100644 product/include/torchvision/io/image/cpu/giflib/gif_lib_private.h create mode 100644 
product/include/torchvision/io/image/cpu/giflib/gifalloc.c create mode 100644 product/include/torchvision/io/image/cpu/giflib/openbsd-reallocarray.c create mode 100644 product/include/torchvision/io/image/cpu/read_write_file.cpp create mode 100644 product/include/torchvision/io/image/cpu/read_write_file.h create mode 100644 product/include/torchvision/io/image/cuda/decode_jpegs_cuda.cpp create mode 100644 product/include/torchvision/io/image/cuda/decode_jpegs_cuda.h create mode 100644 product/include/torchvision/io/image/cuda/encode_decode_jpegs_cuda.h create mode 100644 product/include/torchvision/io/image/cuda/encode_jpegs_cuda.cpp create mode 100644 product/include/torchvision/io/image/cuda/encode_jpegs_cuda.h create mode 100644 product/include/torchvision/io/image/image.cpp create mode 100644 product/include/torchvision/io/image/image.h create mode 100644 product/include/torchvision/io/image/image_read_mode.h create mode 100644 product/include/torchvision/macros.h create mode 100644 product/include/torchvision/ops/autograd/deform_conv2d_kernel.cpp create mode 100644 product/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp create mode 100644 product/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp create mode 100644 product/include/torchvision/ops/autograd/roi_align_kernel.cpp create mode 100644 product/include/torchvision/ops/autograd/roi_pool_kernel.cpp create mode 100644 product/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp create mode 100644 product/include/torchvision/ops/cpu/nms_kernel.cpp create mode 100644 product/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp create mode 100644 product/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp create mode 100644 product/include/torchvision/ops/cpu/roi_align_common.h create mode 100644 product/include/torchvision/ops/cpu/roi_align_kernel.cpp create mode 100644 product/include/torchvision/ops/cpu/roi_pool_kernel.cpp create mode 100644 product/include/torchvision/ops/deform_conv2d.cpp create mode 100644 product/include/torchvision/ops/deform_conv2d.h create mode 100644 product/include/torchvision/ops/mps/mps_helpers.h create mode 100644 product/include/torchvision/ops/mps/mps_kernels.h create mode 100644 product/include/torchvision/ops/mps/nms_kernel.mm create mode 100644 product/include/torchvision/ops/mps/ps_roi_align_kernel.mm create mode 100644 product/include/torchvision/ops/mps/ps_roi_pool_kernel.mm create mode 100644 product/include/torchvision/ops/mps/roi_align_kernel.mm create mode 100644 product/include/torchvision/ops/mps/roi_pool_kernel.mm create mode 100644 product/include/torchvision/ops/nms.cpp create mode 100644 product/include/torchvision/ops/nms.h create mode 100644 product/include/torchvision/ops/ops.h create mode 100644 product/include/torchvision/ops/ps_roi_align.cpp create mode 100644 product/include/torchvision/ops/ps_roi_align.h create mode 100644 product/include/torchvision/ops/ps_roi_pool.cpp create mode 100644 product/include/torchvision/ops/ps_roi_pool.h create mode 100644 product/include/torchvision/ops/roi_align.cpp create mode 100644 product/include/torchvision/ops/roi_align.h create mode 100644 product/include/torchvision/ops/roi_pool.cpp create mode 100644 product/include/torchvision/ops/roi_pool.h create mode 100644 product/include/torchvision/vision.cpp create mode 100644 product/include/torchvision/vision.h create mode 100644 product/share/cmake/TorchVision/TorchVisionConfig.cmake create mode 100644 product/share/cmake/TorchVision/TorchVisionConfigVersion.cmake create mode 
100644 product/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake create mode 100644 product/share/cmake/TorchVision/TorchVisionTargets.cmake create mode 100644 test/playground/test_mps_import.py create mode 100644 torchvision/installTest.cpp diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..1499086fa3f99cdf0adf902c657c89b9a25ecaf6 GIT binary patch literal 6148 zcmeHKJ&)5s5S_jAk%S0H!3j!Bh%OSb5Qx(im=FqDjGzD%>^NMkv%OaAoQo*JqJx6p zf%}?ORQ1SMI!@2mNIs|A(ntkKhnf1)GvYRC$He7@bQICk)P=VHY3=JaYg>J}} z_FM!qxkeb~lQ`f!-iS30!+>Gn?_)sxc6TVC1w|Co+V8jKjt?Af!jr64e{DA5qfrp^ z0JH5k*LSUt`(N&@FPrf2Hxb!v;|(2pMRW2gp(&NbVIMZs(GXRy=l$o0-#(~%8Y6$! z_3Oxn`dXBsCQc#9n@~~j@sizr-MK4SI1$ybjztAJ%&9o?=4s5+18|4!l)CVXt|v63 zBYd75$|()XJkxA{3bHuLi^1S4+uF8v&b1)j)(daOORw;YMLzWB&&1PXY{|h7+Yg>Z z({kLqe4nR)SG%n5)o?h5^ICIs@D4u_w;|2dCfv*OSbXVZbo(UopU} zkvAHmC4IJTG$+nl8~PF|M8s7JWeNs;97{tS#aE$HFlP$`=xVGK!UHjX1SAb+Fbw=r G2EGHhQ>T6a literal 0 HcmV?d00001 diff --git a/android/gradle/wrapper/gradle-wrapper.properties b/android/gradle/wrapper/gradle-wrapper.properties index 442d9132ea3..5ef7a1a3a60 100644 --- a/android/gradle/wrapper/gradle-wrapper.properties +++ b/android/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,6 @@ +#Tue Aug 27 15:56:14 CEST 2024 distributionBase=GRADLE_USER_HOME +distributionUrl=https\://services.gradle.org/distributions/gradle-8.9-bin.zip distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip -zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME diff --git a/framework/.DS_Store b/framework/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..d0ccba84354fbb07f0ee1c0ad907b69509dc98c5 GIT binary patch literal 6148 zcmeHLJ5Iwu5S>i|F@ho`r3+e$w1HxYOiPmsKw?CKY#b4#qVc7;0tI&fS&0=jUGC8czm-StP4 zqLFQmzo-DeyDrV>fhLqu<^ARJaJ!vFx0#U_NixogF(T3D>gxUZ?B!xxR{D#q`IJ|S zvh9B`X&ZMkqp_?bjSSw7H<2p;{ZsO){92b68{THL#y40o-ySV!N;mZQy>^w8PR{Gm zD(`8w=B?CMxwmK3_{lTBnJHijm;$?206m*6Iux|o6fgx$fwcnseTXp_y<#aCJ{_3C z7690RI~b06FTpjjqE{>h5rH`=1xl&YEryeF_+!oUilv~GlMBqajxsylp}4>fe;m@u zc|n^^0aKu^KwD0G-2X3s-v8H&?93D}1^$%+u9J+CAs$J4YvbX#*Txvv7;KzZ3a(3V j5Lz*ExfS!jHNoJI`2gq@OF?*G_Cp}ZV3R4ZQw6>NB|~WX literal 0 HcmV?d00001 diff --git a/framework/include/.DS_Store b/framework/include/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..784dc5a7f8d658bf3890dc86f21b072059f4b66d GIT binary patch literal 6148 zcmeHKOG*Pl5UtWd2C~W0WnUp{H*FYCkfocD1dNb4A%50<5s%>sgsgJ`?~+%aNP?pv zf{0W>)vM{Q>Zy4!-Cab)!^hcxs82*SRFK8#5E<^AI`QBN$hyXyp6HhDp>4-Pe{o1^ z@6iM5UDK2f>%Uvv3`Wy-xze($>(kF8mu=Hb7VQKP_4DiI?d4)~`Bn7n8_}!zK^13P zEhbDb5DWwZ!9XzZ0|szri_|^Cu)#nu5DdICAp1jt3TDS*s9Oh&wg5o6MytS=UP5w` zV|FZtn1QgR0xgxj#b8Uvc=EXHSPU(l*qaabH}9Jl*0*E*q~XNbFl;ao3>-5c9WA6U z=l>Ocna(C(L&5|D!N7lIfU9QI4Doe-wto0dIcpPi3>A^MA`S%p(Io&6vX7kWq|GPs Y5tki{p{yd~nhuPMfD#fW82AMScC|Y_XCp zEkekaW}e5Naq^_h#6-l)hh{-EC!z{NkVTmgF;BW?7A#JdJ#JFlsqOjFe19F}*=KZ5 zX>+?+9Xx-wx2dkzo3ut;RXx`4Pv@@}`;=+a9_PB4syy&^mi)>NRTvb7kj>97Zji-x_TrW5DN*vBgGn-|X2VGlW+ zI4b(=3^)U01`1uy<^I3ICo|dPk5hc*3^)V-i~%m|RlUSZ+1+~ddUDqWjB5-LiR(py mKp*`CU?JzoX?ChVh>o~u*el8|V$bP7{}IT9_~Z=y0t27H*+Jp} literal 0 HcmV?d00001 diff --git a/framework/include/torchvision/io/image/cpu/common_jpeg.cpp b/framework/include/torchvision/io/image/cpu/common_jpeg.cpp new file mode 100644 index 00000000000..4c993106b45 --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/common_jpeg.cpp @@ -0,0 +1,26 @@ +#include "common_jpeg.h" + +namespace vision { 
+namespace image { +namespace detail { + +#if JPEG_FOUND +void torch_jpeg_error_exit(j_common_ptr cinfo) { + /* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce + * pointer */ + torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err; + + /* Always display the message. */ + /* We could postpone this until after returning, if we chose. */ + // (*cinfo->err->output_message)(cinfo); + /* Create the message */ + (*(cinfo->err->format_message))(cinfo, myerr->jpegLastErrorMsg); + + /* Return control to the setjmp point */ + longjmp(myerr->setjmp_buffer, 1); +} +#endif + +} // namespace detail +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/common_jpeg.h b/framework/include/torchvision/io/image/cpu/common_jpeg.h new file mode 100644 index 00000000000..7f7f9f0ccf1 --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/common_jpeg.h @@ -0,0 +1,27 @@ +#pragma once + +#if JPEG_FOUND +#include + +#include +#include + +namespace vision { +namespace image { +namespace detail { + +static const JOCTET EOI_BUFFER[1] = {JPEG_EOI}; +struct torch_jpeg_error_mgr { + struct jpeg_error_mgr pub; /* "public" fields */ + char jpegLastErrorMsg[JMSG_LENGTH_MAX]; /* error messages */ + jmp_buf setjmp_buffer; /* for return to caller */ +}; + +using torch_jpeg_error_ptr = struct torch_jpeg_error_mgr*; +void torch_jpeg_error_exit(j_common_ptr cinfo); + +} // namespace detail +} // namespace image +} // namespace vision + +#endif diff --git a/framework/include/torchvision/io/image/cpu/common_png.h b/framework/include/torchvision/io/image/cpu/common_png.h new file mode 100644 index 00000000000..68400d48e05 --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/common_png.h @@ -0,0 +1,6 @@ +#pragma once + +#if PNG_FOUND +#include +#include +#endif diff --git a/framework/include/torchvision/io/image/cpu/decode_image.cpp b/framework/include/torchvision/io/image/cpu/decode_image.cpp new file mode 100644 index 00000000000..dbf349b06ca --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/decode_image.cpp @@ -0,0 +1,41 @@ +#include "decode_image.h" + +#include "decode_jpeg.h" +#include "decode_png.h" + +namespace vision { +namespace image { + +torch::Tensor decode_image( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { + // Check that tensor is a CPU tensor + TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor"); + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + auto datap = data.data_ptr(); + + const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF" + const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" + + if (memcmp(jpeg_signature, datap, 3) == 0) { + return decode_jpeg(data, mode, apply_exif_orientation); + } else if (memcmp(png_signature, datap, 4) == 0) { + return decode_png( + data, mode, /*allow_16_bits=*/false, apply_exif_orientation); + } else { + TORCH_CHECK( + false, + "Unsupported image file. 
Only jpeg and png ", + "are currently supported."); + } +} + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/decode_image.h b/framework/include/torchvision/io/image/cpu/decode_image.h new file mode 100644 index 00000000000..f0e66d397ac --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/decode_image.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_image( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, + bool apply_exif_orientation = false); + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/decode_jpeg.cpp b/framework/include/torchvision/io/image/cpu/decode_jpeg.cpp new file mode 100644 index 00000000000..ec5953e4106 --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/decode_jpeg.cpp @@ -0,0 +1,271 @@ +#include "decode_jpeg.h" +#include "common_jpeg.h" +#include "exif.h" + +namespace vision { +namespace image { + +#if !JPEG_FOUND +torch::Tensor decode_jpeg( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { + TORCH_CHECK( + false, "decode_jpeg: torchvision not compiled with libjpeg support"); +} +#else + +using namespace detail; +using namespace exif_private; + +namespace { + +struct torch_jpeg_mgr { + struct jpeg_source_mgr pub; + const JOCTET* data; + size_t len; +}; + +static void torch_jpeg_init_source(j_decompress_ptr cinfo) {} + +static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) { + // No more data. Probably an incomplete image; Raise exception. + torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err; + strcpy(myerr->jpegLastErrorMsg, "Image is incomplete or truncated"); + longjmp(myerr->setjmp_buffer, 1); +} + +static void torch_jpeg_skip_input_data(j_decompress_ptr cinfo, long num_bytes) { + torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src; + if (src->pub.bytes_in_buffer < (size_t)num_bytes) { + // Skipping over all of remaining data; output EOI. + src->pub.next_input_byte = EOI_BUFFER; + src->pub.bytes_in_buffer = 1; + } else { + // Skipping over only some of the remaining data. 
+ src->pub.next_input_byte += num_bytes; + src->pub.bytes_in_buffer -= num_bytes; + } +} + +static void torch_jpeg_term_source(j_decompress_ptr cinfo) {} + +static void torch_jpeg_set_source_mgr( + j_decompress_ptr cinfo, + const unsigned char* data, + size_t len) { + torch_jpeg_mgr* src; + if (cinfo->src == 0) { // if this is first time; allocate memory + cinfo->src = (struct jpeg_source_mgr*)(*cinfo->mem->alloc_small)( + (j_common_ptr)cinfo, JPOOL_PERMANENT, sizeof(torch_jpeg_mgr)); + } + src = (torch_jpeg_mgr*)cinfo->src; + src->pub.init_source = torch_jpeg_init_source; + src->pub.fill_input_buffer = torch_jpeg_fill_input_buffer; + src->pub.skip_input_data = torch_jpeg_skip_input_data; + src->pub.resync_to_restart = jpeg_resync_to_restart; // default + src->pub.term_source = torch_jpeg_term_source; + // fill the buffers + src->data = (const JOCTET*)data; + src->len = len; + src->pub.bytes_in_buffer = len; + src->pub.next_input_byte = src->data; + + jpeg_save_markers(cinfo, APP1, 0xffff); +} + +inline unsigned char clamped_cmyk_rgb_convert( + unsigned char k, + unsigned char cmy) { + // Inspired from Pillow: + // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569 + int v = k * cmy + 128; + v = ((v >> 8) + v) >> 8; + return std::clamp(k - v, 0, 255); +} + +void convert_line_cmyk_to_rgb( + j_decompress_ptr cinfo, + const unsigned char* cmyk_line, + unsigned char* rgb_line) { + int width = cinfo->output_width; + for (int i = 0; i < width; ++i) { + int c = cmyk_line[i * 4 + 0]; + int m = cmyk_line[i * 4 + 1]; + int y = cmyk_line[i * 4 + 2]; + int k = cmyk_line[i * 4 + 3]; + + rgb_line[i * 3 + 0] = clamped_cmyk_rgb_convert(k, 255 - c); + rgb_line[i * 3 + 1] = clamped_cmyk_rgb_convert(k, 255 - m); + rgb_line[i * 3 + 2] = clamped_cmyk_rgb_convert(k, 255 - y); + } +} + +inline unsigned char rgb_to_gray(int r, int g, int b) { + // Inspired from Pillow: + // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226 + return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16; +} + +void convert_line_cmyk_to_gray( + j_decompress_ptr cinfo, + const unsigned char* cmyk_line, + unsigned char* gray_line) { + int width = cinfo->output_width; + for (int i = 0; i < width; ++i) { + int c = cmyk_line[i * 4 + 0]; + int m = cmyk_line[i * 4 + 1]; + int y = cmyk_line[i * 4 + 2]; + int k = cmyk_line[i * 4 + 3]; + + int r = clamped_cmyk_rgb_convert(k, 255 - c); + int g = clamped_cmyk_rgb_convert(k, 255 - m); + int b = clamped_cmyk_rgb_convert(k, 255 - y); + + gray_line[i] = rgb_to_gray(r, g, b); + } +} + +} // namespace + +torch::Tensor decode_jpeg( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + struct jpeg_decompress_struct cinfo; + struct torch_jpeg_error_mgr jerr; + + auto datap = data.data_ptr(); + // Setup decompression structure + cinfo.err = jpeg_std_error(&jerr.pub); + jerr.pub.error_exit = torch_jpeg_error_exit; + /* Establish the setjmp return context for my_error_exit to use. */ + if (setjmp(jerr.setjmp_buffer)) { + /* If we get here, the JPEG code has signaled an error. 
+ * We need to clean up the JPEG object. + */ + jpeg_destroy_decompress(&cinfo); + TORCH_CHECK(false, jerr.jpegLastErrorMsg); + } + + jpeg_create_decompress(&cinfo); + torch_jpeg_set_source_mgr(&cinfo, datap, data.numel()); + + // read info from header. + jpeg_read_header(&cinfo, TRUE); + + int channels = cinfo.num_components; + bool cmyk_to_rgb_or_gray = false; + + if (mode != IMAGE_READ_MODE_UNCHANGED) { + switch (mode) { + case IMAGE_READ_MODE_GRAY: + if (cinfo.jpeg_color_space == JCS_CMYK || + cinfo.jpeg_color_space == JCS_YCCK) { + cinfo.out_color_space = JCS_CMYK; + cmyk_to_rgb_or_gray = true; + } else { + cinfo.out_color_space = JCS_GRAYSCALE; + } + channels = 1; + break; + case IMAGE_READ_MODE_RGB: + if (cinfo.jpeg_color_space == JCS_CMYK || + cinfo.jpeg_color_space == JCS_YCCK) { + cinfo.out_color_space = JCS_CMYK; + cmyk_to_rgb_or_gray = true; + } else { + cinfo.out_color_space = JCS_RGB; + } + channels = 3; + break; + /* + * Libjpeg does not support converting from CMYK to grayscale etc. There + * is a way to do this but it involves converting it manually to RGB: + * https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313 + */ + default: + jpeg_destroy_decompress(&cinfo); + TORCH_CHECK(false, "The provided mode is not supported for JPEG files"); + } + + jpeg_calc_output_dimensions(&cinfo); + } + + int exif_orientation = -1; + if (apply_exif_orientation) { + exif_orientation = fetch_jpeg_exif_orientation(&cinfo); + } + + jpeg_start_decompress(&cinfo); + + int height = cinfo.output_height; + int width = cinfo.output_width; + + int stride = width * channels; + auto tensor = + torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); + auto ptr = tensor.data_ptr(); + torch::Tensor cmyk_line_tensor; + if (cmyk_to_rgb_or_gray) { + cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8); + } + + while (cinfo.output_scanline < cinfo.output_height) { + /* jpeg_read_scanlines expects an array of pointers to scanlines. + * Here the array is only one element long, but you could ask for + * more than one scanline at a time if that's more convenient. 
+ */ + if (cmyk_to_rgb_or_gray) { + auto cmyk_line_ptr = cmyk_line_tensor.data_ptr(); + jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1); + + if (channels == 3) { + convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr); + } else if (channels == 1) { + convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr); + } + } else { + jpeg_read_scanlines(&cinfo, &ptr, 1); + } + ptr += stride; + } + + jpeg_finish_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + auto output = tensor.permute({2, 0, 1}); + + if (apply_exif_orientation) { + return exif_orientation_transform(output, exif_orientation); + } + return output; +} +#endif // #if !JPEG_FOUND + +int64_t _jpeg_version() { +#if JPEG_FOUND + return JPEG_LIB_VERSION; +#else + return -1; +#endif +} + +bool _is_compiled_against_turbo() { +#ifdef LIBJPEG_TURBO_VERSION + return true; +#else + return false; +#endif +} + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/decode_jpeg.h b/framework/include/torchvision/io/image/cpu/decode_jpeg.h new file mode 100644 index 00000000000..e0c9a24c846 --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/decode_jpeg.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_jpeg( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, + bool apply_exif_orientation = false); + +C10_EXPORT int64_t _jpeg_version(); +C10_EXPORT bool _is_compiled_against_turbo(); + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/decode_png.cpp b/framework/include/torchvision/io/image/cpu/decode_png.cpp new file mode 100644 index 00000000000..ab4087fdfe2 --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/decode_png.cpp @@ -0,0 +1,259 @@ +#include "decode_png.h" +#include "common_png.h" +#include "exif.h" + +namespace vision { +namespace image { + +using namespace exif_private; + +#if !PNG_FOUND +torch::Tensor decode_png( + const torch::Tensor& data, + ImageReadMode mode, + bool allow_16_bits, + bool apply_exif_orientation) { + TORCH_CHECK( + false, "decode_png: torchvision not compiled with libPNG support"); +} +#else + +bool is_little_endian() { + uint32_t x = 1; + return *(uint8_t*)&x; +} + +torch::Tensor decode_png( + const torch::Tensor& data, + ImageReadMode mode, + bool allow_16_bits, + bool apply_exif_orientation) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png"); + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + auto png_ptr = + png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); + TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") + auto info_ptr = png_create_info_struct(png_ptr); + if (!info_ptr) { + png_destroy_read_struct(&png_ptr, nullptr, nullptr); + // Seems redundant with the if statement. done here to avoid leaking memory. 
+ TORCH_CHECK(info_ptr, "libpng info structure allocation failed!") + } + + auto accessor = data.accessor(); + auto datap = accessor.data(); + auto datap_len = accessor.size(0); + + if (setjmp(png_jmpbuf(png_ptr)) != 0) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(false, "Internal error."); + } + TORCH_CHECK(datap_len >= 8, "Content is too small for png!") + auto is_png = !png_sig_cmp(datap, 0, 8); + TORCH_CHECK(is_png, "Content is not png!") + + struct Reader { + png_const_bytep ptr; + png_size_t count; + } reader; + reader.ptr = png_const_bytep(datap) + 8; + reader.count = datap_len - 8; + + auto read_callback = [](png_structp png_ptr, + png_bytep output, + png_size_t bytes) { + auto reader = static_cast(png_get_io_ptr(png_ptr)); + TORCH_CHECK( + reader->count >= bytes, + "Out of bound read in decode_png. Probably, the input image is corrupted"); + std::copy(reader->ptr, reader->ptr + bytes, output); + reader->ptr += bytes; + reader->count -= bytes; + }; + png_set_sig_bytes(png_ptr, 8); + png_set_read_fn(png_ptr, &reader, read_callback); + png_read_info(png_ptr, info_ptr); + + png_uint_32 width, height; + int bit_depth, color_type; + int interlace_type; + auto retval = png_get_IHDR( + png_ptr, + info_ptr, + &width, + &height, + &bit_depth, + &color_type, + &interlace_type, + nullptr, + nullptr); + + if (retval != 1) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(retval == 1, "Could read image metadata from content.") + } + + auto max_bit_depth = allow_16_bits ? 16 : 8; + auto err_msg = "At most " + std::to_string(max_bit_depth) + + "-bit PNG images are supported currently."; + if (bit_depth > max_bit_depth) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(false, err_msg) + } + + int channels = png_get_channels(png_ptr, info_ptr); + + if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8) + png_set_expand_gray_1_2_4_to_8(png_ptr); + + int number_of_passes; + if (interlace_type == PNG_INTERLACE_ADAM7) { + number_of_passes = png_set_interlace_handling(png_ptr); + } else { + number_of_passes = 1; + } + + if (mode != IMAGE_READ_MODE_UNCHANGED) { + // TODO: consider supporting PNG_INFO_tRNS + bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0; + bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0; + bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0; + + switch (mode) { + case IMAGE_READ_MODE_GRAY: + if (color_type != PNG_COLOR_TYPE_GRAY) { + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } + + if (has_alpha) { + png_set_strip_alpha(png_ptr); + } + + if (has_color) { + png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); + } + channels = 1; + } + break; + case IMAGE_READ_MODE_GRAY_ALPHA: + if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) { + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } + + if (!has_alpha) { + png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); + } + + if (has_color) { + png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); + } + channels = 2; + } + break; + case IMAGE_READ_MODE_RGB: + if (color_type != PNG_COLOR_TYPE_RGB) { + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } else if (!has_color) { + png_set_gray_to_rgb(png_ptr); + } + + if (has_alpha) { + png_set_strip_alpha(png_ptr); + } + channels = 3; + } + break; + case IMAGE_READ_MODE_RGB_ALPHA: + if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) { + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } else if 
(!has_color) { + png_set_gray_to_rgb(png_ptr); + } + + if (!has_alpha) { + png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); + } + channels = 4; + } + break; + default: + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(false, "The provided mode is not supported for PNG files"); + } + + png_read_update_info(png_ptr, info_ptr); + } + + auto num_pixels_per_row = width * channels; + auto tensor = torch::empty( + {int64_t(height), int64_t(width), channels}, + bit_depth <= 8 ? torch::kU8 : torch::kI32); + + if (bit_depth <= 8) { + auto t_ptr = tensor.accessor().data(); + for (int pass = 0; pass < number_of_passes; pass++) { + for (png_uint_32 i = 0; i < height; ++i) { + png_read_row(png_ptr, t_ptr, nullptr); + t_ptr += num_pixels_per_row; + } + t_ptr = tensor.accessor().data(); + } + } else { + // We're reading a 16bits png, but pytorch doesn't support uint16. + // So we read each row in a 16bits tmp_buffer which we then cast into + // a int32 tensor instead. + if (is_little_endian()) { + png_set_swap(png_ptr); + } + int32_t* t_ptr = tensor.accessor().data(); + + // We create a tensor instead of malloc-ing for automatic memory management + auto tmp_buffer_tensor = torch::empty( + {int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8); + uint16_t* tmp_buffer = + (uint16_t*)tmp_buffer_tensor.accessor().data(); + + for (int pass = 0; pass < number_of_passes; pass++) { + for (png_uint_32 i = 0; i < height; ++i) { + png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr); + // Now we copy the uint16 values into the int32 tensor. + for (size_t j = 0; j < num_pixels_per_row; ++j) { + t_ptr[j] = (int32_t)tmp_buffer[j]; + } + t_ptr += num_pixels_per_row; + } + t_ptr = tensor.accessor().data(); + } + } + + int exif_orientation = -1; + if (apply_exif_orientation) { + exif_orientation = fetch_png_exif_orientation(png_ptr, info_ptr); + } + + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + + auto output = tensor.permute({2, 0, 1}); + if (apply_exif_orientation) { + return exif_orientation_transform(output, exif_orientation); + } + return output; +} +#endif + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/decode_png.h b/framework/include/torchvision/io/image/cpu/decode_png.h new file mode 100644 index 00000000000..b091f15e35f --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/decode_png.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_png( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, + bool allow_16_bits = false, + bool apply_exif_orientation = false); + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/encode_jpeg.cpp b/framework/include/torchvision/io/image/cpu/encode_jpeg.cpp new file mode 100644 index 00000000000..d2ed73071a2 --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/encode_jpeg.cpp @@ -0,0 +1,113 @@ +#include "encode_jpeg.h" + +#include "common_jpeg.h" + +namespace vision { +namespace image { + +#if !JPEG_FOUND + +torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { + TORCH_CHECK( + false, "encode_jpeg: torchvision not compiled with libjpeg support"); +} + +#else +// For libjpeg version <= 9b, the out_size parameter in jpeg_mem_dest() is +// defined as unsigned long, whereas in later version, it is defined as size_t. 
+#if !defined(JPEG_LIB_VERSION_MAJOR) || JPEG_LIB_VERSION_MAJOR < 9 || \ + (JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR <= 2) +using JpegSizeType = unsigned long; +#else +using JpegSizeType = size_t; +#endif + +using namespace detail; + +torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cpu.encode_jpeg.encode_jpeg"); + // Define compression structures and error handling + struct jpeg_compress_struct cinfo {}; + struct torch_jpeg_error_mgr jerr {}; + + // Define buffer to write JPEG information to and its size + JpegSizeType jpegSize = 0; + uint8_t* jpegBuf = nullptr; + + cinfo.err = jpeg_std_error(&jerr.pub); + jerr.pub.error_exit = torch_jpeg_error_exit; + + /* Establish the setjmp return context for my_error_exit to use. */ + if (setjmp(jerr.setjmp_buffer)) { + /* If we get here, the JPEG code has signaled an error. + * We need to clean up the JPEG object and the buffer. + */ + jpeg_destroy_compress(&cinfo); + if (jpegBuf != nullptr) { + free(jpegBuf); + } + + TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg); + } + + // Check that the input tensor is on CPU + TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + // Check that the input tensor is 3-dimensional + TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); + + // Get image info + int channels = data.size(0); + int height = data.size(1); + int width = data.size(2); + auto input = data.permute({1, 2, 0}).contiguous(); + + TORCH_CHECK( + channels == 1 || channels == 3, + "The number of channels should be 1 or 3, got: ", + channels); + + // Initialize JPEG structure + jpeg_create_compress(&cinfo); + + // Set output image information + cinfo.image_width = width; + cinfo.image_height = height; + cinfo.input_components = channels; + cinfo.in_color_space = channels == 1 ? 
JCS_GRAYSCALE : JCS_RGB; + + jpeg_set_defaults(&cinfo); + jpeg_set_quality(&cinfo, quality, TRUE); + + // Save JPEG output to a buffer + jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize); + + // Start JPEG compression + jpeg_start_compress(&cinfo, TRUE); + + auto stride = width * channels; + auto ptr = input.data_ptr(); + + // Encode JPEG file + while (cinfo.next_scanline < cinfo.image_height) { + jpeg_write_scanlines(&cinfo, &ptr, 1); + ptr += stride; + } + + jpeg_finish_compress(&cinfo); + jpeg_destroy_compress(&cinfo); + + torch::TensorOptions options = torch::TensorOptions{torch::kU8}; + auto out_tensor = + torch::from_blob(jpegBuf, {(long)jpegSize}, ::free, options); + jpegBuf = nullptr; + return out_tensor; +} +#endif + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/encode_jpeg.h b/framework/include/torchvision/io/image/cpu/encode_jpeg.h new file mode 100644 index 00000000000..25084e154d6 --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/encode_jpeg.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor encode_jpeg( + const torch::Tensor& data, + int64_t quality); + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/encode_png.cpp b/framework/include/torchvision/io/image/cpu/encode_png.cpp new file mode 100644 index 00000000000..a9b7d76ff61 --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/encode_png.cpp @@ -0,0 +1,180 @@ +#include "encode_jpeg.h" + +#include "common_png.h" + +namespace vision { +namespace image { + +#if !PNG_FOUND + +torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { + TORCH_CHECK( + false, "encode_png: torchvision not compiled with libpng support"); +} + +#else + +namespace { + +struct torch_mem_encode { + char* buffer; + size_t size; +}; + +struct torch_png_error_mgr { + const char* pngLastErrorMsg; /* error messages */ + jmp_buf setjmp_buffer; /* for return to caller */ +}; + +using torch_png_error_mgr_ptr = torch_png_error_mgr*; + +void torch_png_error(png_structp png_ptr, png_const_charp error_msg) { + /* png_ptr->err really points to a torch_png_error_mgr struct, so coerce + * pointer */ + auto error_ptr = (torch_png_error_mgr_ptr)png_get_error_ptr(png_ptr); + /* Replace the error message on the error structure */ + error_ptr->pngLastErrorMsg = error_msg; + /* Return control to the setjmp point */ + longjmp(error_ptr->setjmp_buffer, 1); +} + +void torch_png_write_data( + png_structp png_ptr, + png_bytep data, + png_size_t length) { + struct torch_mem_encode* p = + (struct torch_mem_encode*)png_get_io_ptr(png_ptr); + size_t nsize = p->size + length; + + /* allocate or grow buffer */ + if (p->buffer) + p->buffer = (char*)realloc(p->buffer, nsize); + else + p->buffer = (char*)malloc(nsize); + + if (!p->buffer) + png_error(png_ptr, "Write Error"); + + /* copy new bytes to end of buffer */ + memcpy(p->buffer + p->size, data, length); + p->size += length; +} + +} // namespace + +torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png"); + // Define compression structures and error handling + png_structp png_write; + png_infop info_ptr; + struct torch_png_error_mgr err_ptr; + + // Define output buffer + struct torch_mem_encode buf_info; + buf_info.buffer = NULL; + buf_info.size = 0; + + /* Establish the setjmp return context for my_error_exit to use. 
*/ + if (setjmp(err_ptr.setjmp_buffer)) { + /* If we get here, the PNG code has signaled an error. + * We need to clean up the PNG object and the buffer. + */ + if (info_ptr != NULL) { + png_destroy_info_struct(png_write, &info_ptr); + } + + if (png_write != NULL) { + png_destroy_write_struct(&png_write, NULL); + } + + if (buf_info.buffer != NULL) { + free(buf_info.buffer); + } + + TORCH_CHECK(false, err_ptr.pngLastErrorMsg); + } + + // Check that the compression level is between 0 and 9 + TORCH_CHECK( + compression_level >= 0 && compression_level <= 9, + "Compression level should be between 0 and 9"); + + // Check that the input tensor is on CPU + TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + // Check that the input tensor is 3-dimensional + TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); + + // Get image info + int channels = data.size(0); + int height = data.size(1); + int width = data.size(2); + auto input = data.permute({1, 2, 0}).contiguous(); + + TORCH_CHECK( + channels == 1 || channels == 3, + "The number of channels should be 1 or 3, got: ", + channels); + + // Initialize PNG structures + png_write = png_create_write_struct( + PNG_LIBPNG_VER_STRING, &err_ptr, torch_png_error, NULL); + + info_ptr = png_create_info_struct(png_write); + + // Define custom buffer output + png_set_write_fn(png_write, &buf_info, torch_png_write_data, NULL); + + // Set output image information + auto color_type = channels == 1 ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB; + png_set_IHDR( + png_write, + info_ptr, + width, + height, + 8, + color_type, + PNG_INTERLACE_NONE, + PNG_COMPRESSION_TYPE_DEFAULT, + PNG_FILTER_TYPE_DEFAULT); + + // Set image compression level + png_set_compression_level(png_write, compression_level); + + // Write file header + png_write_info(png_write, info_ptr); + + auto stride = width * channels; + auto ptr = input.data_ptr(); + + // Encode PNG file + for (int y = 0; y < height; ++y) { + png_write_row(png_write, ptr); + ptr += stride; + } + + // Write EOF + png_write_end(png_write, info_ptr); + + // Destroy structures + png_destroy_write_struct(&png_write, &info_ptr); + + torch::TensorOptions options = torch::TensorOptions{torch::kU8}; + auto outTensor = torch::empty({(long)buf_info.size}, options); + + // Copy memory from png buffer, since torch cannot get ownership of it via + // `from_blob` + auto outPtr = outTensor.data_ptr(); + std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel()); + free(buf_info.buffer); + + return outTensor; +} + +#endif + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/encode_png.h b/framework/include/torchvision/io/image/cpu/encode_png.h new file mode 100644 index 00000000000..86a67c8706e --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/encode_png.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor encode_png( + const torch::Tensor& data, + int64_t compression_level); + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/exif.h b/framework/include/torchvision/io/image/cpu/exif.h new file mode 100644 index 00000000000..0f9a59417db --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/exif.h @@ -0,0 +1,264 @@ 
+/*M/////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this +license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// +// Copyright (C) 2000-2008, Intel Corporation, all rights reserved. +// Copyright (C) 2009, Willow Garage Inc., all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without +modification, +// are permitted provided that the following conditions are met: +// +// * Redistribution's of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistribution's in binary form must reproduce the above copyright +notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * The name of the copyright holders may not be used to endorse or promote +products +// derived from this software without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" +and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are +disclaimed. +// In no event shall the Intel Corporation or contributors be liable for any +direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. 
+// +//M*/ +#pragma once +// Functions in this module are taken from OpenCV +// https://github.com/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/imgcodecs/src/exif.cpp + +#if JPEG_FOUND +#include +#endif +#if PNG_FOUND +#include +#endif + +#include + +namespace vision { +namespace image { +namespace exif_private { + +constexpr uint16_t APP1 = 0xe1; +constexpr uint16_t ENDIANNESS_INTEL = 0x49; +constexpr uint16_t ENDIANNESS_MOTO = 0x4d; +constexpr uint16_t REQ_EXIF_TAG_MARK = 0x2a; +constexpr uint16_t ORIENTATION_EXIF_TAG = 0x0112; +constexpr uint16_t INCORRECT_TAG = -1; + +class ExifDataReader { + public: + ExifDataReader(unsigned char* p, size_t s) : _ptr(p), _size(s) {} + size_t size() const { + return _size; + } + const unsigned char& operator[](size_t index) const { + TORCH_CHECK(index >= 0 && index < _size); + return _ptr[index]; + } + + protected: + unsigned char* _ptr; + size_t _size; +}; + +inline uint16_t get_endianness(const ExifDataReader& exif_data) { + if ((exif_data.size() < 1) || + (exif_data.size() > 1 && exif_data[0] != exif_data[1])) { + return 0; + } + if (exif_data[0] == 'I') { + return ENDIANNESS_INTEL; + } + if (exif_data[0] == 'M') { + return ENDIANNESS_MOTO; + } + return 0; +} + +inline uint16_t get_uint16( + const ExifDataReader& exif_data, + uint16_t endianness, + const size_t offset) { + if (offset + 1 >= exif_data.size()) { + return INCORRECT_TAG; + } + + if (endianness == ENDIANNESS_INTEL) { + return exif_data[offset] + (exif_data[offset + 1] << 8); + } + return (exif_data[offset] << 8) + exif_data[offset + 1]; +} + +inline uint32_t get_uint32( + const ExifDataReader& exif_data, + uint16_t endianness, + const size_t offset) { + if (offset + 3 >= exif_data.size()) { + return INCORRECT_TAG; + } + + if (endianness == ENDIANNESS_INTEL) { + return exif_data[offset] + (exif_data[offset + 1] << 8) + + (exif_data[offset + 2] << 16) + (exif_data[offset + 3] << 24); + } + return (exif_data[offset] << 24) + (exif_data[offset + 1] << 16) + + (exif_data[offset + 2] << 8) + exif_data[offset + 3]; +} + +inline int fetch_exif_orientation(unsigned char* exif_data_ptr, size_t size) { + int exif_orientation = -1; + + // Exif binary structure looks like this + // First 6 bytes: [E, x, i, f, 0, 0] + // Endianness, 2 bytes : [M, M] or [I, I] + // Tag mark, 2 bytes: [0, 0x2a] + // Offset, 4 bytes + // Num entries, 2 bytes + // Tag entries and data, tag has 2 bytes and its data has 10 bytes + // For more details: + // http://www.media.mit.edu/pia/Research/deepview/exif.html + + ExifDataReader exif_data(exif_data_ptr, size); + auto endianness = get_endianness(exif_data); + + // Checking whether Tag Mark (0x002A) correspond to one contained in the + // Jpeg file + uint16_t tag_mark = get_uint16(exif_data, endianness, 2); + if (tag_mark == REQ_EXIF_TAG_MARK) { + auto offset = get_uint32(exif_data, endianness, 4); + size_t num_entry = get_uint16(exif_data, endianness, offset); + offset += 2; // go to start of tag fields + constexpr size_t tiff_field_size = 12; + for (size_t entry = 0; entry < num_entry; entry++) { + // Here we just search for orientation tag and parse it + auto tag_num = get_uint16(exif_data, endianness, offset); + if (tag_num == INCORRECT_TAG) { + break; + } + if (tag_num == ORIENTATION_EXIF_TAG) { + exif_orientation = get_uint16(exif_data, endianness, offset + 8); + break; + } + offset += tiff_field_size; + } + } + return exif_orientation; +} + +#if JPEG_FOUND +inline int fetch_jpeg_exif_orientation(j_decompress_ptr cinfo) { + // Check for Exif 
marker APP1 + jpeg_saved_marker_ptr exif_marker = 0; + jpeg_saved_marker_ptr cmarker = cinfo->marker_list; + while (cmarker && exif_marker == 0) { + if (cmarker->marker == APP1) { + exif_marker = cmarker; + } + cmarker = cmarker->next; + } + + if (!exif_marker) { + return -1; + } + + constexpr size_t start_offset = 6; + if (exif_marker->data_length <= start_offset) { + return -1; + } + + auto* exif_data_ptr = exif_marker->data + start_offset; + auto size = exif_marker->data_length - start_offset; + + return fetch_exif_orientation(exif_data_ptr, size); +} +#else // #if JPEG_FOUND +inline int fetch_jpeg_exif_orientation(j_decompress_ptr cinfo) { + return -1; +} +#endif // #if JPEG_FOUND + +#if PNG_FOUND && defined(PNG_eXIf_SUPPORTED) +inline int fetch_png_exif_orientation(png_structp png_ptr, png_infop info_ptr) { + png_uint_32 num_exif = 0; + png_bytep exif = 0; + + // Exif info could be in info_ptr + if (png_get_valid(png_ptr, info_ptr, PNG_INFO_eXIf)) { + png_get_eXIf_1(png_ptr, info_ptr, &num_exif, &exif); + } + + if (exif && num_exif > 0) { + return fetch_exif_orientation(exif, num_exif); + } + return -1; +} +#else // #if PNG_FOUND && defined(PNG_eXIf_SUPPORTED) +inline int fetch_png_exif_orientation(png_structp png_ptr, png_infop info_ptr) { + return -1; +} +#endif // #if PNG_FOUND && defined(PNG_eXIf_SUPPORTED) + +constexpr uint16_t IMAGE_ORIENTATION_TL = 1; // normal orientation +constexpr uint16_t IMAGE_ORIENTATION_TR = 2; // needs horizontal flip +constexpr uint16_t IMAGE_ORIENTATION_BR = 3; // needs 180 rotation +constexpr uint16_t IMAGE_ORIENTATION_BL = 4; // needs vertical flip +constexpr uint16_t IMAGE_ORIENTATION_LT = + 5; // mirrored horizontal & rotate 270 CW +constexpr uint16_t IMAGE_ORIENTATION_RT = 6; // rotate 90 CW +constexpr uint16_t IMAGE_ORIENTATION_RB = + 7; // mirrored horizontal & rotate 90 CW +constexpr uint16_t IMAGE_ORIENTATION_LB = 8; // needs 270 CW rotation + +inline torch::Tensor exif_orientation_transform( + const torch::Tensor& image, + int orientation) { + if (orientation == IMAGE_ORIENTATION_TL) { + return image; + } else if (orientation == IMAGE_ORIENTATION_TR) { + return image.flip(-1); + } else if (orientation == IMAGE_ORIENTATION_BR) { + // needs 180 rotation equivalent to + // flip both horizontally and vertically + return image.flip({-2, -1}); + } else if (orientation == IMAGE_ORIENTATION_BL) { + return image.flip(-2); + } else if (orientation == IMAGE_ORIENTATION_LT) { + return image.transpose(-1, -2); + } else if (orientation == IMAGE_ORIENTATION_RT) { + return image.transpose(-1, -2).flip(-1); + } else if (orientation == IMAGE_ORIENTATION_RB) { + return image.transpose(-1, -2).flip({-2, -1}); + } else if (orientation == IMAGE_ORIENTATION_LB) { + return image.transpose(-1, -2).flip(-2); + } + return image; +} + +} // namespace exif_private +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/read_write_file.cpp b/framework/include/torchvision/io/image/cpu/read_write_file.cpp new file mode 100644 index 00000000000..def74c6721a --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/read_write_file.cpp @@ -0,0 +1,108 @@ +#include "read_write_file.h" + +#include + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include +#endif + +namespace vision { +namespace image { + +#ifdef _WIN32 +namespace { +std::wstring utf8_decode(const std::string& str) { + if (str.empty()) { + return std::wstring(); + } + int size_needed = MultiByteToWideChar( + CP_UTF8, 0, str.c_str(), static_cast(str.size()), NULL, 
0); + TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode"); + std::wstring wstrTo(size_needed, 0); + MultiByteToWideChar( + CP_UTF8, + 0, + str.c_str(), + static_cast(str.size()), + &wstrTo[0], + size_needed); + return wstrTo; +} +} // namespace +#endif + +torch::Tensor read_file(const std::string& filename) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cpu.read_write_file.read_file"); +#ifdef _WIN32 + // According to + // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019, + // we should use struct __stat64 and _wstat64 for 64-bit file size on Windows. + struct __stat64 stat_buf; + auto fileW = utf8_decode(filename); + int rc = _wstat64(fileW.c_str(), &stat_buf); +#else + struct stat stat_buf; + int rc = stat(filename.c_str(), &stat_buf); +#endif + // errno is a variable defined in errno.h + TORCH_CHECK( + rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'"); + + int64_t size = stat_buf.st_size; + + TORCH_CHECK(size > 0, "Expected a non empty file"); + +#ifdef _WIN32 + // TODO: Once torch::from_file handles UTF-8 paths correctly, we should move + // back to use the following implementation since it uses file mapping. + // auto data = + // torch::from_file(filename, /*shared=*/false, /*size=*/size, + // torch::kU8).clone() + FILE* infile = _wfopen(fileW.c_str(), L"rb"); + + TORCH_CHECK(infile != nullptr, "Error opening input file"); + + auto data = torch::empty({size}, torch::kU8); + auto dataBytes = data.data_ptr(); + + fread(dataBytes, sizeof(uint8_t), size, infile); + fclose(infile); +#else + auto data = + torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8); +#endif + + return data; +} + +void write_file(const std::string& filename, torch::Tensor& data) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cpu.read_write_file.write_file"); + // Check that the input tensor is on CPU + TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + // Check that the input tensor is 3-dimensional + TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor"); + + auto fileBytes = data.data_ptr(); + auto fileCStr = filename.c_str(); +#ifdef _WIN32 + auto fileW = utf8_decode(filename); + FILE* outfile = _wfopen(fileW.c_str(), L"wb"); +#else + FILE* outfile = fopen(fileCStr, "wb"); +#endif + + TORCH_CHECK(outfile != nullptr, "Error opening output file"); + + fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile); + fclose(outfile); +} + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/read_write_file.h b/framework/include/torchvision/io/image/cpu/read_write_file.h new file mode 100644 index 00000000000..a5a712dd8e2 --- /dev/null +++ b/framework/include/torchvision/io/image/cpu/read_write_file.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor read_file(const std::string& filename); + +C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data); + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.cpp b/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.cpp new file mode 100644 index 00000000000..ee7d432f30d --- /dev/null +++ b/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.cpp @@ -0,0 +1,208 @@ 
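For orientation, a minimal usage sketch (illustrative only, not part of the patch): it shows how the read_file() helper declared above and the decode_jpeg_cuda() entry point defined in the file below are typically combined. The include paths, the "example.jpg" filename, and the main() driver are assumptions for illustration, and the decode call only succeeds in a torchvision build with nvJPEG support.

#include <torchvision/io/image/cpu/read_write_file.h>   // assumed install layout
#include <torchvision/io/image/cuda/decode_jpeg_cuda.h> // assumed install layout

#include <torch/torch.h>

#include <iostream>

int main() {
  // read_file() loads the raw JPEG bytes into a 1-D uint8 CPU tensor.
  torch::Tensor encoded = vision::image::read_file("example.jpg");

  // decode_jpeg_cuda() decodes on the requested CUDA device and returns a
  // CHW uint8 tensor that already lives on the GPU.
  torch::Tensor img = vision::image::decode_jpeg_cuda(
      encoded,
      vision::image::IMAGE_READ_MODE_RGB,
      torch::Device(torch::kCUDA, 0));

  std::cout << img.sizes() << std::endl;
  return 0;
}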
+#include "decode_jpeg_cuda.h" + +#include + +#if NVJPEG_FOUND +#include +#include +#include +#endif + +#include + +namespace vision { +namespace image { + +#if !NVJPEG_FOUND + +torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device) { + TORCH_CHECK( + false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); +} + +#else + +namespace { +static nvjpegHandle_t nvjpeg_handle = nullptr; +} + +torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cuda.decode_jpeg_cuda.decode_jpeg_cuda"); + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + + TORCH_CHECK( + !data.is_cuda(), + "The input tensor must be on CPU when decoding with nvjpeg") + + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + TORCH_CHECK(device.is_cuda(), "Expected a cuda device") + + int major_version; + int minor_version; + nvjpegStatus_t get_major_property_status = + nvjpegGetProperty(MAJOR_VERSION, &major_version); + nvjpegStatus_t get_minor_property_status = + nvjpegGetProperty(MINOR_VERSION, &minor_version); + + TORCH_CHECK( + get_major_property_status == NVJPEG_STATUS_SUCCESS, + "nvjpegGetProperty failed: ", + get_major_property_status); + TORCH_CHECK( + get_minor_property_status == NVJPEG_STATUS_SUCCESS, + "nvjpegGetProperty failed: ", + get_minor_property_status); + if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) { + TORCH_WARN_ONCE( + "There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. " + "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); + } + + at::cuda::CUDAGuard device_guard(device); + + // Create global nvJPEG handle + static std::once_flag nvjpeg_handle_creation_flag; + std::call_once(nvjpeg_handle_creation_flag, []() { + if (nvjpeg_handle == nullptr) { + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + + if (create_status != NVJPEG_STATUS_SUCCESS) { + // Reset handle so that one can still call the function again in the + // same process if there was a failure + free(nvjpeg_handle); + nvjpeg_handle = nullptr; + } + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + } + }); + + // Create the jpeg state + nvjpegJpegState_t jpeg_state; + nvjpegStatus_t state_status = + nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state); + + TORCH_CHECK( + state_status == NVJPEG_STATUS_SUCCESS, + "nvjpegJpegStateCreate failed: ", + state_status); + + auto datap = data.data_ptr(); + + // Get the image information + int num_channels; + nvjpegChromaSubsampling_t subsampling; + int widths[NVJPEG_MAX_COMPONENT]; + int heights[NVJPEG_MAX_COMPONENT]; + nvjpegStatus_t info_status = nvjpegGetImageInfo( + nvjpeg_handle, + datap, + data.numel(), + &num_channels, + &subsampling, + widths, + heights); + + if (info_status != NVJPEG_STATUS_SUCCESS) { + nvjpegJpegStateDestroy(jpeg_state); + TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status); + } + + if (subsampling == NVJPEG_CSS_UNKNOWN) { + nvjpegJpegStateDestroy(jpeg_state); + TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling"); + } + + int width = widths[0]; + int height = heights[0]; + + nvjpegOutputFormat_t ouput_format; + int num_channels_output; + + switch (mode) { + case IMAGE_READ_MODE_UNCHANGED: + num_channels_output = num_channels; + // For some reason, 
setting output_format to NVJPEG_OUTPUT_UNCHANGED will + // not properly decode RGB images (it's fine for grayscale), so we set + // output_format manually here + if (num_channels == 1) { + ouput_format = NVJPEG_OUTPUT_Y; + } else if (num_channels == 3) { + ouput_format = NVJPEG_OUTPUT_RGB; + } else { + nvjpegJpegStateDestroy(jpeg_state); + TORCH_CHECK( + false, + "When mode is UNCHANGED, only 1 or 3 input channels are allowed."); + } + break; + case IMAGE_READ_MODE_GRAY: + ouput_format = NVJPEG_OUTPUT_Y; + num_channels_output = 1; + break; + case IMAGE_READ_MODE_RGB: + ouput_format = NVJPEG_OUTPUT_RGB; + num_channels_output = 3; + break; + default: + nvjpegJpegStateDestroy(jpeg_state); + TORCH_CHECK( + false, "The provided mode is not supported for JPEG decoding on GPU"); + } + + auto out_tensor = torch::empty( + {int64_t(num_channels_output), int64_t(height), int64_t(width)}, + torch::dtype(torch::kU8).device(device)); + + // nvjpegImage_t is a struct with + // - an array of pointers to each channel + // - the pitch for each channel + // which must be filled in manually + nvjpegImage_t out_image; + + for (int c = 0; c < num_channels_output; c++) { + out_image.channel[c] = out_tensor[c].data_ptr(); + out_image.pitch[c] = width; + } + for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) { + out_image.channel[c] = nullptr; + out_image.pitch[c] = 0; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()); + + nvjpegStatus_t decode_status = nvjpegDecode( + nvjpeg_handle, + jpeg_state, + datap, + data.numel(), + ouput_format, + &out_image, + stream); + + nvjpegJpegStateDestroy(jpeg_state); + + TORCH_CHECK( + decode_status == NVJPEG_STATUS_SUCCESS, + "nvjpegDecode failed: ", + decode_status); + + return out_tensor; +} + +#endif // NVJPEG_FOUND + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.h b/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.h new file mode 100644 index 00000000000..496b355e9b7 --- /dev/null +++ b/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device); + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/image.cpp b/framework/include/torchvision/io/image/image.cpp new file mode 100644 index 00000000000..53d588e4746 --- /dev/null +++ b/framework/include/torchvision/io/image/image.cpp @@ -0,0 +1,39 @@ +#include "image.h" + +#include +#ifdef USE_PYTHON +#include +#endif + +// If we are in a Windows environment, we need to define +// initialization functions for the _custom_ops extension +#ifdef USE_PYTHON +#ifdef _WIN32 +PyMODINIT_FUNC PyInit_image(void) { + // No need to do anything. 
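  // (The empty PyInit_image stub exists because Windows Python extensions must
  //  export an init symbol for the module to be importable; the operators
  //  themselves are registered via the RegisterOperators block further down.)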
+ return NULL; +} +#endif +#endif // USE_PYTHON + +namespace vision { +namespace image { + +static auto registry = + torch::RegisterOperators() + .op("image::decode_png(Tensor data, int mode, bool allow_16_bits = False, bool apply_exif_orientation=False) -> Tensor", + &decode_png) + .op("image::encode_png", &encode_png) + .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", + &decode_jpeg) + .op("image::encode_jpeg", &encode_jpeg) + .op("image::read_file", &read_file) + .op("image::write_file", &write_file) + .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", + &decode_image) + .op("image::decode_jpeg_cuda", &decode_jpeg_cuda) + .op("image::_jpeg_version", &_jpeg_version) + .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo); + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/io/image/image.h b/framework/include/torchvision/io/image/image.h new file mode 100644 index 00000000000..05bac44c77d --- /dev/null +++ b/framework/include/torchvision/io/image/image.h @@ -0,0 +1,9 @@ +#pragma once + +#include "cpu/decode_image.h" +#include "cpu/decode_jpeg.h" +#include "cpu/decode_png.h" +#include "cpu/encode_jpeg.h" +#include "cpu/encode_png.h" +#include "cpu/read_write_file.h" +#include "cuda/decode_jpeg_cuda.h" diff --git a/framework/include/torchvision/io/image/image_read_mode.h b/framework/include/torchvision/io/image/image_read_mode.h new file mode 100644 index 00000000000..84425265c34 --- /dev/null +++ b/framework/include/torchvision/io/image/image_read_mode.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +/* Should be kept in-sync with Python ImageReadMode enum */ +using ImageReadMode = int64_t; +const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0; +const ImageReadMode IMAGE_READ_MODE_GRAY = 1; +const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2; +const ImageReadMode IMAGE_READ_MODE_RGB = 3; +const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4; + +} // namespace image +} // namespace vision diff --git a/framework/include/torchvision/macros.h b/framework/include/torchvision/macros.h new file mode 100644 index 00000000000..64ca89429a9 --- /dev/null +++ b/framework/include/torchvision/macros.h @@ -0,0 +1,22 @@ +#pragma once + +#if defined(_WIN32) && !defined(TORCHVISION_BUILD_STATIC_LIBS) +#if defined(torchvision_EXPORTS) +#define VISION_API __declspec(dllexport) +#else +#define VISION_API __declspec(dllimport) +#endif +#else +#define VISION_API +#endif + +#if (defined __cpp_inline_variables) || __cplusplus >= 201703L +#define VISION_INLINE_VARIABLE inline +#else +#ifdef _MSC_VER +#define VISION_INLINE_VARIABLE __declspec(selectany) +#define HINT_MSVC_LINKER_INCLUDE_SYMBOL +#else +#define VISION_INLINE_VARIABLE __attribute__((weak)) +#endif +#endif diff --git a/framework/include/torchvision/ops/.DS_Store b/framework/include/torchvision/ops/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..ba279f808b2887b459c9e79f767d8973c56e4b0c GIT binary patch literal 6148 zcmeHKy-veG47S@2MKE+=U?Z=Pxl5?hiLomKw2{h?C>78>FM`xZVBr-Qcmy7R=iu|% z%8wExCa91t`M%59KIgux;+lwfvCm?n84*pPf};}*10r_Oo{Vhe49H=REj`d3W%)K< z^}K!I7#Wbeo8r#)^h71K?r&2TJAL8J^N>$brs+B_*JAs;UcSH1->!!JvLE^F?&DRH z^MI{rOdHzKUVcrrg_C(>X}`mHUcZCI6=jpTZ|&scaMUps$oJUIck>bQ1xq{M<)M|& ztWCbYwVUrn&gbF`I0MeWUon81Eiye-^wAk`2AqMK0r@^eP{G8oQVgFC450-8POuyV zbLk}{CNfM6D@9l!tf4>+Wot25!?7NjUt(A(YB;eqA8eV~Ius7sv40fYi4#R1odIW{ z%fMJ4XHx%{-}nFBAb)ZOoPmGE08i3Yy1 
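The autograd kernels that follow (deform_conv2d, ps_roi_align, ps_roi_pool, roi_align, roi_pool) all instantiate the same pattern: a torch::autograd::Function subclass whose forward() stashes non-tensor arguments in ctx->saved_data and tensors in save_for_backward(), then calls the underlying operator below the autograd dispatch key, while backward() replays the saved state into the matching *_backward operator. Below is a minimal, self-contained sketch of that pattern; the "ScaleFunction" op and its "factor" argument are hypothetical and not part of this patch.

#include <torch/torch.h>

namespace {

// Hypothetical op (y = x * factor), used only to show the
// torch::autograd::Function / AutogradContext plumbing.
class ScaleFunction : public torch::autograd::Function<ScaleFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input,
      double factor) {
    ctx->saved_data["factor"] = factor;  // non-tensor state for backward()
    ctx->save_for_backward({input});     // tensors are saved separately
    return {input * factor};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    auto saved = ctx->get_saved_variables();  // {input}; unused by this toy op
    auto factor = ctx->saved_data["factor"].toDouble();
    // One entry per forward() argument: a gradient for the tensor input and
    // an undefined Variable for the non-differentiable double.
    return {grad_output[0] * factor, torch::autograd::Variable()};
  }
};

} // namespace

// Usage: gradients flow through apply() exactly as through a built-in op.
//   auto x = torch::randn({3}, torch::requires_grad());
//   auto y = ScaleFunction::apply(x, 2.0)[0];
//   y.sum().backward();  // x.grad() is a tensor of 2s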
+#include + +namespace vision { +namespace ops { + +namespace { + +class DeformConv2dFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& weight, + const torch::autograd::Variable& offset, + const torch::autograd::Variable& mask, + const torch::autograd::Variable& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + at::AutoDispatchBelowADInplaceOrView g; + auto output = deform_conv2d_symint( + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); + + ctx->save_for_backward({input, weight, offset, mask, bias}); + ctx->saved_data["stride_h"] = stride_h; + ctx->saved_data["stride_w"] = stride_w; + ctx->saved_data["pad_h"] = pad_h; + ctx->saved_data["pad_w"] = pad_w; + ctx->saved_data["dilation_h"] = dilation_h; + ctx->saved_data["dilation_w"] = dilation_w; + ctx->saved_data["groups"] = groups; + ctx->saved_data["offset_groups"] = offset_groups; + ctx->saved_data["use_mask"] = use_mask; + + return { + output, + }; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + auto saved = ctx->get_saved_variables(); + auto input = saved[0]; + auto weight = saved[1]; + auto offset = saved[2]; + auto mask = saved[3]; + auto bias = saved[4]; + + auto stride_h = ctx->saved_data["stride_h"].toSymInt(); + auto stride_w = ctx->saved_data["stride_w"].toSymInt(); + auto pad_h = ctx->saved_data["pad_h"].toSymInt(); + auto pad_w = ctx->saved_data["pad_w"].toSymInt(); + auto dilation_h = ctx->saved_data["dilation_h"].toSymInt(); + auto dilation_w = ctx->saved_data["dilation_w"].toSymInt(); + auto groups = ctx->saved_data["groups"].toSymInt(); + auto offset_groups = ctx->saved_data["offset_groups"].toSymInt(); + auto use_mask = ctx->saved_data["use_mask"].toBool(); + + auto grads = detail::_deform_conv2d_backward_symint( + grad_output[0], + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); + auto grad_input = std::get<0>(grads); + auto grad_weight = std::get<1>(grads); + auto grad_offset = std::get<2>(grads); + auto grad_mask = std::get<3>(grads); + auto grad_bias = std::get<4>(grads); + + return { + grad_input, + grad_weight, + grad_offset, + grad_mask, + grad_bias, + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + }; + } +}; + +// TODO: There should be an easier way to do this +class DeformConv2dBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& input, + const torch::autograd::Variable& weight, + const torch::autograd::Variable& offset, + const torch::autograd::Variable& mask, + const torch::autograd::Variable& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt 
pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + at::AutoDispatchBelowADInplaceOrView g; + auto result = detail::_deform_conv2d_backward_symint( + grad, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); + + auto grad_input = std::get<0>(result); + auto grad_weight = std::get<1>(result); + auto grad_offset = std::get<2>(result); + auto grad_mask = std::get<3>(result); + auto grad_bias = std::get<4>(result); + + return { + grad_input, + grad_weight, + grad_offset, + grad_mask, + grad_bias, + }; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on deform_conv2d not supported"); + } +}; + +at::Tensor deform_conv2d_autograd( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + return DeformConv2dFunction::apply( + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask)[0]; +} + +std::tuple +deform_conv2d_backward_autograd( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + auto result = DeformConv2dBackwardFunction::apply( + grad, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); + + return std::make_tuple(result[0], result[1], result[2], result[3], result[4]); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN(deform_conv2d_autograd)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), + TORCH_FN(deform_conv2d_backward_autograd)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp b/framework/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp new file mode 100644 index 00000000000..7205e9b15db --- /dev/null +++ b/framework/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp @@ -0,0 +1,167 @@ +#include "../ps_roi_align.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +class PSROIAlignFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["sampling_ratio"] = 
sampling_ratio; + ctx->saved_data["input_shape"] = input.sym_sizes(); + at::AutoDispatchBelowADInplaceOrView g; + auto result = ps_roi_align_symint( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio); + + auto output = std::get<0>(result); + auto channel_mapping = std::get<1>(result); + ctx->save_for_backward({rois, channel_mapping}); + ctx->mark_non_differentiable({channel_mapping}); + + return {output, channel_mapping}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + // Use data saved in forward + auto saved = ctx->get_saved_variables(); + auto rois = saved[0]; + auto channel_mapping = saved[1]; + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_ps_roi_align_backward_symint( + grad_output[0], + rois, + channel_mapping, + ctx->saved_data["spatial_scale"].toDouble(), + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), + ctx->saved_data["sampling_ratio"].toInt(), + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt()); + + return { + grad_in, + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable()}; + } +}; + +// TODO: There should be an easier way to do this +class PSROIAlignBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& rois, + const torch::autograd::Variable& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + at::AutoDispatchBelowADInplaceOrView g; + auto grad_in = detail::_ps_roi_align_backward_symint( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); + + return {grad_in}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on ps_roi_align not supported"); + } +}; + +std::tuple ps_roi_align_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio) { + auto result = PSROIAlignFunction::apply( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); + + return std::make_tuple(result[0], result[1]); +} + +at::Tensor ps_roi_align_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + return PSROIAlignBackwardFunction::apply( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width)[0]; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), + TORCH_FN(ps_roi_align_autograd)); 
+ m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), + TORCH_FN(ps_roi_align_backward_autograd)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp b/framework/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp new file mode 100644 index 00000000000..39b83819f94 --- /dev/null +++ b/framework/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp @@ -0,0 +1,152 @@ +#include "../ps_roi_pool.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +class PSROIPoolFunction : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["input_shape"] = input.sym_sizes(); + at::AutoDispatchBelowADInplaceOrView g; + auto result = ps_roi_pool_symint( + input, rois, spatial_scale, pooled_height, pooled_width); + + auto output = std::get<0>(result); + auto channel_mapping = std::get<1>(result); + ctx->save_for_backward({rois, channel_mapping}); + ctx->mark_non_differentiable({channel_mapping}); + + return {output, channel_mapping}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + // Use data saved in forward + auto saved = ctx->get_saved_variables(); + auto rois = saved[0]; + auto channel_mapping = saved[1]; + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_ps_roi_pool_backward_symint( + grad_output[0], + rois, + channel_mapping, + ctx->saved_data["spatial_scale"].toDouble(), + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt()); + + return { + grad_in, + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable()}; + } +}; + +// TODO: There should be an easier way to do this +class PSROIPoolBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& rois, + const torch::autograd::Variable& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + at::AutoDispatchBelowADInplaceOrView g; + auto grad_in = detail::_ps_roi_pool_backward_symint( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); + + return {grad_in}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on ps_roi_pool not supported"); + } +}; + +std::tuple ps_roi_pool_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + auto 
result = PSROIPoolFunction::apply( + input, rois, spatial_scale, pooled_height, pooled_width); + + return std::make_tuple(result[0], result[1]); +} + +at::Tensor ps_roi_pool_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + return PSROIPoolBackwardFunction::apply( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width)[0]; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), + TORCH_FN(ps_roi_pool_autograd)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), + TORCH_FN(ps_roi_pool_backward_autograd)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/autograd/roi_align_kernel.cpp b/framework/include/torchvision/ops/autograd/roi_align_kernel.cpp new file mode 100644 index 00000000000..6d792fe09d9 --- /dev/null +++ b/framework/include/torchvision/ops/autograd/roi_align_kernel.cpp @@ -0,0 +1,167 @@ +#include "../roi_align.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +class ROIAlignFunction : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + bool aligned) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["sampling_ratio"] = sampling_ratio; + ctx->saved_data["aligned"] = aligned; + ctx->saved_data["input_shape"] = input.sym_sizes(); + ctx->save_for_backward({rois}); + at::AutoDispatchBelowADInplaceOrView g; + auto result = roi_align_symint( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + aligned); + return {result}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + // Use data saved in forward + auto saved = ctx->get_saved_variables(); + auto rois = saved[0]; + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_roi_align_backward_symint( + grad_output[0], + rois, + ctx->saved_data["spatial_scale"].toDouble(), + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt(), + ctx->saved_data["sampling_ratio"].toInt(), + ctx->saved_data["aligned"].toBool()); + return { + grad_in, + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable()}; + } +}; + +// TODO: There should be an easier way to do this +class ROIAlignBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& rois, + double 
spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width, + int64_t sampling_ratio, + bool aligned) { + at::AutoDispatchBelowADInplaceOrView g; + auto result = detail::_roi_align_backward_symint( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + aligned); + return {result}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on roi_align not supported"); + } +}; + +at::Tensor roi_align_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + bool aligned) { + return ROIAlignFunction::apply( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + aligned)[0]; +} + +at::Tensor roi_align_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width, + int64_t sampling_ratio, + bool aligned) { + return ROIAlignBackwardFunction::apply( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + aligned)[0]; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_align"), + TORCH_FN(roi_align_autograd)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), + TORCH_FN(roi_align_backward_autograd)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/autograd/roi_pool_kernel.cpp b/framework/include/torchvision/ops/autograd/roi_pool_kernel.cpp new file mode 100644 index 00000000000..508bafb2b1e --- /dev/null +++ b/framework/include/torchvision/ops/autograd/roi_pool_kernel.cpp @@ -0,0 +1,152 @@ +#include "../roi_pool.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +class ROIPoolFunction : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["input_shape"] = input.sym_sizes(); + at::AutoDispatchBelowADInplaceOrView g; + auto result = roi_pool_symint( + input, rois, spatial_scale, pooled_height, pooled_width); + + auto output = std::get<0>(result); + auto argmax = std::get<1>(result); + ctx->save_for_backward({rois, argmax}); + ctx->mark_non_differentiable({argmax}); + + return {output, argmax}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + // Use data saved in forward + auto saved = ctx->get_saved_variables(); + auto rois = saved[0]; + auto argmax = saved[1]; + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_roi_pool_backward_symint( + grad_output[0], + rois, + argmax, + 
ctx->saved_data["spatial_scale"].toDouble(), + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt()); + + return { + grad_in, + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable()}; + } +}; + +// TODO: There should be an easier way to do this +class ROIPoolBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& rois, + const torch::autograd::Variable& argmax, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + at::AutoDispatchBelowADInplaceOrView g; + auto grad_in = detail::_roi_pool_backward_symint( + grad, + rois, + argmax, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); + + return {grad_in}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on roi_pool not supported"); + } +}; + +std::tuple roi_pool_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + auto result = ROIPoolFunction::apply( + input, rois, spatial_scale, pooled_height, pooled_width); + + return std::make_tuple(result[0], result[1]); +} + +at::Tensor roi_pool_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + return ROIPoolBackwardFunction::apply( + grad, + rois, + argmax, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width)[0]; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_pool"), + TORCH_FN(roi_pool_autograd)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), + TORCH_FN(roi_pool_backward_autograd)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp b/framework/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp new file mode 100644 index 00000000000..c5e59077aa6 --- /dev/null +++ b/framework/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp @@ -0,0 +1,1172 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer + ***************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. 
If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer + ********************* + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. 
+ * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +// modified from +// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +const int kMaxParallelImgs = 32; + +template +scalar_t bilinear_interpolate( + const scalar_t* in, + int height, + int width, + scalar_t h, + scalar_t w) { + if (h <= -1 || height <= h || w <= -1 || width <= w) { + return 0; + } + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = in[h_low * width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = in[h_low * width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = in[h_high * width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = in[h_high * width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +void deformable_im2col_kernel( + int n, + const scalar_t* input, + const scalar_t* offset, + const scalar_t* mask, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int batch_sz, + int n_in_channels, + int n_offset_grps, + int out_h, + int out_w, + bool use_mask, + scalar_t* columns) { + for (int index = 0; index != n; ++index) { + const int out_x = index % out_w; + const int out_y = (index / out_w) % out_h; + const int out_b = (index / (out_w * out_h)) % batch_sz; + const int in_c = index / (out_w * out_h * batch_sz); + const int out_c = in_c * weight_h * weight_w; + + int c_per_offset_grp = n_in_channels / n_offset_grps; + const int grp_idx = in_c / c_per_offset_grp; + + auto columns_ptr = columns + + (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + + out_y * out_w + out_x); + + auto input_ptr = input + + (out_b * (n_in_channels * height * width) + in_c * (height * width)); + + auto offset_ptr = offset + + (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * + out_w; + + auto mask_ptr = mask; + if (use_mask) { + mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * + out_h * out_w; + } + + for (int i = 0; i < weight_h; ++i) { + for (int j = 0; j < weight_w; ++j) { + const int mask_idx = i * weight_w + j; + const int offset_idx = 2 * mask_idx; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = + mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; + } + + const scalar_t offset_h = + offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_w = offset_ptr + [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t y = + (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = + (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + *columns_ptr = + mask_value * bilinear_interpolate(input_ptr, height, width, y, x); + columns_ptr += batch_sz * out_h * out_w; + } + } + } +} + +void deformable_im2col( + const at::Tensor& input, + const 
at::Tensor& data_offset, + const at::Tensor& data_mask, + int n_in_channels, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int out_h, + int out_w, + int parallel_imgs, + int deformable_group, + bool use_mask, + at::Tensor data_col) { + int num_kernels = n_in_channels * out_h * out_w * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "deformable_im2col", ([&] { + deformable_im2col_kernel( + num_kernels, + input.data_ptr(), + data_offset.data_ptr(), + data_mask.data_ptr(), + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + parallel_imgs, + n_in_channels, + deformable_group, + out_h, + out_w, + use_mask, + data_col.data_ptr()); + })); +} + +int get_greatest_divisor_below_bound(int n, int bound) { + for (int k = bound; k > 1; --k) { + if (n % k == 0) { + return k; + } + } + return 1; +} + +template +void deformable_col2im_kernel( + int n, + const scalar_t* col, + const scalar_t* offset, + const scalar_t* mask, + int channels, + int height, + int width, + int kernel_h, + int kernel_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int batch_sz, + int n_offset_grps, + int out_h, + int out_w, + bool use_mask, + scalar_t* grad_im) { + for (int index = 0; index != n; ++index) { + const int out_x = index % out_w; + const int out_y = (index / out_w) % out_h; + const int b = (index / (out_w * out_h)) % batch_sz; + const int j = (index / (out_w * out_h * batch_sz)) % kernel_w; + const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; + const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); + + int c_per_offset_grp = channels / n_offset_grps; + const int offset_grp = c / c_per_offset_grp; + + auto offset_ptr = offset + + (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * + out_w; + + auto mask_ptr = mask; + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * + out_h * out_w; + } + + const int mask_idx = i * kernel_w + j; + const int offset_idx = 2 * mask_idx; + + const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; + const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; + + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + + const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + for (int dy = -1; dy <= 1; dy++) { + for (int dx = -1; dx <= 1; dx++) { + int yp = int(y) + dy; + int xp = int(x) + dx; + if (0 <= yp && yp < height && 0 <= xp && xp < width && + std::abs(y - yp) < 1 && std::abs(x - xp) < 1) { + int grad_pos = ((b * channels + c) * height + yp) * width + xp; + scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); + grad_im[grad_pos] += mask_value * weight * col[index]; + } + } + } + } +} + +void compute_grad_input( + const at::Tensor& columns, + const at::Tensor& offset, + const at::Tensor& mask, + int channels, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int parallel_imgs, + int n_offset_grps, + bool use_mask, + at::Tensor 
grad_im) { + int out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = + channels * weight_h * weight_w * out_h * out_w * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + columns.scalar_type(), "compute_grad_input", ([&] { + deformable_col2im_kernel( + num_kernels, + columns.data_ptr(), + offset.data_ptr(), + mask.data_ptr(), + channels, + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + parallel_imgs, + n_offset_grps, + out_h, + out_w, + use_mask, + grad_im.data_ptr()); + })); +} + +template +scalar_t get_coordinate_weight( + const scalar_t* im_data, + int height, + int width, + scalar_t y, + scalar_t x, + bool is_y_direction) { + int y_l = floor(y); + int x_l = floor(x); + int y_h = y_l + 1; + int x_h = x_l + 1; + + bool valid_y_l = 0 <= y_l && y_l < height; + bool valid_y_h = 0 <= y_h && y_h < height; + bool valid_x_l = 0 <= x_l && x_l < width; + bool valid_x_h = 0 <= x_h && x_h < width; + + scalar_t zero = 0; + scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; + scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; + scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; + scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero; + + if (is_y_direction) { + scalar_t dx = x - x_l; + return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); + } else { + scalar_t dy = y - y_l; + return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); + } +} + +template +void deformable_col2im_coord_kernel( + int n, + const scalar_t* col, + const scalar_t* im, + const scalar_t* offset, + const scalar_t* mask, + int channels, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int batch_sz, + int offset_channels, + int n_offset_grps, + int out_h, + int out_w, + bool use_mask, + scalar_t* grad_offset, + scalar_t* grad_mask) { + for (int index = 0; index != n; ++index) { + scalar_t grad_offset_val = 0; + scalar_t grad_mask_val = 0; + + int w = index % out_w; + int h = (index / out_w) % out_h; + int w_w = (index / (out_w * out_h * 2)) % weight_w; + int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; + int c = (index / (out_w * out_h)) % offset_channels; + int b = index / (out_w * out_h * offset_channels); + + const int offset_grp = c / (2 * weight_h * weight_w); + const int col_step = weight_h * weight_w; + + int c_per_offset_grp = channels / n_offset_grps; + + auto col_ptr = col + + offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * + out_h; + auto im_ptr = im + + (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; + auto offset_ptr = offset + + (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * + out_w; + + auto mask_ptr = mask; + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * + out_h * out_w; + } + + const int offset_c = c - offset_grp * 2 * weight_h * weight_w; + const bool is_y_direction = offset_c % 2 == 0; + + const int c_bound = c_per_offset_grp * weight_h * weight_w; + for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { + const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w; + + int out_x = col_pos % out_w; + int out_y = (col_pos / out_w) % out_h; + int j = 
(col_pos / (out_w * out_h * batch_sz)) % weight_w; + int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; + + const int mask_idx = i * weight_w + j; + + const int offset_h_idx = + (((2 * mask_idx) * out_h + out_y) * out_w + out_x); + const int offset_w_idx = + (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); + const scalar_t offset_h = offset_ptr[offset_h_idx]; + const scalar_t offset_w = offset_ptr[offset_w_idx]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + + scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + const scalar_t weight = + get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); + grad_offset_val += mask_value * weight * col_ptr[col_pos]; + + if (use_mask && is_y_direction) { + grad_mask_val += col_ptr[col_pos] * + bilinear_interpolate(im_ptr, height, width, y, x); + } + + im_ptr += height * width; + } + + grad_offset[index] = grad_offset_val; + + if (use_mask && is_y_direction) { + const int idx = + ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + + w_w) * + out_h + + h) * + out_w + + w; + grad_mask[idx] = grad_mask_val; + } + } +} + +void compute_grad_offset_and_mask( + const at::Tensor& columns, + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& mask, + int channels, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int parallel_imgs, + int n_offset_grps, + bool use_mask, + at::Tensor grad_offset, + at::Tensor grad_mask) { + int out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = + out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + columns.scalar_type(), "compute_grad_offset_and_mask", ([&] { + deformable_col2im_coord_kernel( + num_kernels, + columns.data_ptr(), + input.data_ptr(), + offset.data_ptr(), + mask.data_ptr(), + channels, + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + parallel_imgs, + 2 * weight_h * weight_w * n_offset_grps, + n_offset_grps, + out_h, + out_w, + use_mask, + grad_offset.data_ptr(), + grad_mask.data_ptr()); + })); +} + +std::tuple backward_gradient_inputs( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor mask, + at::Tensor grad_out, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int n_weight_grps, + int n_offset_grps, + int n_parallel_imgs, + bool use_mask) { + int batch_sz = input.size(0); + int n_in_channels = input.size(1); + int in_h = input.size(2); + int in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + long n_out_channels = weight.size(0); + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + long out_h = + (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + long out_w = + (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + + auto grad_input = at::zeros_like(input); + auto grad_offset = at::zeros_like(offset); + auto grad_mask = at::zeros_like(mask); + + if (batch_sz == 0) { + return std::make_tuple(grad_input, grad_offset, grad_mask); + } + + auto columns = 
at::empty( + {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, + input.options()); + + // Separate into blocks + grad_input = grad_input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + input = input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + grad_offset = grad_offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + offset = offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + grad_mask = grad_mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + mask = mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + grad_out = grad_out + .reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}); + + weight = weight.reshape( + {n_weight_grps, + weight.size(0) / n_weight_grps, + weight.size(1), + weight.size(2), + weight.size(3)}); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + + for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + columns.zero_(); + // Separate into weight groups + for (int g = 0; g < n_weight_grps; g++) { + columns[g] = columns[g].addmm_( + weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); + } + + compute_grad_offset_and_mask( + columns, + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_offset[elt], + grad_mask[elt]); + + compute_grad_input( + columns, + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_input[elt]); + } + + grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + if (use_mask) { + grad_mask = grad_mask.view( + {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); + } + + return std::make_tuple(grad_input, grad_offset, grad_mask); +} + +at::Tensor backward_gradient_parameters( + at::Tensor input, + const at::Tensor& weight, + at::Tensor offset, + at::Tensor mask, + const at::Tensor& grad_out, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int n_weight_grps, + int n_offset_grps, + int n_parallel_imgs, + bool use_mask) { + int batch_sz = input.size(0); + int n_in_channels = input.size(1); + int in_h = input.size(2); + int in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + long n_out_channels = weight.size(0); + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + long out_h = grad_out.size(2); + long out_w = grad_out.size(3); + + auto grad_weight = at::zeros_like(weight); + if (batch_sz == 0) { + return grad_weight; + } + + at::Tensor grad_out_buf = grad_out + .reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}) 
+ .contiguous(); + + input = input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + offset = offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask = mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + grad_weight = grad_weight.view( + {n_weight_grps, + grad_weight.size(0) / n_weight_grps, + grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3)}); + + auto columns = at::empty( + {n_weight_grps, + n_in_channels * weight_w * weight_h / n_weight_grps, + n_parallel_imgs * out_h * out_w}, + input.options()); + + for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + deformable_im2col( + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + for (int g = 0; g < n_weight_grps; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_( + grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) + .view_as(grad_weight[g]); + } + } + + grad_weight = grad_weight.view( + {grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3), + grad_weight.size(4)}); + return grad_weight; +} + +at::Tensor deform_conv2d_forward_kernel( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + at::Tensor input_c = input.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + TORCH_CHECK(input_c.ndimension() == 4); + TORCH_CHECK(offset_c.ndimension() == 4); + TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); + TORCH_CHECK(weight_c.ndimension() == 4); + TORCH_CHECK(input_c.device().is_cpu(), "input must be a CPU tensor"); + + int batch_sz = input_c.size(0); + int n_in_channels = input_c.size(1); + int in_h = input_c.size(2); + int in_w = input_c.size(3); + + int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); + + // Unpack shapes and args + int out_channels = weight_c.size(0); + int weight_h = weight_c.size(2); + int weight_w = weight_c.size(3); + + int ker_h = dilation_h * (weight_h - 1) + 1; + int ker_w = dilation_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK( + weight_h > 0 && weight_w > 0, + "weight_h: ", + weight_h, + " weight_w: ", + weight_w); + TORCH_CHECK( + stride_h > 0 && stride_w > 0, + "stride_h: ", + stride_h, + " stride_w: ", + stride_w); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); + TORCH_CHECK( + dilation_h > 0 && dilation_w > 0, + "dilation_h: ", + dilation_h, + " dilation_w: ", + dilation_w); + + TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); + TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); + TORCH_CHECK( + (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "offset.shape[1] is not valid: got: ", + offset_c.size(1), + " 
expected: ", + n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), + "mask.shape[1] is not valid: got: ", + mask_c.size(1), + " expected: ", + n_offset_grps * weight_h * weight_w); + TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); + + TORCH_CHECK( + (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); + TORCH_CHECK( + (offset_c.size(2) == out_h && offset_c.size(3) == out_w), + "offset output dims: (", + offset_c.size(2), + ", ", + offset_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask"); + TORCH_CHECK( + (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), + "mask output dims: (", + mask_c.size(2), + ", ", + mask_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", + out_h, + " out_w: ", + out_w); + + auto out = + at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); + if (batch_sz == 0) { + return out; + } + + // Separate batches into blocks + out = out.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + out_channels, + out_h, + out_w}); + input_c = input_c.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + offset_c = offset_c.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask_c = mask_c.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + at::Tensor out_buf = at::zeros( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs * out_h, + out_w}, + out.options()); + + // Separate channels into convolution groups + out_buf = out_buf.view( + {out_buf.size(0), + n_weight_grps, + out_buf.size(1) / n_weight_grps, + out_buf.size(2), + out_buf.size(3)}); + weight_c = weight_c.view( + {n_weight_grps, + weight_c.size(0) / n_weight_grps, + weight_c.size(1), + weight_c.size(2), + weight_c.size(3)}); + + // Sample points and perform convolution + auto columns = at::zeros( + {n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, + input_c.options()); + for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { + deformable_im2col( + input_c[b], + offset_c[b], + mask_c[b], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int g = 0; g < n_weight_grps; g++) { + out_buf[b][g] = out_buf[b][g] + .flatten(1) + .addmm_(weight_c[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); + } + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + out_buf = out_buf.view( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs, + out_h, + out_w}); + out_buf.transpose_(1, 2); + out.copy_(out_buf); + out = out.view({batch_sz, out_channels, out_h, out_w}); + + return out + bias_c.view({1, out_channels, 1, 1}); +} + +std::tuple +deform_conv2d_backward_kernel( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t 
stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + at::Tensor grad_out_c = grad_out.contiguous(); + at::Tensor input_c = input.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + const int batch_sz = input_c.size(0); + const int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); + + auto grad_input_and_offset_and_mask = backward_gradient_inputs( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + n_weight_grps, + n_offset_grps, + n_parallel_imgs, + use_mask); + + auto grad_input = std::get<0>(grad_input_and_offset_and_mask); + auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); + auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); + + auto grad_weight = backward_gradient_parameters( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + n_weight_grps, + n_offset_grps, + n_parallel_imgs, + use_mask); + + auto grad_bias = at::ones_like(bias_c) * grad_out_c.sum({0, 2, 3}); + + return std::make_tuple( + grad_input, grad_weight, grad_offset, grad_mask, grad_bias); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN(deform_conv2d_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), + TORCH_FN(deform_conv2d_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/nms_kernel.cpp b/framework/include/torchvision/ops/cpu/nms_kernel.cpp new file mode 100644 index 00000000000..50479066cbd --- /dev/null +++ b/framework/include/torchvision/ops/cpu/nms_kernel.cpp @@ -0,0 +1,117 @@ +#include +#include + +namespace vision { +namespace ops { + +namespace { + +template +at::Tensor nms_kernel_impl( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); + TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); + TORCH_CHECK( + dets.scalar_type() == scores.scalar_type(), + "dets should have the same type as scores"); + + if (dets.numel() == 0) + return at::empty({0}, dets.options().dtype(at::kLong)); + + auto x1_t = dets.select(1, 0).contiguous(); + auto y1_t = dets.select(1, 1).contiguous(); + auto x2_t = dets.select(1, 2).contiguous(); + auto y2_t = dets.select(1, 3).contiguous(); + + at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); + + auto order_t = std::get<1>( + scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + + auto ndets = dets.size(0); + at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); + at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); + + auto suppressed = suppressed_t.data_ptr(); + auto keep = keep_t.data_ptr(); + auto order = order_t.data_ptr(); + auto x1 = x1_t.data_ptr(); + auto y1 = y1_t.data_ptr(); + auto x2 = x2_t.data_ptr(); + auto y2 = y2_t.data_ptr(); + auto areas = areas_t.data_ptr(); + + int64_t num_to_keep = 0; + + for (int64_t _i = 0; _i < ndets; _i++) { + auto i = order[_i]; + if (suppressed[i] == 1) + continue; + keep[num_to_keep++] = i; + auto ix1 = x1[i]; + auto iy1 = y1[i]; 
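+ // Editorial note, added for clarity (not part of the upstream torchvision sources):
+ // together with the remaining corner/area loads below, the inner loop computes the
+ // clipped overlap rectangle between box i and every not-yet-suppressed, lower-scoring
+ // box j, derives IoU = inter / (area_i + area_j - inter), and marks j as suppressed
+ // when that IoU exceeds iou_threshold. This is the standard greedy NMS sweep over the
+ // score-sorted order.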
+ auto ix2 = x2[i]; + auto iy2 = y2[i]; + auto iarea = areas[i]; + + for (int64_t _j = _i + 1; _j < ndets; _j++) { + auto j = order[_j]; + if (suppressed[j] == 1) + continue; + auto xx1 = std::max(ix1, x1[j]); + auto yy1 = std::max(iy1, y1[j]); + auto xx2 = std::min(ix2, x2[j]); + auto yy2 = std::min(iy2, y2[j]); + + auto w = std::max(static_cast(0), xx2 - xx1); + auto h = std::max(static_cast(0), yy2 - yy1); + auto inter = w * h; + auto ovr = inter / (iarea + areas[j] - inter); + if (ovr > iou_threshold) + suppressed[j] = 1; + } + } + return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); +} + +at::Tensor nms_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 4, + "boxes should have 4 elements in dimension 1, got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)); + + auto result = at::empty({0}, dets.options()); + + AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { + result = nms_kernel_impl(dets, scores, iou_threshold); + }); + return result; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp b/framework/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp new file mode 100644 index 00000000000..1c272427d3f --- /dev/null +++ b/framework/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp @@ -0,0 +1,429 @@ +#include +#include + +namespace vision { +namespace ops { + +namespace { + +template +T bilinear_interpolate( + const T* input, + int height, + int width, + T y, + T x, + int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +void ps_roi_align_forward_kernel_impl( + int num_rois, + const T* input, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + const T* rois, + int channels_out, + T* output, + int* channel_mapping) { + for (int n = 0; n < num_rois; n++) { + // [start, end) interval for spatial sampling + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + int c_in = 0; + for (int c_out = 0; c_out < channels_out; ++c_out) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int index = + ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + + pw; + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + + T out_sum = 0; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = bilinear_interpolate( + offset_input, height, width, y, x, index); + out_sum += val; + } + } + + out_sum /= count; + output[index] = out_sum; + channel_mapping[index] = c_in; + c_in++; + } + } + } + } +} + +template +void bilinear_interpolate_gradient( + int height, + int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high, + int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + +template +inline void add(T* address, const T& val) { + *address += val; +} + +template +void ps_roi_align_backward_kernel_impl( + int nthreads, + const T* grad_output, + const int* channel_mapping, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + int channels_out, + T* grad_input, + const T* rois) { + for (int index = 0; index < nthreads; index++) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int n = index / pooled_width / pooled_height / channels_out; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + // Force too small ROIs to be 1x1 + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + int c_in = channel_mapping[index]; + T* grad_input_offset = + grad_input + (roi_batch_ind * channels + c_in) * height * width; + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + const T grad_output_this_bin = grad_output[index]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + add(grad_input_offset + y_low * width + x_low, g1); + add(grad_input_offset + y_low * width + x_high, g2); + add(grad_input_offset + y_high * width + x_low, g3); + add(grad_input_offset + y_high * width + x_high, g4); + } // if + } // ix + } // iy + } +} + +std::tuple ps_roi_align_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { + // Check if input tensors are CPU tensors + TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK( + rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_align_forward_kernel"; + at::checkAllSameType(c, {input_t, rois_t}); + + int num_rois = rois.size(0); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + TORCH_CHECK( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros( + {num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = + at::zeros(output.sizes(), input.options().dtype(at::kInt)); + + if (output.numel() == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "ps_roi_align_forward_kernel", [&] { + ps_roi_align_forward_kernel_impl( + num_rois, + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois_.data_ptr(), + channels_out, + output.data_ptr(), + channel_mapping.data_ptr()); + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_align_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + // Check if input tensors are CPU tensors + TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK( + channel_mapping.device().is_cpu(), + "channel_mapping must be a CPU tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = 
"ps_roi_align_backward_kernel"; + at::checkAllSameType(c, {grad_t, rois_t}); + + auto grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + int channels_out = channels / (pooled_height * pooled_width); + + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "ps_roi_align_backward_kernel", [&] { + ps_roi_align_backward_kernel_impl( + grad.numel(), + grad_.data_ptr(), + channel_mapping.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + channels_out, + grad_input.data_ptr(), + rois_.data_ptr()); + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), + TORCH_FN(ps_roi_align_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), + TORCH_FN(ps_roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp b/framework/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp new file mode 100644 index 00000000000..607cbe4bab6 --- /dev/null +++ b/framework/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp @@ -0,0 +1,273 @@ +#include +#include + +namespace vision { +namespace ops { + +namespace { + +template +inline void add(T* address, const T& val) { + *address += val; +} + +template +void ps_roi_pool_forward_kernel_impl( + const T* input, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + const T* rois, + int channels_out, + int num_rois, + T* output, + int* channel_mapping) { + for (int n = 0; n < num_rois; ++n) { + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = round(offset_rois[1] * spatial_scale); + int roi_start_h = round(offset_rois[2] * spatial_scale); + int roi_end_w = round(offset_rois[3] * spatial_scale); + int roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w, 1); + int roi_height = std::max(roi_end_h - roi_start_h, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + int c_in = 0; + for (int c_out = 0; c_out < channels_out; ++c_out) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = + static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = + static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height - 1); + hend = std::min(std::max(hend + roi_start_h, 0), height - 1); + wstart = std::min(std::max(wstart + roi_start_w, 0), width - 1); + wend = std::min(std::max(wend + roi_start_w, 0), width - 1); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + const T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + + T out_sum = 0; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * width + w; + out_sum += offset_input[input_index]; + } + } + + int index = + 
((n * channels_out + c_out) * pooled_height + ph) * pooled_width + + pw; + T bin_area = (hend - hstart) * (wend - wstart); + output[index] = is_empty ? static_cast(0) : out_sum / bin_area; + channel_mapping[index] = c_in; + c_in++; + } + } + } + } +} + +template +void ps_roi_pool_backward_kernel_impl( + const T* grad_output, + const int* channel_mapping, + int num_rois, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int channels_out, + T* grad_input, + const T* rois) { + for (int n = 0; n < num_rois; ++n) { + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = roundf(offset_rois[1] * spatial_scale); + int roi_start_h = roundf(offset_rois[2] * spatial_scale); + int roi_end_w = roundf(offset_rois[3] * spatial_scale); + int roi_end_h = roundf(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w, 1); + int roi_height = std::max(roi_end_h - roi_start_h, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height); + hend = std::min(std::max(hend + roi_start_h, 0), height); + wstart = std::min(std::max(wstart + roi_start_w, 0), width); + wend = std::min(std::max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + for (int c_out = 0; c_out < channels_out; ++c_out) { + int index = + ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + + pw; + int c_in = channel_mapping[index]; + + T* grad_input_offset = + grad_input + (roi_batch_ind * channels + c_in) * height * width; + T bin_area = (hend - hstart) * (wend - wstart); + T diff_val = + is_empty ? 
static_cast(0) : grad_output[index] / bin_area; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int grad_input_index = h * width + w; + add(grad_input_offset + grad_input_index, diff_val); + } + } + } + } + } + } +} + +std::tuple ps_roi_pool_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + // Check if input tensors are CPU tensors + TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK( + rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_pool_forward_kernel"; + at::checkAllSameType(c, {input_t, rois_t}); + + int num_rois = rois.size(0); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + TORCH_CHECK( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros( + {num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = + at::zeros(output.sizes(), input.options().dtype(at::kInt)); + + auto output_size = output.numel(); + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "ps_roi_pool_forward_kernel", [&] { + ps_roi_pool_forward_kernel_impl( + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.data_ptr(), + channels_out, + num_rois, + output.data_ptr(), + channel_mapping.data_ptr()); + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_pool_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + // Check if input tensors are CPU tensors + TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK( + channel_mapping.device().is_cpu(), + "channel_mapping must be a CPU tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "ps_roi_pool_backward_kernel"; + at::checkAllSameType(c, {grad_t, rois_t}); + + auto num_rois = rois.size(0); + auto grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + int channels_out = channels / (pooled_height * pooled_width); + + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] { + ps_roi_pool_backward_kernel_impl( + grad_.data_ptr(), + channel_mapping.data_ptr(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + channels_out, + grad_input.data_ptr(), + rois_.data_ptr()); + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl( + 
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), + TORCH_FN(ps_roi_pool_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), + TORCH_FN(ps_roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/roi_align_common.h b/framework/include/torchvision/ops/cpu/roi_align_common.h new file mode 100644 index 00000000000..e10c67b5b79 --- /dev/null +++ b/framework/include/torchvision/ops/cpu/roi_align_common.h @@ -0,0 +1,128 @@ +#pragma once + +#include + +namespace vision { +namespace ops { +namespace detail { + +template +struct PreCalc { + int pos1; + int pos2; + int pos3; + int pos4; + T w1; + T w2; + T w3; + T w4; +}; + +// This helper computes the interpolation weights (w1, w2...) for every sampling +// point of a given box. There are pool_height * pool_width * roi_bin_grid_h * +// roi_bin_grid_w such sampling points. +// +// The weights (w1, w2...) are computed as the areas in this figure: +// https://en.wikipedia.org/wiki/Bilinear_interpolation#/media/File:Bilinear_interpolation_visualisation.svg +// and pos1, pos2 etc correspond to the indices of their respective pixels. +// +// Note: the weights and indices are shared across all channels, which is why +// they are pre-calculated prior to the main loop in the RoIAlign kernel. +// implementation taken from Caffe2 +template +void pre_calc_for_bilinear_interpolate( + int height, + int width, + int pooled_height, + int pooled_width, + T roi_start_h, + T roi_start_w, + T bin_size_h, + T bin_size_w, + int roi_bin_grid_h, + int roi_bin_grid_w, + std::vector>& pre_calc) { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T x = xx; + T y = yy; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + PreCalc pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indices + PreCalc pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } + } +} + +} // namespace detail +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/roi_align_kernel.cpp b/framework/include/torchvision/ops/cpu/roi_align_kernel.cpp new file mode 100644 index 00000000000..b787de6f6bb --- /dev/null +++ b/framework/include/torchvision/ops/cpu/roi_align_kernel.cpp @@ -0,0 +1,400 @@ +#include +#include + +#include "./roi_align_common.h" + +namespace vision { +namespace ops { + +namespace { + +template +void roi_align_forward_kernel_impl( + int n_rois, + const T* input, + const T& spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + bool aligned, + const T* rois, + T* output) { + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) { + int index_n = n * channels * pooled_width * pooled_height; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros. + const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 + + // we want to precalculate indices and weights shared by all channels, + // this is the key point of optimization + std::vector> pre_calc( + roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); + detail::pre_calc_for_bilinear_interpolate( + height, + width, + pooled_height, + pooled_width, + roi_start_h, + roi_start_w, + bin_size_h, + bin_size_w, + roi_bin_grid_h, + roi_bin_grid_w, + pre_calc); + + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * pooled_width * pooled_height; + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + int index = index_n_c + ph * pooled_width + pw; + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + detail::PreCalc pc = pre_calc[pre_calc_index]; + output_val += pc.w1 * offset_input[pc.pos1] + + pc.w2 * offset_input[pc.pos2] + + pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= count; // Average pooling + + output[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n +} + +template +void bilinear_interpolate_gradient( + int height, + int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high, + int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + +template +inline void add(T* address, const T& val) { + *address += val; +} + +template +void roi_align_backward_kernel_impl( + int nthreads, + const T* grad_output, + const T& spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + bool aligned, + T* grad_input, + const T* rois, + int n_stride, + int c_stride, + int h_stride, + int w_stride) { + for (int index = 0; index < nthreads; index++) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? 
(T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + int output_offset = n * n_stride + c * c_stride; + const T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + // atomic add is not needed for now since it is single threaded + add(offset_grad_input + y_low * width + x_low, static_cast(g1)); + add(offset_grad_input + y_low * width + x_high, static_cast(g2)); + add(offset_grad_input + y_high * width + x_low, static_cast(g3)); + add(offset_grad_input + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // for +} + +at::Tensor roi_align_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_forward_kernel"; + at::checkAllSameType(c, {input_t, rois_t}); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + + if (output.numel() == 0) + return output; + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "roi_align_forward_kernel", [&] { + roi_align_forward_kernel_impl( + num_rois, + 
input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + rois_.data_ptr(), + output.data_ptr()); + }); + return output; +} + +at::Tensor roi_align_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_backward_kernel"; + at::checkAllSameType(c, {grad_t, rois_t}); + + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + // get stride values to ensure indexing into gradients is correct. + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "roi_align_backward_kernel", [&] { + roi_align_backward_kernel_impl( + grad.numel(), + grad.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + grad_input.data_ptr(), + rois_.data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_align"), + TORCH_FN(roi_align_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), + TORCH_FN(roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/roi_pool_kernel.cpp b/framework/include/torchvision/ops/cpu/roi_pool_kernel.cpp new file mode 100644 index 00000000000..b099523896a --- /dev/null +++ b/framework/include/torchvision/ops/cpu/roi_pool_kernel.cpp @@ -0,0 +1,249 @@ +#include + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +template +inline void add(T* address, const T& val) { + *address += val; +} + +template +void roi_pool_forward_kernel_impl( + const T* input, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + const T* rois, + int num_rois, + T* output, + int* argmax_data) { + for (int n = 0; n < num_rois; ++n) { + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = round(offset_rois[1] * spatial_scale); + int roi_start_h = round(offset_rois[2] * spatial_scale); + int roi_end_w = round(offset_rois[3] * spatial_scale); + int roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w + 1, 1); + int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * 
bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height); + hend = std::min(std::max(hend + roi_start_h, 0), height); + wstart = std::min(std::max(wstart + roi_start_w, 0), width); + wend = std::min(std::max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + for (int c = 0; c < channels; ++c) { + // Define an empty pooling region to be zero + T maxval = is_empty ? 0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + + const T* input_offset = + input + (roi_batch_ind * channels + c) * height * width; + + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * width + w; + if (input_offset[input_index] > maxval) { + maxval = input_offset[input_index]; + maxidx = input_index; + } + } + } + int index = + ((n * channels + c) * pooled_height + ph) * pooled_width + pw; + output[index] = maxval; + argmax_data[index] = maxidx; + } // channels + } // pooled_width + } // pooled_height + } // num_rois +} + +template +void roi_pool_backward_kernel_impl( + const T* grad_output, + const int* argmax_data, + int num_rois, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + T* grad_input, + const T* rois, + int n_stride, + int c_stride, + int h_stride, + int w_stride) { + for (int n = 0; n < num_rois; ++n) { + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + for (int c = 0; c < channels; ++c) { + T* grad_input_offset = + grad_input + ((roi_batch_ind * channels + c) * height * width); + const int* argmax_data_offset = + argmax_data + (n * channels + c) * pooled_height * pooled_width; + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int output_offset = n * n_stride + c * c_stride; + int argmax = argmax_data_offset[ph * pooled_width + pw]; + + if (argmax != -1) { + add(grad_input_offset + argmax, + static_cast( + grad_output + [output_offset + ph * h_stride + pw * w_stride])); + } + } // pooled_width + } // pooled_height + } // channels + } // num_rois +} + +std::tuple roi_pool_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_pool_forward_kernel"; + at::checkAllSameType(c, {input_t, rois_t}); + + int num_rois = rois.size(0); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + at::Tensor argmax = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, + input.options().dtype(at::kInt)); + + if (output.numel() == 0) { + return std::make_tuple(output, argmax); + } + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "roi_pool_forward_kernel", [&] { + roi_pool_forward_kernel_impl( + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.data_ptr(), + num_rois, + output.data_ptr(), + argmax.data_ptr()); + }); + return 
std::make_tuple(output, argmax); +} + +at::Tensor roi_pool_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + // Check if input tensors are CPU tensors + TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK(argmax.device().is_cpu(), "argmax must be a CPU tensor"); + TORCH_CHECK( + rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_pool_backward_kernel"; + at::checkAllSameType(c, {grad_t, rois_t}); + + auto num_rois = rois.size(0); + + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + // get stride values to ensure indexing into gradients is correct. + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "roi_pool_backward_kernel", [&] { + roi_pool_backward_kernel_impl( + grad.data_ptr(), + argmax.data_ptr(), + num_rois, + channels, + height, + width, + pooled_height, + pooled_width, + grad_input.data_ptr(), + rois_.data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_pool"), + TORCH_FN(roi_pool_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), + TORCH_FN(roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/deform_conv2d.cpp b/framework/include/torchvision/ops/deform_conv2d.cpp new file mode 100644 index 00000000000..3cda60fe0bc --- /dev/null +++ b/framework/include/torchvision/ops/deform_conv2d.cpp @@ -0,0 +1,172 @@ +#include "deform_conv2d.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +at::Tensor deform_conv2d( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::deform_conv2d", "") + .typed(); + return op.call( + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); +} + +at::Tensor deform_conv2d_symint( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); + static auto op = c10::Dispatcher::singleton() + 
.findSchemaOrThrow("torchvision::deform_conv2d", "") + .typed(); + return op.call( + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); +} + +namespace detail { + +std::tuple +_deform_conv2d_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") + .typed(); + return op.call( + grad, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); +} + +std::tuple +_deform_conv2d_backward_symint( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") + .typed(); + return op.call( + grad, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); +} + +} // namespace detail + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/deform_conv2d.h b/framework/include/torchvision/ops/deform_conv2d.h new file mode 100644 index 00000000000..cf1f142e648 --- /dev/null +++ b/framework/include/torchvision/ops/deform_conv2d.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API at::Tensor deform_conv2d( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask); + +VISION_API at::Tensor deform_conv2d_symint( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask); + 
+namespace detail { + +std::tuple +_deform_conv2d_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask); + +std::tuple +_deform_conv2d_backward_symint( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask); + +} // namespace detail + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/mps/mps_helpers.h b/framework/include/torchvision/ops/mps/mps_helpers.h new file mode 100644 index 00000000000..d3c0e8d94b7 --- /dev/null +++ b/framework/include/torchvision/ops/mps/mps_helpers.h @@ -0,0 +1,6 @@ +constexpr int threadsPerBlock = 512; + +template +constexpr inline T ceil_div(T n, T m) { + return (n + m - 1) / m; +} diff --git a/framework/include/torchvision/ops/mps/mps_kernels.h b/framework/include/torchvision/ops/mps/mps_kernels.h new file mode 100644 index 00000000000..e720a1608f1 --- /dev/null +++ b/framework/include/torchvision/ops/mps/mps_kernels.h @@ -0,0 +1,1102 @@ +#include + +namespace vision { +namespace ops { + +namespace mps { + +static const char* METAL_VISION = R"VISION_METAL( + +#include +#include +using namespace metal; + +/*----------Macros----------*/ + +#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \ + for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ + i += (tptg.x * n_tgs)) + +#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint) + +/*----------Helpers--------*/ + +template +inline T ceil_div(T n, T m) { + return (n + m - 1) / m; +} + +template +inline void atomic_add_float( device T* data_ptr, const T val) +{ +#if __METAL_VERSION__ >= 300 + // atomic_float is supported in Metal 3 (macOS Ventura) onward. + device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); +#else + // Custom atomic addition implementation + // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 + // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 + // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide) + + // Create an atomic uint pointer for atomic transaction. + device atomic_uint* atom_var = (device atomic_uint*)data_ptr; + // Create necessary storage. + uint fetched_uint, assigning_uint; + T fetched_float, assigning_float; + + // Replace the value in atom_var with 0 and return the previous value in atom_var. + fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed); + // Read out the previous value as float. + fetched_float = *( (thread T*) &fetched_uint ); + + // Do addition and represent the addition result in uint for atomic transaction. + assigning_float = fetched_float + val; + assigning_uint = *((thread uint*) &assigning_float); + + // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). 
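+  // In short: the cell is used as a hand-off slot for partial sums. If the
+  // exchange below returns non-zero, this thread has displaced another
+  // thread's pending contribution; it then empties the cell again, folds the
+  // displaced value into its running sum, and retries until its deposit lands
+  // on a zero cell, so contributions are not lost (at the cost of extra
+  // exchanges under contention).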
+ while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) { + // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. + // Try to assign 0 and get the previously assigned addition result. + uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); + T fetched_float_again = *( (thread T*) &fetched_uint_again ); + // Re-add again + fetched_float = *((thread T*) &(fetched_uint)); + // Previously assigned addition result + addition result from other threads. + assigning_float = fetched_float_again + fetched_float; + assigning_uint = *( (thread uint*) &assigning_float); + } +#endif +} + +template +inline T bilinear_interpolate( + constant T* input, + integer_t height, + integer_t width, + T y, + T x, + uint index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + integer_t y_low = (integer_t)y; + integer_t x_low = (integer_t)x; + integer_t y_high; + integer_t x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +inline void bilinear_interpolate_gradient( + integer_t height, + integer_t width, + T y, + T x, + thread T& w1, + thread T& w2, + thread T& w3, + thread T& w4, + thread integer_t& x_low, + thread integer_t& x_high, + thread integer_t& y_low, + thread integer_t& y_high, + uint index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (integer_t)y; + x_low = (integer_t)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + +template +inline bool IoU( + constant T & a, + threadgroup T & b, + const float threshold) { + auto xx1 = max(a.x, b.x); + auto yy1 = max(a.y, b.y); + auto xx2 = min(a.z, b.z); + auto yy2 = min(a.w, b.w); + auto w = max(static_cast(0), xx2 - xx1); + auto h = max(static_cast(0), yy2 - yy1); + // Upcast to float before multiplications to circumvent precision issues in half. 
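+  // Note: the expression below is the standard Jaccard overlap,
+  //   IoU(a, b) = inter / (area_a + area_b - inter),
+  // computed from the clamped intersection rectangle above; the NMS kernel
+  // marks box b for suppression when this ratio exceeds iou_threshold.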
+ auto inter = static_cast(w) * static_cast(h); + auto area_b = static_cast(b.z - b.x) * static_cast(b.w - b.y); + auto area_a = static_cast(a.z - a.x) * static_cast(a.w - a.y); + return (inter / (area_a + area_b - inter)) > threshold; +} + +/*----------Kernels----------*/ + +// This should be in sync with the one in nms_kernel.mm. +// Since metal does not support dynamic array, +// we need to make it static instead of deriving it from [[threads_per_threadgroup]]. +constant int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; + +template +kernel void nms(constant T * dev_boxes [[buffer(0)]], + device uint64_t * mask [[buffer(1)]], + constant int64_t & n_boxes [[buffer(2)]], + constant float & iou_threshold [[buffer(3)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tid2 [[thread_position_in_threadgroup]]) { + + const uint row_start = tgid.y; + const uint col_start = tgid.x; + const uint tid = tid2.x; + const uint row_size = + min(n_boxes - row_start * nmsThreadsPerBlock, nmsThreadsPerBlock); + const uint col_size = + min(n_boxes - col_start * nmsThreadsPerBlock, nmsThreadsPerBlock); + + threadgroup T block_boxes[nmsThreadsPerBlock]; + block_boxes[tid] = dev_boxes[nmsThreadsPerBlock * col_start + tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid < row_size) { + const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid; + uint64_t t = 0; + uint start = 0; + + if (row_start == col_start) { + start = tid + 1; + } + + for (uint i = start; i < col_size; i++){ + if (IoU(dev_boxes[cur_box_idx], block_boxes[i], iou_threshold)){ + t |= static_cast(1) << i; // discard 1 keep 0 + } + } + const uint col_blocks = ceil_div(n_boxes, nmsThreadsPerBlock); + mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +#define REGISTER_NMS_OP(DTYPE) \ +template \ +[[host_name("nms_" #DTYPE)]] \ +kernel void nms( \ + constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \ + device uint64_t * mask [[buffer(1)]], \ + constant int64_t & n_boxes [[buffer(2)]], \ + constant float & iou_threshold [[buffer(3)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_align( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + constant int64_t & output_size [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & pooled_height [[buffer(7)]], + constant int64_t & pooled_width [[buffer(8)]], + constant int64_t & sampling_ratio [[buffer(9)]], + constant bool & aligned [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? 
(T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + constant T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros. + const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast(1)); // e.g. = 4 + + T output_val = 0.; + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + output_val += val; + } + } + output_val /= count; + + output[index] = output_val; + } +} + +#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_align_" #DTYPE)]] \ +kernel void roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & pooled_height [[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_align_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * grad_input [[buffer(2)]], + constant int64_t & output_size [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & pooled_height [[buffer(7)]], + constant int64_t & pooled_width [[buffer(8)]], + constant int64_t & sampling_ratio [[buffer(9)]], + constant bool & aligned [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + constant int64_t & n_stride [[buffer(12)]], + constant int64_t & c_stride [[buffer(13)]], + constant int64_t & h_stride [[buffer(14)]], + constant int64_t & w_stride [[buffer(15)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element 
in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We need to index the gradient using the tensor strides to access the + // correct values. + const integer_t output_offset = n * n_stride + c * c_stride; + constant T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + const integer_t input_offset = (roi_batch_ind * channels + c) * height * width; + + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + integer_t x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomic_add_float(grad_input + input_offset + y_low * width + x_low, static_cast(g1)); + atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast(g2)); + atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast(g3)); + atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast(g4)); + + } // if + } // ix + } // iy + } // MPS_1D_KERNEL_LOOP +} + +#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_align_backward_" #DTYPE)]] \ +kernel void roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * grad_input [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ 
+ constant int64_t & pooled_height [[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + constant int64_t & n_stride [[buffer(12)]], \ + constant int64_t & c_stride [[buffer(13)]], \ + constant int64_t & h_stride [[buffer(14)]], \ + constant int64_t & w_stride [[buffer(15)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_pool( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * argmax [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant float & spatial_scale [[buffer(10)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + integer_t roi_width = max(roi_end_w - roi_start_w + 1, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h + 1, static_cast(1)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + T maxval = is_empty ? 
0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + integer_t maxidx = -1; + constant T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + for (integer_t h = hstart; h < hend; ++h) { + for (integer_t w = wstart; w < wend; ++w) { + integer_t input_index = h * width + w; + if (offset_input[input_index] > maxval) { + maxval = offset_input[input_index]; + maxidx = input_index; + } + } + } + output[index] = maxval; + argmax[index] = maxidx; + } +} + +#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_pool_" #DTYPE)]] \ +kernel void roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * argmax_data [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_pool_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * argmax_data [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant float & spatial_scale [[buffer(10)]], + constant int64_t & n_stride [[buffer(11)]], + constant int64_t & c_stride [[buffer(12)]], + constant int64_t & h_stride [[buffer(13)]], + constant int64_t & w_stride [[buffer(14)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + const integer_t output_offset = n * n_stride + c * c_stride; + constant integer_t * argmax_data_offset = + argmax_data + (n * channels + c) * pooled_height * pooled_width; + const integer_t argmax = argmax_data_offset[ph * pooled_width + pw]; + const integer_t offset = (roi_batch_ind * channels + c) * height * width; + + if (argmax != -1) { + atomic_add_float(grad_input + offset + argmax, static_cast(grad_output[output_offset + ph * h_stride + pw * w_stride])); + } + + } // MPS_1D_KERNEL_LOOP +} + +#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_pool_backward_" #DTYPE)]] \ +kernel void roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * argmax_data [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + 
constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + constant int64_t & n_stride [[buffer(11)]], \ + constant int64_t & c_stride [[buffer(12)]], \ + constant int64_t & h_stride [[buffer(13)]], \ + constant int64_t & w_stride [[buffer(14)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_align( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * channel_mapping [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & sampling_ratio [[buffer(10)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c_out, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c_out = (index / pooled_width / pooled_height) % channels_out; + integer_t n = index / pooled_width / pooled_height / channels_out; + + // (n, c_in, ph, pw) is the associated element in the input + integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; + + // [start, end) interval for spatial sampling + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + constant T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + T out_sum = 0; + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + out_sum += val; + } + } + + out_sum /= count; + output[index] = out_sum; + channel_mapping[index] = c_in; + } +} + +#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_align_" #DTYPE)]] \ +kernel void ps_roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_align_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & sampling_ratio [[buffer(10)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, *, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t n = index / pooled_width / pooled_height / channels_out; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + // Force too small ROIs to be 1x1 + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + integer_t c_in = channel_mapping[index]; + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + const T 
grad_output_this_bin = grad_output[index]; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; + + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + integer_t x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomic_add_float(grad_input + offset + y_low * width + x_low, static_cast(g1)); + atomic_add_float(grad_input + offset + y_low * width + x_high, static_cast(g2)); + atomic_add_float(grad_input + offset + y_high * width + x_low, static_cast(g3)); + atomic_add_float(grad_input + offset + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } +} + +#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_align_backward_" #DTYPE)]] \ +kernel void ps_roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_pool( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * channel_mapping [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & channels_out [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c_out, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c_out = (index / (pooled_width * pooled_height)) % channels_out; + integer_t n = index / pooled_width / pooled_height 
/ channels_out; + + // (n, c_in, ph, pw) is the associated element in the input + integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; + + // [start, end) interval for spatial sampling + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height - 1)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height - 1)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width - 1)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width - 1)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + constant T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + T out_sum = 0; + for (integer_t h = hstart; h < hend; ++h) { + for (integer_t w = wstart; w < wend; ++w) { + integer_t input_index = h * width + w; + out_sum += offset_input[input_index]; + } + } + + T bin_area = (hend - hstart) * (wend - wstart); + output[index] = is_empty ? 
static_cast(0) : out_sum / bin_area; + channel_mapping[index] = c_in; + } +} + +#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_pool_" #DTYPE)]] \ +kernel void ps_roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_pool_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & channels_out [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, *, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t n = index / pooled_width / pooled_height / channels_out; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + integer_t c_in = channel_mapping[index]; + T bin_area = (hend - hstart) * (wend - wstart); + T diff_val = is_empty ? 
static_cast(0) : grad_output[index] / bin_area; + + const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; + + for (integer_t h = hstart; h < hend; ++h) { + for (integer_t w = wstart; w < wend; ++w) { + integer_t grad_input_index = h * width + w; + atomic_add_float(grad_input + offset + grad_input_index, diff_val); + } + } + + } // MPS_1D_KERNEL_LOOP +} + +#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_pool_backward_" #DTYPE)]] \ +kernel void ps_roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +REGISTER_NMS_OP(float); +REGISTER_NMS_OP(half); +REGISTER_ROI_ALIGN_OP(float, int64_t); +REGISTER_ROI_ALIGN_OP(half, int64_t); +REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t); +REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t); +REGISTER_ROI_POOL_OP(float, int64_t); +REGISTER_ROI_POOL_OP(half, int64_t); +REGISTER_ROI_POOL_BACKWARD_OP(float, int64_t); +REGISTER_ROI_POOL_BACKWARD_OP(half, int64_t); +REGISTER_PS_ROI_ALIGN_OP(float, int64_t); +REGISTER_PS_ROI_ALIGN_OP(half, int64_t); +REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float, int64_t); +REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half, int64_t); +REGISTER_PS_ROI_POOL_OP(float, int64_t); +REGISTER_PS_ROI_POOL_OP(half, int64_t); +REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t); +REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t); + +)VISION_METAL"; + +static id compileVisionOpsLibrary(id device) { + static id visionLibrary = nil; + if (visionLibrary) { + return visionLibrary; + } + + NSError* error = nil; + MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:MTLLanguageVersion2_3]; + visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding] + options:options + error:&error]; + TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]); + return visionLibrary; +} + +static id visionPipelineState(id device, const std::string& kernel) { + static std::unordered_map> psoCache; + id pso = psoCache[kernel]; + if (pso) { + return pso; + } + + NSError* error = nil; + id visionLib = compileVisionOpsLibrary(device); + id visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; + TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel); + pso = [device newComputePipelineStateWithFunction:visionFunc error:&error]; + TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); + + psoCache[kernel] = pso; + return pso; +} + +} // namespace mps +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/mps/nms_kernel.mm b/framework/include/torchvision/ops/mps/nms_kernel.mm new file mode 100644 index 00000000000..5ee9b5cbeae --- /dev/null +++ 
b/framework/include/torchvision/ops/mps/nms_kernel.mm @@ -0,0 +1,109 @@ +#include +#include +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +// This should be in sync with `nmsThreadsPerBlock` in the metal kernel. +constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; + +at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { + using namespace at::native::mps; + TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor"); + TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor"); + + TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1)); + TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D"); + TORCH_CHECK(dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)) + + if (dets.numel() == 0) { + return at::empty({0}, dets.options().dtype(at::kLong)); + } + + auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + auto dets_sorted = dets.index_select(0, order_t).contiguous(); + int64_t dets_num = dets.size(0); + float iou_threshold_f = static_cast(iou_threshold); + + const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock; + at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + + id inputBuffer = getMTLBufferStorage(dets_sorted); + id outputBuffer = getMTLBufferStorage(mask); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1); + + const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores}); + + [computeEncoder setComputePipelineState:visionPSO]; + [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0]; + [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1]; + [computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2]; + [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3]; + + // A threadGroup is equivalent to a cuda's block. 
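+      // Note: the grid is col_blocks x col_blocks threadgroups, so each
+      // threadgroup compares one 64-box "row" tile against one 64-box "column"
+      // tile and writes one uint64_t suppression bitmap per (row box, column
+      // block) pair into `mask`; the CPU pass below then walks those bitmaps in
+      // descending-score order to build `keep`. The threadgroup width is
+      // clamped to nmsThreadsPerBlock (64): one thread per box in a tile.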
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > nmsThreadsPerBlock) { + tgSize = nmsThreadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + + int64_t num_to_keep = 0; + + at::Tensor mask_cpu = mask.to(at::kCPU); + unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr(); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data_ptr(); + + for (int64_t i = 0; i < dets_num; i++) { + int64_t nblock = i / nmsThreadsPerBlock; + int64_t inblock = i % nmsThreadsPerBlock; + + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long* p = mask_host + i * col_blocks; + for (int64_t j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + + return order_t.index( + {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())}); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/mps/ps_roi_align_kernel.mm b/framework/include/torchvision/ops/mps/ps_roi_align_kernel.mm new file mode 100644 index 00000000000..16b711ad5ef --- /dev/null +++ b/framework/include/torchvision/ops/mps/ps_roi_align_kernel.mm @@ -0,0 +1,205 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +std::tuple ps_roi_align_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_align_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + + int64_t channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); + + int64_t output_size = output.numel(); + + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + 
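+  // Note: the launch below is issued synchronously on the current MPS stream.
+  // The grid is one-dimensional, at most 4096 threadgroups of at most
+  // threadsPerBlock (512, from mps_helpers.h) threads; MPS_1D_KERNEL_LOOP in
+  // the Metal kernel guards every index against output_size.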
dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; + + // A threadGroup is equivalent to a cuda's block. 
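+      // Note: tensor arguments above are bound as MTLBuffers with their storage
+      // offsets applied explicitly, while scalar parameters are passed by value
+      // via setBytes; the atIndex values must match the [[buffer(n)]] slots
+      // declared for ps_roi_align in mps_kernels.h.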
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_align_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_align backward with float16 inputs."); + TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "ps_roi_align_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t output_size = grad.numel(); + int64_t channels_out = channels / (pooled_height * pooled_width); + + at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad_); + id roisBuffer = getMTLBufferStorage(rois_); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height 
length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/mps/ps_roi_pool_kernel.mm b/framework/include/torchvision/ops/mps/ps_roi_pool_kernel.mm new file mode 100644 index 00000000000..fc24f6990fa --- /dev/null +++ b/framework/include/torchvision/ops/mps/ps_roi_pool_kernel.mm @@ -0,0 +1,200 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +std::tuple ps_roi_pool_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_pool_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int64_t channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); + auto output_size = output.numel(); + + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type()); + id 
visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_pool_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_pool backward with float16 inputs."); + TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "ps_roi_pool_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + auto num_rois = rois.size(0); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t channels_out = channels / (pooled_height * pooled_width); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel"); + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad_); + id roisBuffer = getMTLBufferStorage(rois_); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id 
computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad_, rois_, channel_mapping}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), TORCH_FN(ps_roi_pool_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/mps/roi_align_kernel.mm b/framework/include/torchvision/ops/mps/roi_align_kernel.mm new file mode 100644 index 00000000000..d4ed8b43fd2 --- /dev/null +++ b/framework/include/torchvision/ops/mps/roi_align_kernel.mm @@ -0,0 +1,197 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +at::Tensor roi_align_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t 
width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + + int64_t output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return output; + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. 
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return output; +} + +at::Tensor roi_align_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs."); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t n_stride = grad.stride(0); + int64_t c_stride = grad.stride(1); + int64_t h_stride = grad.stride(2); + int64_t w_stride = grad.stride(3); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("roi_align_backward_kernel"); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) 
atIndex:12]; + [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/mps/roi_pool_kernel.mm b/framework/include/torchvision/ops/mps/roi_pool_kernel.mm new file mode 100644 index 00000000000..816d8d70863 --- /dev/null +++ b/framework/include/torchvision/ops/mps/roi_pool_kernel.mm @@ -0,0 +1,196 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +std::tuple roi_pool_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_pool_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong)); + + int64_t output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return std::make_tuple(output, argmax); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id argmaxBuffer = getMTLBufferStorage(argmax); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder 
setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return std::make_tuple(output, argmax); +} + +at::Tensor roi_pool_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs."); + TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor"); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; + + at::CheckedFrom c = "roi_pool_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t n_stride = grad.stride(0); + int64_t c_stride = grad.stride(1); + int64_t h_stride = grad.stride(2); + int64_t w_stride = grad.stride(3); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("roi_pool_backward_kernel"); + auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad); + id roisBuffer = getMTLBufferStorage(rois_); + id argmaxBuffer = getMTLBufferStorage(argmax_); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + 
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_, argmax_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/nms.cpp b/framework/include/torchvision/ops/nms.cpp new file mode 100644 index 00000000000..07a934bce5a --- /dev/null +++ b/framework/include/torchvision/ops/nms.cpp @@ -0,0 +1,27 @@ +#include "nms.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +at::Tensor nms( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms.nms"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::nms", "") + .typed(); + return op.call(dets, scores, iou_threshold); +} + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/nms.h b/framework/include/torchvision/ops/nms.h new file mode 100644 index 00000000000..8c75a242bff --- /dev/null +++ b/framework/include/torchvision/ops/nms.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API at::Tensor nms( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold); + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/ops.h b/framework/include/torchvision/ops/ops.h new file mode 100644 index 
00000000000..77995e44197 --- /dev/null +++ b/framework/include/torchvision/ops/ops.h @@ -0,0 +1,8 @@ +#pragma once + +#include "deform_conv2d.h" +#include "nms.h" +#include "ps_roi_align.h" +#include "ps_roi_pool.h" +#include "roi_align.h" +#include "roi_pool.h" diff --git a/framework/include/torchvision/ops/ps_roi_align.cpp b/framework/include/torchvision/ops/ps_roi_align.cpp new file mode 100644 index 00000000000..de458c0d62d --- /dev/null +++ b/framework/include/torchvision/ops/ps_roi_align.cpp @@ -0,0 +1,112 @@ +#include "ps_roi_align.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +std::tuple ps_roi_align( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_align", "") + .typed(); + return op.call( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + +std::tuple ps_roi_align_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_align", "") + .typed(); + return op.call( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + +namespace detail { + +at::Tensor _ps_roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") + .typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); +} + +at::Tensor _ps_roi_align_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") + .typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); +} + +} // namespace detail + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/ps_roi_align.h b/framework/include/torchvision/ops/ps_roi_align.h new file mode 
100644 index 00000000000..75650586bc6 --- /dev/null +++ b/framework/include/torchvision/ops/ps_roi_align.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API std::tuple ps_roi_align( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio); + +VISION_API std::tuple ps_roi_align_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio); + +namespace detail { + +at::Tensor _ps_roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); + +at::Tensor _ps_roi_align_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width); + +} // namespace detail + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/ps_roi_pool.cpp b/framework/include/torchvision/ops/ps_roi_pool.cpp new file mode 100644 index 00000000000..92469d5e380 --- /dev/null +++ b/framework/include/torchvision/ops/ps_roi_pool.cpp @@ -0,0 +1,104 @@ +#include "ps_roi_pool.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +std::tuple ps_roi_pool( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + +std::tuple ps_roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + +namespace detail { + +at::Tensor _ps_roi_pool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") + .typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +} + +at::Tensor _ps_roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") + 
.typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +} + +} // namespace detail + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/ps_roi_pool.h b/framework/include/torchvision/ops/ps_roi_pool.h new file mode 100644 index 00000000000..4a3cc54e0e5 --- /dev/null +++ b/framework/include/torchvision/ops/ps_roi_pool.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API std::tuple ps_roi_pool( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +VISION_API std::tuple ps_roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width); + +namespace detail { + +at::Tensor _ps_roi_pool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); + +at::Tensor _ps_roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width); + +} // namespace detail + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/roi_align.cpp b/framework/include/torchvision/ops/roi_align.cpp new file mode 100644 index 00000000000..aa6dccb44f2 --- /dev/null +++ b/framework/include/torchvision/ops/roi_align.cpp @@ -0,0 +1,132 @@ +#include "roi_align.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +at::Tensor roi_align( + const at::Tensor& input, // Input feature map. + const at::Tensor& rois, // List of ROIs to pool over. + double spatial_scale, // The scale of the image features. ROIs will be + // scaled to this. + int64_t pooled_height, // The height of the pooled feature map. + int64_t pooled_width, // The width of the pooled feature + int64_t sampling_ratio, // The number of points to sample in each bin + bool aligned) // The flag for pixel shift +// along each axis. +{ + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::roi_align", "") + .typed(); + return op.call( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + aligned); +} + +at::Tensor roi_align_symint( + const at::Tensor& input, // Input feature map. + const at::Tensor& rois, // List of ROIs to pool over. + double spatial_scale, // The scale of the image features. ROIs will be + // scaled to this. + c10::SymInt pooled_height, // The height of the pooled feature map. 
+ c10::SymInt pooled_width, // The width of the pooled feature + int64_t sampling_ratio, // The number of points to sample in each bin + bool aligned) // The flag for pixel shift +// along each axis. +{ + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::roi_align", "") + .typed(); + return op.call( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + aligned); +} + +namespace detail { + +at::Tensor _roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_roi_align_backward", "") + .typed(); + return op.call( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + aligned); +} + +at::Tensor _roi_align_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width, + int64_t sampling_ratio, + bool aligned) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_roi_align_backward", "") + .typed(); + return op.call( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + aligned); +} + +} // namespace detail + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/roi_align.h b/framework/include/torchvision/ops/roi_align.h new file mode 100644 index 00000000000..072d6d4231c --- /dev/null +++ b/framework/include/torchvision/ops/roi_align.h @@ -0,0 +1,58 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API at::Tensor roi_align( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned); + +VISION_API at::Tensor roi_align_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + bool aligned); + +namespace detail { + +at::Tensor _roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned); + +at::Tensor _roi_align_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, 
+ c10::SymInt height, + c10::SymInt width, + int64_t sampling_ratio, + bool aligned); + +} // namespace detail + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/roi_pool.cpp b/framework/include/torchvision/ops/roi_pool.cpp new file mode 100644 index 00000000000..20ca3ca91e7 --- /dev/null +++ b/framework/include/torchvision/ops/roi_pool.cpp @@ -0,0 +1,102 @@ +#include "roi_pool.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +std::tuple roi_pool( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + +std::tuple roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + +namespace detail { + +at::Tensor _roi_pool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_roi_pool_backward", "") + .typed(); + return op.call( + grad, + rois, + argmax, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +} + +at::Tensor _roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_roi_pool_backward", "") + .typed(); + return op.call( + grad, + rois, + argmax, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +} + +} // namespace detail + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/ops/roi_pool.h b/framework/include/torchvision/ops/roi_pool.h new file mode 100644 index 00000000000..e2133240f4f --- /dev/null +++ b/framework/include/torchvision/ops/roi_pool.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API std::tuple roi_pool( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +VISION_API std::tuple roi_pool_symint( + const at::Tensor& input, + const 
at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width); + +namespace detail { + +at::Tensor _roi_pool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); + +at::Tensor _roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width); + +} // namespace detail + +} // namespace ops +} // namespace vision diff --git a/framework/include/torchvision/vision.cpp b/framework/include/torchvision/vision.cpp new file mode 100644 index 00000000000..161b8ecfa2f --- /dev/null +++ b/framework/include/torchvision/vision.cpp @@ -0,0 +1,41 @@ +#include "vision.h" + +#ifndef MOBILE +#ifdef USE_PYTHON +#include +#endif +#endif +#include + +#ifdef WITH_CUDA +#include +#endif +#ifdef WITH_HIP +#include +#endif + +// If we are in a Windows environment, we need to define +// initialization functions for the _custom_ops extension. +// For PyMODINIT_FUNC to work, we need to include Python.h +#if !defined(MOBILE) && defined(_WIN32) +#ifdef USE_PYTHON +PyMODINIT_FUNC PyInit__C(void) { + // No need to do anything. + return NULL; +} +#endif // USE_PYTHON +#endif // !defined(MOBILE) && defined(_WIN32) + +namespace vision { +int64_t cuda_version() { +#ifdef WITH_CUDA + return CUDA_VERSION; +#else + return -1; +#endif +} + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def("_cuda_version", &cuda_version); +} +} // namespace vision diff --git a/framework/include/torchvision/vision.h b/framework/include/torchvision/vision.h new file mode 100644 index 00000000000..22f8c6cdd38 --- /dev/null +++ b/framework/include/torchvision/vision.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include "macros.h" + +namespace vision { +VISION_API int64_t cuda_version(); + +namespace detail { +extern "C" VISION_INLINE_VARIABLE auto _register_ops = &cuda_version; +#ifdef HINT_MSVC_LINKER_INCLUDE_SYMBOL +#pragma comment(linker, "/include:_register_ops") +#endif + +} // namespace detail +} // namespace vision diff --git a/framework/share/cmake/TorchVision/TorchVisionConfig.cmake b/framework/share/cmake/TorchVision/TorchVisionConfig.cmake new file mode 100644 index 00000000000..f04d2919ebf --- /dev/null +++ b/framework/share/cmake/TorchVision/TorchVisionConfig.cmake @@ -0,0 +1,82 @@ +# TorchVisionConfig.cmake +# -------------------- +# +# Exported targets:: Vision +# + + +####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run #### +####### The input file was TorchVisionConfig.cmake.in ######## + +get_filename_component(PACKAGE_${CMAKE_FIND_PACKAGE_NAME}_COUNTER_1 "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) + +macro(set_and_check _var _file) + set(${_var} "${_file}") + if(NOT EXISTS "${_file}") + message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") + endif() +endmacro() + +macro(check_required_components _NAME) + foreach(comp ${${_NAME}_FIND_COMPONENTS}) + if(NOT ${_NAME}_${comp}_FOUND) + if(${_NAME}_FIND_REQUIRED_${comp}) + set(${_NAME}_FOUND FALSE) + endif() + endif() + endforeach() +endmacro() + 
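+# Illustrative usage from a consuming project (the target name `my_app` is a placeholder):
+#   find_package(TorchVision REQUIRED)
+#   target_link_libraries(my_app PRIVATE TorchVision::TorchVision)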
+#################################################################################### + +set(PN TorchVision) + +# location of include/torchvision +set(${PN}_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/include") + +set(${PN}_LIBRARY "") +set(${PN}_DEFINITIONS USING_${PN}) + +check_required_components(${PN}) + + +if(NOT (CMAKE_VERSION VERSION_LESS 3.0)) +#----------------------------------------------------------------------------- +# Don't include targets if this file is being picked up by another +# project which has already built this as a subproject +#----------------------------------------------------------------------------- +if(NOT TARGET ${PN}::${PN}) +include("${CMAKE_CURRENT_LIST_DIR}/${PN}Targets.cmake") + +target_include_directories(${PN}::${PN} INTERFACE "${${PN}_INCLUDE_DIR}") + +if(OFF) + target_compile_definitions(${PN}::${PN} INTERFACE WITH_CUDA) +endif() + +find_package(Torch REQUIRED) +target_link_libraries(${PN}::${PN} INTERFACE torch) + +if(ON) + find_package(PNG REQUIRED) + target_link_libraries(${PN}::${PN} INTERFACE ${PNG_LIBRARY}) + target_compile_definitions(${PN}::${PN} INTERFACE PNG_FOUND) +endif() + +if(ON) + find_package(JPEG REQUIRED) + target_link_libraries(${PN}::${PN} INTERFACE ${JPEG_LIBRARIES}) + target_compile_definitions(${PN}::${PN} INTERFACE JPEG_FOUND) +endif() + +if (OFF) + if(NOT TARGET Python3::Python) + find_package(Python3 COMPONENTS Development) + endif() + target_link_libraries(torch INTERFACE Python3::Python) + target_compile_definitions(${PN}::${PN} INTERFACE USE_PYTHON) +endif() + +endif() +endif() diff --git a/framework/share/cmake/TorchVision/TorchVisionConfigVersion.cmake b/framework/share/cmake/TorchVision/TorchVisionConfigVersion.cmake new file mode 100644 index 00000000000..cb344ba7a65 --- /dev/null +++ b/framework/share/cmake/TorchVision/TorchVisionConfigVersion.cmake @@ -0,0 +1,43 @@ +# This is a basic version file for the Config-mode of find_package(). +# It is used by write_basic_package_version_file() as input file for configure_file() +# to create a version-file which can be installed along a config.cmake file. +# +# The created file sets PACKAGE_VERSION_EXACT if the current version string and +# the requested version string are exactly the same and it sets +# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version. +# The variable CVF_VERSION must be set before calling configure_file(). 
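+# e.g. find_package(TorchVision 0.18) is expected to be satisfied by the 0.18.0a0 version set below,
+# while find_package(TorchVision 0.19) sets PACKAGE_VERSION_COMPATIBLE to FALSE.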
+ +set(PACKAGE_VERSION "0.18.0a0") + +if (PACKAGE_FIND_VERSION_RANGE) + # Package version must be in the requested version range + if ((PACKAGE_FIND_VERSION_RANGE_MIN STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MIN) + OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_GREATER PACKAGE_FIND_VERSION_MAX) + OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_GREATER_EQUAL PACKAGE_FIND_VERSION_MAX))) + set(PACKAGE_VERSION_COMPATIBLE FALSE) + else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + endif() +else() + if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION) + set(PACKAGE_VERSION_COMPATIBLE FALSE) + else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) + set(PACKAGE_VERSION_EXACT TRUE) + endif() + endif() +endif() + + +# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: +if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "") + return() +endif() + +# check that the installed version has the same 32/64bit-ness as the one which is currently searching: +if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8") + math(EXPR installedBits "8 * 8") + set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") + set(PACKAGE_VERSION_UNSUITABLE TRUE) +endif() diff --git a/framework/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake b/framework/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake new file mode 100644 index 00000000000..91aa482bb9c --- /dev/null +++ b/framework/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake @@ -0,0 +1,20 @@ +#---------------------------------------------------------------- +# Generated CMake target import file. +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Import target "TorchVision::TorchVision" for configuration "" +set_property(TARGET TorchVision::TorchVision APPEND PROPERTY IMPORTED_CONFIGURATIONS NOCONFIG) +set_target_properties(TorchVision::TorchVision PROPERTIES + IMPORTED_LINK_DEPENDENT_LIBRARIES_NOCONFIG "torch" + IMPORTED_LOCATION_NOCONFIG "${_IMPORT_PREFIX}/lib/libtorchvision.dylib" + IMPORTED_SONAME_NOCONFIG "@rpath/libtorchvision.dylib" + ) + +list(APPEND _cmake_import_check_targets TorchVision::TorchVision ) +list(APPEND _cmake_import_check_files_for_TorchVision::TorchVision "${_IMPORT_PREFIX}/lib/libtorchvision.dylib" ) + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) diff --git a/framework/share/cmake/TorchVision/TorchVisionTargets.cmake b/framework/share/cmake/TorchVision/TorchVisionTargets.cmake new file mode 100644 index 00000000000..1e07b7fc626 --- /dev/null +++ b/framework/share/cmake/TorchVision/TorchVisionTargets.cmake @@ -0,0 +1,102 @@ +# Generated by CMake + +if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8) + message(FATAL_ERROR "CMake >= 2.8.0 required") +endif() +if(CMAKE_VERSION VERSION_LESS "2.8.3") + message(FATAL_ERROR "CMake >= 2.8.3 required") +endif() +cmake_policy(PUSH) +cmake_policy(VERSION 2.8.3...3.27) +#---------------------------------------------------------------- +# Generated CMake target import file. +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Protect against multiple inclusion, which would fail when already imported targets are added once more. 
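+# (If every expected target is already defined the file returns early; if only some of them are
+# defined, a fatal error is raised below.)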
+set(_cmake_targets_defined "") +set(_cmake_targets_not_defined "") +set(_cmake_expected_targets "") +foreach(_cmake_expected_target IN ITEMS TorchVision::TorchVision) + list(APPEND _cmake_expected_targets "${_cmake_expected_target}") + if(TARGET "${_cmake_expected_target}") + list(APPEND _cmake_targets_defined "${_cmake_expected_target}") + else() + list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}") + endif() +endforeach() +unset(_cmake_expected_target) +if(_cmake_targets_defined STREQUAL _cmake_expected_targets) + unset(_cmake_targets_defined) + unset(_cmake_targets_not_defined) + unset(_cmake_expected_targets) + unset(CMAKE_IMPORT_FILE_VERSION) + cmake_policy(POP) + return() +endif() +if(NOT _cmake_targets_defined STREQUAL "") + string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}") + string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}") + message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n") +endif() +unset(_cmake_targets_defined) +unset(_cmake_targets_not_defined) +unset(_cmake_expected_targets) + + +# Compute the installation prefix relative to this file. +get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +if(_IMPORT_PREFIX STREQUAL "/") + set(_IMPORT_PREFIX "") +endif() + +# Create imported target TorchVision::TorchVision +add_library(TorchVision::TorchVision SHARED IMPORTED) + +# Load information for each installed configuration. +file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/TorchVisionTargets-*.cmake") +foreach(_cmake_config_file IN LISTS _cmake_config_files) + include("${_cmake_config_file}") +endforeach() +unset(_cmake_config_file) +unset(_cmake_config_files) + +# Cleanup temporary variables. +set(_IMPORT_PREFIX) + +# Loop over all imported files and verify that they actually exist +foreach(_cmake_target IN LISTS _cmake_import_check_targets) + if(CMAKE_VERSION VERSION_LESS "3.28" + OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target} + OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}") + foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}") + if(NOT EXISTS "${_cmake_file}") + message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file + \"${_cmake_file}\" +but this file does not exist. Possible reasons include: +* The file was deleted, renamed, or moved to another location. +* An install or uninstall procedure did not complete successfully. +* The installation package was faulty and contained + \"${CMAKE_CURRENT_LIST_FILE}\" +but not all the files it references. +") + endif() + endforeach() + endif() + unset(_cmake_file) + unset("_cmake_import_check_files_for_${_cmake_target}") +endforeach() +unset(_cmake_target) +unset(_cmake_import_check_targets) + +# This file does not depend on other imported targets which have +# been exported from the same project but in a separate export set. + +# Commands beyond this point should not need to know the version. 
+set(CMAKE_IMPORT_FILE_VERSION) +cmake_policy(POP) diff --git a/product/.DS_Store b/product/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..773050d530cd0ef7bf753032e0c2257470adce8f GIT binary patch literal 6148 zcmeHKO-sW-5S?wSR`iggHv?Y0_9k8uOB6g8@3oECLSl+i(33g(%e?y!{3*WK8PkMB z@Ki)*VE1k2V_)*JWV1x%#_M!M6cEu6jj^|nuERLatz|v+90Hx6V@)aDW%r}Wyl7>c z<1Z?}?{1Hl^gwgUsP_KymA~EH{%boUFOp=I6*EMJ$JggCLG*snm6iS?Yqj9jqHMAP z`?!-O&14;EB>H#*gGTE7&u0~X72WW*qYb|36O+;oJ<)=0>G6B*Dkq)XG^2IitJ8|N zR-YFe?(U2RKgG;%W(t@Brogrpz@E)EJP@?r6fgx$fl2}XK7?qDv0^FcKON}&5db)V z+Z)Dwmf)IDF;*-E5rH`=1xl&YBZiZ5_(RQ$6-z-WC+C?lk1{(wp*YVDf9TW6#e&wG z0;WJyfxcV@y#HVReEx41*_A0^3j8Yt+#s1GV?2`Xt*wXSy*5OfqOoyaDOiYC=<_-Rk2f$dd6odz6KLTC`>r8>ID)0rT`(E7u literal 0 HcmV?d00001 diff --git a/product/include/torchvision/io/image/cpu/common_jpeg.cpp b/product/include/torchvision/io/image/cpu/common_jpeg.cpp new file mode 100644 index 00000000000..4c993106b45 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/common_jpeg.cpp @@ -0,0 +1,26 @@ +#include "common_jpeg.h" + +namespace vision { +namespace image { +namespace detail { + +#if JPEG_FOUND +void torch_jpeg_error_exit(j_common_ptr cinfo) { + /* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce + * pointer */ + torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err; + + /* Always display the message. */ + /* We could postpone this until after returning, if we chose. */ + // (*cinfo->err->output_message)(cinfo); + /* Create the message */ + (*(cinfo->err->format_message))(cinfo, myerr->jpegLastErrorMsg); + + /* Return control to the setjmp point */ + longjmp(myerr->setjmp_buffer, 1); +} +#endif + +} // namespace detail +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/common_jpeg.h b/product/include/torchvision/io/image/cpu/common_jpeg.h new file mode 100644 index 00000000000..7f7f9f0ccf1 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/common_jpeg.h @@ -0,0 +1,27 @@ +#pragma once + +#if JPEG_FOUND +#include + +#include +#include + +namespace vision { +namespace image { +namespace detail { + +static const JOCTET EOI_BUFFER[1] = {JPEG_EOI}; +struct torch_jpeg_error_mgr { + struct jpeg_error_mgr pub; /* "public" fields */ + char jpegLastErrorMsg[JMSG_LENGTH_MAX]; /* error messages */ + jmp_buf setjmp_buffer; /* for return to caller */ +}; + +using torch_jpeg_error_ptr = struct torch_jpeg_error_mgr*; +void torch_jpeg_error_exit(j_common_ptr cinfo); + +} // namespace detail +} // namespace image +} // namespace vision + +#endif diff --git a/product/include/torchvision/io/image/cpu/common_png.h b/product/include/torchvision/io/image/cpu/common_png.h new file mode 100644 index 00000000000..68400d48e05 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/common_png.h @@ -0,0 +1,6 @@ +#pragma once + +#if PNG_FOUND +#include +#include +#endif diff --git a/product/include/torchvision/io/image/cpu/decode_avif.cpp b/product/include/torchvision/io/image/cpu/decode_avif.cpp new file mode 100644 index 00000000000..ec136743806 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_avif.cpp @@ -0,0 +1,92 @@ +#include "decode_avif.h" + +#if AVIF_FOUND +#include "avif/avif.h" +#endif // AVIF_FOUND + +namespace vision { +namespace image { + +#if !AVIF_FOUND +torch::Tensor decode_avif(const torch::Tensor& data) { + TORCH_CHECK( + false, "decode_avif: torchvision not compiled with libavif support"); +} +#else + +// This normally comes from avif_cxx.h, but it's not always 
present when +// installing libavif. So we just copy/paste it here. +struct UniquePtrDeleter { + void operator()(avifDecoder* decoder) const { + avifDecoderDestroy(decoder); + } +}; +using DecoderPtr = std::unique_ptr; + +torch::Tensor decode_avif(const torch::Tensor& encoded_data) { + // This is based on + // https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c + // Refer there for more detail about what each function does, and which + // structure/data is available after which call. + + TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); + TORCH_CHECK( + encoded_data.dtype() == torch::kU8, + "Input tensor must have uint8 data type, got ", + encoded_data.dtype()); + TORCH_CHECK( + encoded_data.dim() == 1, + "Input tensor must be 1-dimensional, got ", + encoded_data.dim(), + " dims."); + + DecoderPtr decoder(avifDecoderCreate()); + TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder."); + + auto result = AVIF_RESULT_UNKNOWN_ERROR; + result = avifDecoderSetIOMemory( + decoder.get(), encoded_data.data_ptr(), encoded_data.numel()); + TORCH_CHECK( + result == AVIF_RESULT_OK, + "avifDecoderSetIOMemory failed:", + avifResultToString(result)); + + result = avifDecoderParse(decoder.get()); + TORCH_CHECK( + result == AVIF_RESULT_OK, + "avifDecoderParse failed: ", + avifResultToString(result)); + TORCH_CHECK( + decoder->imageCount == 1, "Avif file contains more than one image"); + TORCH_CHECK( + decoder->image->depth <= 8, + "avif images with bitdepth > 8 are not supported"); + + result = avifDecoderNextImage(decoder.get()); + TORCH_CHECK( + result == AVIF_RESULT_OK, + "avifDecoderNextImage failed:", + avifResultToString(result)); + + auto out = torch::empty( + {decoder->image->height, decoder->image->width, 3}, torch::kUInt8); + + avifRGBImage rgb; + memset(&rgb, 0, sizeof(rgb)); + avifRGBImageSetDefaults(&rgb, decoder->image); + rgb.format = AVIF_RGB_FORMAT_RGB; + rgb.pixels = out.data_ptr(); + rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb); + + result = avifImageYUVToRGB(decoder->image, &rgb); + TORCH_CHECK( + result == AVIF_RESULT_OK, + "avifImageYUVToRGB failed: ", + avifResultToString(result)); + + return out.permute({2, 0, 1}); // return CHW, channels-last +} +#endif // AVIF_FOUND + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_avif.h b/product/include/torchvision/io/image/cpu/decode_avif.h new file mode 100644 index 00000000000..269bce52197 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_avif.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_avif(const torch::Tensor& data); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_gif.cpp b/product/include/torchvision/io/image/cpu/decode_gif.cpp new file mode 100644 index 00000000000..183d42e86a4 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_gif.cpp @@ -0,0 +1,173 @@ +#include "decode_gif.h" +#include +#include "giflib/gif_lib.h" + +namespace vision { +namespace image { + +typedef struct reader_helper_t { + uint8_t const* encoded_data; // input tensor data pointer + size_t encoded_data_size; // size of input tensor in bytes + size_t num_bytes_read; // number of bytes read so far in the tensor +} reader_helper_t; + +// That function is used by GIFLIB routines to read the encoded bytes. +// This reads `len` bytes and writes them into `buf`. 
The data is read from the +// input tensor passed to decode_gif() starting at the `num_bytes_read` +// position. +int read_from_tensor(GifFileType* gifFile, GifByteType* buf, int len) { + // the UserData field was set in DGifOpen() + reader_helper_t* reader_helper = + static_cast(gifFile->UserData); + + size_t num_bytes_to_read = std::min( + (size_t)len, + reader_helper->encoded_data_size - reader_helper->num_bytes_read); + std::memcpy( + buf, reader_helper->encoded_data + reader_helper->num_bytes_read, len); + reader_helper->num_bytes_read += num_bytes_to_read; + return num_bytes_to_read; +} + +torch::Tensor decode_gif(const torch::Tensor& encoded_data) { + // LibGif docs: https://giflib.sourceforge.net/intro.html + // Refer over there for more details on the libgif API, API ref, and a + // detailed description of the GIF format. + + TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); + TORCH_CHECK( + encoded_data.dtype() == torch::kU8, + "Input tensor must have uint8 data type, got ", + encoded_data.dtype()); + TORCH_CHECK( + encoded_data.dim() == 1, + "Input tensor must be 1-dimensional, got ", + encoded_data.dim(), + " dims."); + + int error = D_GIF_SUCCEEDED; + + // We're using DGidOpen. The other entrypoints of libgif are + // DGifOpenFileName and DGifOpenFileHandle but we don't want to use those, + // since we need to read the encoded bytes from a tensor of encoded bytes, not + // from a file (for consistency with existing jpeg and png decoders). Using + // DGifOpen is the only way to read from a custom source. + // For that we need to provide a reader function `read_from_tensor` that + // reads from the tensor, and we have to keep track of the number of bytes + // read so far: this is why we need the reader_helper struct. + + // TODO: We are potentially doing an unnecessary copy of the encoded bytes: + // - 1 copy in from file to tensor (in read_file()) + // - 1 copy from tensor to GIFLIB buffers (in read_from_tensor()) + // Since we're vendoring GIFLIB we can potentially modify the calls to + // InternalRead() and just set the `buf` pointer to the tensor data directly. + // That might even save allocation of those buffers. + // If we do that, we'd have to make sure the buffers are never written to by + // GIFLIB, otherwise we'd be overridding the tensor data. + reader_helper_t reader_helper; + reader_helper.encoded_data = encoded_data.data_ptr(); + reader_helper.encoded_data_size = encoded_data.numel(); + reader_helper.num_bytes_read = 0; + GifFileType* gifFile = + DGifOpen(static_cast(&reader_helper), read_from_tensor, &error); + + TORCH_CHECK( + (gifFile != nullptr) && (error == D_GIF_SUCCEEDED), + "DGifOpenFileName() failed - ", + error); + + if (DGifSlurp(gifFile) == GIF_ERROR) { + auto gifFileError = gifFile->Error; + DGifCloseFile(gifFile, &error); + TORCH_CHECK(false, "DGifSlurp() failed - ", gifFileError); + } + auto num_images = gifFile->ImageCount; + + // This check should already done within DGifSlurp(), just to be safe + TORCH_CHECK(num_images > 0, "GIF file should contain at least one image!"); + + GifColorType bg = {0, 0, 0}; + if (gifFile->SColorMap) { + bg = gifFile->SColorMap->Colors[gifFile->SBackGroundColor]; + } + + // The GIFLIB docs say that the canvas's height and width are potentially + // ignored by modern viewers, so to be on the safe side we set the output + // height to max(canvas_heigh, first_image_height). Same for width. 
+ // https://giflib.sourceforge.net/whatsinagif/bits_and_bytes.html + auto out_h = + std::max(gifFile->SHeight, gifFile->SavedImages[0].ImageDesc.Height); + auto out_w = + std::max(gifFile->SWidth, gifFile->SavedImages[0].ImageDesc.Width); + + // We output a channels-last tensor for consistency with other image decoders. + // Torchvision's resize tends to be is faster on uint8 channels-last tensors. + auto options = torch::TensorOptions() + .dtype(torch::kU8) + .memory_format(torch::MemoryFormat::ChannelsLast); + auto out = torch::empty( + {int64_t(num_images), 3, int64_t(out_h), int64_t(out_w)}, options); + auto out_a = out.accessor(); + for (int i = 0; i < num_images; i++) { + const SavedImage& img = gifFile->SavedImages[i]; + + GraphicsControlBlock gcb; + DGifSavedExtensionToGCB(gifFile, i, &gcb); + + const GifImageDesc& desc = img.ImageDesc; + const ColorMapObject* cmap = + desc.ColorMap ? desc.ColorMap : gifFile->SColorMap; + TORCH_CHECK( + cmap != nullptr, + "Global and local color maps are missing. This should never happen!"); + + // When going from one image to another, there is a "disposal method" which + // specifies how to handle the transition. E.g. DISPOSE_DO_NOT means that + // the current image should essentially be drawn on top of the previous + // canvas. The pixels of that previous canvas will appear on the new one if + // either: + // - a pixel is transparent in the current image + // - the current image is smaller than the canvas, hence exposing its pixels + // The "background" disposal method means that the current canvas should be + // set to the background color. + // We only support these 2 modes and default to "background" when the + // disposal method is unspecified, or when it's set to "DISPOSE_PREVIOUS" + // which according to GIFLIB is not widely supported. + // (https://giflib.sourceforge.net/whatsinagif/animation_and_transparency.html). + if (i > 0 && gcb.DisposalMode == DISPOSE_DO_NOT) { + out[i] = out[i - 1]; + } else { + // Background. 
If bg wasn't defined, it will be (0, 0, 0) + for (int h = 0; h < gifFile->SHeight; h++) { + for (int w = 0; w < gifFile->SWidth; w++) { + out_a[i][0][h][w] = bg.Red; + out_a[i][1][h][w] = bg.Green; + out_a[i][2][h][w] = bg.Blue; + } + } + } + + for (int h = 0; h < desc.Height; h++) { + for (int w = 0; w < desc.Width; w++) { + auto c = img.RasterBits[h * desc.Width + w]; + if (c == gcb.TransparentColor) { + continue; + } + GifColorType rgb = cmap->Colors[c]; + out_a[i][0][h + desc.Top][w + desc.Left] = rgb.Red; + out_a[i][1][h + desc.Top][w + desc.Left] = rgb.Green; + out_a[i][2][h + desc.Top][w + desc.Left] = rgb.Blue; + } + } + } + + out = out.squeeze(0); // remove batch dim if there's only one image + + DGifCloseFile(gifFile, &error); + TORCH_CHECK(error == D_GIF_SUCCEEDED, "DGifCloseFile() failed - ", error); + + return out; +} + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_gif.h b/product/include/torchvision/io/image/cpu/decode_gif.h new file mode 100644 index 00000000000..68d5073c91b --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_gif.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +// encoded_data tensor must be 1D uint8 and contiguous +C10_EXPORT torch::Tensor decode_gif(const torch::Tensor& encoded_data); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_image.cpp b/product/include/torchvision/io/image/cpu/decode_image.cpp new file mode 100644 index 00000000000..75c7e06195a --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_image.cpp @@ -0,0 +1,77 @@ +#include "decode_image.h" + +#include "decode_avif.h" +#include "decode_gif.h" +#include "decode_jpeg.h" +#include "decode_png.h" +#include "decode_webp.h" + +namespace vision { +namespace image { + +torch::Tensor decode_image( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { + // Check that tensor is a CPU tensor + TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor"); + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + auto err_msg = + "Unsupported image file. Only jpeg, png and gif are currently supported."; + + auto datap = data.data_ptr(); + + const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF" + TORCH_CHECK(data.numel() >= 3, err_msg); + if (memcmp(jpeg_signature, datap, 3) == 0) { + return decode_jpeg(data, mode, apply_exif_orientation); + } + + const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" + TORCH_CHECK(data.numel() >= 4, err_msg); + if (memcmp(png_signature, datap, 4) == 0) { + return decode_png(data, mode, apply_exif_orientation); + } + + const uint8_t gif_signature_1[6] = { + 0x47, 0x49, 0x46, 0x38, 0x39, 0x61}; // == "GIF89a" + const uint8_t gif_signature_2[6] = { + 0x47, 0x49, 0x46, 0x38, 0x37, 0x61}; // == "GIF87a" + TORCH_CHECK(data.numel() >= 6, err_msg); + if (memcmp(gif_signature_1, datap, 6) == 0 || + memcmp(gif_signature_2, datap, 6) == 0) { + return decode_gif(data); + } + + // We assume the signature of an avif file is + // 0000 0020 6674 7970 6176 6966 + // xxxx xxxx f t y p a v i f + // We only check for the "ftyp avif" part. + // This is probably not perfect, but hopefully this should cover most files. 
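  // (The "xxxx xxxx" above is the 4-byte size field of the ISO-BMFF 'ftyp'
  // box, which varies per file; hence the comparison below starts at byte
  // offset 4 rather than 0.)
  //
  // Illustrative sketch of how this dispatcher is typically invoked, with
  // `encoded` standing in for a 1-D torch::kU8 tensor of file bytes (e.g.
  // produced by the read_file() helper elsewhere in this package):
  //
  //   torch::Tensor img = vision::image::decode_image(
  //       encoded, vision::image::IMAGE_READ_MODE_RGB);
  //   // `img` comes back as a CHW tensor (uint8 for most formats); the
  //   // container format is detected purely from the magic bytes checked in
  //   // this function.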
+ const uint8_t avif_signature[8] = { + 0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66}; // == "ftypavif" + TORCH_CHECK(data.numel() >= 12, err_msg); + if ((memcmp(avif_signature, datap + 4, 8) == 0)) { + return decode_avif(data); + } + + const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" + const uint8_t webp_signature_end[7] = { + 0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8" + TORCH_CHECK(data.numel() >= 15, err_msg); + if ((memcmp(webp_signature_begin, datap, 4) == 0) && + (memcmp(webp_signature_end, datap + 8, 7) == 0)) { + return decode_webp(data); + } + + TORCH_CHECK(false, err_msg); +} + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_image.h b/product/include/torchvision/io/image/cpu/decode_image.h new file mode 100644 index 00000000000..f0e66d397ac --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_image.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_image( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, + bool apply_exif_orientation = false); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_jpeg.cpp b/product/include/torchvision/io/image/cpu/decode_jpeg.cpp new file mode 100644 index 00000000000..ec5953e4106 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_jpeg.cpp @@ -0,0 +1,271 @@ +#include "decode_jpeg.h" +#include "common_jpeg.h" +#include "exif.h" + +namespace vision { +namespace image { + +#if !JPEG_FOUND +torch::Tensor decode_jpeg( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { + TORCH_CHECK( + false, "decode_jpeg: torchvision not compiled with libjpeg support"); +} +#else + +using namespace detail; +using namespace exif_private; + +namespace { + +struct torch_jpeg_mgr { + struct jpeg_source_mgr pub; + const JOCTET* data; + size_t len; +}; + +static void torch_jpeg_init_source(j_decompress_ptr cinfo) {} + +static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) { + // No more data. Probably an incomplete image; Raise exception. + torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err; + strcpy(myerr->jpegLastErrorMsg, "Image is incomplete or truncated"); + longjmp(myerr->setjmp_buffer, 1); +} + +static void torch_jpeg_skip_input_data(j_decompress_ptr cinfo, long num_bytes) { + torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src; + if (src->pub.bytes_in_buffer < (size_t)num_bytes) { + // Skipping over all of remaining data; output EOI. + src->pub.next_input_byte = EOI_BUFFER; + src->pub.bytes_in_buffer = 1; + } else { + // Skipping over only some of the remaining data. 
+ src->pub.next_input_byte += num_bytes; + src->pub.bytes_in_buffer -= num_bytes; + } +} + +static void torch_jpeg_term_source(j_decompress_ptr cinfo) {} + +static void torch_jpeg_set_source_mgr( + j_decompress_ptr cinfo, + const unsigned char* data, + size_t len) { + torch_jpeg_mgr* src; + if (cinfo->src == 0) { // if this is first time; allocate memory + cinfo->src = (struct jpeg_source_mgr*)(*cinfo->mem->alloc_small)( + (j_common_ptr)cinfo, JPOOL_PERMANENT, sizeof(torch_jpeg_mgr)); + } + src = (torch_jpeg_mgr*)cinfo->src; + src->pub.init_source = torch_jpeg_init_source; + src->pub.fill_input_buffer = torch_jpeg_fill_input_buffer; + src->pub.skip_input_data = torch_jpeg_skip_input_data; + src->pub.resync_to_restart = jpeg_resync_to_restart; // default + src->pub.term_source = torch_jpeg_term_source; + // fill the buffers + src->data = (const JOCTET*)data; + src->len = len; + src->pub.bytes_in_buffer = len; + src->pub.next_input_byte = src->data; + + jpeg_save_markers(cinfo, APP1, 0xffff); +} + +inline unsigned char clamped_cmyk_rgb_convert( + unsigned char k, + unsigned char cmy) { + // Inspired from Pillow: + // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569 + int v = k * cmy + 128; + v = ((v >> 8) + v) >> 8; + return std::clamp(k - v, 0, 255); +} + +void convert_line_cmyk_to_rgb( + j_decompress_ptr cinfo, + const unsigned char* cmyk_line, + unsigned char* rgb_line) { + int width = cinfo->output_width; + for (int i = 0; i < width; ++i) { + int c = cmyk_line[i * 4 + 0]; + int m = cmyk_line[i * 4 + 1]; + int y = cmyk_line[i * 4 + 2]; + int k = cmyk_line[i * 4 + 3]; + + rgb_line[i * 3 + 0] = clamped_cmyk_rgb_convert(k, 255 - c); + rgb_line[i * 3 + 1] = clamped_cmyk_rgb_convert(k, 255 - m); + rgb_line[i * 3 + 2] = clamped_cmyk_rgb_convert(k, 255 - y); + } +} + +inline unsigned char rgb_to_gray(int r, int g, int b) { + // Inspired from Pillow: + // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226 + return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16; +} + +void convert_line_cmyk_to_gray( + j_decompress_ptr cinfo, + const unsigned char* cmyk_line, + unsigned char* gray_line) { + int width = cinfo->output_width; + for (int i = 0; i < width; ++i) { + int c = cmyk_line[i * 4 + 0]; + int m = cmyk_line[i * 4 + 1]; + int y = cmyk_line[i * 4 + 2]; + int k = cmyk_line[i * 4 + 3]; + + int r = clamped_cmyk_rgb_convert(k, 255 - c); + int g = clamped_cmyk_rgb_convert(k, 255 - m); + int b = clamped_cmyk_rgb_convert(k, 255 - y); + + gray_line[i] = rgb_to_gray(r, g, b); + } +} + +} // namespace + +torch::Tensor decode_jpeg( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + struct jpeg_decompress_struct cinfo; + struct torch_jpeg_error_mgr jerr; + + auto datap = data.data_ptr(); + // Setup decompression structure + cinfo.err = jpeg_std_error(&jerr.pub); + jerr.pub.error_exit = torch_jpeg_error_exit; + /* Establish the setjmp return context for my_error_exit to use. */ + if (setjmp(jerr.setjmp_buffer)) { + /* If we get here, the JPEG code has signaled an error. 
+ * We need to clean up the JPEG object. + */ + jpeg_destroy_decompress(&cinfo); + TORCH_CHECK(false, jerr.jpegLastErrorMsg); + } + + jpeg_create_decompress(&cinfo); + torch_jpeg_set_source_mgr(&cinfo, datap, data.numel()); + + // read info from header. + jpeg_read_header(&cinfo, TRUE); + + int channels = cinfo.num_components; + bool cmyk_to_rgb_or_gray = false; + + if (mode != IMAGE_READ_MODE_UNCHANGED) { + switch (mode) { + case IMAGE_READ_MODE_GRAY: + if (cinfo.jpeg_color_space == JCS_CMYK || + cinfo.jpeg_color_space == JCS_YCCK) { + cinfo.out_color_space = JCS_CMYK; + cmyk_to_rgb_or_gray = true; + } else { + cinfo.out_color_space = JCS_GRAYSCALE; + } + channels = 1; + break; + case IMAGE_READ_MODE_RGB: + if (cinfo.jpeg_color_space == JCS_CMYK || + cinfo.jpeg_color_space == JCS_YCCK) { + cinfo.out_color_space = JCS_CMYK; + cmyk_to_rgb_or_gray = true; + } else { + cinfo.out_color_space = JCS_RGB; + } + channels = 3; + break; + /* + * Libjpeg does not support converting from CMYK to grayscale etc. There + * is a way to do this but it involves converting it manually to RGB: + * https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313 + */ + default: + jpeg_destroy_decompress(&cinfo); + TORCH_CHECK(false, "The provided mode is not supported for JPEG files"); + } + + jpeg_calc_output_dimensions(&cinfo); + } + + int exif_orientation = -1; + if (apply_exif_orientation) { + exif_orientation = fetch_jpeg_exif_orientation(&cinfo); + } + + jpeg_start_decompress(&cinfo); + + int height = cinfo.output_height; + int width = cinfo.output_width; + + int stride = width * channels; + auto tensor = + torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); + auto ptr = tensor.data_ptr(); + torch::Tensor cmyk_line_tensor; + if (cmyk_to_rgb_or_gray) { + cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8); + } + + while (cinfo.output_scanline < cinfo.output_height) { + /* jpeg_read_scanlines expects an array of pointers to scanlines. + * Here the array is only one element long, but you could ask for + * more than one scanline at a time if that's more convenient. 
+ */ + if (cmyk_to_rgb_or_gray) { + auto cmyk_line_ptr = cmyk_line_tensor.data_ptr(); + jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1); + + if (channels == 3) { + convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr); + } else if (channels == 1) { + convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr); + } + } else { + jpeg_read_scanlines(&cinfo, &ptr, 1); + } + ptr += stride; + } + + jpeg_finish_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + auto output = tensor.permute({2, 0, 1}); + + if (apply_exif_orientation) { + return exif_orientation_transform(output, exif_orientation); + } + return output; +} +#endif // #if !JPEG_FOUND + +int64_t _jpeg_version() { +#if JPEG_FOUND + return JPEG_LIB_VERSION; +#else + return -1; +#endif +} + +bool _is_compiled_against_turbo() { +#ifdef LIBJPEG_TURBO_VERSION + return true; +#else + return false; +#endif +} + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_jpeg.h b/product/include/torchvision/io/image/cpu/decode_jpeg.h new file mode 100644 index 00000000000..e0c9a24c846 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_jpeg.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_jpeg( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, + bool apply_exif_orientation = false); + +C10_EXPORT int64_t _jpeg_version(); +C10_EXPORT bool _is_compiled_against_turbo(); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_png.cpp b/product/include/torchvision/io/image/cpu/decode_png.cpp new file mode 100644 index 00000000000..ac14ae934a4 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_png.cpp @@ -0,0 +1,232 @@ +#include "decode_png.h" +#include "common_png.h" +#include "exif.h" + +namespace vision { +namespace image { + +using namespace exif_private; + +#if !PNG_FOUND +torch::Tensor decode_png( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { + TORCH_CHECK( + false, "decode_png: torchvision not compiled with libPNG support"); +} +#else + +bool is_little_endian() { + uint32_t x = 1; + return *(uint8_t*)&x; +} + +torch::Tensor decode_png( + const torch::Tensor& data, + ImageReadMode mode, + bool apply_exif_orientation) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png"); + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + auto png_ptr = + png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); + TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") + auto info_ptr = png_create_info_struct(png_ptr); + if (!info_ptr) { + png_destroy_read_struct(&png_ptr, nullptr, nullptr); + // Seems redundant with the if statement. done here to avoid leaking memory. 
+ TORCH_CHECK(info_ptr, "libpng info structure allocation failed!") + } + + auto accessor = data.accessor(); + auto datap = accessor.data(); + auto datap_len = accessor.size(0); + + if (setjmp(png_jmpbuf(png_ptr)) != 0) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(false, "Internal error."); + } + TORCH_CHECK(datap_len >= 8, "Content is too small for png!") + auto is_png = !png_sig_cmp(datap, 0, 8); + TORCH_CHECK(is_png, "Content is not png!") + + struct Reader { + png_const_bytep ptr; + png_size_t count; + } reader; + reader.ptr = png_const_bytep(datap) + 8; + reader.count = datap_len - 8; + + auto read_callback = [](png_structp png_ptr, + png_bytep output, + png_size_t bytes) { + auto reader = static_cast(png_get_io_ptr(png_ptr)); + TORCH_CHECK( + reader->count >= bytes, + "Out of bound read in decode_png. Probably, the input image is corrupted"); + std::copy(reader->ptr, reader->ptr + bytes, output); + reader->ptr += bytes; + reader->count -= bytes; + }; + png_set_sig_bytes(png_ptr, 8); + png_set_read_fn(png_ptr, &reader, read_callback); + png_read_info(png_ptr, info_ptr); + + png_uint_32 width, height; + int bit_depth, color_type; + int interlace_type; + auto retval = png_get_IHDR( + png_ptr, + info_ptr, + &width, + &height, + &bit_depth, + &color_type, + &interlace_type, + nullptr, + nullptr); + + if (retval != 1) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(retval == 1, "Could read image metadata from content.") + } + + if (bit_depth > 8 && bit_depth != 16) { + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK( + false, + "bit depth of png image is " + std::to_string(bit_depth) + + ". Only <=8 and 16 are supported.") + } + + int channels = png_get_channels(png_ptr, info_ptr); + + if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8) + png_set_expand_gray_1_2_4_to_8(png_ptr); + + int number_of_passes; + if (interlace_type == PNG_INTERLACE_ADAM7) { + number_of_passes = png_set_interlace_handling(png_ptr); + } else { + number_of_passes = 1; + } + + if (mode != IMAGE_READ_MODE_UNCHANGED) { + // TODO: consider supporting PNG_INFO_tRNS + bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0; + bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0; + bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0; + + switch (mode) { + case IMAGE_READ_MODE_GRAY: + if (color_type != PNG_COLOR_TYPE_GRAY) { + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } + + if (has_alpha) { + png_set_strip_alpha(png_ptr); + } + + if (has_color) { + png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); + } + channels = 1; + } + break; + case IMAGE_READ_MODE_GRAY_ALPHA: + if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) { + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } + + if (!has_alpha) { + png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); + } + + if (has_color) { + png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); + } + channels = 2; + } + break; + case IMAGE_READ_MODE_RGB: + if (color_type != PNG_COLOR_TYPE_RGB) { + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } else if (!has_color) { + png_set_gray_to_rgb(png_ptr); + } + + if (has_alpha) { + png_set_strip_alpha(png_ptr); + } + channels = 3; + } + break; + case IMAGE_READ_MODE_RGB_ALPHA: + if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) { + if (is_palette) { + png_set_palette_to_rgb(png_ptr); + has_alpha = true; + } else if (!has_color) { + png_set_gray_to_rgb(png_ptr); + } + + if 
(!has_alpha) { + png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); + } + channels = 4; + } + break; + default: + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(false, "The provided mode is not supported for PNG files"); + } + + png_read_update_info(png_ptr, info_ptr); + } + + auto num_pixels_per_row = width * channels; + auto is_16_bits = bit_depth == 16; + auto tensor = torch::empty( + {int64_t(height), int64_t(width), channels}, + is_16_bits ? at::kUInt16 : torch::kU8); + if (is_little_endian()) { + png_set_swap(png_ptr); + } + auto t_ptr = (uint8_t*)tensor.data_ptr(); + for (int pass = 0; pass < number_of_passes; pass++) { + for (png_uint_32 i = 0; i < height; ++i) { + png_read_row(png_ptr, t_ptr, nullptr); + t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1); + } + t_ptr = (uint8_t*)tensor.data_ptr(); + } + + int exif_orientation = -1; + if (apply_exif_orientation) { + exif_orientation = fetch_png_exif_orientation(png_ptr, info_ptr); + } + + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + + auto output = tensor.permute({2, 0, 1}); + if (apply_exif_orientation) { + return exif_orientation_transform(output, exif_orientation); + } + return output; +} +#endif + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_png.h b/product/include/torchvision/io/image/cpu/decode_png.h new file mode 100644 index 00000000000..0866711e987 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_png.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_png( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, + bool apply_exif_orientation = false); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_webp.cpp b/product/include/torchvision/io/image/cpu/decode_webp.cpp new file mode 100644 index 00000000000..844ce61a3e3 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_webp.cpp @@ -0,0 +1,40 @@ +#include "decode_webp.h" + +#if WEBP_FOUND +#include "webp/decode.h" +#endif // WEBP_FOUND + +namespace vision { +namespace image { + +#if !WEBP_FOUND +torch::Tensor decode_webp(const torch::Tensor& data) { + TORCH_CHECK( + false, "decode_webp: torchvision not compiled with libwebp support"); +} +#else + +torch::Tensor decode_webp(const torch::Tensor& encoded_data) { + TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); + TORCH_CHECK( + encoded_data.dtype() == torch::kU8, + "Input tensor must have uint8 data type, got ", + encoded_data.dtype()); + TORCH_CHECK( + encoded_data.dim() == 1, + "Input tensor must be 1-dimensional, got ", + encoded_data.dim(), + " dims."); + + int width = 0; + int height = 0; + auto decoded_data = WebPDecodeRGB( + encoded_data.data_ptr(), encoded_data.numel(), &width, &height); + TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed."); + auto out = torch::from_blob(decoded_data, {height, width, 3}, torch::kUInt8); + return out.permute({2, 0, 1}); // return CHW, channels-last +} +#endif // WEBP_FOUND + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_webp.h b/product/include/torchvision/io/image/cpu/decode_webp.h new file mode 100644 index 00000000000..00a0c3362f7 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/decode_webp.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace 
vision { +namespace image { + +C10_EXPORT torch::Tensor decode_webp(const torch::Tensor& data); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/encode_jpeg.cpp b/product/include/torchvision/io/image/cpu/encode_jpeg.cpp new file mode 100644 index 00000000000..d2ed73071a2 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/encode_jpeg.cpp @@ -0,0 +1,113 @@ +#include "encode_jpeg.h" + +#include "common_jpeg.h" + +namespace vision { +namespace image { + +#if !JPEG_FOUND + +torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { + TORCH_CHECK( + false, "encode_jpeg: torchvision not compiled with libjpeg support"); +} + +#else +// For libjpeg version <= 9b, the out_size parameter in jpeg_mem_dest() is +// defined as unsigned long, whereas in later version, it is defined as size_t. +#if !defined(JPEG_LIB_VERSION_MAJOR) || JPEG_LIB_VERSION_MAJOR < 9 || \ + (JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR <= 2) +using JpegSizeType = unsigned long; +#else +using JpegSizeType = size_t; +#endif + +using namespace detail; + +torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cpu.encode_jpeg.encode_jpeg"); + // Define compression structures and error handling + struct jpeg_compress_struct cinfo {}; + struct torch_jpeg_error_mgr jerr {}; + + // Define buffer to write JPEG information to and its size + JpegSizeType jpegSize = 0; + uint8_t* jpegBuf = nullptr; + + cinfo.err = jpeg_std_error(&jerr.pub); + jerr.pub.error_exit = torch_jpeg_error_exit; + + /* Establish the setjmp return context for my_error_exit to use. */ + if (setjmp(jerr.setjmp_buffer)) { + /* If we get here, the JPEG code has signaled an error. + * We need to clean up the JPEG object and the buffer. + */ + jpeg_destroy_compress(&cinfo); + if (jpegBuf != nullptr) { + free(jpegBuf); + } + + TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg); + } + + // Check that the input tensor is on CPU + TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + // Check that the input tensor is 3-dimensional + TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); + + // Get image info + int channels = data.size(0); + int height = data.size(1); + int width = data.size(2); + auto input = data.permute({1, 2, 0}).contiguous(); + + TORCH_CHECK( + channels == 1 || channels == 3, + "The number of channels should be 1 or 3, got: ", + channels); + + // Initialize JPEG structure + jpeg_create_compress(&cinfo); + + // Set output image information + cinfo.image_width = width; + cinfo.image_height = height; + cinfo.input_components = channels; + cinfo.in_color_space = channels == 1 ? 
JCS_GRAYSCALE : JCS_RGB; + + jpeg_set_defaults(&cinfo); + jpeg_set_quality(&cinfo, quality, TRUE); + + // Save JPEG output to a buffer + jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize); + + // Start JPEG compression + jpeg_start_compress(&cinfo, TRUE); + + auto stride = width * channels; + auto ptr = input.data_ptr(); + + // Encode JPEG file + while (cinfo.next_scanline < cinfo.image_height) { + jpeg_write_scanlines(&cinfo, &ptr, 1); + ptr += stride; + } + + jpeg_finish_compress(&cinfo); + jpeg_destroy_compress(&cinfo); + + torch::TensorOptions options = torch::TensorOptions{torch::kU8}; + auto out_tensor = + torch::from_blob(jpegBuf, {(long)jpegSize}, ::free, options); + jpegBuf = nullptr; + return out_tensor; +} +#endif + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/encode_jpeg.h b/product/include/torchvision/io/image/cpu/encode_jpeg.h new file mode 100644 index 00000000000..25084e154d6 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/encode_jpeg.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor encode_jpeg( + const torch::Tensor& data, + int64_t quality); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/encode_png.cpp b/product/include/torchvision/io/image/cpu/encode_png.cpp new file mode 100644 index 00000000000..5596d3a6789 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/encode_png.cpp @@ -0,0 +1,180 @@ +#include "encode_jpeg.h" + +#include "common_png.h" + +namespace vision { +namespace image { + +#if !PNG_FOUND + +torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { + TORCH_CHECK( + false, "encode_png: torchvision not compiled with libpng support"); +} + +#else + +namespace { + +struct torch_mem_encode { + char* buffer; + size_t size; +}; + +struct torch_png_error_mgr { + const char* pngLastErrorMsg; /* error messages */ + jmp_buf setjmp_buffer; /* for return to caller */ +}; + +using torch_png_error_mgr_ptr = torch_png_error_mgr*; + +void torch_png_error(png_structp png_ptr, png_const_charp error_msg) { + /* png_ptr->err really points to a torch_png_error_mgr struct, so coerce + * pointer */ + auto error_ptr = (torch_png_error_mgr_ptr)png_get_error_ptr(png_ptr); + /* Replace the error message on the error structure */ + error_ptr->pngLastErrorMsg = error_msg; + /* Return control to the setjmp point */ + longjmp(error_ptr->setjmp_buffer, 1); +} + +void torch_png_write_data( + png_structp png_ptr, + png_bytep data, + png_size_t length) { + struct torch_mem_encode* p = + (struct torch_mem_encode*)png_get_io_ptr(png_ptr); + size_t nsize = p->size + length; + + /* allocate or grow buffer */ + if (p->buffer) + p->buffer = (char*)realloc(p->buffer, nsize); + else + p->buffer = (char*)malloc(nsize); + + if (!p->buffer) + png_error(png_ptr, "Write Error"); + + /* copy new bytes to end of buffer */ + memcpy(p->buffer + p->size, data, length); + p->size += length; +} + +} // namespace + +torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png"); + // Define compression structures and error handling + png_structp png_write; + png_infop info_ptr; + struct torch_png_error_mgr err_ptr; + + // Define output buffer + struct torch_mem_encode buf_info; + buf_info.buffer = nullptr; + buf_info.size = 0; + + /* Establish the setjmp return context for my_error_exit to use. 
*/ + if (setjmp(err_ptr.setjmp_buffer)) { + /* If we get here, the PNG code has signaled an error. + * We need to clean up the PNG object and the buffer. + */ + if (info_ptr != nullptr) { + png_destroy_info_struct(png_write, &info_ptr); + } + + if (png_write != nullptr) { + png_destroy_write_struct(&png_write, nullptr); + } + + if (buf_info.buffer != nullptr) { + free(buf_info.buffer); + } + + TORCH_CHECK(false, err_ptr.pngLastErrorMsg); + } + + // Check that the compression level is between 0 and 9 + TORCH_CHECK( + compression_level >= 0 && compression_level <= 9, + "Compression level should be between 0 and 9"); + + // Check that the input tensor is on CPU + TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + // Check that the input tensor is 3-dimensional + TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); + + // Get image info + int channels = data.size(0); + int height = data.size(1); + int width = data.size(2); + auto input = data.permute({1, 2, 0}).contiguous(); + + TORCH_CHECK( + channels == 1 || channels == 3, + "The number of channels should be 1 or 3, got: ", + channels); + + // Initialize PNG structures + png_write = png_create_write_struct( + PNG_LIBPNG_VER_STRING, &err_ptr, torch_png_error, nullptr); + + info_ptr = png_create_info_struct(png_write); + + // Define custom buffer output + png_set_write_fn(png_write, &buf_info, torch_png_write_data, nullptr); + + // Set output image information + auto color_type = channels == 1 ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB; + png_set_IHDR( + png_write, + info_ptr, + width, + height, + 8, + color_type, + PNG_INTERLACE_NONE, + PNG_COMPRESSION_TYPE_DEFAULT, + PNG_FILTER_TYPE_DEFAULT); + + // Set image compression level + png_set_compression_level(png_write, compression_level); + + // Write file header + png_write_info(png_write, info_ptr); + + auto stride = width * channels; + auto ptr = input.data_ptr(); + + // Encode PNG file + for (int y = 0; y < height; ++y) { + png_write_row(png_write, ptr); + ptr += stride; + } + + // Write EOF + png_write_end(png_write, info_ptr); + + // Destroy structures + png_destroy_write_struct(&png_write, &info_ptr); + + torch::TensorOptions options = torch::TensorOptions{torch::kU8}; + auto outTensor = torch::empty({(long)buf_info.size}, options); + + // Copy memory from png buffer, since torch cannot get ownership of it via + // `from_blob` + auto outPtr = outTensor.data_ptr(); + std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel()); + free(buf_info.buffer); + + return outTensor; +} + +#endif + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/encode_png.h b/product/include/torchvision/io/image/cpu/encode_png.h new file mode 100644 index 00000000000..86a67c8706e --- /dev/null +++ b/product/include/torchvision/io/image/cpu/encode_png.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor encode_png( + const torch::Tensor& data, + int64_t compression_level); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/exif.h b/product/include/torchvision/io/image/cpu/exif.h new file mode 100644 index 00000000000..61948bfe16a --- /dev/null +++ b/product/include/torchvision/io/image/cpu/exif.h @@ -0,0 +1,256 @@ 
+/*M/////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this +license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// +// Copyright (C) 2000-2008, Intel Corporation, all rights reserved. +// Copyright (C) 2009, Willow Garage Inc., all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without +modification, +// are permitted provided that the following conditions are met: +// +// * Redistribution's of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistribution's in binary form must reproduce the above copyright +notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * The name of the copyright holders may not be used to endorse or promote +products +// derived from this software without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" +and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are +disclaimed. +// In no event shall the Intel Corporation or contributors be liable for any +direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. 
+// +//M*/ +#pragma once +// Functions in this module are taken from OpenCV +// https://github.com/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/imgcodecs/src/exif.cpp + +#if JPEG_FOUND +#include +#endif +#if PNG_FOUND +#include +#endif + +#include + +namespace vision { +namespace image { +namespace exif_private { + +constexpr uint16_t APP1 = 0xe1; +constexpr uint16_t ENDIANNESS_INTEL = 0x49; +constexpr uint16_t ENDIANNESS_MOTO = 0x4d; +constexpr uint16_t REQ_EXIF_TAG_MARK = 0x2a; +constexpr uint16_t ORIENTATION_EXIF_TAG = 0x0112; +constexpr uint16_t INCORRECT_TAG = -1; + +class ExifDataReader { + public: + ExifDataReader(unsigned char* p, size_t s) : _ptr(p), _size(s) {} + size_t size() const { + return _size; + } + const unsigned char& operator[](size_t index) const { + TORCH_CHECK(index >= 0 && index < _size); + return _ptr[index]; + } + + protected: + unsigned char* _ptr; + size_t _size; +}; + +inline uint16_t get_endianness(const ExifDataReader& exif_data) { + if ((exif_data.size() < 1) || + (exif_data.size() > 1 && exif_data[0] != exif_data[1])) { + return 0; + } + if (exif_data[0] == 'I') { + return ENDIANNESS_INTEL; + } + if (exif_data[0] == 'M') { + return ENDIANNESS_MOTO; + } + return 0; +} + +inline uint16_t get_uint16( + const ExifDataReader& exif_data, + uint16_t endianness, + const size_t offset) { + if (offset + 1 >= exif_data.size()) { + return INCORRECT_TAG; + } + + if (endianness == ENDIANNESS_INTEL) { + return exif_data[offset] + (exif_data[offset + 1] << 8); + } + return (exif_data[offset] << 8) + exif_data[offset + 1]; +} + +inline uint32_t get_uint32( + const ExifDataReader& exif_data, + uint16_t endianness, + const size_t offset) { + if (offset + 3 >= exif_data.size()) { + return INCORRECT_TAG; + } + + if (endianness == ENDIANNESS_INTEL) { + return exif_data[offset] + (exif_data[offset + 1] << 8) + + (exif_data[offset + 2] << 16) + (exif_data[offset + 3] << 24); + } + return (exif_data[offset] << 24) + (exif_data[offset + 1] << 16) + + (exif_data[offset + 2] << 8) + exif_data[offset + 3]; +} + +inline int fetch_exif_orientation(unsigned char* exif_data_ptr, size_t size) { + int exif_orientation = -1; + + // Exif binary structure looks like this + // First 6 bytes: [E, x, i, f, 0, 0] + // Endianness, 2 bytes : [M, M] or [I, I] + // Tag mark, 2 bytes: [0, 0x2a] + // Offset, 4 bytes + // Num entries, 2 bytes + // Tag entries and data, tag has 2 bytes and its data has 10 bytes + // For more details: + // http://www.media.mit.edu/pia/Research/deepview/exif.html + + ExifDataReader exif_data(exif_data_ptr, size); + auto endianness = get_endianness(exif_data); + + // Checking whether Tag Mark (0x002A) correspond to one contained in the + // Jpeg file + uint16_t tag_mark = get_uint16(exif_data, endianness, 2); + if (tag_mark == REQ_EXIF_TAG_MARK) { + auto offset = get_uint32(exif_data, endianness, 4); + size_t num_entry = get_uint16(exif_data, endianness, offset); + offset += 2; // go to start of tag fields + constexpr size_t tiff_field_size = 12; + for (size_t entry = 0; entry < num_entry; entry++) { + // Here we just search for orientation tag and parse it + auto tag_num = get_uint16(exif_data, endianness, offset); + if (tag_num == INCORRECT_TAG) { + break; + } + if (tag_num == ORIENTATION_EXIF_TAG) { + exif_orientation = get_uint16(exif_data, endianness, offset + 8); + break; + } + offset += tiff_field_size; + } + } + return exif_orientation; +} + +#if JPEG_FOUND +inline int fetch_jpeg_exif_orientation(j_decompress_ptr cinfo) { + // Check for Exif 
marker APP1 + jpeg_saved_marker_ptr exif_marker = 0; + jpeg_saved_marker_ptr cmarker = cinfo->marker_list; + while (cmarker && exif_marker == 0) { + if (cmarker->marker == APP1) { + exif_marker = cmarker; + } + cmarker = cmarker->next; + } + + if (!exif_marker) { + return -1; + } + + constexpr size_t start_offset = 6; + if (exif_marker->data_length <= start_offset) { + return -1; + } + + auto* exif_data_ptr = exif_marker->data + start_offset; + auto size = exif_marker->data_length - start_offset; + + return fetch_exif_orientation(exif_data_ptr, size); +} +#endif // #if JPEG_FOUND + +#if PNG_FOUND && defined(PNG_eXIf_SUPPORTED) +inline int fetch_png_exif_orientation(png_structp png_ptr, png_infop info_ptr) { + png_uint_32 num_exif = 0; + png_bytep exif = 0; + + // Exif info could be in info_ptr + if (png_get_valid(png_ptr, info_ptr, PNG_INFO_eXIf)) { + png_get_eXIf_1(png_ptr, info_ptr, &num_exif, &exif); + } + + if (exif && num_exif > 0) { + return fetch_exif_orientation(exif, num_exif); + } + return -1; +} +#endif // #if PNG_FOUND && defined(PNG_eXIf_SUPPORTED) + +constexpr uint16_t IMAGE_ORIENTATION_TL = 1; // normal orientation +constexpr uint16_t IMAGE_ORIENTATION_TR = 2; // needs horizontal flip +constexpr uint16_t IMAGE_ORIENTATION_BR = 3; // needs 180 rotation +constexpr uint16_t IMAGE_ORIENTATION_BL = 4; // needs vertical flip +constexpr uint16_t IMAGE_ORIENTATION_LT = + 5; // mirrored horizontal & rotate 270 CW +constexpr uint16_t IMAGE_ORIENTATION_RT = 6; // rotate 90 CW +constexpr uint16_t IMAGE_ORIENTATION_RB = + 7; // mirrored horizontal & rotate 90 CW +constexpr uint16_t IMAGE_ORIENTATION_LB = 8; // needs 270 CW rotation + +inline torch::Tensor exif_orientation_transform( + const torch::Tensor& image, + int orientation) { + if (orientation == IMAGE_ORIENTATION_TL) { + return image; + } else if (orientation == IMAGE_ORIENTATION_TR) { + return image.flip(-1); + } else if (orientation == IMAGE_ORIENTATION_BR) { + // needs 180 rotation equivalent to + // flip both horizontally and vertically + return image.flip({-2, -1}); + } else if (orientation == IMAGE_ORIENTATION_BL) { + return image.flip(-2); + } else if (orientation == IMAGE_ORIENTATION_LT) { + return image.transpose(-1, -2); + } else if (orientation == IMAGE_ORIENTATION_RT) { + return image.transpose(-1, -2).flip(-1); + } else if (orientation == IMAGE_ORIENTATION_RB) { + return image.transpose(-1, -2).flip({-2, -1}); + } else if (orientation == IMAGE_ORIENTATION_LB) { + return image.transpose(-1, -2).flip(-2); + } + return image; +} + +} // namespace exif_private +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/giflib/dgif_lib.c b/product/include/torchvision/io/image/cpu/giflib/dgif_lib.c new file mode 100644 index 00000000000..297f12f15c4 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/giflib/dgif_lib.c @@ -0,0 +1,1312 @@ +/****************************************************************************** + +dgif_lib.c - GIF decoding + +The functions here and in egif_lib.c are partitioned carefully so that +if you only require one of read and write capability, only one of these +two modules will be linked. Preserve this property! + +*****************************************************************************/ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (C) Eric S. 
Raymond + +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif /* _WIN32 */ + +#include "gif_lib.h" +#include "gif_lib_private.h" + +/* compose unsigned little endian value */ +#define UNSIGNED_LITTLE_ENDIAN(lo, hi) ((lo) | ((hi) << 8)) + +/* avoid extra function call in case we use fread (TVT) */ +static int InternalRead(GifFileType *gif, GifByteType *buf, int len) { + // fprintf(stderr, "### Read: %d\n", len); + return (((GifFilePrivateType *)gif->Private)->Read + ? ((GifFilePrivateType *)gif->Private)->Read(gif, buf, len) + : fread(buf, 1, len, + ((GifFilePrivateType *)gif->Private)->File)); +} + +static int DGifGetWord(GifFileType *GifFile, GifWord *Word); +static int DGifSetupDecompress(GifFileType *GifFile); +static int DGifDecompressLine(GifFileType *GifFile, GifPixelType *Line, + int LineLen); +static int DGifGetPrefixChar(const GifPrefixType *Prefix, int Code, + int ClearCode); +static int DGifDecompressInput(GifFileType *GifFile, int *Code); +static int DGifBufferedInput(GifFileType *GifFile, GifByteType *Buf, + GifByteType *NextByte); + +/****************************************************************************** + Open a new GIF file for read, given by its name. + Returns dynamically allocated GifFileType pointer which serves as the GIF + info record. +******************************************************************************/ +GifFileType *DGifOpenFileName(const char *FileName, int *Error) { + int FileHandle; + GifFileType *GifFile; + + if ((FileHandle = open(FileName, O_RDONLY)) == -1) { + if (Error != NULL) { + *Error = D_GIF_ERR_OPEN_FAILED; + } + return NULL; + } + + GifFile = DGifOpenFileHandle(FileHandle, Error); + return GifFile; +} + +/****************************************************************************** + Update a new GIF file, given its file handle. + Returns dynamically allocated GifFileType pointer which serves as the GIF + info record. +******************************************************************************/ +GifFileType *DGifOpenFileHandle(int FileHandle, int *Error) { + char Buf[GIF_STAMP_LEN + 1]; + GifFileType *GifFile; + GifFilePrivateType *Private; + FILE *f; + + GifFile = (GifFileType *)malloc(sizeof(GifFileType)); + if (GifFile == NULL) { + if (Error != NULL) { + *Error = D_GIF_ERR_NOT_ENOUGH_MEM; + } + (void)close(FileHandle); + return NULL; + } + + /*@i1@*/ memset(GifFile, '\0', sizeof(GifFileType)); + + /* Belt and suspenders, in case the null pointer isn't zero */ + GifFile->SavedImages = NULL; + GifFile->SColorMap = NULL; + + Private = (GifFilePrivateType *)calloc(1, sizeof(GifFilePrivateType)); + if (Private == NULL) { + if (Error != NULL) { + *Error = D_GIF_ERR_NOT_ENOUGH_MEM; + } + (void)close(FileHandle); + free((char *)GifFile); + return NULL; + } + + /*@i1@*/ memset(Private, '\0', sizeof(GifFilePrivateType)); + +#ifdef _WIN32 + _setmode(FileHandle, O_BINARY); /* Make sure it is in binary mode. 
*/ +#endif /* _WIN32 */ + + f = fdopen(FileHandle, "rb"); /* Make it into a stream: */ + + /*@-mustfreeonly@*/ + GifFile->Private = (void *)Private; + Private->FileHandle = FileHandle; + Private->File = f; + Private->FileState = FILE_STATE_READ; + Private->Read = NULL; /* don't use alternate input method (TVT) */ + GifFile->UserData = NULL; /* TVT */ + /*@=mustfreeonly@*/ + + /* Let's see if this is a GIF file: */ + /* coverity[check_return] */ + if (InternalRead(GifFile, (unsigned char *)Buf, GIF_STAMP_LEN) != + GIF_STAMP_LEN) { + if (Error != NULL) { + *Error = D_GIF_ERR_READ_FAILED; + } + (void)fclose(f); + free((char *)Private); + free((char *)GifFile); + return NULL; + } + + /* Check for GIF prefix at start of file */ + Buf[GIF_STAMP_LEN] = 0; + if (strncmp(GIF_STAMP, Buf, GIF_VERSION_POS) != 0) { + if (Error != NULL) { + *Error = D_GIF_ERR_NOT_GIF_FILE; + } + (void)fclose(f); + free((char *)Private); + free((char *)GifFile); + return NULL; + } + + if (DGifGetScreenDesc(GifFile) == GIF_ERROR) { + (void)fclose(f); + free((char *)Private); + free((char *)GifFile); + return NULL; + } + + GifFile->Error = 0; + + /* What version of GIF? */ + Private->gif89 = (Buf[GIF_VERSION_POS + 1] == '9'); + + return GifFile; +} + +/****************************************************************************** + GifFileType constructor with user supplied input function (TVT) +******************************************************************************/ +GifFileType *DGifOpen(void *userData, InputFunc readFunc, int *Error) { + char Buf[GIF_STAMP_LEN + 1]; + GifFileType *GifFile; + GifFilePrivateType *Private; + + GifFile = (GifFileType *)malloc(sizeof(GifFileType)); + if (GifFile == NULL) { + if (Error != NULL) { + *Error = D_GIF_ERR_NOT_ENOUGH_MEM; + } + return NULL; + } + + memset(GifFile, '\0', sizeof(GifFileType)); + + /* Belt and suspenders, in case the null pointer isn't zero */ + GifFile->SavedImages = NULL; + GifFile->SColorMap = NULL; + + Private = (GifFilePrivateType *)calloc(1, sizeof(GifFilePrivateType)); + if (!Private) { + if (Error != NULL) { + *Error = D_GIF_ERR_NOT_ENOUGH_MEM; + } + free((char *)GifFile); + return NULL; + } + /*@i1@*/ memset(Private, '\0', sizeof(GifFilePrivateType)); + + GifFile->Private = (void *)Private; + Private->FileHandle = 0; + Private->File = NULL; + Private->FileState = FILE_STATE_READ; + + Private->Read = readFunc; /* TVT */ + GifFile->UserData = userData; /* TVT */ + + /* Lets see if this is a GIF file: */ + /* coverity[check_return] */ + if (InternalRead(GifFile, (unsigned char *)Buf, GIF_STAMP_LEN) != + GIF_STAMP_LEN) { + if (Error != NULL) { + *Error = D_GIF_ERR_READ_FAILED; + } + free((char *)Private); + free((char *)GifFile); + return NULL; + } + + /* Check for GIF prefix at start of file */ + Buf[GIF_STAMP_LEN] = '\0'; + if (strncmp(GIF_STAMP, Buf, GIF_VERSION_POS) != 0) { + if (Error != NULL) { + *Error = D_GIF_ERR_NOT_GIF_FILE; + } + free((char *)Private); + free((char *)GifFile); + return NULL; + } + + if (DGifGetScreenDesc(GifFile) == GIF_ERROR) { + free((char *)Private); + free((char *)GifFile); + if (Error != NULL) { + *Error = D_GIF_ERR_NO_SCRN_DSCR; + } + return NULL; + } + + GifFile->Error = 0; + + /* What version of GIF? */ + Private->gif89 = (Buf[GIF_VERSION_POS + 1] == '9'); + + return GifFile; +} + +/****************************************************************************** + This routine should be called before any other DGif calls. Note that + this routine is called automatically from DGif file open routines. 
+******************************************************************************/ +int DGifGetScreenDesc(GifFileType *GifFile) { + int BitsPerPixel; + bool SortFlag; + GifByteType Buf[3]; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + if (!IS_READABLE(Private)) { + /* This file was NOT open for reading: */ + GifFile->Error = D_GIF_ERR_NOT_READABLE; + return GIF_ERROR; + } + + /* Put the screen descriptor into the file: */ + if (DGifGetWord(GifFile, &GifFile->SWidth) == GIF_ERROR || + DGifGetWord(GifFile, &GifFile->SHeight) == GIF_ERROR) { + return GIF_ERROR; + } + + if (InternalRead(GifFile, Buf, 3) != 3) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + GifFreeMapObject(GifFile->SColorMap); + GifFile->SColorMap = NULL; + return GIF_ERROR; + } + GifFile->SColorResolution = (((Buf[0] & 0x70) + 1) >> 4) + 1; + SortFlag = (Buf[0] & 0x08) != 0; + BitsPerPixel = (Buf[0] & 0x07) + 1; + GifFile->SBackGroundColor = Buf[1]; + GifFile->AspectByte = Buf[2]; + if (Buf[0] & 0x80) { /* Do we have global color map? */ + int i; + + GifFile->SColorMap = GifMakeMapObject(1 << BitsPerPixel, NULL); + if (GifFile->SColorMap == NULL) { + GifFile->Error = D_GIF_ERR_NOT_ENOUGH_MEM; + return GIF_ERROR; + } + + /* Get the global color map: */ + GifFile->SColorMap->SortFlag = SortFlag; + for (i = 0; i < GifFile->SColorMap->ColorCount; i++) { + /* coverity[check_return] */ + if (InternalRead(GifFile, Buf, 3) != 3) { + GifFreeMapObject(GifFile->SColorMap); + GifFile->SColorMap = NULL; + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; + } + GifFile->SColorMap->Colors[i].Red = Buf[0]; + GifFile->SColorMap->Colors[i].Green = Buf[1]; + GifFile->SColorMap->Colors[i].Blue = Buf[2]; + } + } else { + GifFile->SColorMap = NULL; + } + + /* + * No check here for whether the background color is in range for the + * screen color map. Possibly there should be. + */ + + return GIF_OK; +} + +const char *DGifGetGifVersion(GifFileType *GifFile) { + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + if (Private->gif89) { + return GIF89_STAMP; + } else { + return GIF87_STAMP; + } +} + +/****************************************************************************** + This routine should be called before any attempt to read an image. 
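+
+ Editor's note (added annotation): the record introducer read below is a
+ single byte: ',' (0x2c) starts an image descriptor, '!' (0x21) an
+ extension block and ';' (0x3b) the trailer that ends the stream; any
+ other value is reported as D_GIF_ERR_WRONG_RECORD.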
+******************************************************************************/ +int DGifGetRecordType(GifFileType *GifFile, GifRecordType *Type) { + GifByteType Buf; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + if (!IS_READABLE(Private)) { + /* This file was NOT open for reading: */ + GifFile->Error = D_GIF_ERR_NOT_READABLE; + return GIF_ERROR; + } + + /* coverity[check_return] */ + if (InternalRead(GifFile, &Buf, 1) != 1) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; + } + + // fprintf(stderr, "### DGifGetRecordType: %02x\n", Buf); + switch (Buf) { + case DESCRIPTOR_INTRODUCER: + *Type = IMAGE_DESC_RECORD_TYPE; + break; + case EXTENSION_INTRODUCER: + *Type = EXTENSION_RECORD_TYPE; + break; + case TERMINATOR_INTRODUCER: + *Type = TERMINATE_RECORD_TYPE; + break; + default: + *Type = UNDEFINED_RECORD_TYPE; + GifFile->Error = D_GIF_ERR_WRONG_RECORD; + return GIF_ERROR; + } + + return GIF_OK; +} + +int DGifGetImageHeader(GifFileType *GifFile) { + unsigned int BitsPerPixel; + GifByteType Buf[3]; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + if (!IS_READABLE(Private)) { + /* This file was NOT open for reading: */ + GifFile->Error = D_GIF_ERR_NOT_READABLE; + return GIF_ERROR; + } + + if (DGifGetWord(GifFile, &GifFile->Image.Left) == GIF_ERROR || + DGifGetWord(GifFile, &GifFile->Image.Top) == GIF_ERROR || + DGifGetWord(GifFile, &GifFile->Image.Width) == GIF_ERROR || + DGifGetWord(GifFile, &GifFile->Image.Height) == GIF_ERROR) { + return GIF_ERROR; + } + if (InternalRead(GifFile, Buf, 1) != 1) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + GifFreeMapObject(GifFile->Image.ColorMap); + GifFile->Image.ColorMap = NULL; + return GIF_ERROR; + } + BitsPerPixel = (Buf[0] & 0x07) + 1; + GifFile->Image.Interlace = (Buf[0] & 0x40) ? true : false; + + /* Setup the colormap */ + if (GifFile->Image.ColorMap) { + GifFreeMapObject(GifFile->Image.ColorMap); + GifFile->Image.ColorMap = NULL; + } + /* Does this image have local color map? */ + if (Buf[0] & 0x80) { + int i; + + GifFile->Image.ColorMap = + GifMakeMapObject(1 << BitsPerPixel, NULL); + if (GifFile->Image.ColorMap == NULL) { + GifFile->Error = D_GIF_ERR_NOT_ENOUGH_MEM; + return GIF_ERROR; + } + + /* Get the image local color map: */ + for (i = 0; i < GifFile->Image.ColorMap->ColorCount; i++) { + /* coverity[check_return] */ + if (InternalRead(GifFile, Buf, 3) != 3) { + GifFreeMapObject(GifFile->Image.ColorMap); + GifFile->Error = D_GIF_ERR_READ_FAILED; + GifFile->Image.ColorMap = NULL; + return GIF_ERROR; + } + GifFile->Image.ColorMap->Colors[i].Red = Buf[0]; + GifFile->Image.ColorMap->Colors[i].Green = Buf[1]; + GifFile->Image.ColorMap->Colors[i].Blue = Buf[2]; + } + } + + Private->PixelCount = + (long)GifFile->Image.Width * (long)GifFile->Image.Height; + + /* Reset decompress algorithm parameters. */ + return DGifSetupDecompress(GifFile); +} + +/****************************************************************************** + This routine should be called before any attempt to read an image. + Note it is assumed the Image desc. header has been read. 
+******************************************************************************/ +int DGifGetImageDesc(GifFileType *GifFile) { + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + SavedImage *sp; + + if (!IS_READABLE(Private)) { + /* This file was NOT open for reading: */ + GifFile->Error = D_GIF_ERR_NOT_READABLE; + return GIF_ERROR; + } + + if (DGifGetImageHeader(GifFile) == GIF_ERROR) { + return GIF_ERROR; + } + + if (GifFile->SavedImages) { + SavedImage *new_saved_images = (SavedImage *)reallocarray( + GifFile->SavedImages, (GifFile->ImageCount + 1), + sizeof(SavedImage)); + if (new_saved_images == NULL) { + GifFile->Error = D_GIF_ERR_NOT_ENOUGH_MEM; + return GIF_ERROR; + } + GifFile->SavedImages = new_saved_images; + } else { + if ((GifFile->SavedImages = + (SavedImage *)malloc(sizeof(SavedImage))) == NULL) { + GifFile->Error = D_GIF_ERR_NOT_ENOUGH_MEM; + return GIF_ERROR; + } + } + + sp = &GifFile->SavedImages[GifFile->ImageCount]; + memcpy(&sp->ImageDesc, &GifFile->Image, sizeof(GifImageDesc)); + if (GifFile->Image.ColorMap != NULL) { + sp->ImageDesc.ColorMap = + GifMakeMapObject(GifFile->Image.ColorMap->ColorCount, + GifFile->Image.ColorMap->Colors); + if (sp->ImageDesc.ColorMap == NULL) { + GifFile->Error = D_GIF_ERR_NOT_ENOUGH_MEM; + return GIF_ERROR; + } + } + sp->RasterBits = (unsigned char *)NULL; + sp->ExtensionBlockCount = 0; + sp->ExtensionBlocks = (ExtensionBlock *)NULL; + + GifFile->ImageCount++; + + return GIF_OK; +} + +/****************************************************************************** + Get one full scanned line (Line) of length LineLen from GIF file. +******************************************************************************/ +int DGifGetLine(GifFileType *GifFile, GifPixelType *Line, int LineLen) { + GifByteType *Dummy; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + if (!IS_READABLE(Private)) { + /* This file was NOT open for reading: */ + GifFile->Error = D_GIF_ERR_NOT_READABLE; + return GIF_ERROR; + } + + if (!LineLen) { + LineLen = GifFile->Image.Width; + } + + if ((Private->PixelCount -= LineLen) > 0xffff0000UL) { + GifFile->Error = D_GIF_ERR_DATA_TOO_BIG; + return GIF_ERROR; + } + + if (DGifDecompressLine(GifFile, Line, LineLen) == GIF_OK) { + if (Private->PixelCount == 0) { + /* We probably won't be called any more, so let's clean + * up everything before we return: need to flush out all + * the rest of image until an empty block (size 0) + * detected. We use GetCodeNext. + */ + do { + if (DGifGetCodeNext(GifFile, &Dummy) == + GIF_ERROR) { + return GIF_ERROR; + } + } while (Dummy != NULL); + } + return GIF_OK; + } else { + return GIF_ERROR; + } +} + +/****************************************************************************** + Put one pixel (Pixel) into GIF file. 
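+ (Editor's note: despite the wording above, this is a decode-side routine;
+ it consumes a single pixel from the compressed stream and is the
+ one-pixel counterpart of DGifGetLine.)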
+******************************************************************************/ +int DGifGetPixel(GifFileType *GifFile, GifPixelType Pixel) { + GifByteType *Dummy; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + if (!IS_READABLE(Private)) { + /* This file was NOT open for reading: */ + GifFile->Error = D_GIF_ERR_NOT_READABLE; + return GIF_ERROR; + } + if (--Private->PixelCount > 0xffff0000UL) { + GifFile->Error = D_GIF_ERR_DATA_TOO_BIG; + return GIF_ERROR; + } + + if (DGifDecompressLine(GifFile, &Pixel, 1) == GIF_OK) { + if (Private->PixelCount == 0) { + /* We probably won't be called any more, so let's clean + * up everything before we return: need to flush out all + * the rest of image until an empty block (size 0) + * detected. We use GetCodeNext. + */ + do { + if (DGifGetCodeNext(GifFile, &Dummy) == + GIF_ERROR) { + return GIF_ERROR; + } + } while (Dummy != NULL); + } + return GIF_OK; + } else { + return GIF_ERROR; + } +} + +/****************************************************************************** + Get an extension block (see GIF manual) from GIF file. This routine only + returns the first data block, and DGifGetExtensionNext should be called + after this one until NULL extension is returned. + The Extension should NOT be freed by the user (not dynamically allocated). + Note it is assumed the Extension description header has been read. +******************************************************************************/ +int DGifGetExtension(GifFileType *GifFile, int *ExtCode, + GifByteType **Extension) { + GifByteType Buf; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + // fprintf(stderr, "### -> DGifGetExtension:\n"); + if (!IS_READABLE(Private)) { + /* This file was NOT open for reading: */ + GifFile->Error = D_GIF_ERR_NOT_READABLE; + return GIF_ERROR; + } + + /* coverity[check_return] */ + if (InternalRead(GifFile, &Buf, 1) != 1) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; + } + *ExtCode = Buf; + // fprintf(stderr, "### <- DGifGetExtension: %02x, about to call + // next\n", Buf); + + return DGifGetExtensionNext(GifFile, Extension); +} + +/****************************************************************************** + Get a following extension block (see GIF manual) from GIF file. This + routine should be called until NULL Extension is returned. + The Extension should NOT be freed by the user (not dynamically allocated). +******************************************************************************/ +int DGifGetExtensionNext(GifFileType *GifFile, GifByteType **Extension) { + GifByteType Buf; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + // fprintf(stderr, "### -> DGifGetExtensionNext\n"); + if (InternalRead(GifFile, &Buf, 1) != 1) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; + } + // fprintf(stderr, "### DGifGetExtensionNext sees %d\n", Buf); + + if (Buf > 0) { + *Extension = Private->Buf; /* Use private unused buffer. */ + (*Extension)[0] = + Buf; /* Pascal strings notation (pos. 0 is len.). 
*/ + /* coverity[tainted_data,check_return] */ + if (InternalRead(GifFile, &((*Extension)[1]), Buf) != Buf) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; + } + } else { + *Extension = NULL; + } + // fprintf(stderr, "### <- DGifGetExtensionNext: %p\n", Extension); + + return GIF_OK; +} + +/****************************************************************************** + Extract a Graphics Control Block from raw extension data +******************************************************************************/ + +int DGifExtensionToGCB(const size_t GifExtensionLength, + const GifByteType *GifExtension, + GraphicsControlBlock *GCB) { + if (GifExtensionLength != 4) { + return GIF_ERROR; + } + + GCB->DisposalMode = (GifExtension[0] >> 2) & 0x07; + GCB->UserInputFlag = (GifExtension[0] & 0x02) != 0; + GCB->DelayTime = + UNSIGNED_LITTLE_ENDIAN(GifExtension[1], GifExtension[2]); + if (GifExtension[0] & 0x01) { + GCB->TransparentColor = (int)GifExtension[3]; + } else { + GCB->TransparentColor = NO_TRANSPARENT_COLOR; + } + + return GIF_OK; +} + +/****************************************************************************** + Extract the Graphics Control Block for a saved image, if it exists. +******************************************************************************/ + +int DGifSavedExtensionToGCB(GifFileType *GifFile, int ImageIndex, + GraphicsControlBlock *GCB) { + int i; + + if (ImageIndex < 0 || ImageIndex > GifFile->ImageCount - 1) { + return GIF_ERROR; + } + + GCB->DisposalMode = DISPOSAL_UNSPECIFIED; + GCB->UserInputFlag = false; + GCB->DelayTime = 0; + GCB->TransparentColor = NO_TRANSPARENT_COLOR; + + for (i = 0; i < GifFile->SavedImages[ImageIndex].ExtensionBlockCount; + i++) { + ExtensionBlock *ep = + &GifFile->SavedImages[ImageIndex].ExtensionBlocks[i]; + if (ep->Function == GRAPHICS_EXT_FUNC_CODE) { + return DGifExtensionToGCB(ep->ByteCount, ep->Bytes, + GCB); + } + } + + return GIF_ERROR; +} + +/****************************************************************************** + This routine should be called last, to close the GIF file. 
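+
+ Editor's note (added annotation): once it has a valid handle, this call
+ releases everything reachable from it (the global and per-image color
+ maps, the SavedImages array with its raster bits, and any extension
+ blocks) and then frees the GifFileType itself, on failure as well as on
+ success.  Copy out any pixel data you still need before closing; the
+ optional ErrorCode receives D_GIF_SUCCEEDED or the reason the close
+ failed.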
+******************************************************************************/ +int DGifCloseFile(GifFileType *GifFile, int *ErrorCode) { + GifFilePrivateType *Private; + + if (GifFile == NULL || GifFile->Private == NULL) { + return GIF_ERROR; + } + + if (GifFile->Image.ColorMap) { + GifFreeMapObject(GifFile->Image.ColorMap); + GifFile->Image.ColorMap = NULL; + } + + if (GifFile->SColorMap) { + GifFreeMapObject(GifFile->SColorMap); + GifFile->SColorMap = NULL; + } + + if (GifFile->SavedImages) { + GifFreeSavedImages(GifFile); + GifFile->SavedImages = NULL; + } + + GifFreeExtensions(&GifFile->ExtensionBlockCount, + &GifFile->ExtensionBlocks); + + Private = (GifFilePrivateType *)GifFile->Private; + + if (!IS_READABLE(Private)) { + /* This file was NOT open for reading: */ + if (ErrorCode != NULL) { + *ErrorCode = D_GIF_ERR_NOT_READABLE; + } + free((char *)GifFile->Private); + free(GifFile); + return GIF_ERROR; + } + + if (Private->File && (fclose(Private->File) != 0)) { + if (ErrorCode != NULL) { + *ErrorCode = D_GIF_ERR_CLOSE_FAILED; + } + free((char *)GifFile->Private); + free(GifFile); + return GIF_ERROR; + } + + free((char *)GifFile->Private); + free(GifFile); + if (ErrorCode != NULL) { + *ErrorCode = D_GIF_SUCCEEDED; + } + return GIF_OK; +} + +/****************************************************************************** + Get 2 bytes (word) from the given file: +******************************************************************************/ +static int DGifGetWord(GifFileType *GifFile, GifWord *Word) { + unsigned char c[2]; + + /* coverity[check_return] */ + if (InternalRead(GifFile, c, 2) != 2) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; + } + + *Word = (GifWord)UNSIGNED_LITTLE_ENDIAN(c[0], c[1]); + return GIF_OK; +} + +/****************************************************************************** + Get the image code in compressed form. This routine can be called if the + information needed to be piped out as is. Obviously this is much faster + than decoding and encoding again. This routine should be followed by calls + to DGifGetCodeNext, until NULL block is returned. + The block should NOT be freed by the user (not dynamically allocated). +******************************************************************************/ +int DGifGetCode(GifFileType *GifFile, int *CodeSize, GifByteType **CodeBlock) { + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + if (!IS_READABLE(Private)) { + /* This file was NOT open for reading: */ + GifFile->Error = D_GIF_ERR_NOT_READABLE; + return GIF_ERROR; + } + + *CodeSize = Private->BitsPerPixel; + + return DGifGetCodeNext(GifFile, CodeBlock); +} + +/****************************************************************************** + Continue to get the image code in compressed form. This routine should be + called until NULL block is returned. + The block should NOT be freed by the user (not dynamically allocated). +******************************************************************************/ +int DGifGetCodeNext(GifFileType *GifFile, GifByteType **CodeBlock) { + GifByteType Buf; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + /* coverity[tainted_data_argument] */ + /* coverity[check_return] */ + if (InternalRead(GifFile, &Buf, 1) != 1) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; + } + + /* coverity[lower_bounds] */ + if (Buf > 0) { + *CodeBlock = Private->Buf; /* Use private unused buffer. */ + (*CodeBlock)[0] = + Buf; /* Pascal strings notation (pos. 
0 is len.). */ + /* coverity[tainted_data] */ + if (InternalRead(GifFile, &((*CodeBlock)[1]), Buf) != Buf) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; + } + } else { + *CodeBlock = NULL; + Private->Buf[0] = 0; /* Make sure the buffer is empty! */ + Private->PixelCount = + 0; /* And local info. indicate image read. */ + } + + return GIF_OK; +} + +/****************************************************************************** + Setup the LZ decompression for this image: +******************************************************************************/ +static int DGifSetupDecompress(GifFileType *GifFile) { + int i, BitsPerPixel; + GifByteType CodeSize; + GifPrefixType *Prefix; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + /* coverity[check_return] */ + if (InternalRead(GifFile, &CodeSize, 1) < + 1) { /* Read Code size from file. */ + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; /* Failed to read Code size. */ + } + BitsPerPixel = CodeSize; + + /* this can only happen on a severely malformed GIF */ + if (BitsPerPixel > 8) { + GifFile->Error = + D_GIF_ERR_READ_FAILED; /* somewhat bogus error code */ + return GIF_ERROR; /* Failed to read Code size. */ + } + + Private->Buf[0] = 0; /* Input Buffer empty. */ + Private->BitsPerPixel = BitsPerPixel; + Private->ClearCode = (1 << BitsPerPixel); + Private->EOFCode = Private->ClearCode + 1; + Private->RunningCode = Private->EOFCode + 1; + Private->RunningBits = BitsPerPixel + 1; /* Number of bits per code. */ + Private->MaxCode1 = 1 << Private->RunningBits; /* Max. code + 1. */ + Private->StackPtr = 0; /* No pixels on the pixel stack. */ + Private->LastCode = NO_SUCH_CODE; + Private->CrntShiftState = 0; /* No information in CrntShiftDWord. */ + Private->CrntShiftDWord = 0; + + Prefix = Private->Prefix; + for (i = 0; i <= LZ_MAX_CODE; i++) { + Prefix[i] = NO_SUCH_CODE; + } + + return GIF_OK; +} + +/****************************************************************************** + The LZ decompression routine: + This version decompress the given GIF file into Line of length LineLen. + This routine can be called few times (one per scan line, for example), in + order the complete the whole image. +******************************************************************************/ +static int DGifDecompressLine(GifFileType *GifFile, GifPixelType *Line, + int LineLen) { + int i = 0; + int j, CrntCode, EOFCode, ClearCode, CrntPrefix, LastCode, StackPtr; + GifByteType *Stack, *Suffix; + GifPrefixType *Prefix; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + StackPtr = Private->StackPtr; + Prefix = Private->Prefix; + Suffix = Private->Suffix; + Stack = Private->Stack; + EOFCode = Private->EOFCode; + ClearCode = Private->ClearCode; + LastCode = Private->LastCode; + + if (StackPtr > LZ_MAX_CODE) { + return GIF_ERROR; + } + + if (StackPtr != 0) { + /* Let pop the stack off before continueing to read the GIF + * file: */ + while (StackPtr != 0 && i < LineLen) { + Line[i++] = Stack[--StackPtr]; + } + } + + while (i < LineLen) { /* Decode LineLen items. */ + if (DGifDecompressInput(GifFile, &CrntCode) == GIF_ERROR) { + return GIF_ERROR; + } + + if (CrntCode == EOFCode) { + /* Note however that usually we will not be here as we + * will stop decoding as soon as we got all the pixel, + * or EOF code will not be read at all, and + * DGifGetLine/Pixel clean everything. 
*/ + GifFile->Error = D_GIF_ERR_EOF_TOO_SOON; + return GIF_ERROR; + } else if (CrntCode == ClearCode) { + /* We need to start over again: */ + for (j = 0; j <= LZ_MAX_CODE; j++) { + Prefix[j] = NO_SUCH_CODE; + } + Private->RunningCode = Private->EOFCode + 1; + Private->RunningBits = Private->BitsPerPixel + 1; + Private->MaxCode1 = 1 << Private->RunningBits; + LastCode = Private->LastCode = NO_SUCH_CODE; + } else { + /* Its regular code - if in pixel range simply add it to + * output stream, otherwise trace to codes linked list + * until the prefix is in pixel range: */ + if (CrntCode < ClearCode) { + /* This is simple - its pixel scalar, so add it + * to output: */ + Line[i++] = CrntCode; + } else { + /* Its a code to needed to be traced: trace the + * linked list until the prefix is a pixel, + * while pushing the suffix pixels on our stack. + * If we done, pop the stack in reverse (thats + * what stack is good for!) order to output. */ + if (Prefix[CrntCode] == NO_SUCH_CODE) { + CrntPrefix = LastCode; + + /* Only allowed if CrntCode is exactly + * the running code: In that case + * CrntCode = XXXCode, CrntCode or the + * prefix code is last code and the + * suffix char is exactly the prefix of + * last code! */ + if (CrntCode == + Private->RunningCode - 2) { + Suffix[Private->RunningCode - + 2] = Stack[StackPtr++] = + DGifGetPrefixChar( + Prefix, LastCode, + ClearCode); + } else { + Suffix[Private->RunningCode - + 2] = Stack[StackPtr++] = + DGifGetPrefixChar( + Prefix, CrntCode, + ClearCode); + } + } else { + CrntPrefix = CrntCode; + } + + /* Now (if image is O.K.) we should not get a + * NO_SUCH_CODE during the trace. As we might + * loop forever, in case of defective image, we + * use StackPtr as loop counter and stop before + * overflowing Stack[]. */ + while (StackPtr < LZ_MAX_CODE && + CrntPrefix > ClearCode && + CrntPrefix <= LZ_MAX_CODE) { + Stack[StackPtr++] = Suffix[CrntPrefix]; + CrntPrefix = Prefix[CrntPrefix]; + } + if (StackPtr >= LZ_MAX_CODE || + CrntPrefix > LZ_MAX_CODE) { + GifFile->Error = D_GIF_ERR_IMAGE_DEFECT; + return GIF_ERROR; + } + /* Push the last character on stack: */ + Stack[StackPtr++] = CrntPrefix; + + /* Now lets pop all the stack into output: */ + while (StackPtr != 0 && i < LineLen) { + Line[i++] = Stack[--StackPtr]; + } + } + if (LastCode != NO_SUCH_CODE && + Private->RunningCode - 2 < (LZ_MAX_CODE + 1) && + Prefix[Private->RunningCode - 2] == NO_SUCH_CODE) { + Prefix[Private->RunningCode - 2] = LastCode; + + if (CrntCode == Private->RunningCode - 2) { + /* Only allowed if CrntCode is exactly + * the running code: In that case + * CrntCode = XXXCode, CrntCode or the + * prefix code is last code and the + * suffix char is exactly the prefix of + * last code! */ + Suffix[Private->RunningCode - 2] = + DGifGetPrefixChar(Prefix, LastCode, + ClearCode); + } else { + Suffix[Private->RunningCode - 2] = + DGifGetPrefixChar(Prefix, CrntCode, + ClearCode); + } + } + LastCode = CrntCode; + } + } + + Private->LastCode = LastCode; + Private->StackPtr = StackPtr; + + return GIF_OK; +} + +/****************************************************************************** + Routine to trace the Prefixes linked list until we get a prefix which is + not code, but a pixel value (less than ClearCode). Returns that pixel value. + If image is defective, we might loop here forever, so we limit the loops to + the maximum possible if image O.k. - LZ_MAX_CODE times. 
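+ Illustrative example (editor's note, hypothetical table contents): with a
+ ClearCode of 256, suppose Prefix[300] is 258 and Prefix[258] is 97; the
+ walk from code 300 visits 300, 258, 97, stops because 97 is below the
+ clear code, and returns 97, the first pixel of the string that code 300
+ expands to.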
+******************************************************************************/ +static int DGifGetPrefixChar(const GifPrefixType *Prefix, int Code, + int ClearCode) { + int i = 0; + + while (Code > ClearCode && i++ <= LZ_MAX_CODE) { + if (Code > LZ_MAX_CODE) { + return NO_SUCH_CODE; + } + Code = Prefix[Code]; + } + return Code; +} + +/****************************************************************************** + Interface for accessing the LZ codes directly. Set Code to the real code + (12bits), or to -1 if EOF code is returned. +******************************************************************************/ +int DGifGetLZCodes(GifFileType *GifFile, int *Code) { + GifByteType *CodeBlock; + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + if (!IS_READABLE(Private)) { + /* This file was NOT open for reading: */ + GifFile->Error = D_GIF_ERR_NOT_READABLE; + return GIF_ERROR; + } + + if (DGifDecompressInput(GifFile, Code) == GIF_ERROR) { + return GIF_ERROR; + } + + if (*Code == Private->EOFCode) { + /* Skip rest of codes (hopefully only NULL terminating block): + */ + do { + if (DGifGetCodeNext(GifFile, &CodeBlock) == GIF_ERROR) { + return GIF_ERROR; + } + } while (CodeBlock != NULL); + + *Code = -1; + } else if (*Code == Private->ClearCode) { + /* We need to start over again: */ + Private->RunningCode = Private->EOFCode + 1; + Private->RunningBits = Private->BitsPerPixel + 1; + Private->MaxCode1 = 1 << Private->RunningBits; + } + + return GIF_OK; +} + +/****************************************************************************** + The LZ decompression input routine: + This routine is responsable for the decompression of the bit stream from + 8 bits (bytes) packets, into the real codes. + Returns GIF_OK if read successfully. +******************************************************************************/ +static int DGifDecompressInput(GifFileType *GifFile, int *Code) { + static const unsigned short CodeMasks[] = { + 0x0000, 0x0001, 0x0003, 0x0007, 0x000f, 0x001f, 0x003f, + 0x007f, 0x00ff, 0x01ff, 0x03ff, 0x07ff, 0x0fff}; + + GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; + + GifByteType NextByte; + + /* The image can't contain more than LZ_BITS per code. */ + if (Private->RunningBits > LZ_BITS) { + GifFile->Error = D_GIF_ERR_IMAGE_DEFECT; + return GIF_ERROR; + } + + while (Private->CrntShiftState < Private->RunningBits) { + /* Needs to get more bytes from input stream for next code: */ + if (DGifBufferedInput(GifFile, Private->Buf, &NextByte) == + GIF_ERROR) { + return GIF_ERROR; + } + Private->CrntShiftDWord |= ((unsigned long)NextByte) + << Private->CrntShiftState; + Private->CrntShiftState += 8; + } + *Code = Private->CrntShiftDWord & CodeMasks[Private->RunningBits]; + + Private->CrntShiftDWord >>= Private->RunningBits; + Private->CrntShiftState -= Private->RunningBits; + + /* If code cannot fit into RunningBits bits, must raise its size. Note + * however that codes above 4095 are used for special signaling. + * If we're using LZ_BITS bits already and we're at the max code, just + * keep using the table as it is, don't increment Private->RunningCode. 
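+ * For example (editor's note): with 8-bit pixels decoding starts at
+ * 9-bit codes (MaxCode1 = 512); the width bumps to 10, 11 and finally
+ * 12 bits as RunningCode exceeds 512, 1024 and 2048, never grows past
+ * LZ_BITS, and a later clear code resets it to 9 bits again.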
+ */ + if (Private->RunningCode < LZ_MAX_CODE + 2 && + ++Private->RunningCode > Private->MaxCode1 && + Private->RunningBits < LZ_BITS) { + Private->MaxCode1 <<= 1; + Private->RunningBits++; + } + return GIF_OK; +} + +/****************************************************************************** + This routines read one GIF data block at a time and buffers it internally + so that the decompression routine could access it. + The routine returns the next byte from its internal buffer (or read next + block in if buffer empty) and returns GIF_OK if succesful. +******************************************************************************/ +static int DGifBufferedInput(GifFileType *GifFile, GifByteType *Buf, + GifByteType *NextByte) { + if (Buf[0] == 0) { + /* Needs to read the next buffer - this one is empty: */ + /* coverity[check_return] */ + if (InternalRead(GifFile, Buf, 1) != 1) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; + } + /* There shouldn't be any empty data blocks here as the LZW spec + * says the LZW termination code should come first. Therefore + * we shouldn't be inside this routine at that point. + */ + if (Buf[0] == 0) { + GifFile->Error = D_GIF_ERR_IMAGE_DEFECT; + return GIF_ERROR; + } + if (InternalRead(GifFile, &Buf[1], Buf[0]) != Buf[0]) { + GifFile->Error = D_GIF_ERR_READ_FAILED; + return GIF_ERROR; + } + *NextByte = Buf[1]; + Buf[1] = 2; /* We use now the second place as last char read! */ + Buf[0]--; + } else { + *NextByte = Buf[Buf[1]++]; + Buf[0]--; + } + + return GIF_OK; +} + +/****************************************************************************** + This routine is called in case of error during parsing image. We need to + decrease image counter and reallocate memory for saved images. Not decreasing + ImageCount may lead to null pointer dereference, because the last element in + SavedImages may point to the spoilt image and null pointer buffers. +*******************************************************************************/ +void DGifDecreaseImageCounter(GifFileType *GifFile) { + GifFile->ImageCount--; + if (GifFile->SavedImages[GifFile->ImageCount].RasterBits != NULL) { + free(GifFile->SavedImages[GifFile->ImageCount].RasterBits); + } + + // Realloc array according to the new image counter. + SavedImage *correct_saved_images = (SavedImage *)reallocarray( + GifFile->SavedImages, GifFile->ImageCount, sizeof(SavedImage)); + if (correct_saved_images != NULL) { + GifFile->SavedImages = correct_saved_images; + } +} + +/****************************************************************************** + This routine reads an entire GIF into core, hanging all its state info off + the GifFileType pointer. Call DGifOpenFileName() or DGifOpenFileHandle() + first to initialize I/O. Its inverse is EGifSpew(). 
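+
+ A minimal usage sketch (editor's note; the file name is hypothetical and
+ error handling is abbreviated):
+
+     int err = D_GIF_SUCCEEDED;
+     GifFileType *gif = DGifOpenFileName("frames.gif", &err);
+     if (gif == NULL) {
+         report GifErrorString(err) and bail out
+     }
+     if (DGifSlurp(gif) == GIF_ERROR) {
+         report gif->Error, then still close the handle
+     }
+     for (int i = 0; i < gif->ImageCount; i++) {
+         SavedImage *frame = &gif->SavedImages[i];
+         frame->ImageDesc gives the geometry, frame->RasterBits the
+         Width x Height palette indices of that frame
+     }
+     (void)DGifCloseFile(gif, &err);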
+*******************************************************************************/ +int DGifSlurp(GifFileType *GifFile) { + size_t ImageSize; + GifRecordType RecordType; + SavedImage *sp; + GifByteType *ExtData; + int ExtFunction; + + GifFile->ExtensionBlocks = NULL; + GifFile->ExtensionBlockCount = 0; + + do { + if (DGifGetRecordType(GifFile, &RecordType) == GIF_ERROR) { + return (GIF_ERROR); + } + + switch (RecordType) { + case IMAGE_DESC_RECORD_TYPE: + if (DGifGetImageDesc(GifFile) == GIF_ERROR) { + return (GIF_ERROR); + } + + sp = &GifFile->SavedImages[GifFile->ImageCount - 1]; + /* Allocate memory for the image */ + if (sp->ImageDesc.Width <= 0 || + sp->ImageDesc.Height <= 0 || + sp->ImageDesc.Width > + (INT_MAX / sp->ImageDesc.Height)) { + DGifDecreaseImageCounter(GifFile); + return GIF_ERROR; + } + ImageSize = sp->ImageDesc.Width * sp->ImageDesc.Height; + + if (ImageSize > (SIZE_MAX / sizeof(GifPixelType))) { + DGifDecreaseImageCounter(GifFile); + return GIF_ERROR; + } + sp->RasterBits = (unsigned char *)reallocarray( + NULL, ImageSize, sizeof(GifPixelType)); + + if (sp->RasterBits == NULL) { + DGifDecreaseImageCounter(GifFile); + return GIF_ERROR; + } + + if (sp->ImageDesc.Interlace) { + int i, j; + /* + * The way an interlaced image should be read - + * offsets and jumps... + */ + static const int InterlacedOffset[] = {0, 4, 2, + 1}; + static const int InterlacedJumps[] = {8, 8, 4, + 2}; + /* Need to perform 4 passes on the image */ + for (i = 0; i < 4; i++) { + for (j = InterlacedOffset[i]; + j < sp->ImageDesc.Height; + j += InterlacedJumps[i]) { + if (DGifGetLine( + GifFile, + sp->RasterBits + + j * sp->ImageDesc + .Width, + sp->ImageDesc.Width) == + GIF_ERROR) { + DGifDecreaseImageCounter( + GifFile); + return GIF_ERROR; + } + } + } + } else { + if (DGifGetLine(GifFile, sp->RasterBits, + ImageSize) == GIF_ERROR) { + DGifDecreaseImageCounter(GifFile); + return GIF_ERROR; + } + } + + if (GifFile->ExtensionBlocks) { + sp->ExtensionBlocks = GifFile->ExtensionBlocks; + sp->ExtensionBlockCount = + GifFile->ExtensionBlockCount; + + GifFile->ExtensionBlocks = NULL; + GifFile->ExtensionBlockCount = 0; + } + break; + + case EXTENSION_RECORD_TYPE: + if (DGifGetExtension(GifFile, &ExtFunction, &ExtData) == + GIF_ERROR) { + return (GIF_ERROR); + } + /* Create an extension block with our data */ + if (ExtData != NULL) { + if (GifAddExtensionBlock( + &GifFile->ExtensionBlockCount, + &GifFile->ExtensionBlocks, ExtFunction, + ExtData[0], &ExtData[1]) == GIF_ERROR) { + return (GIF_ERROR); + } + } + for (;;) { + if (DGifGetExtensionNext(GifFile, &ExtData) == + GIF_ERROR) { + return (GIF_ERROR); + } + if (ExtData == NULL) { + break; + } + /* Continue the extension block */ + if (GifAddExtensionBlock( + &GifFile->ExtensionBlockCount, + &GifFile->ExtensionBlocks, + CONTINUE_EXT_FUNC_CODE, ExtData[0], + &ExtData[1]) == GIF_ERROR) { + return (GIF_ERROR); + } + } + break; + + case TERMINATE_RECORD_TYPE: + break; + + default: /* Should be trapped by DGifGetRecordType */ + break; + } + } while (RecordType != TERMINATE_RECORD_TYPE); + + /* Sanity check for corrupted file */ + if (GifFile->ImageCount == 0) { + GifFile->Error = D_GIF_ERR_NO_IMAG_DSCR; + return (GIF_ERROR); + } + + return (GIF_OK); +} + +/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/gif_hash.c b/product/include/torchvision/io/image/cpu/giflib/gif_hash.c new file mode 100644 index 00000000000..e63a72accd4 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/giflib/gif_hash.c @@ -0,0 +1,128 @@ 
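+/* Editor's note (not part of upstream giflib): the hash table implemented
+ * in this file is used only on the encoding side, mapping a 20-bit
+ * (prefix code, suffix byte) key to an already-assigned LZW code.  A
+ * minimal sketch of how an encoder might use it, assuming the
+ * declarations from gif_hash.h; kept compiled out on purpose:
+ */
+#if 0
+static int lookup_or_insert(GifHashTableType *ht, uint32_t key, int next_code) {
+    int code = _ExistsHashTable(ht, key);
+    if (code >= 0) {
+        return code; /* this (prefix, suffix) pair already has a code */
+    }
+    _InsertHashTable(ht, key, next_code); /* remember it for the next probe */
+    return -1; /* caller emits the prefix code and grows the table */
+}
+#endif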
+/***************************************************************************** + +gif_hash.c -- module to support the following operations: + +1. InitHashTable - initialize hash table. +2. ClearHashTable - clear the hash table to an empty state. +2. InsertHashTable - insert one item into data structure. +3. ExistsHashTable - test if item exists in data structure. + +This module is used to hash the GIF codes during encoding. + +*****************************************************************************/ +// SPDX-License-Identifier: MIT +// SPDX-File-Copyright-Txt: (C) Copyright 1989 Gershon Elber + +#include +#include +#include +#include +#include + +#include "gif_hash.h" +#include "gif_lib.h" +#include "gif_lib_private.h" + +/* #define DEBUG_HIT_RATE Debug number of misses per hash Insert/Exists. */ + +#ifdef DEBUG_HIT_RATE +static long NumberOfTests = 0, NumberOfMisses = 0; +#endif /* DEBUG_HIT_RATE */ + +static int KeyItem(uint32_t Item); + +/****************************************************************************** + Initialize HashTable - allocate the memory needed and clear it. * +******************************************************************************/ +GifHashTableType *_InitHashTable(void) { + GifHashTableType *HashTable; + + if ((HashTable = (GifHashTableType *)malloc( + sizeof(GifHashTableType))) == NULL) { + return NULL; + } + + _ClearHashTable(HashTable); + + return HashTable; +} + +/****************************************************************************** + Routine to clear the HashTable to an empty state. * + This part is a little machine depended. Use the commented part otherwise. * +******************************************************************************/ +void _ClearHashTable(GifHashTableType *HashTable) { + memset(HashTable->HTable, 0xFF, HT_SIZE * sizeof(uint32_t)); +} + +/****************************************************************************** + Routine to insert a new Item into the HashTable. The data is assumed to be * + new one. * +******************************************************************************/ +void _InsertHashTable(GifHashTableType *HashTable, uint32_t Key, int Code) { + int HKey = KeyItem(Key); + uint32_t *HTable = HashTable->HTable; + +#ifdef DEBUG_HIT_RATE + NumberOfTests++; + NumberOfMisses++; +#endif /* DEBUG_HIT_RATE */ + + while (HT_GET_KEY(HTable[HKey]) != 0xFFFFFL) { +#ifdef DEBUG_HIT_RATE + NumberOfMisses++; +#endif /* DEBUG_HIT_RATE */ + HKey = (HKey + 1) & HT_KEY_MASK; + } + HTable[HKey] = HT_PUT_KEY(Key) | HT_PUT_CODE(Code); +} + +/****************************************************************************** + Routine to test if given Key exists in HashTable and if so returns its code * + Returns the Code if key was found, -1 if not. * +******************************************************************************/ +int _ExistsHashTable(GifHashTableType *HashTable, uint32_t Key) { + int HKey = KeyItem(Key); + uint32_t *HTable = HashTable->HTable, HTKey; + +#ifdef DEBUG_HIT_RATE + NumberOfTests++; + NumberOfMisses++; +#endif /* DEBUG_HIT_RATE */ + + while ((HTKey = HT_GET_KEY(HTable[HKey])) != 0xFFFFFL) { +#ifdef DEBUG_HIT_RATE + NumberOfMisses++; +#endif /* DEBUG_HIT_RATE */ + if (Key == HTKey) { + return HT_GET_CODE(HTable[HKey]); + } + HKey = (HKey + 1) & HT_KEY_MASK; + } + + return -1; +} + +/****************************************************************************** + Routine to generate an HKey for the hashtable out of the given unique key. 
* + The given Key is assumed to be 20 bits as follows: lower 8 bits are the * + new postfix character, while the upper 12 bits are the prefix code. * + Because the average hit ratio is only 2 (2 hash references per entry), * + evaluating more complex keys (such as twin prime keys) does not worth it! * +******************************************************************************/ +static int KeyItem(uint32_t Item) { + return ((Item >> 12) ^ Item) & HT_KEY_MASK; +} + +#ifdef DEBUG_HIT_RATE +/****************************************************************************** + Debugging routine to print the hit ratio - number of times the hash table * + was tested per operation. This routine was used to test the KeyItem routine * +******************************************************************************/ +void HashTablePrintHitRatio(void) { + printf("Hash Table Hit Ratio is %ld/%ld = %ld%%.\n", NumberOfMisses, + NumberOfTests, NumberOfMisses * 100 / NumberOfTests); +} +#endif /* DEBUG_HIT_RATE */ + +/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/gif_hash.h b/product/include/torchvision/io/image/cpu/giflib/gif_hash.h new file mode 100644 index 00000000000..009cb5b8081 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/giflib/gif_hash.h @@ -0,0 +1,42 @@ +/****************************************************************************** + +gif_hash.h - magfic constants and declarations for GIF LZW + +******************************************************************************/ +// SPDX-License-Identifier: MIT + +#ifndef _GIF_HASH_H_ +#define _GIF_HASH_H_ + +#ifndef _WIN32 +#include +#endif /* _WIN32 */ +#include + +#define HT_SIZE 8192 /* 12bits = 4096 or twice as big! */ +#define HT_KEY_MASK 0x1FFF /* 13bits keys */ +#define HT_KEY_NUM_BITS 13 /* 13bits keys */ +#define HT_MAX_KEY 8191 /* 13bits - 1, maximal code possible */ +#define HT_MAX_CODE 4095 /* Biggest code possible in 12 bits. */ + +/* The 32 bits of the long are divided into two parts for the key & code: */ +/* 1. The code is 12 bits as our compression algorithm is limited to 12bits */ +/* 2. The key is 12 bits Prefix code + 8 bit new char or 20 bits. */ +/* The key is the upper 20 bits. The code is the lower 12. 
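+   Worked example (editor's note): prefix code 0xABC with suffix byte 0xDE
+   gives Key = (0xABC << 8) | 0xDE = 0xABCDE; HT_PUT_KEY(0xABCDE) combined
+   with HT_PUT_CODE(0x123) stores 0xABCDE123, from which HT_GET_KEY and
+   HT_GET_CODE recover 0xABCDE and 0x123.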
*/ +#define HT_GET_KEY(l) (l >> 12) +#define HT_GET_CODE(l) (l & 0x0FFF) +#define HT_PUT_KEY(l) (l << 12) +#define HT_PUT_CODE(l) (l & 0x0FFF) + +typedef struct GifHashTableType { + uint32_t HTable[HT_SIZE]; +} GifHashTableType; + +GifHashTableType *_InitHashTable(void); +void _ClearHashTable(GifHashTableType *HashTable); +void _InsertHashTable(GifHashTableType *HashTable, uint32_t Key, int Code); +int _ExistsHashTable(GifHashTableType *HashTable, uint32_t Key); + +#endif /* _GIF_HASH_H_ */ + +/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/gif_lib.h b/product/include/torchvision/io/image/cpu/giflib/gif_lib.h new file mode 100644 index 00000000000..d0c61d51682 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/giflib/gif_lib.h @@ -0,0 +1,291 @@ +/****************************************************************************** + +gif_lib.h - service library for decoding and encoding GIF images + +SPDX-License-Identifier: MIT + +*****************************************************************************/ + +#ifndef _GIF_LIB_H_ +#define _GIF_LIB_H_ 1 + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#define GIFLIB_MAJOR 5 +#define GIFLIB_MINOR 2 +#define GIFLIB_RELEASE 2 + +#define GIF_ERROR 0 +#define GIF_OK 1 + +#include +#include + +#define GIF_STAMP "GIFVER" /* First chars in file - GIF stamp. */ +#define GIF_STAMP_LEN sizeof(GIF_STAMP) - 1 +#define GIF_VERSION_POS 3 /* Version first character in stamp. */ +#define GIF87_STAMP "GIF87a" /* First chars in file - GIF stamp. */ +#define GIF89_STAMP "GIF89a" /* First chars in file - GIF stamp. */ + +typedef unsigned char GifPixelType; +typedef unsigned char *GifRowType; +typedef unsigned char GifByteType; +typedef unsigned int GifPrefixType; +typedef int GifWord; + +typedef struct GifColorType { + GifByteType Red, Green, Blue; +} GifColorType; + +typedef struct ColorMapObject { + int ColorCount; + int BitsPerPixel; + bool SortFlag; + GifColorType *Colors; /* on malloc(3) heap */ +} ColorMapObject; + +typedef struct GifImageDesc { + GifWord Left, Top, Width, Height; /* Current image dimensions. */ + bool Interlace; /* Sequential/Interlaced lines. */ + ColorMapObject *ColorMap; /* The local color map */ +} GifImageDesc; + +typedef struct ExtensionBlock { + int ByteCount; + GifByteType *Bytes; /* on malloc(3) heap */ + int Function; /* The block function code */ +#define CONTINUE_EXT_FUNC_CODE 0x00 /* continuation subblock */ +#define COMMENT_EXT_FUNC_CODE 0xfe /* comment */ +#define GRAPHICS_EXT_FUNC_CODE 0xf9 /* graphics control (GIF89) */ +#define PLAINTEXT_EXT_FUNC_CODE 0x01 /* plaintext */ +#define APPLICATION_EXT_FUNC_CODE 0xff /* application block (GIF89) */ +} ExtensionBlock; + +typedef struct SavedImage { + GifImageDesc ImageDesc; + GifByteType *RasterBits; /* on malloc(3) heap */ + int ExtensionBlockCount; /* Count of extensions before image */ + ExtensionBlock *ExtensionBlocks; /* Extensions before image */ +} SavedImage; + +typedef struct GifFileType { + GifWord SWidth, SHeight; /* Size of virtual canvas */ + GifWord SColorResolution; /* How many colors can we generate? */ + GifWord SBackGroundColor; /* Background color for virtual canvas */ + GifByteType AspectByte; /* Used to compute pixel aspect ratio */ + ColorMapObject *SColorMap; /* Global colormap, NULL if nonexistent. 
*/ + int ImageCount; /* Number of current image (both APIs) */ + GifImageDesc Image; /* Current image (low-level API) */ + SavedImage *SavedImages; /* Image sequence (high-level API) */ + int ExtensionBlockCount; /* Count extensions past last image */ + ExtensionBlock *ExtensionBlocks; /* Extensions past last image */ + int Error; /* Last error condition reported */ + void *UserData; /* hook to attach user data (TVT) */ + void *Private; /* Don't mess with this! */ +} GifFileType; + +#define GIF_ASPECT_RATIO(n) ((n) + 15.0 / 64.0) + +typedef enum { + UNDEFINED_RECORD_TYPE, + SCREEN_DESC_RECORD_TYPE, + IMAGE_DESC_RECORD_TYPE, /* Begin with ',' */ + EXTENSION_RECORD_TYPE, /* Begin with '!' */ + TERMINATE_RECORD_TYPE /* Begin with ';' */ +} GifRecordType; + +/* func type to read gif data from arbitrary sources (TVT) */ +typedef int (*InputFunc)(GifFileType *, GifByteType *, int); + +/* func type to write gif data to arbitrary targets. + * Returns count of bytes written. (MRB) + */ +typedef int (*OutputFunc)(GifFileType *, const GifByteType *, int); + +/****************************************************************************** + GIF89 structures +******************************************************************************/ + +typedef struct GraphicsControlBlock { + int DisposalMode; +#define DISPOSAL_UNSPECIFIED 0 /* No disposal specified. */ +#define DISPOSE_DO_NOT 1 /* Leave image in place */ +#define DISPOSE_BACKGROUND 2 /* Set area too background color */ +#define DISPOSE_PREVIOUS 3 /* Restore to previous content */ + bool UserInputFlag; /* User confirmation required before disposal */ + int DelayTime; /* pre-display delay in 0.01sec units */ + int TransparentColor; /* Palette index for transparency, -1 if none */ +#define NO_TRANSPARENT_COLOR -1 +} GraphicsControlBlock; + +/****************************************************************************** + GIF encoding routines +******************************************************************************/ + +/* Main entry points */ +GifFileType *EGifOpenFileName(const char *GifFileName, + const bool GifTestExistence, int *Error); +GifFileType *EGifOpenFileHandle(const int GifFileHandle, int *Error); +GifFileType *EGifOpen(void *userPtr, OutputFunc writeFunc, int *Error); +int EGifSpew(GifFileType *GifFile); +const char *EGifGetGifVersion(GifFileType *GifFile); /* new in 5.x */ +int EGifCloseFile(GifFileType *GifFile, int *ErrorCode); + +#define E_GIF_SUCCEEDED 0 +#define E_GIF_ERR_OPEN_FAILED 1 /* And EGif possible errors. */ +#define E_GIF_ERR_WRITE_FAILED 2 +#define E_GIF_ERR_HAS_SCRN_DSCR 3 +#define E_GIF_ERR_HAS_IMAG_DSCR 4 +#define E_GIF_ERR_NO_COLOR_MAP 5 +#define E_GIF_ERR_DATA_TOO_BIG 6 +#define E_GIF_ERR_NOT_ENOUGH_MEM 7 +#define E_GIF_ERR_DISK_IS_FULL 8 +#define E_GIF_ERR_CLOSE_FAILED 9 +#define E_GIF_ERR_NOT_WRITEABLE 10 + +/* These are legacy. 
You probably do not want to call them directly */ +int EGifPutScreenDesc(GifFileType *GifFile, const int GifWidth, + const int GifHeight, const int GifColorRes, + const int GifBackGround, + const ColorMapObject *GifColorMap); +int EGifPutImageDesc(GifFileType *GifFile, const int GifLeft, const int GifTop, + const int GifWidth, const int GifHeight, + const bool GifInterlace, + const ColorMapObject *GifColorMap); +void EGifSetGifVersion(GifFileType *GifFile, const bool gif89); +int EGifPutLine(GifFileType *GifFile, GifPixelType *GifLine, int GifLineLen); +int EGifPutPixel(GifFileType *GifFile, const GifPixelType GifPixel); +int EGifPutComment(GifFileType *GifFile, const char *GifComment); +int EGifPutExtensionLeader(GifFileType *GifFile, const int GifExtCode); +int EGifPutExtensionBlock(GifFileType *GifFile, const int GifExtLen, + const void *GifExtension); +int EGifPutExtensionTrailer(GifFileType *GifFile); +int EGifPutExtension(GifFileType *GifFile, const int GifExtCode, + const int GifExtLen, const void *GifExtension); +int EGifPutCode(GifFileType *GifFile, int GifCodeSize, + const GifByteType *GifCodeBlock); +int EGifPutCodeNext(GifFileType *GifFile, const GifByteType *GifCodeBlock); + +/****************************************************************************** + GIF decoding routines +******************************************************************************/ + +/* Main entry points */ +GifFileType *DGifOpenFileName(const char *GifFileName, int *Error); +GifFileType *DGifOpenFileHandle(int GifFileHandle, int *Error); +int DGifSlurp(GifFileType *GifFile); +GifFileType *DGifOpen(void *userPtr, InputFunc readFunc, + int *Error); /* new one (TVT) */ +int DGifCloseFile(GifFileType *GifFile, int *ErrorCode); + +#define D_GIF_SUCCEEDED 0 +#define D_GIF_ERR_OPEN_FAILED 101 /* And DGif possible errors. */ +#define D_GIF_ERR_READ_FAILED 102 +#define D_GIF_ERR_NOT_GIF_FILE 103 +#define D_GIF_ERR_NO_SCRN_DSCR 104 +#define D_GIF_ERR_NO_IMAG_DSCR 105 +#define D_GIF_ERR_NO_COLOR_MAP 106 +#define D_GIF_ERR_WRONG_RECORD 107 +#define D_GIF_ERR_DATA_TOO_BIG 108 +#define D_GIF_ERR_NOT_ENOUGH_MEM 109 +#define D_GIF_ERR_CLOSE_FAILED 110 +#define D_GIF_ERR_NOT_READABLE 111 +#define D_GIF_ERR_IMAGE_DEFECT 112 +#define D_GIF_ERR_EOF_TOO_SOON 113 + +/* These are legacy. You probably do not want to call them directly */ +int DGifGetScreenDesc(GifFileType *GifFile); +int DGifGetRecordType(GifFileType *GifFile, GifRecordType *GifType); +int DGifGetImageHeader(GifFileType *GifFile); +int DGifGetImageDesc(GifFileType *GifFile); +int DGifGetLine(GifFileType *GifFile, GifPixelType *GifLine, int GifLineLen); +int DGifGetPixel(GifFileType *GifFile, GifPixelType GifPixel); +int DGifGetExtension(GifFileType *GifFile, int *GifExtCode, + GifByteType **GifExtension); +int DGifGetExtensionNext(GifFileType *GifFile, GifByteType **GifExtension); +int DGifGetCode(GifFileType *GifFile, int *GifCodeSize, + GifByteType **GifCodeBlock); +int DGifGetCodeNext(GifFileType *GifFile, GifByteType **GifCodeBlock); +int DGifGetLZCodes(GifFileType *GifFile, int *GifCode); +const char *DGifGetGifVersion(GifFileType *GifFile); + +/****************************************************************************** + Error handling and reporting. +******************************************************************************/ +extern const char *GifErrorString(int ErrorCode); /* new in 2012 - ESR */ + +/***************************************************************************** + it g in core. 
+******************************************************************************/ + +/****************************************************************************** + Color map handling from gif_alloc.c +******************************************************************************/ + +extern ColorMapObject *GifMakeMapObject(int ColorCount, + const GifColorType *ColorMap); +extern void GifFreeMapObject(ColorMapObject *Object); +extern ColorMapObject *GifUnionColorMap(const ColorMapObject *ColorIn1, + const ColorMapObject *ColorIn2, + GifPixelType ColorTransIn2[]); +extern int GifBitSize(int n); + +/****************************************************************************** + Support for the in-core structures allocation (slurp mode). +******************************************************************************/ + +extern void GifApplyTranslation(SavedImage *Image, + const GifPixelType Translation[]); +extern int GifAddExtensionBlock(int *ExtensionBlock_Count, + ExtensionBlock **ExtensionBlocks, int Function, + unsigned int Len, unsigned char ExtData[]); +extern void GifFreeExtensions(int *ExtensionBlock_Count, + ExtensionBlock **ExtensionBlocks); +extern SavedImage *GifMakeSavedImage(GifFileType *GifFile, + const SavedImage *CopyFrom); +extern void GifFreeSavedImages(GifFileType *GifFile); + +/****************************************************************************** + 5.x functions for GIF89 graphics control blocks +******************************************************************************/ + +int DGifExtensionToGCB(const size_t GifExtensionLength, + const GifByteType *GifExtension, + GraphicsControlBlock *GCB); +size_t EGifGCBToExtension(const GraphicsControlBlock *GCB, + GifByteType *GifExtension); + +int DGifSavedExtensionToGCB(GifFileType *GifFile, int ImageIndex, + GraphicsControlBlock *GCB); +int EGifGCBToSavedExtension(const GraphicsControlBlock *GCB, + GifFileType *GifFile, int ImageIndex); + +/****************************************************************************** + The library's internal utility font +******************************************************************************/ + +#define GIF_FONT_WIDTH 8 +#define GIF_FONT_HEIGHT 8 +extern const unsigned char GifAsciiTable8x8[][GIF_FONT_WIDTH]; + +extern void GifDrawText8x8(SavedImage *Image, const int x, const int y, + const char *legend, const int color); + +extern void GifDrawBox(SavedImage *Image, const int x, const int y, const int w, + const int d, const int color); + +extern void GifDrawRectangle(SavedImage *Image, const int x, const int y, + const int w, const int d, const int color); + +extern void GifDrawBoxedText8x8(SavedImage *Image, const int x, const int y, + const char *legend, const int border, + const int bg, const int fg); + +#ifdef __cplusplus +} +#endif /* __cplusplus */ +#endif /* _GIF_LIB_H */ + +/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/gif_lib_private.h b/product/include/torchvision/io/image/cpu/giflib/gif_lib_private.h new file mode 100644 index 00000000000..19578d4530c --- /dev/null +++ b/product/include/torchvision/io/image/cpu/giflib/gif_lib_private.h @@ -0,0 +1,72 @@ +/**************************************************************************** + +gif_lib_private.h - internal giflib routines and structures + +SPDX-License-Identifier: MIT + +****************************************************************************/ + +#ifndef _GIF_LIB_PRIVATE_H +#define _GIF_LIB_PRIVATE_H + +#include "gif_hash.h" +#include "gif_lib.h" + +#ifndef SIZE_MAX +#define 
SIZE_MAX UINTPTR_MAX +#endif + +#define EXTENSION_INTRODUCER 0x21 +#define DESCRIPTOR_INTRODUCER 0x2c +#define TERMINATOR_INTRODUCER 0x3b + +#define LZ_MAX_CODE 4095 /* Biggest code possible in 12 bits. */ +#define LZ_BITS 12 + +#define FLUSH_OUTPUT 4096 /* Impossible code, to signal flush. */ +#define FIRST_CODE 4097 /* Impossible code, to signal first. */ +#define NO_SUCH_CODE 4098 /* Impossible code, to signal empty. */ + +#define FILE_STATE_WRITE 0x01 +#define FILE_STATE_SCREEN 0x02 +#define FILE_STATE_IMAGE 0x04 +#define FILE_STATE_READ 0x08 + +#define IS_READABLE(Private) (Private->FileState & FILE_STATE_READ) +#define IS_WRITEABLE(Private) (Private->FileState & FILE_STATE_WRITE) + +typedef struct GifFilePrivateType { + GifWord FileState, FileHandle, /* Where all this data goes to! */ + BitsPerPixel, /* Bits per pixel (Codes uses at least this + 1). */ + ClearCode, /* The CLEAR LZ code. */ + EOFCode, /* The EOF LZ code. */ + RunningCode, /* The next code algorithm can generate. */ + RunningBits, /* The number of bits required to represent + RunningCode. */ + MaxCode1, /* 1 bigger than max. possible code, in RunningBits bits. + */ + LastCode, /* The code before the current code. */ + CrntCode, /* Current algorithm code. */ + StackPtr, /* For character stack (see below). */ + CrntShiftState; /* Number of bits in CrntShiftDWord. */ + unsigned long CrntShiftDWord; /* For bytes decomposition into codes. */ + unsigned long PixelCount; /* Number of pixels in image. */ + FILE *File; /* File as stream. */ + InputFunc Read; /* function to read gif input (TVT) */ + OutputFunc Write; /* function to write gif output (MRB) */ + GifByteType Buf[256]; /* Compressed input is buffered here. */ + GifByteType Stack[LZ_MAX_CODE]; /* Decoded pixels are stacked here. */ + GifByteType Suffix[LZ_MAX_CODE + 1]; /* So we can trace the codes. */ + GifPrefixType Prefix[LZ_MAX_CODE + 1]; + GifHashTableType *HashTable; + bool gif89; +} GifFilePrivateType; + +#ifndef HAVE_REALLOCARRAY +extern void *openbsd_reallocarray(void *optr, size_t nmemb, size_t size); +#define reallocarray openbsd_reallocarray +#endif + +#endif /* _GIF_LIB_PRIVATE_H */ + +/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/gifalloc.c b/product/include/torchvision/io/image/cpu/giflib/gifalloc.c new file mode 100644 index 00000000000..926d54ebcf7 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/giflib/gifalloc.c @@ -0,0 +1,425 @@ +/***************************************************************************** + + GIF construction tools + +****************************************************************************/ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (C) Eric S. Raymond + +#include +#include +#include + +#include "gif_lib.h" +#include "gif_lib_private.h" + +#define MAX(x, y) (((x) > (y)) ? (x) : (y)) + +/****************************************************************************** + Miscellaneous utility functions +******************************************************************************/ + +/* return smallest bitfield size n will fit in */ +int GifBitSize(int n) { + int i; + + for (i = 1; i <= 8; i++) { + if ((1 << i) >= n) { + break; + } + } + return (i); +} + +/****************************************************************************** + Color map object functions +******************************************************************************/ + +/* + * Allocate a color map of given size; initialize with contents of + * ColorMap if that pointer is non-NULL. 
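+ * Editor's note: ColorCount must already be a power of two or the call
+ * fails (see the FIXME below).  For example, GifMakeMapObject(256, NULL)
+ * returns a zeroed 256-entry map with BitsPerPixel 8, while
+ * GifMakeMapObject(200, NULL) returns NULL.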
+ */ +ColorMapObject *GifMakeMapObject(int ColorCount, const GifColorType *ColorMap) { + ColorMapObject *Object; + + /*** FIXME: Our ColorCount has to be a power of two. Is it necessary to + * make the user know that or should we automatically round up instead? + */ + if (ColorCount != (1 << GifBitSize(ColorCount))) { + return ((ColorMapObject *)NULL); + } + + Object = (ColorMapObject *)malloc(sizeof(ColorMapObject)); + if (Object == (ColorMapObject *)NULL) { + return ((ColorMapObject *)NULL); + } + + Object->Colors = + (GifColorType *)calloc(ColorCount, sizeof(GifColorType)); + if (Object->Colors == (GifColorType *)NULL) { + free(Object); + return ((ColorMapObject *)NULL); + } + + Object->ColorCount = ColorCount; + Object->BitsPerPixel = GifBitSize(ColorCount); + Object->SortFlag = false; + + if (ColorMap != NULL) { + memcpy((char *)Object->Colors, (char *)ColorMap, + ColorCount * sizeof(GifColorType)); + } + + return (Object); +} + +/******************************************************************************* + Free a color map object +*******************************************************************************/ +void GifFreeMapObject(ColorMapObject *Object) { + if (Object != NULL) { + (void)free(Object->Colors); + (void)free(Object); + } +} + +#ifdef DEBUG +void DumpColorMap(ColorMapObject *Object, FILE *fp) { + if (Object != NULL) { + int i, j, Len = Object->ColorCount; + + for (i = 0; i < Len; i += 4) { + for (j = 0; j < 4 && j < Len; j++) { + (void)fprintf(fp, "%3d: %02x %02x %02x ", + i + j, Object->Colors[i + j].Red, + Object->Colors[i + j].Green, + Object->Colors[i + j].Blue); + } + (void)fprintf(fp, "\n"); + } + } +} +#endif /* DEBUG */ + +/******************************************************************************* + Compute the union of two given color maps and return it. If result can't + fit into 256 colors, NULL is returned, the allocated union otherwise. + ColorIn1 is copied as is to ColorUnion, while colors from ColorIn2 are + copied iff they didn't exist before. ColorTransIn2 maps the old + ColorIn2 into the ColorUnion color map table./ +*******************************************************************************/ +ColorMapObject *GifUnionColorMap(const ColorMapObject *ColorIn1, + const ColorMapObject *ColorIn2, + GifPixelType ColorTransIn2[]) { + int i, j, CrntSlot, RoundUpTo, NewGifBitSize; + ColorMapObject *ColorUnion; + + /* + * We don't worry about duplicates within either color map; if + * the caller wants to resolve those, he can perform unions + * with an empty color map. + */ + + /* Allocate table which will hold the result for sure. */ + ColorUnion = GifMakeMapObject( + MAX(ColorIn1->ColorCount, ColorIn2->ColorCount) * 2, NULL); + + if (ColorUnion == NULL) { + return (NULL); + } + + /* + * Copy ColorIn1 to ColorUnion. + */ + for (i = 0; i < ColorIn1->ColorCount; i++) { + ColorUnion->Colors[i] = ColorIn1->Colors[i]; + } + CrntSlot = ColorIn1->ColorCount; + + /* + * Potentially obnoxious hack: + * + * Back CrntSlot down past all contiguous {0, 0, 0} slots at the end + * of table 1. This is very useful if your display is limited to + * 16 colors. 
+ */ + while (ColorIn1->Colors[CrntSlot - 1].Red == 0 && + ColorIn1->Colors[CrntSlot - 1].Green == 0 && + ColorIn1->Colors[CrntSlot - 1].Blue == 0) { + CrntSlot--; + } + + /* Copy ColorIn2 to ColorUnion (use old colors if they exist): */ + for (i = 0; i < ColorIn2->ColorCount && CrntSlot <= 256; i++) { + /* Let's see if this color already exists: */ + for (j = 0; j < ColorIn1->ColorCount; j++) { + if (memcmp(&ColorIn1->Colors[j], &ColorIn2->Colors[i], + sizeof(GifColorType)) == 0) { + break; + } + } + + if (j < ColorIn1->ColorCount) { + ColorTransIn2[i] = j; /* color exists in Color1 */ + } else { + /* Color is new - copy it to a new slot: */ + ColorUnion->Colors[CrntSlot] = ColorIn2->Colors[i]; + ColorTransIn2[i] = CrntSlot++; + } + } + + if (CrntSlot > 256) { + GifFreeMapObject(ColorUnion); + return ((ColorMapObject *)NULL); + } + + NewGifBitSize = GifBitSize(CrntSlot); + RoundUpTo = (1 << NewGifBitSize); + + if (RoundUpTo != ColorUnion->ColorCount) { + GifColorType *Map = ColorUnion->Colors; + + /* + * Zero out slots up to next power of 2. + * We know these slots exist because of the way ColorUnion's + * start dimension was computed. + */ + for (j = CrntSlot; j < RoundUpTo; j++) { + Map[j].Red = Map[j].Green = Map[j].Blue = 0; + } + + /* perhaps we can shrink the map? */ + if (RoundUpTo < ColorUnion->ColorCount) { + GifColorType *new_map = (GifColorType *)reallocarray( + Map, RoundUpTo, sizeof(GifColorType)); + if (new_map == NULL) { + GifFreeMapObject(ColorUnion); + return ((ColorMapObject *)NULL); + } + ColorUnion->Colors = new_map; + } + } + + ColorUnion->ColorCount = RoundUpTo; + ColorUnion->BitsPerPixel = NewGifBitSize; + + return (ColorUnion); +} + +/******************************************************************************* + Apply a given color translation to the raster bits of an image +*******************************************************************************/ +void GifApplyTranslation(SavedImage *Image, const GifPixelType Translation[]) { + int i; + int RasterSize = + Image->ImageDesc.Height * Image->ImageDesc.Width; + + for (i = 0; i < RasterSize; i++) { + Image->RasterBits[i] = Translation[Image->RasterBits[i]]; + } +} + +/****************************************************************************** + Extension record functions +******************************************************************************/ +int GifAddExtensionBlock(int *ExtensionBlockCount, + ExtensionBlock **ExtensionBlocks, int Function, + unsigned int Len, unsigned char ExtData[]) { + ExtensionBlock *ep; + + if (*ExtensionBlocks == NULL) { + *ExtensionBlocks = + (ExtensionBlock *)malloc(sizeof(ExtensionBlock)); + } else { + ExtensionBlock *ep_new = (ExtensionBlock *)reallocarray( + *ExtensionBlocks, (*ExtensionBlockCount + 1), + sizeof(ExtensionBlock)); + if (ep_new == NULL) { + return (GIF_ERROR); + } + *ExtensionBlocks = ep_new; + } + + if (*ExtensionBlocks == NULL) { + return (GIF_ERROR); + } + + ep = &(*ExtensionBlocks)[(*ExtensionBlockCount)++]; + + ep->Function = Function; + ep->ByteCount = Len; + ep->Bytes = (GifByteType *)malloc(ep->ByteCount); + if (ep->Bytes == NULL) { + return (GIF_ERROR); + } + + if (ExtData != NULL) { + memcpy(ep->Bytes, ExtData, Len); + } + + return (GIF_OK); +} + +void GifFreeExtensions(int *ExtensionBlockCount, + ExtensionBlock **ExtensionBlocks) { + ExtensionBlock *ep; + + if (*ExtensionBlocks == NULL) { + return; + } + + for (ep = *ExtensionBlocks; + ep < (*ExtensionBlocks + *ExtensionBlockCount); ep++) { + (void)free((char *)ep->Bytes); + } + 
(void)free((char *)*ExtensionBlocks); + *ExtensionBlocks = NULL; + *ExtensionBlockCount = 0; +} + +/****************************************************************************** + Image block allocation functions +******************************************************************************/ + +/* Private Function: + * Frees the last image in the GifFile->SavedImages array + */ +void FreeLastSavedImage(GifFileType *GifFile) { + SavedImage *sp; + + if ((GifFile == NULL) || (GifFile->SavedImages == NULL)) { + return; + } + + /* Remove one SavedImage from the GifFile */ + GifFile->ImageCount--; + sp = &GifFile->SavedImages[GifFile->ImageCount]; + + /* Deallocate its Colormap */ + if (sp->ImageDesc.ColorMap != NULL) { + GifFreeMapObject(sp->ImageDesc.ColorMap); + sp->ImageDesc.ColorMap = NULL; + } + + /* Deallocate the image data */ + if (sp->RasterBits != NULL) { + free((char *)sp->RasterBits); + } + + /* Deallocate any extensions */ + GifFreeExtensions(&sp->ExtensionBlockCount, &sp->ExtensionBlocks); + + /*** FIXME: We could realloc the GifFile->SavedImages structure but is + * there a point to it? Saves some memory but we'd have to do it every + * time. If this is used in GifFreeSavedImages then it would be + * inefficient (The whole array is going to be deallocated.) If we just + * use it when we want to free the last Image it's convenient to do it + * here. + */ +} + +/* + * Append an image block to the SavedImages array + */ +SavedImage *GifMakeSavedImage(GifFileType *GifFile, + const SavedImage *CopyFrom) { + // cppcheck-suppress ctunullpointer + if (GifFile->SavedImages == NULL) { + GifFile->SavedImages = (SavedImage *)malloc(sizeof(SavedImage)); + } else { + SavedImage *newSavedImages = (SavedImage *)reallocarray( + GifFile->SavedImages, (GifFile->ImageCount + 1), + sizeof(SavedImage)); + if (newSavedImages == NULL) { + return ((SavedImage *)NULL); + } + GifFile->SavedImages = newSavedImages; + } + if (GifFile->SavedImages == NULL) { + return ((SavedImage *)NULL); + } else { + SavedImage *sp = &GifFile->SavedImages[GifFile->ImageCount++]; + + if (CopyFrom != NULL) { + memcpy((char *)sp, CopyFrom, sizeof(SavedImage)); + + /* + * Make our own allocated copies of the heap fields in + * the copied record. This guards against potential + * aliasing problems. 
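+ * (Without these deep copies both records would share the same ColorMap,
+ * RasterBits and ExtensionBlocks pointers, and freeing one image would
+ * leave the other holding dangling pointers.)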
+ */ + + /* first, the local color map */ + if (CopyFrom->ImageDesc.ColorMap != NULL) { + sp->ImageDesc.ColorMap = GifMakeMapObject( + CopyFrom->ImageDesc.ColorMap->ColorCount, + CopyFrom->ImageDesc.ColorMap->Colors); + if (sp->ImageDesc.ColorMap == NULL) { + FreeLastSavedImage(GifFile); + return (SavedImage *)(NULL); + } + } + + /* next, the raster */ + sp->RasterBits = (unsigned char *)reallocarray( + NULL, + (CopyFrom->ImageDesc.Height * + CopyFrom->ImageDesc.Width), + sizeof(GifPixelType)); + if (sp->RasterBits == NULL) { + FreeLastSavedImage(GifFile); + return (SavedImage *)(NULL); + } + memcpy(sp->RasterBits, CopyFrom->RasterBits, + sizeof(GifPixelType) * + CopyFrom->ImageDesc.Height * + CopyFrom->ImageDesc.Width); + + /* finally, the extension blocks */ + if (CopyFrom->ExtensionBlocks != NULL) { + sp->ExtensionBlocks = + (ExtensionBlock *)reallocarray( + NULL, CopyFrom->ExtensionBlockCount, + sizeof(ExtensionBlock)); + if (sp->ExtensionBlocks == NULL) { + FreeLastSavedImage(GifFile); + return (SavedImage *)(NULL); + } + memcpy(sp->ExtensionBlocks, + CopyFrom->ExtensionBlocks, + sizeof(ExtensionBlock) * + CopyFrom->ExtensionBlockCount); + } + } else { + memset((char *)sp, '\0', sizeof(SavedImage)); + } + + return (sp); + } +} + +void GifFreeSavedImages(GifFileType *GifFile) { + SavedImage *sp; + + if ((GifFile == NULL) || (GifFile->SavedImages == NULL)) { + return; + } + for (sp = GifFile->SavedImages; + sp < GifFile->SavedImages + GifFile->ImageCount; sp++) { + if (sp->ImageDesc.ColorMap != NULL) { + GifFreeMapObject(sp->ImageDesc.ColorMap); + sp->ImageDesc.ColorMap = NULL; + } + + if (sp->RasterBits != NULL) { + free((char *)sp->RasterBits); + } + + GifFreeExtensions(&sp->ExtensionBlockCount, + &sp->ExtensionBlocks); + } + free((char *)GifFile->SavedImages); + GifFile->SavedImages = NULL; +} + +/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/openbsd-reallocarray.c b/product/include/torchvision/io/image/cpu/giflib/openbsd-reallocarray.c new file mode 100644 index 00000000000..e09ab245ad4 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/giflib/openbsd-reallocarray.c @@ -0,0 +1,73 @@ +/* + * SPDX-FileCopyrightText: Copyright (C) 2008 Otto Moerbeek + * SPDX-License-Identifier: MIT + */ + +#include +#include +#include +#include + +#ifndef SIZE_MAX +#define SIZE_MAX UINTPTR_MAX +#endif + +/* + * This is sqrt(SIZE_MAX+1), as s1*s2 <= SIZE_MAX + * if both s1 < MUL_NO_OVERFLOW and s2 < MUL_NO_OVERFLOW + */ +#define MUL_NO_OVERFLOW ((size_t)1 << (sizeof(size_t) * 4)) + +void *openbsd_reallocarray(void *optr, size_t nmemb, size_t size) { + if ((nmemb >= MUL_NO_OVERFLOW || size >= MUL_NO_OVERFLOW) && + nmemb > 0 && SIZE_MAX / nmemb < size) { + errno = ENOMEM; + return NULL; + } + /* + * Head off variations in realloc behavior on different + * platforms (reported by MarkR ) + * + * The behaviour of reallocarray is implementation-defined if + * nmemb or size is zero. It can return NULL or non-NULL + * depending on the platform. + * https://www.securecoding.cert.org/confluence/display/c/MEM04-C.Beware+of+zero-lengthallocations + * + * Here are some extracts from realloc man pages on different platforms. + * + * void realloc( void memblock, size_t size ); + * + * Windows: + * + * If there is not enough available memory to expand the block + * to the given size, the original block is left unchanged, + * and NULL is returned. 
If size is zero, then the block + * pointed to by memblock is freed; the return value is NULL, + * and memblock is left pointing at a freed block. + * + * OpenBSD: + * + * If size or nmemb is equal to 0, a unique pointer to an + * access protected, zero sized object is returned. Access via + * this pointer will generate a SIGSEGV exception. + * + * Linux: + * + * If size was equal to 0, either NULL or a pointer suitable + * to be passed to free() is returned. + * + * OS X: + * + * If size is zero and ptr is not NULL, a new, minimum sized + * object is allocated and the original object is freed. + * + * It looks like images with zero width or height can trigger + * this, and fuzzing behaviour will differ by platform, so + * fuzzing on one platform may not detect zero-size allocation + * problems on other platforms. + */ + if (size == 0 || nmemb == 0) { + return NULL; + } + return realloc(optr, size * nmemb); +} diff --git a/product/include/torchvision/io/image/cpu/read_write_file.cpp b/product/include/torchvision/io/image/cpu/read_write_file.cpp new file mode 100644 index 00000000000..06de72a5053 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/read_write_file.cpp @@ -0,0 +1,108 @@ +#include "read_write_file.h" + +#include + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include +#endif + +namespace vision { +namespace image { + +#ifdef _WIN32 +namespace { +std::wstring utf8_decode(const std::string& str) { + if (str.empty()) { + return std::wstring(); + } + int size_needed = MultiByteToWideChar( + CP_UTF8, 0, str.c_str(), static_cast(str.size()), nullptr, 0); + TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode"); + std::wstring wstrTo(size_needed, 0); + MultiByteToWideChar( + CP_UTF8, + 0, + str.c_str(), + static_cast(str.size()), + &wstrTo[0], + size_needed); + return wstrTo; +} +} // namespace +#endif + +torch::Tensor read_file(const std::string& filename) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cpu.read_write_file.read_file"); +#ifdef _WIN32 + // According to + // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019, + // we should use struct __stat64 and _wstat64 for 64-bit file size on Windows. + struct __stat64 stat_buf; + auto fileW = utf8_decode(filename); + int rc = _wstat64(fileW.c_str(), &stat_buf); +#else + struct stat stat_buf; + int rc = stat(filename.c_str(), &stat_buf); +#endif + // errno is a variable defined in errno.h + TORCH_CHECK( + rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'"); + + int64_t size = stat_buf.st_size; + + TORCH_CHECK(size > 0, "Expected a non empty file"); + +#ifdef _WIN32 + // TODO: Once torch::from_file handles UTF-8 paths correctly, we should move + // back to use the following implementation since it uses file mapping. 
+ // auto data = + // torch::from_file(filename, /*shared=*/false, /*size=*/size, + // torch::kU8).clone() + FILE* infile = _wfopen(fileW.c_str(), L"rb"); + + TORCH_CHECK(infile != nullptr, "Error opening input file"); + + auto data = torch::empty({size}, torch::kU8); + auto dataBytes = data.data_ptr(); + + fread(dataBytes, sizeof(uint8_t), size, infile); + fclose(infile); +#else + auto data = + torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8); +#endif + + return data; +} + +void write_file(const std::string& filename, torch::Tensor& data) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cpu.read_write_file.write_file"); + // Check that the input tensor is on CPU + TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); + + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + // Check that the input tensor is 3-dimensional + TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor"); + + auto fileBytes = data.data_ptr(); + auto fileCStr = filename.c_str(); +#ifdef _WIN32 + auto fileW = utf8_decode(filename); + FILE* outfile = _wfopen(fileW.c_str(), L"wb"); +#else + FILE* outfile = fopen(fileCStr, "wb"); +#endif + + TORCH_CHECK(outfile != nullptr, "Error opening output file"); + + fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile); + fclose(outfile); +} + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/read_write_file.h b/product/include/torchvision/io/image/cpu/read_write_file.h new file mode 100644 index 00000000000..a5a712dd8e2 --- /dev/null +++ b/product/include/torchvision/io/image/cpu/read_write_file.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor read_file(const std::string& filename); + +C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.cpp b/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.cpp new file mode 100644 index 00000000000..6314ececef1 --- /dev/null +++ b/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.cpp @@ -0,0 +1,603 @@ +#include "decode_jpegs_cuda.h" +#if !NVJPEG_FOUND +namespace vision { +namespace image { +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, + torch::Device device) { + TORCH_CHECK( + false, "decode_jpegs_cuda: torchvision not compiled with nvJPEG support"); +} +} // namespace image +} // namespace vision + +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace vision { +namespace image { + +std::mutex decoderMutex; +std::unique_ptr cudaJpegDecoder; + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, + torch::Device device) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); + + std::lock_guard lock(decoderMutex); + std::vector contig_images; + contig_images.reserve(encoded_images.size()); + + TORCH_CHECK( + device.is_cuda(), "Expected the device parameter to be a cuda device"); + + for (auto& encoded_image : encoded_images) { + TORCH_CHECK( + encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + + TORCH_CHECK( + !encoded_image.is_cuda(), + "The input 
tensor must be on CPU when decoding with nvjpeg") + + TORCH_CHECK( + encoded_image.dim() == 1 && encoded_image.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + // nvjpeg requires images to be contiguous + if (encoded_image.is_contiguous()) { + contig_images.push_back(encoded_image); + } else { + contig_images.push_back(encoded_image.contiguous()); + } + } + + int major_version; + int minor_version; + nvjpegStatus_t get_major_property_status = + nvjpegGetProperty(MAJOR_VERSION, &major_version); + nvjpegStatus_t get_minor_property_status = + nvjpegGetProperty(MINOR_VERSION, &minor_version); + + TORCH_CHECK( + get_major_property_status == NVJPEG_STATUS_SUCCESS, + "nvjpegGetProperty failed: ", + get_major_property_status); + TORCH_CHECK( + get_minor_property_status == NVJPEG_STATUS_SUCCESS, + "nvjpegGetProperty failed: ", + get_minor_property_status); + if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) { + TORCH_WARN_ONCE( + "There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. " + "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); + } + + at::cuda::CUDAGuard device_guard(device); + + if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device) { + if (cudaJpegDecoder != nullptr) + cudaJpegDecoder.reset(new CUDAJpegDecoder(device)); + else { + cudaJpegDecoder = std::make_unique(device); + std::atexit([]() { cudaJpegDecoder.reset(); }); + } + } + + nvjpegOutputFormat_t output_format; + + switch (mode) { + case vision::image::IMAGE_READ_MODE_UNCHANGED: + // Using NVJPEG_OUTPUT_UNCHANGED causes differently sized output channels + // which is related to the subsampling used I'm not sure why this is the + // case, but for now we're just using RGB and later removing channels from + // grayscale images. + output_format = NVJPEG_OUTPUT_UNCHANGED; + break; + case vision::image::IMAGE_READ_MODE_GRAY: + output_format = NVJPEG_OUTPUT_Y; + break; + case vision::image::IMAGE_READ_MODE_RGB: + output_format = NVJPEG_OUTPUT_RGB; + break; + default: + TORCH_CHECK( + false, "The provided mode is not supported for JPEG decoding on GPU"); + } + + try { + at::cuda::CUDAEvent event; + auto result = cudaJpegDecoder->decode_images(contig_images, output_format); + auto current_stream{ + device.has_index() ? at::cuda::getCurrentCUDAStream( + cudaJpegDecoder->original_device.index()) + : at::cuda::getCurrentCUDAStream()}; + event.record(cudaJpegDecoder->stream); + event.block(current_stream); + return result; + } catch (const std::exception& e) { + if (typeid(e) != typeid(std::runtime_error)) { + TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); + } else { + throw; + } + } +} + +CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) + : original_device{torch::kCUDA, torch::cuda::current_device()}, + target_device{target_device}, + stream{ + target_device.has_index() + ? 
at::cuda::getStreamFromPool(false, target_device.index()) + : at::cuda::getStreamFromPool(false)} { + nvjpegStatus_t status; + + hw_decode_available = true; + status = nvjpegCreateEx( + NVJPEG_BACKEND_HARDWARE, + NULL, + NULL, + NVJPEG_FLAGS_DEFAULT, + &nvjpeg_handle); + if (status == NVJPEG_STATUS_ARCH_MISMATCH) { + status = nvjpegCreateEx( + NVJPEG_BACKEND_DEFAULT, + NULL, + NULL, + NVJPEG_FLAGS_DEFAULT, + &nvjpeg_handle); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to initialize nvjpeg with default backend: ", + status); + hw_decode_available = false; + } else { + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to initialize nvjpeg with hardware backend: ", + status); + } + + status = nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg state: ", + status); + + status = nvjpegDecoderCreate( + nvjpeg_handle, NVJPEG_BACKEND_DEFAULT, &nvjpeg_decoder); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg decoder: ", + status); + + status = nvjpegDecoderStateCreate( + nvjpeg_handle, nvjpeg_decoder, &nvjpeg_decoupled_state); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg decoder state: ", + status); + + status = nvjpegBufferPinnedCreate(nvjpeg_handle, NULL, &pinned_buffers[0]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create pinned buffer: ", + status); + + status = nvjpegBufferPinnedCreate(nvjpeg_handle, NULL, &pinned_buffers[1]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create pinned buffer: ", + status); + + status = nvjpegBufferDeviceCreate(nvjpeg_handle, NULL, &device_buffer); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create device buffer: ", + status); + + status = nvjpegJpegStreamCreate(nvjpeg_handle, &jpeg_streams[0]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create jpeg stream: ", + status); + + status = nvjpegJpegStreamCreate(nvjpeg_handle, &jpeg_streams[1]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create jpeg stream: ", + status); + + status = nvjpegDecodeParamsCreate(nvjpeg_handle, &nvjpeg_decode_params); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create decode params: ", + status); +} + +CUDAJpegDecoder::~CUDAJpegDecoder() { + /* + The below code works on Mac and Linux, but fails on Windows. + This is because on Windows, the atexit hook which calls this + destructor executes after cuda is already shut down causing SIGSEGV. + We do not have a solution to this problem at the moment, so we'll + just leak the libnvjpeg & cuda variables for the time being and hope + that the CUDA runtime handles cleanup for us. + Please send a PR if you have a solution for this problem. 
+ */ + + // nvjpegStatus_t status; + + // status = nvjpegDecodeParamsDestroy(nvjpeg_decode_params); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg decode params: ", + // status); + + // status = nvjpegJpegStreamDestroy(jpeg_streams[0]); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy jpeg stream: ", + // status); + + // status = nvjpegJpegStreamDestroy(jpeg_streams[1]); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy jpeg stream: ", + // status); + + // status = nvjpegBufferPinnedDestroy(pinned_buffers[0]); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy pinned buffer[0]: ", + // status); + + // status = nvjpegBufferPinnedDestroy(pinned_buffers[1]); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy pinned buffer[1]: ", + // status); + + // status = nvjpegBufferDeviceDestroy(device_buffer); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy device buffer: ", + // status); + + // status = nvjpegJpegStateDestroy(nvjpeg_decoupled_state); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg decoupled state: ", + // status); + + // status = nvjpegDecoderDestroy(nvjpeg_decoder); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg decoder: ", + // status); + + // status = nvjpegJpegStateDestroy(nvjpeg_state); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg state: ", + // status); + + // status = nvjpegDestroy(nvjpeg_handle); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); +} + +std::tuple< + std::vector, + std::vector, + std::vector> +CUDAJpegDecoder::prepare_buffers( + const std::vector& encoded_images, + const nvjpegOutputFormat_t& output_format) { + /* + This function scans the encoded images' jpeg headers and + allocates decoding buffers based on the metadata found + + Args: + - encoded_images (std::vector): a vector of tensors + containing the jpeg bitstreams to be decoded. Each tensor must have dtype + torch.uint8 and device cpu + - output_format (nvjpegOutputFormat_t): NVJPEG_OUTPUT_RGB, NVJPEG_OUTPUT_Y + or NVJPEG_OUTPUT_UNCHANGED + + Returns: + - decoded_images (std::vector): a vector of nvjpegImages + containing pointers to the memory of the decoded images + - output_tensors (std::vector): a vector of Tensors + containing the decoded images. 
`decoded_images` points to the memory of + output_tensors + - channels (std::vector): a vector of ints containing the number of + output image channels for every image + */ + + int width[NVJPEG_MAX_COMPONENT]; + int height[NVJPEG_MAX_COMPONENT]; + std::vector channels(encoded_images.size()); + nvjpegChromaSubsampling_t subsampling; + nvjpegStatus_t status; + + std::vector output_tensors{encoded_images.size()}; + std::vector decoded_images{encoded_images.size()}; + + for (std::vector::size_type i = 0; i < encoded_images.size(); + i++) { + // extract bitstream meta data to figure out the number of channels, height, + // width for every image + status = nvjpegGetImageInfo( + nvjpeg_handle, + (unsigned char*)encoded_images[i].data_ptr(), + encoded_images[i].numel(), + &channels[i], + &subsampling, + width, + height); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, "Failed to get image info: ", status); + + TORCH_CHECK( + subsampling != NVJPEG_CSS_UNKNOWN, "Unknown chroma subsampling"); + + // output channels may be different from the actual number of channels in + // the image, e.g. we decode a grayscale image as RGB and slice off the + // extra channels later + int output_channels = 3; + if (output_format == NVJPEG_OUTPUT_RGB || + output_format == NVJPEG_OUTPUT_UNCHANGED) { + output_channels = 3; + } else if (output_format == NVJPEG_OUTPUT_Y) { + output_channels = 1; + } + + // reserve output buffer + auto output_tensor = torch::empty( + {int64_t(output_channels), int64_t(height[0]), int64_t(width[0])}, + torch::dtype(torch::kU8).device(target_device)); + output_tensors[i] = output_tensor; + + // fill nvjpegImage_t struct + for (int c = 0; c < output_channels; c++) { + decoded_images[i].channel[c] = output_tensor[c].data_ptr(); + decoded_images[i].pitch[c] = width[0]; + } + for (int c = output_channels; c < NVJPEG_MAX_COMPONENT; c++) { + decoded_images[i].channel[c] = NULL; + decoded_images[i].pitch[c] = 0; + } + } + return {decoded_images, output_tensors, channels}; +} + +std::vector CUDAJpegDecoder::decode_images( + const std::vector& encoded_images, + const nvjpegOutputFormat_t& output_format) { + /* + This function decodes a batch of jpeg bitstreams. + We scan all encoded bitstreams and sort them into two groups: + 1. Baseline JPEGs: Can be decoded with hardware support on A100+ GPUs. + 2. Other JPEGs (e.g. progressive JPEGs): Can also be decoded on the + GPU (albeit with software support only) but need some preprocessing on the + host first. + + See + https://github.com/NVIDIA/CUDALibrarySamples/blob/f17940ac4e705bf47a8c39f5365925c1665f6c98/nvJPEG/nvJPEG-Decoder/nvjpegDecoder.cpp#L33 + for reference. + + Args: + - encoded_images (std::vector): a vector of tensors + containing the jpeg bitstreams to be decoded + - output_format (nvjpegOutputFormat_t): NVJPEG_OUTPUT_RGB, NVJPEG_OUTPUT_Y + or NVJPEG_OUTPUT_UNCHANGED + - device (torch::Device): The desired CUDA device for the returned Tensors + + Returns: + - output_tensors (std::vector): a vector of Tensors + containing the decoded images + */ + + auto [decoded_imgs_buf, output_tensors, channels] = + prepare_buffers(encoded_images, output_format); + + nvjpegStatus_t status; + cudaError_t cudaStatus; + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + // baseline JPEGs can be batch decoded with hardware support on A100+ GPUs + // ultra fast! 
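+ // When a hardware decoder is present, nvjpegDecodeBatchedSupported() (called
+ // below) routes each parsed bitstream: supported baseline images take the
+ // batched hardware path, everything else the one-by-one software path.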
+ std::vector hw_input_buffer; + std::vector hw_input_buffer_size; + std::vector hw_output_buffer; + + // other JPEG types such as progressive JPEGs can be decoded one-by-one in + // software slow :( + std::vector sw_input_buffer; + std::vector sw_input_buffer_size; + std::vector sw_output_buffer; + + if (hw_decode_available) { + for (std::vector::size_type i = 0; i < encoded_images.size(); + ++i) { + // extract bitstream meta data to figure out whether a bit-stream can be + // decoded + nvjpegJpegStreamParseHeader( + nvjpeg_handle, + encoded_images[i].data_ptr(), + encoded_images[i].numel(), + jpeg_streams[0]); + int isSupported = -1; + nvjpegDecodeBatchedSupported( + nvjpeg_handle, jpeg_streams[0], &isSupported); + + if (isSupported == 0) { + hw_input_buffer.push_back(encoded_images[i].data_ptr()); + hw_input_buffer_size.push_back(encoded_images[i].numel()); + hw_output_buffer.push_back(decoded_imgs_buf[i]); + } else { + sw_input_buffer.push_back(encoded_images[i].data_ptr()); + sw_input_buffer_size.push_back(encoded_images[i].numel()); + sw_output_buffer.push_back(decoded_imgs_buf[i]); + } + } + } else { + for (std::vector::size_type i = 0; i < encoded_images.size(); + ++i) { + sw_input_buffer.push_back(encoded_images[i].data_ptr()); + sw_input_buffer_size.push_back(encoded_images[i].numel()); + sw_output_buffer.push_back(decoded_imgs_buf[i]); + } + } + + if (hw_input_buffer.size() > 0) { + // UNCHANGED behaves weird, so we use RGB instead + status = nvjpegDecodeBatchedInitialize( + nvjpeg_handle, + nvjpeg_state, + hw_input_buffer.size(), + 1, + output_format == NVJPEG_OUTPUT_UNCHANGED ? NVJPEG_OUTPUT_RGB + : output_format); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to initialize batch decoding: ", + status); + + status = nvjpegDecodeBatched( + nvjpeg_handle, + nvjpeg_state, + hw_input_buffer.data(), + hw_input_buffer_size.data(), + hw_output_buffer.data(), + stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, "Failed to decode batch: ", status); + } + + if (sw_input_buffer.size() > 0) { + status = + nvjpegStateAttachDeviceBuffer(nvjpeg_decoupled_state, device_buffer); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to attach device buffer: ", + status); + int buffer_index = 0; + // UNCHANGED behaves weird, so we use RGB instead + status = nvjpegDecodeParamsSetOutputFormat( + nvjpeg_decode_params, + output_format == NVJPEG_OUTPUT_UNCHANGED ? 
NVJPEG_OUTPUT_RGB + : output_format); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to set output format: ", + status); + for (std::vector::size_type i = 0; i < sw_input_buffer.size(); + ++i) { + status = nvjpegJpegStreamParse( + nvjpeg_handle, + sw_input_buffer[i], + sw_input_buffer_size[i], + 0, + 0, + jpeg_streams[buffer_index]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to parse jpeg stream: ", + status); + + status = nvjpegStateAttachPinnedBuffer( + nvjpeg_decoupled_state, pinned_buffers[buffer_index]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to attach pinned buffer: ", + status); + + status = nvjpegDecodeJpegHost( + nvjpeg_handle, + nvjpeg_decoder, + nvjpeg_decoupled_state, + nvjpeg_decode_params, + jpeg_streams[buffer_index]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to decode jpeg stream: ", + status); + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + status = nvjpegDecodeJpegTransferToDevice( + nvjpeg_handle, + nvjpeg_decoder, + nvjpeg_decoupled_state, + jpeg_streams[buffer_index], + stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to transfer jpeg to device: ", + status); + + buffer_index = 1 - buffer_index; // switch pinned buffer in pipeline mode + // to avoid an extra sync + + status = nvjpegDecodeJpegDevice( + nvjpeg_handle, + nvjpeg_decoder, + nvjpeg_decoupled_state, + &sw_output_buffer[i], + stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to decode jpeg stream: ", + status); + } + } + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + // prune extraneous channels from single channel images + if (output_format == NVJPEG_OUTPUT_UNCHANGED) { + for (std::vector::size_type i = 0; i < output_tensors.size(); + ++i) { + if (channels[i] == 1) { + output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); + } + } + } + + return output_tensors; +} + +} // namespace image +} // namespace vision + +#endif diff --git a/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.h b/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.h new file mode 100644 index 00000000000..2458a103a3a --- /dev/null +++ b/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.h @@ -0,0 +1,45 @@ +#pragma once +#include +#include +#include "../image_read_mode.h" + +#if NVJPEG_FOUND +#include +#include + +namespace vision { +namespace image { +class CUDAJpegDecoder { + public: + CUDAJpegDecoder(const torch::Device& target_device); + ~CUDAJpegDecoder(); + + std::vector decode_images( + const std::vector& encoded_images, + const nvjpegOutputFormat_t& output_format); + + const torch::Device original_device; + const torch::Device target_device; + const c10::cuda::CUDAStream stream; + + private: + std::tuple< + std::vector, + std::vector, + std::vector> + prepare_buffers( + const std::vector& encoded_images, + const nvjpegOutputFormat_t& output_format); + nvjpegJpegState_t nvjpeg_state; + nvjpegJpegState_t nvjpeg_decoupled_state; + nvjpegBufferPinned_t pinned_buffers[2]; + nvjpegBufferDevice_t device_buffer; + nvjpegJpegStream_t jpeg_streams[2]; + nvjpegDecodeParams_t nvjpeg_decode_params; + nvjpegJpegDecoder_t nvjpeg_decoder; + bool hw_decode_available{false}; + nvjpegHandle_t nvjpeg_handle; +}; +} // namespace image +} // namespace vision +#endif diff --git 
a/product/include/torchvision/io/image/cuda/encode_decode_jpegs_cuda.h b/product/include/torchvision/io/image/cuda/encode_decode_jpegs_cuda.h new file mode 100644 index 00000000000..3fdf715b00f --- /dev/null +++ b/product/include/torchvision/io/image/cuda/encode_decode_jpegs_cuda.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include "../image_read_mode.h" +#include "decode_jpegs_cuda.h" +#include "encode_jpegs_cuda.h" + +namespace vision { +namespace image { + +/* + +Fast jpeg decoding with CUDA. +A100+ GPUs have dedicated hardware support for jpeg decoding. + +Args: + - encoded_images (const std::vector&): a vector of tensors + containing the jpeg bitstreams to be decoded. Each tensor must have dtype + torch.uint8 and device cpu + - mode (ImageReadMode): IMAGE_READ_MODE_UNCHANGED, IMAGE_READ_MODE_GRAY and +IMAGE_READ_MODE_RGB are supported + - device (torch::Device): The desired CUDA device to run the decoding on and +which will contain the output tensors + +Returns: + - decoded_images (std::vector): a vector of torch::Tensors of +dtype torch.uint8 on the specified containing the decoded images + +Notes: + - If a single image fails, the whole batch fails. + - This function is thread-safe +*/ +C10_EXPORT std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, + torch::Device device); + +/* +Fast jpeg encoding with CUDA. + +Args: + - decoded_images (const std::vector&): a vector of contiguous +CUDA tensors of dtype torch.uint8 to be encoded. + - quality (int64_t): 0-100, 75 is the default + +Returns: + - encoded_images (std::vector): a vector of CUDA +torch::Tensors of dtype torch.uint8 containing the encoded images + +Notes: + - If a single image fails, the whole batch fails. + - This function is thread-safe +*/ +C10_EXPORT std::vector encode_jpegs_cuda( + const std::vector& decoded_images, + const int64_t quality); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.cpp b/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.cpp new file mode 100644 index 00000000000..1f10327ddbf --- /dev/null +++ b/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.cpp @@ -0,0 +1,274 @@ +#include "encode_jpegs_cuda.h" +#if !NVJPEG_FOUND +namespace vision { +namespace image { +std::vector encode_jpegs_cuda( + const std::vector& decoded_images, + const int64_t quality) { + TORCH_CHECK( + false, "encode_jpegs_cuda: torchvision not compiled with nvJPEG support"); +} +} // namespace image +} // namespace vision +#else + +#include +#include +#include +#include +#include +#include +#include +#include +#include "c10/core/ScalarType.h" + +namespace vision { +namespace image { + +// We use global variables to cache the encoder and decoder instances and +// reuse them across calls to the corresponding pytorch functions +std::mutex encoderMutex; +std::unique_ptr cudaJpegEncoder; + +std::vector encode_jpegs_cuda( + const std::vector& decoded_images, + const int64_t quality) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cuda.encode_jpegs_cuda.encode_jpegs_cuda"); + + // Some nvjpeg structures are not thread safe so we're keeping it single + // threaded for now. 
In the future this may be an opportunity to unlock + // further speedups + std::lock_guard lock(encoderMutex); + TORCH_CHECK(decoded_images.size() > 0, "Empty input tensor list"); + torch::Device device = decoded_images[0].device(); + at::cuda::CUDAGuard device_guard(device); + + // lazy init of the encoder class + // the encoder object holds on to a lot of state and is expensive to create, + // so we reuse it across calls. NB: the cached structures are device specific + // and cannot be reused across devices + if (cudaJpegEncoder == nullptr || device != cudaJpegEncoder->target_device) { + if (cudaJpegEncoder != nullptr) + delete cudaJpegEncoder.release(); + + cudaJpegEncoder = std::make_unique(device); + + // Unfortunately, we cannot rely on the smart pointer releasing the encoder + // object correctly upon program exit. This is because, when cudaJpegEncoder + // gets destroyed, the CUDA runtime may already be shut down, rendering all + // destroy* calls in the encoder destructor invalid. Instead, we use an + // atexit hook which executes after main() finishes, but hopefully before + // CUDA shuts down when the program exits. If CUDA is already shut down the + // destructor will detect this and will not attempt to destroy any encoder + // structures. + std::atexit([]() { delete cudaJpegEncoder.release(); }); + } + + std::vector contig_images; + contig_images.reserve(decoded_images.size()); + for (const auto& image : decoded_images) { + TORCH_CHECK( + image.dtype() == torch::kU8, "Input tensor dtype should be uint8"); + + TORCH_CHECK( + image.device() == device, + "All input tensors must be on the same CUDA device when encoding with nvjpeg") + + TORCH_CHECK( + image.dim() == 3 && image.numel() > 0, + "Input data should be a 3-dimensional tensor"); + + TORCH_CHECK( + image.size(0) == 3, + "The number of channels should be 3, got: ", + image.size(0)); + + // nvjpeg requires images to be contiguous + if (image.is_contiguous()) { + contig_images.push_back(image); + } else { + contig_images.push_back(image.contiguous()); + } + } + + cudaJpegEncoder->set_quality(quality); + std::vector encoded_images; + at::cuda::CUDAEvent event; + event.record(cudaJpegEncoder->stream); + for (const auto& image : contig_images) { + auto encoded_image = cudaJpegEncoder->encode_jpeg(image); + encoded_images.push_back(encoded_image); + } + + // We use a dedicated stream to do the encoding and even though the results + // may be ready on that stream we cannot assume that they are also available + // on the current stream of the calling context when this function returns. We + // use a blocking event to ensure that this is indeed the case. Crucially, we + // do not want to block the host at this particular point + // (which is what cudaStreamSynchronize would do.) Events allow us to + // synchronize the streams without blocking the host. + event.block(at::cuda::getCurrentCUDAStream( + cudaJpegEncoder->original_device.has_index() + ? cudaJpegEncoder->original_device.index() + : 0)); + return encoded_images; +} + +CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device) + : original_device{torch::kCUDA, torch::cuda::current_device()}, + target_device{target_device}, + stream{ + target_device.has_index() + ? 
at::cuda::getStreamFromPool(false, target_device.index()) + : at::cuda::getStreamFromPool(false)} { + nvjpegStatus_t status; + status = nvjpegCreateSimple(&nvjpeg_handle); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg handle: ", + status); + + status = nvjpegEncoderStateCreate(nvjpeg_handle, &nv_enc_state, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg encoder state: ", + status); + + status = nvjpegEncoderParamsCreate(nvjpeg_handle, &nv_enc_params, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg encoder params: ", + status); +} + +CUDAJpegEncoder::~CUDAJpegEncoder() { + /* + The below code works on Mac and Linux, but fails on Windows. + This is because on Windows, the atexit hook which calls this + destructor executes after cuda is already shut down causing SIGSEGV. + We do not have a solution to this problem at the moment, so we'll + just leak the libnvjpeg & cuda variables for the time being and hope + that the CUDA runtime handles cleanup for us. + Please send a PR if you have a solution for this problem. + */ + + // // We run cudaGetDeviceCount as a dummy to test if the CUDA runtime is + // still + // // initialized. If it is not, we can skip the rest of this function as it + // is + // // unsafe to execute. + // int deviceCount = 0; + // cudaError_t error = cudaGetDeviceCount(&deviceCount); + // if (error != cudaSuccess) + // return; // CUDA runtime has already shut down. There's nothing we can do + // // now. + + // nvjpegStatus_t status; + + // status = nvjpegEncoderParamsDestroy(nv_enc_params); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg encoder params: ", + // status); + + // status = nvjpegEncoderStateDestroy(nv_enc_state); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg encoder state: ", + // status); + + // cudaStreamSynchronize(stream); + + // status = nvjpegDestroy(nvjpeg_handle); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); +} + +torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { + int channels = src_image.size(0); + int height = src_image.size(1); + int width = src_image.size(2); + + nvjpegStatus_t status; + cudaError_t cudaStatus; + status = nvjpegEncoderParamsSetSamplingFactors( + nv_enc_params, NVJPEG_CSS_444, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to set nvjpeg encoder params sampling factors: ", + status); + + nvjpegImage_t target_image; + for (int c = 0; c < channels; c++) { + target_image.channel[c] = src_image[c].data_ptr(); + // this is why we need contiguous tensors + target_image.pitch[c] = width; + } + for (int c = channels; c < NVJPEG_MAX_COMPONENT; c++) { + target_image.channel[c] = nullptr; + target_image.pitch[c] = 0; + } + // Encode the image + status = nvjpegEncodeImage( + nvjpeg_handle, + nv_enc_state, + nv_enc_params, + &target_image, + NVJPEG_INPUT_RGB, + width, + height, + stream); + + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, "image encoding failed: ", status); + // Retrieve length of the encoded image + size_t length; + status = nvjpegEncodeRetrieveBitstreamDevice( + nvjpeg_handle, nv_enc_state, NULL, &length, stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to retrieve encoded image stream state: ", + status); + + // Synchronize the stream to ensure that the encoded image is ready + cudaStatus = cudaStreamSynchronize(stream); + 
TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + + // Reserve buffer for the encoded image + torch::Tensor encoded_image = torch::empty( + {static_cast(length)}, + torch::TensorOptions() + .dtype(torch::kByte) + .layout(torch::kStrided) + .device(target_device) + .requires_grad(false)); + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); + // Retrieve the encoded image + status = nvjpegEncodeRetrieveBitstreamDevice( + nvjpeg_handle, + nv_enc_state, + encoded_image.data_ptr(), + &length, + 0); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to retrieve encoded image: ", + status); + return encoded_image; +} + +void CUDAJpegEncoder::set_quality(const int64_t quality) { + nvjpegStatus_t paramsQualityStatus = + nvjpegEncoderParamsSetQuality(nv_enc_params, quality, stream); + TORCH_CHECK( + paramsQualityStatus == NVJPEG_STATUS_SUCCESS, + "Failed to set nvjpeg encoder params quality: ", + paramsQualityStatus); +} + +} // namespace image +} // namespace vision + +#endif // NVJPEG_FOUND diff --git a/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.h b/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.h new file mode 100644 index 00000000000..543940f1585 --- /dev/null +++ b/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.h @@ -0,0 +1,33 @@ +#pragma once +#include +#include +#if NVJPEG_FOUND + +#include +#include +#include + +namespace vision { +namespace image { + +class CUDAJpegEncoder { + public: + CUDAJpegEncoder(const torch::Device& device); + ~CUDAJpegEncoder(); + + torch::Tensor encode_jpeg(const torch::Tensor& src_image); + + void set_quality(const int64_t quality); + + const torch::Device original_device; + const torch::Device target_device; + const c10::cuda::CUDAStream stream; + + protected: + nvjpegEncoderState_t nv_enc_state; + nvjpegEncoderParams_t nv_enc_params; + nvjpegHandle_t nvjpeg_handle; +}; +} // namespace image +} // namespace vision +#endif diff --git a/product/include/torchvision/io/image/image.cpp b/product/include/torchvision/io/image/image.cpp new file mode 100644 index 00000000000..43e8ecbe4a2 --- /dev/null +++ b/product/include/torchvision/io/image/image.cpp @@ -0,0 +1,37 @@ +#include "image.h" + +#include + +// If we are in a Windows environment, we need to define +// initialization functions for the _custom_ops extension +#ifdef _WIN32 +void* PyInit_image(void) { + return nullptr; +} +#endif + +namespace vision { +namespace image { + +static auto registry = + torch::RegisterOperators() + .op("image::decode_gif", &decode_gif) + .op("image::decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", + &decode_png) + .op("image::encode_png", &encode_png) + .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", + &decode_jpeg) + .op("image::decode_webp", &decode_webp) + .op("image::decode_avif", &decode_avif) + .op("image::encode_jpeg", &encode_jpeg) + .op("image::read_file", &read_file) + .op("image::write_file", &write_file) + .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", + &decode_image) + .op("image::decode_jpegs_cuda", &decode_jpegs_cuda) + .op("image::encode_jpegs_cuda", &encode_jpegs_cuda) + .op("image::_jpeg_version", &_jpeg_version) + .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo); + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/io/image/image.h 
b/product/include/torchvision/io/image/image.h new file mode 100644 index 00000000000..91a5144fa1c --- /dev/null +++ b/product/include/torchvision/io/image/image.h @@ -0,0 +1,12 @@ +#pragma once + +#include "cpu/decode_avif.h" +#include "cpu/decode_gif.h" +#include "cpu/decode_image.h" +#include "cpu/decode_jpeg.h" +#include "cpu/decode_png.h" +#include "cpu/decode_webp.h" +#include "cpu/encode_jpeg.h" +#include "cpu/encode_png.h" +#include "cpu/read_write_file.h" +#include "cuda/encode_decode_jpegs_cuda.h" diff --git a/product/include/torchvision/io/image/image_read_mode.h b/product/include/torchvision/io/image/image_read_mode.h new file mode 100644 index 00000000000..84425265c34 --- /dev/null +++ b/product/include/torchvision/io/image/image_read_mode.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +/* Should be kept in-sync with Python ImageReadMode enum */ +using ImageReadMode = int64_t; +const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0; +const ImageReadMode IMAGE_READ_MODE_GRAY = 1; +const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2; +const ImageReadMode IMAGE_READ_MODE_RGB = 3; +const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4; + +} // namespace image +} // namespace vision diff --git a/product/include/torchvision/macros.h b/product/include/torchvision/macros.h new file mode 100644 index 00000000000..f907280e24e --- /dev/null +++ b/product/include/torchvision/macros.h @@ -0,0 +1,11 @@ +#pragma once + +#if defined(_WIN32) && !defined(TORCHVISION_BUILD_STATIC_LIBS) +#if defined(torchvision_EXPORTS) +#define VISION_API __declspec(dllexport) +#else +#define VISION_API __declspec(dllimport) +#endif +#else +#define VISION_API +#endif diff --git a/product/include/torchvision/ops/autograd/deform_conv2d_kernel.cpp b/product/include/torchvision/ops/autograd/deform_conv2d_kernel.cpp new file mode 100644 index 00000000000..0a7bbf9014e --- /dev/null +++ b/product/include/torchvision/ops/autograd/deform_conv2d_kernel.cpp @@ -0,0 +1,266 @@ +#include "../deform_conv2d.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +class DeformConv2dFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& weight, + const torch::autograd::Variable& offset, + const torch::autograd::Variable& mask, + const torch::autograd::Variable& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + at::AutoDispatchBelowADInplaceOrView g; + auto output = deform_conv2d_symint( + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); + + ctx->save_for_backward({input, weight, offset, mask, bias}); + ctx->saved_data["stride_h"] = stride_h; + ctx->saved_data["stride_w"] = stride_w; + ctx->saved_data["pad_h"] = pad_h; + ctx->saved_data["pad_w"] = pad_w; + ctx->saved_data["dilation_h"] = dilation_h; + ctx->saved_data["dilation_w"] = dilation_w; + ctx->saved_data["groups"] = groups; + ctx->saved_data["offset_groups"] = offset_groups; + ctx->saved_data["use_mask"] = use_mask; + + return { + output, + }; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& 
grad_output) { + auto saved = ctx->get_saved_variables(); + auto input = saved[0]; + auto weight = saved[1]; + auto offset = saved[2]; + auto mask = saved[3]; + auto bias = saved[4]; + + auto stride_h = ctx->saved_data["stride_h"].toSymInt(); + auto stride_w = ctx->saved_data["stride_w"].toSymInt(); + auto pad_h = ctx->saved_data["pad_h"].toSymInt(); + auto pad_w = ctx->saved_data["pad_w"].toSymInt(); + auto dilation_h = ctx->saved_data["dilation_h"].toSymInt(); + auto dilation_w = ctx->saved_data["dilation_w"].toSymInt(); + auto groups = ctx->saved_data["groups"].toSymInt(); + auto offset_groups = ctx->saved_data["offset_groups"].toSymInt(); + auto use_mask = ctx->saved_data["use_mask"].toBool(); + + auto grads = detail::_deform_conv2d_backward_symint( + grad_output[0], + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); + auto grad_input = std::get<0>(grads); + auto grad_weight = std::get<1>(grads); + auto grad_offset = std::get<2>(grads); + auto grad_mask = std::get<3>(grads); + auto grad_bias = std::get<4>(grads); + + return { + grad_input, + grad_weight, + grad_offset, + grad_mask, + grad_bias, + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + }; + } +}; + +// TODO: There should be an easier way to do this +class DeformConv2dBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& input, + const torch::autograd::Variable& weight, + const torch::autograd::Variable& offset, + const torch::autograd::Variable& mask, + const torch::autograd::Variable& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + at::AutoDispatchBelowADInplaceOrView g; + auto result = detail::_deform_conv2d_backward_symint( + grad, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); + + auto grad_input = std::get<0>(result); + auto grad_weight = std::get<1>(result); + auto grad_offset = std::get<2>(result); + auto grad_mask = std::get<3>(result); + auto grad_bias = std::get<4>(result); + + return { + grad_input, + grad_weight, + grad_offset, + grad_mask, + grad_bias, + }; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on deform_conv2d not supported"); + } +}; + +at::Tensor deform_conv2d_autograd( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + return DeformConv2dFunction::apply( + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + 
use_mask)[0]; +} + +std::tuple +deform_conv2d_backward_autograd( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + auto result = DeformConv2dBackwardFunction::apply( + grad, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); + + return std::make_tuple(result[0], result[1], result[2], result[3], result[4]); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN(deform_conv2d_autograd)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), + TORCH_FN(deform_conv2d_backward_autograd)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp b/product/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp new file mode 100644 index 00000000000..7205e9b15db --- /dev/null +++ b/product/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp @@ -0,0 +1,167 @@ +#include "../ps_roi_align.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +class PSROIAlignFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["sampling_ratio"] = sampling_ratio; + ctx->saved_data["input_shape"] = input.sym_sizes(); + at::AutoDispatchBelowADInplaceOrView g; + auto result = ps_roi_align_symint( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio); + + auto output = std::get<0>(result); + auto channel_mapping = std::get<1>(result); + ctx->save_for_backward({rois, channel_mapping}); + ctx->mark_non_differentiable({channel_mapping}); + + return {output, channel_mapping}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + // Use data saved in forward + auto saved = ctx->get_saved_variables(); + auto rois = saved[0]; + auto channel_mapping = saved[1]; + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_ps_roi_align_backward_symint( + grad_output[0], + rois, + channel_mapping, + ctx->saved_data["spatial_scale"].toDouble(), + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), + ctx->saved_data["sampling_ratio"].toInt(), + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt()); + + return { + grad_in, + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable()}; + } +}; + +// TODO: There should be an easier way to do this +class PSROIAlignBackwardFunction + : public 
torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& rois, + const torch::autograd::Variable& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + at::AutoDispatchBelowADInplaceOrView g; + auto grad_in = detail::_ps_roi_align_backward_symint( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); + + return {grad_in}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on ps_roi_align not supported"); + } +}; + +std::tuple ps_roi_align_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio) { + auto result = PSROIAlignFunction::apply( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); + + return std::make_tuple(result[0], result[1]); +} + +at::Tensor ps_roi_align_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + return PSROIAlignBackwardFunction::apply( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width)[0]; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), + TORCH_FN(ps_roi_align_autograd)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), + TORCH_FN(ps_roi_align_backward_autograd)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp b/product/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp new file mode 100644 index 00000000000..39b83819f94 --- /dev/null +++ b/product/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp @@ -0,0 +1,152 @@ +#include "../ps_roi_pool.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +class PSROIPoolFunction : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["input_shape"] = input.sym_sizes(); + at::AutoDispatchBelowADInplaceOrView g; + auto result = ps_roi_pool_symint( + input, rois, spatial_scale, pooled_height, pooled_width); + + auto output = std::get<0>(result); + auto channel_mapping = std::get<1>(result); + ctx->save_for_backward({rois, channel_mapping}); + ctx->mark_non_differentiable({channel_mapping}); + + return {output, channel_mapping}; + } + + static torch::autograd::variable_list 
backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + // Use data saved in forward + auto saved = ctx->get_saved_variables(); + auto rois = saved[0]; + auto channel_mapping = saved[1]; + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_ps_roi_pool_backward_symint( + grad_output[0], + rois, + channel_mapping, + ctx->saved_data["spatial_scale"].toDouble(), + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt()); + + return { + grad_in, + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable()}; + } +}; + +// TODO: There should be an easier way to do this +class PSROIPoolBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& rois, + const torch::autograd::Variable& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + at::AutoDispatchBelowADInplaceOrView g; + auto grad_in = detail::_ps_roi_pool_backward_symint( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); + + return {grad_in}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on ps_roi_pool not supported"); + } +}; + +std::tuple ps_roi_pool_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + auto result = PSROIPoolFunction::apply( + input, rois, spatial_scale, pooled_height, pooled_width); + + return std::make_tuple(result[0], result[1]); +} + +at::Tensor ps_roi_pool_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + return PSROIPoolBackwardFunction::apply( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width)[0]; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), + TORCH_FN(ps_roi_pool_autograd)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), + TORCH_FN(ps_roi_pool_backward_autograd)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/autograd/roi_align_kernel.cpp b/product/include/torchvision/ops/autograd/roi_align_kernel.cpp new file mode 100644 index 00000000000..6d792fe09d9 --- /dev/null +++ b/product/include/torchvision/ops/autograd/roi_align_kernel.cpp @@ -0,0 +1,167 @@ +#include "../roi_align.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +class ROIAlignFunction : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, 
+ const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + bool aligned) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["sampling_ratio"] = sampling_ratio; + ctx->saved_data["aligned"] = aligned; + ctx->saved_data["input_shape"] = input.sym_sizes(); + ctx->save_for_backward({rois}); + at::AutoDispatchBelowADInplaceOrView g; + auto result = roi_align_symint( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + aligned); + return {result}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + // Use data saved in forward + auto saved = ctx->get_saved_variables(); + auto rois = saved[0]; + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_roi_align_backward_symint( + grad_output[0], + rois, + ctx->saved_data["spatial_scale"].toDouble(), + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt(), + ctx->saved_data["sampling_ratio"].toInt(), + ctx->saved_data["aligned"].toBool()); + return { + grad_in, + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable()}; + } +}; + +// TODO: There should be an easier way to do this +class ROIAlignBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width, + int64_t sampling_ratio, + bool aligned) { + at::AutoDispatchBelowADInplaceOrView g; + auto result = detail::_roi_align_backward_symint( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + aligned); + return {result}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on roi_align not supported"); + } +}; + +at::Tensor roi_align_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + bool aligned) { + return ROIAlignFunction::apply( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + aligned)[0]; +} + +at::Tensor roi_align_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width, + int64_t sampling_ratio, + bool aligned) { + return ROIAlignBackwardFunction::apply( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + 
aligned)[0]; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_align"), + TORCH_FN(roi_align_autograd)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), + TORCH_FN(roi_align_backward_autograd)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/autograd/roi_pool_kernel.cpp b/product/include/torchvision/ops/autograd/roi_pool_kernel.cpp new file mode 100644 index 00000000000..508bafb2b1e --- /dev/null +++ b/product/include/torchvision/ops/autograd/roi_pool_kernel.cpp @@ -0,0 +1,152 @@ +#include "../roi_pool.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +class ROIPoolFunction : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["input_shape"] = input.sym_sizes(); + at::AutoDispatchBelowADInplaceOrView g; + auto result = roi_pool_symint( + input, rois, spatial_scale, pooled_height, pooled_width); + + auto output = std::get<0>(result); + auto argmax = std::get<1>(result); + ctx->save_for_backward({rois, argmax}); + ctx->mark_non_differentiable({argmax}); + + return {output, argmax}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + // Use data saved in forward + auto saved = ctx->get_saved_variables(); + auto rois = saved[0]; + auto argmax = saved[1]; + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_roi_pool_backward_symint( + grad_output[0], + rois, + argmax, + ctx->saved_data["spatial_scale"].toDouble(), + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt()); + + return { + grad_in, + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable()}; + } +}; + +// TODO: There should be an easier way to do this +class ROIPoolBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& grad, + const torch::autograd::Variable& rois, + const torch::autograd::Variable& argmax, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + at::AutoDispatchBelowADInplaceOrView g; + auto grad_in = detail::_roi_pool_backward_symint( + grad, + rois, + argmax, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); + + return {grad_in}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list& grad_output) { + TORCH_CHECK(0, "double backwards on roi_pool not supported"); + } +}; + +std::tuple roi_pool_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + 
c10::SymInt pooled_height, + c10::SymInt pooled_width) { + auto result = ROIPoolFunction::apply( + input, rois, spatial_scale, pooled_height, pooled_width); + + return std::make_tuple(result[0], result[1]); +} + +at::Tensor roi_pool_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + return ROIPoolBackwardFunction::apply( + grad, + rois, + argmax, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width)[0]; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_pool"), + TORCH_FN(roi_pool_autograd)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), + TORCH_FN(roi_pool_backward_autograd)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp b/product/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp new file mode 100644 index 00000000000..c5e59077aa6 --- /dev/null +++ b/product/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp @@ -0,0 +1,1172 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer + ***************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer + ********************* + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +// modified from +// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +const int kMaxParallelImgs = 32; + +template +scalar_t bilinear_interpolate( + const scalar_t* in, + int height, + int width, + scalar_t h, + scalar_t w) { + if (h <= -1 || height <= h || w <= -1 || width <= w) { + return 0; + } + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = in[h_low * width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = in[h_low * width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = in[h_high * width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = in[h_high * width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +void deformable_im2col_kernel( + int n, + const scalar_t* input, + const scalar_t* offset, + const scalar_t* mask, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int batch_sz, + int n_in_channels, + int n_offset_grps, + int out_h, + int out_w, + bool use_mask, + scalar_t* columns) { + for (int index = 0; index != n; ++index) { + const int out_x = index % out_w; + const int out_y = (index / out_w) % out_h; + const int out_b = (index / (out_w * out_h)) % batch_sz; + const int in_c = index / (out_w * out_h * batch_sz); + const int out_c = in_c * weight_h * weight_w; + + int c_per_offset_grp = n_in_channels / n_offset_grps; + const int grp_idx = in_c / c_per_offset_grp; + + auto columns_ptr = columns + + (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + + out_y * out_w + out_x); + + auto input_ptr = input + + (out_b * (n_in_channels * height * width) + in_c * (height * width)); + + auto offset_ptr = offset + + (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * + out_w; + + auto mask_ptr = mask; + if (use_mask) { + mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * + out_h * out_w; + } + + for (int i = 0; i < weight_h; ++i) { + for (int j = 0; j < weight_w; ++j) { + const int mask_idx = i * weight_w + j; + const int offset_idx = 2 * mask_idx; + + scalar_t mask_value = 1; + if (use_mask) { + 
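+          // Modulated (DCNv2) path: the bilinearly sampled value below is
+          // scaled by a learned per-tap mask weight; when use_mask is false,
+          // mask_value stays 1 and this reduces to plain deformable
+          // convolution (DCNv1).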
mask_value = + mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; + } + + const scalar_t offset_h = + offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_w = offset_ptr + [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t y = + (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = + (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + *columns_ptr = + mask_value * bilinear_interpolate(input_ptr, height, width, y, x); + columns_ptr += batch_sz * out_h * out_w; + } + } + } +} + +void deformable_im2col( + const at::Tensor& input, + const at::Tensor& data_offset, + const at::Tensor& data_mask, + int n_in_channels, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int out_h, + int out_w, + int parallel_imgs, + int deformable_group, + bool use_mask, + at::Tensor data_col) { + int num_kernels = n_in_channels * out_h * out_w * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "deformable_im2col", ([&] { + deformable_im2col_kernel( + num_kernels, + input.data_ptr(), + data_offset.data_ptr(), + data_mask.data_ptr(), + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + parallel_imgs, + n_in_channels, + deformable_group, + out_h, + out_w, + use_mask, + data_col.data_ptr()); + })); +} + +int get_greatest_divisor_below_bound(int n, int bound) { + for (int k = bound; k > 1; --k) { + if (n % k == 0) { + return k; + } + } + return 1; +} + +template +void deformable_col2im_kernel( + int n, + const scalar_t* col, + const scalar_t* offset, + const scalar_t* mask, + int channels, + int height, + int width, + int kernel_h, + int kernel_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int batch_sz, + int n_offset_grps, + int out_h, + int out_w, + bool use_mask, + scalar_t* grad_im) { + for (int index = 0; index != n; ++index) { + const int out_x = index % out_w; + const int out_y = (index / out_w) % out_h; + const int b = (index / (out_w * out_h)) % batch_sz; + const int j = (index / (out_w * out_h * batch_sz)) % kernel_w; + const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; + const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); + + int c_per_offset_grp = channels / n_offset_grps; + const int offset_grp = c / c_per_offset_grp; + + auto offset_ptr = offset + + (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * + out_w; + + auto mask_ptr = mask; + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * + out_h * out_w; + } + + const int mask_idx = i * kernel_w + j; + const int offset_idx = 2 * mask_idx; + + const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; + const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; + + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + + const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + for (int dy = -1; dy <= 1; dy++) { + for (int dx = -1; dx <= 1; dx++) { + int yp = int(y) + dy; + int xp = int(x) + dx; + if (0 <= yp && yp < 
height && 0 <= xp && xp < width && + std::abs(y - yp) < 1 && std::abs(x - xp) < 1) { + int grad_pos = ((b * channels + c) * height + yp) * width + xp; + scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); + grad_im[grad_pos] += mask_value * weight * col[index]; + } + } + } + } +} + +void compute_grad_input( + const at::Tensor& columns, + const at::Tensor& offset, + const at::Tensor& mask, + int channels, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int parallel_imgs, + int n_offset_grps, + bool use_mask, + at::Tensor grad_im) { + int out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = + channels * weight_h * weight_w * out_h * out_w * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + columns.scalar_type(), "compute_grad_input", ([&] { + deformable_col2im_kernel( + num_kernels, + columns.data_ptr(), + offset.data_ptr(), + mask.data_ptr(), + channels, + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + parallel_imgs, + n_offset_grps, + out_h, + out_w, + use_mask, + grad_im.data_ptr()); + })); +} + +template +scalar_t get_coordinate_weight( + const scalar_t* im_data, + int height, + int width, + scalar_t y, + scalar_t x, + bool is_y_direction) { + int y_l = floor(y); + int x_l = floor(x); + int y_h = y_l + 1; + int x_h = x_l + 1; + + bool valid_y_l = 0 <= y_l && y_l < height; + bool valid_y_h = 0 <= y_h && y_h < height; + bool valid_x_l = 0 <= x_l && x_l < width; + bool valid_x_h = 0 <= x_h && x_h < width; + + scalar_t zero = 0; + scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; + scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; + scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; + scalar_t v_YX = (valid_y_h && valid_x_h) ? 
im_data[y_h * width + x_h] : zero; + + if (is_y_direction) { + scalar_t dx = x - x_l; + return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); + } else { + scalar_t dy = y - y_l; + return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); + } +} + +template +void deformable_col2im_coord_kernel( + int n, + const scalar_t* col, + const scalar_t* im, + const scalar_t* offset, + const scalar_t* mask, + int channels, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int batch_sz, + int offset_channels, + int n_offset_grps, + int out_h, + int out_w, + bool use_mask, + scalar_t* grad_offset, + scalar_t* grad_mask) { + for (int index = 0; index != n; ++index) { + scalar_t grad_offset_val = 0; + scalar_t grad_mask_val = 0; + + int w = index % out_w; + int h = (index / out_w) % out_h; + int w_w = (index / (out_w * out_h * 2)) % weight_w; + int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; + int c = (index / (out_w * out_h)) % offset_channels; + int b = index / (out_w * out_h * offset_channels); + + const int offset_grp = c / (2 * weight_h * weight_w); + const int col_step = weight_h * weight_w; + + int c_per_offset_grp = channels / n_offset_grps; + + auto col_ptr = col + + offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * + out_h; + auto im_ptr = im + + (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; + auto offset_ptr = offset + + (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * + out_w; + + auto mask_ptr = mask; + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * + out_h * out_w; + } + + const int offset_c = c - offset_grp * 2 * weight_h * weight_w; + const bool is_y_direction = offset_c % 2 == 0; + + const int c_bound = c_per_offset_grp * weight_h * weight_w; + for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { + const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w; + + int out_x = col_pos % out_w; + int out_y = (col_pos / out_w) % out_h; + int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; + int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; + + const int mask_idx = i * weight_w + j; + + const int offset_h_idx = + (((2 * mask_idx) * out_h + out_y) * out_w + out_x); + const int offset_w_idx = + (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); + const scalar_t offset_h = offset_ptr[offset_h_idx]; + const scalar_t offset_w = offset_ptr[offset_w_idx]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + + scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + const scalar_t weight = + get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); + grad_offset_val += mask_value * weight * col_ptr[col_pos]; + + if (use_mask && is_y_direction) { + grad_mask_val += col_ptr[col_pos] * + bilinear_interpolate(im_ptr, height, width, y, x); + } + + im_ptr += height * width; + } + + grad_offset[index] = grad_offset_val; + + if (use_mask && is_y_direction) { + const int idx = + ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + + w_w) * + out_h + + h) * + out_w + + w; + grad_mask[idx] = grad_mask_val; + } + } +} + +void compute_grad_offset_and_mask( + const at::Tensor& columns, + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& 
mask, + int channels, + int height, + int width, + int weight_h, + int weight_w, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + int parallel_imgs, + int n_offset_grps, + bool use_mask, + at::Tensor grad_offset, + at::Tensor grad_mask) { + int out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + int out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int num_kernels = + out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + columns.scalar_type(), "compute_grad_offset_and_mask", ([&] { + deformable_col2im_coord_kernel( + num_kernels, + columns.data_ptr(), + input.data_ptr(), + offset.data_ptr(), + mask.data_ptr(), + channels, + height, + width, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + parallel_imgs, + 2 * weight_h * weight_w * n_offset_grps, + n_offset_grps, + out_h, + out_w, + use_mask, + grad_offset.data_ptr(), + grad_mask.data_ptr()); + })); +} + +std::tuple backward_gradient_inputs( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor mask, + at::Tensor grad_out, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int n_weight_grps, + int n_offset_grps, + int n_parallel_imgs, + bool use_mask) { + int batch_sz = input.size(0); + int n_in_channels = input.size(1); + int in_h = input.size(2); + int in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + long n_out_channels = weight.size(0); + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + long out_h = + (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + long out_w = + (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + + auto grad_input = at::zeros_like(input); + auto grad_offset = at::zeros_like(offset); + auto grad_mask = at::zeros_like(mask); + + if (batch_sz == 0) { + return std::make_tuple(grad_input, grad_offset, grad_mask); + } + + auto columns = at::empty( + {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, + input.options()); + + // Separate into blocks + grad_input = grad_input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + input = input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + grad_offset = grad_offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + offset = offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + grad_mask = grad_mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + mask = mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + grad_out = grad_out + .reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}); + + weight = weight.reshape( + {n_weight_grps, + weight.size(0) / n_weight_grps, + weight.size(1), + weight.size(2), + weight.size(3)}); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + + for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + 
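+    // For each block of n_parallel_imgs images: rebuild the column buffer
+    // from the incoming gradient (columns[g] = W[g]^T * grad_out), then
+    // scatter it back into the offset/mask gradients and, via col2im, into
+    // the input gradient.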
columns.zero_(); + // Separate into weight groups + for (int g = 0; g < n_weight_grps; g++) { + columns[g] = columns[g].addmm_( + weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); + } + + compute_grad_offset_and_mask( + columns, + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_offset[elt], + grad_mask[elt]); + + compute_grad_input( + columns, + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_input[elt]); + } + + grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + if (use_mask) { + grad_mask = grad_mask.view( + {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); + } + + return std::make_tuple(grad_input, grad_offset, grad_mask); +} + +at::Tensor backward_gradient_parameters( + at::Tensor input, + const at::Tensor& weight, + at::Tensor offset, + at::Tensor mask, + const at::Tensor& grad_out, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int n_weight_grps, + int n_offset_grps, + int n_parallel_imgs, + bool use_mask) { + int batch_sz = input.size(0); + int n_in_channels = input.size(1); + int in_h = input.size(2); + int in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + long n_out_channels = weight.size(0); + int weight_h = weight.size(2); + int weight_w = weight.size(3); + + long out_h = grad_out.size(2); + long out_w = grad_out.size(3); + + auto grad_weight = at::zeros_like(weight); + if (batch_sz == 0) { + return grad_weight; + } + + at::Tensor grad_out_buf = grad_out + .reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}) + .contiguous(); + + input = input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + offset = offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask = mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + grad_weight = grad_weight.view( + {n_weight_grps, + grad_weight.size(0) / n_weight_grps, + grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3)}); + + auto columns = at::empty( + {n_weight_grps, + n_in_channels * weight_w * weight_h / n_weight_grps, + n_parallel_imgs * out_h * out_w}, + input.options()); + + for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + deformable_im2col( + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + for (int g = 0; g < n_weight_grps; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_( + grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) + .view_as(grad_weight[g]); + } + } + + grad_weight = grad_weight.view( + {grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3), + 
grad_weight.size(4)}); + return grad_weight; +} + +at::Tensor deform_conv2d_forward_kernel( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + at::Tensor input_c = input.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + TORCH_CHECK(input_c.ndimension() == 4); + TORCH_CHECK(offset_c.ndimension() == 4); + TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); + TORCH_CHECK(weight_c.ndimension() == 4); + TORCH_CHECK(input_c.device().is_cpu(), "input must be a CPU tensor"); + + int batch_sz = input_c.size(0); + int n_in_channels = input_c.size(1); + int in_h = input_c.size(2); + int in_w = input_c.size(3); + + int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); + + // Unpack shapes and args + int out_channels = weight_c.size(0); + int weight_h = weight_c.size(2); + int weight_w = weight_c.size(3); + + int ker_h = dilation_h * (weight_h - 1) + 1; + int ker_w = dilation_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK( + weight_h > 0 && weight_w > 0, + "weight_h: ", + weight_h, + " weight_w: ", + weight_w); + TORCH_CHECK( + stride_h > 0 && stride_w > 0, + "stride_h: ", + stride_h, + " stride_w: ", + stride_w); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); + TORCH_CHECK( + dilation_h > 0 && dilation_w > 0, + "dilation_h: ", + dilation_h, + " dilation_w: ", + dilation_w); + + TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); + TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); + TORCH_CHECK( + (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "offset.shape[1] is not valid: got: ", + offset_c.size(1), + " expected: ", + n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), + "mask.shape[1] is not valid: got: ", + mask_c.size(1), + " expected: ", + n_offset_grps * weight_h * weight_w); + TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); + + TORCH_CHECK( + (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); + TORCH_CHECK( + (offset_c.size(2) == out_h && offset_c.size(3) == out_w), + "offset output dims: (", + offset_c.size(2), + ", ", + offset_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask"); + TORCH_CHECK( + (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), + "mask output dims: (", + mask_c.size(2), + ", ", + mask_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", + out_h, + " out_w: ", + out_w); + + auto out = + at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); + if (batch_sz == 0) { + return out; + } + + // Separate batches into blocks + out = out.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + out_channels, + out_h, + out_w}); + input_c = input_c.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, 
n_in_channels, in_h, in_w}); + + offset_c = offset_c.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask_c = mask_c.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + at::Tensor out_buf = at::zeros( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs * out_h, + out_w}, + out.options()); + + // Separate channels into convolution groups + out_buf = out_buf.view( + {out_buf.size(0), + n_weight_grps, + out_buf.size(1) / n_weight_grps, + out_buf.size(2), + out_buf.size(3)}); + weight_c = weight_c.view( + {n_weight_grps, + weight_c.size(0) / n_weight_grps, + weight_c.size(1), + weight_c.size(2), + weight_c.size(3)}); + + // Sample points and perform convolution + auto columns = at::zeros( + {n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, + input_c.options()); + for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { + deformable_im2col( + input_c[b], + offset_c[b], + mask_c[b], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int g = 0; g < n_weight_grps; g++) { + out_buf[b][g] = out_buf[b][g] + .flatten(1) + .addmm_(weight_c[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); + } + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + out_buf = out_buf.view( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs, + out_h, + out_w}); + out_buf.transpose_(1, 2); + out.copy_(out_buf); + out = out.view({batch_sz, out_channels, out_h, out_w}); + + return out + bias_c.view({1, out_channels, 1, 1}); +} + +std::tuple +deform_conv2d_backward_kernel( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + at::Tensor grad_out_c = grad_out.contiguous(); + at::Tensor input_c = input.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + const int batch_sz = input_c.size(0); + const int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); + + auto grad_input_and_offset_and_mask = backward_gradient_inputs( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + n_weight_grps, + n_offset_grps, + n_parallel_imgs, + use_mask); + + auto grad_input = std::get<0>(grad_input_and_offset_and_mask); + auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); + auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); + + auto grad_weight = backward_gradient_parameters( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + n_weight_grps, + n_offset_grps, + n_parallel_imgs, + use_mask); + + auto grad_bias = at::ones_like(bias_c) * grad_out_c.sum({0, 2, 3}); + + return std::make_tuple( + grad_input, 
grad_weight, grad_offset, grad_mask, grad_bias);
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
+      TORCH_FN(deform_conv2d_forward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
+      TORCH_FN(deform_conv2d_backward_kernel));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/product/include/torchvision/ops/cpu/nms_kernel.cpp b/product/include/torchvision/ops/cpu/nms_kernel.cpp
new file mode 100644
index 00000000000..50479066cbd
--- /dev/null
+++ b/product/include/torchvision/ops/cpu/nms_kernel.cpp
@@ -0,0 +1,117 @@
+#include <ATen/ATen.h>
+#include <torch/library.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+template <typename scalar_t>
+at::Tensor nms_kernel_impl(
+    const at::Tensor& dets,
+    const at::Tensor& scores,
+    double iou_threshold) {
+  TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
+  TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
+  TORCH_CHECK(
+      dets.scalar_type() == scores.scalar_type(),
+      "dets should have the same type as scores");
+
+  if (dets.numel() == 0)
+    return at::empty({0}, dets.options().dtype(at::kLong));
+
+  auto x1_t = dets.select(1, 0).contiguous();
+  auto y1_t = dets.select(1, 1).contiguous();
+  auto x2_t = dets.select(1, 2).contiguous();
+  auto y2_t = dets.select(1, 3).contiguous();
+
+  at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
+
+  auto order_t = std::get<1>(
+      scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
+
+  auto ndets = dets.size(0);
+  at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
+  at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
+
+  auto suppressed = suppressed_t.data_ptr<uint8_t>();
+  auto keep = keep_t.data_ptr<int64_t>();
+  auto order = order_t.data_ptr<int64_t>();
+  auto x1 = x1_t.data_ptr<scalar_t>();
+  auto y1 = y1_t.data_ptr<scalar_t>();
+  auto x2 = x2_t.data_ptr<scalar_t>();
+  auto y2 = y2_t.data_ptr<scalar_t>();
+  auto areas = areas_t.data_ptr<scalar_t>();
+
+  int64_t num_to_keep = 0;
+
+  for (int64_t _i = 0; _i < ndets; _i++) {
+    auto i = order[_i];
+    if (suppressed[i] == 1)
+      continue;
+    keep[num_to_keep++] = i;
+    auto ix1 = x1[i];
+    auto iy1 = y1[i];
+    auto ix2 = x2[i];
+    auto iy2 = y2[i];
+    auto iarea = areas[i];
+
+    for (int64_t _j = _i + 1; _j < ndets; _j++) {
+      auto j = order[_j];
+      if (suppressed[j] == 1)
+        continue;
+      auto xx1 = std::max(ix1, x1[j]);
+      auto yy1 = std::max(iy1, y1[j]);
+      auto xx2 = std::min(ix2, x2[j]);
+      auto yy2 = std::min(iy2, y2[j]);
+
+      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
+      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
+      auto inter = w * h;
+      auto ovr = inter / (iarea + areas[j] - inter);
+      if (ovr > iou_threshold)
+        suppressed[j] = 1;
+    }
+  }
+  return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
+}
+
+at::Tensor nms_kernel(
+    const at::Tensor& dets,
+    const at::Tensor& scores,
+    double iou_threshold) {
+  TORCH_CHECK(
+      dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
+  TORCH_CHECK(
+      dets.size(1) == 4,
+      "boxes should have 4 elements in dimension 1, got ",
+      dets.size(1));
+  TORCH_CHECK(
+      scores.dim() == 1,
+      "scores should be a 1d tensor, got ",
+      scores.dim(),
+      "D");
+  TORCH_CHECK(
+      dets.size(0) == scores.size(0),
+      "boxes and scores should have same number of elements in ",
+      "dimension 0, got ",
+      dets.size(0),
+      " and ",
+      scores.size(0));
+
+  auto result = at::empty({0}, dets.options());
+
+  AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] {
+    result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
+  });
+  return result;
+}
+
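+// [Editor's note] Illustrative sketch only -- not part of the upstream
+// nms_kernel.cpp. It spells out, for one hypothetical pair of boxes in
+// (x1, y1, x2, y2) form, the IoU test that nms_kernel_impl applies when
+// deciding whether a lower-scoring box is suppressed. The helper name and
+// its parameters are invented for illustration.
+[[maybe_unused]] inline bool nms_would_suppress_example(
+    float ix1, float iy1, float ix2, float iy2,
+    float jx1, float jy1, float jx2, float jy2,
+    double iou_threshold) {
+  // Intersection rectangle, clamped to zero when the boxes are disjoint.
+  float w = std::max(0.0f, std::min(ix2, jx2) - std::max(ix1, jx1));
+  float h = std::max(0.0f, std::min(iy2, jy2) - std::max(iy1, jy1));
+  float inter = w * h;
+  // Union = area(i) + area(j) - intersection, matching areas_t above.
+  float union_area =
+      (ix2 - ix1) * (iy2 - iy1) + (jx2 - jx1) * (jy2 - jy1) - inter;
+  // Box j is suppressed when its overlap with the higher-scoring box i
+  // exceeds the threshold.
+  return inter / union_area > iou_threshold;
+}
+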
+} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp b/product/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp new file mode 100644 index 00000000000..1c272427d3f --- /dev/null +++ b/product/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp @@ -0,0 +1,429 @@ +#include +#include + +namespace vision { +namespace ops { + +namespace { + +template +T bilinear_interpolate( + const T* input, + int height, + int width, + T y, + T x, + int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +void ps_roi_align_forward_kernel_impl( + int num_rois, + const T* input, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + const T* rois, + int channels_out, + T* output, + int* channel_mapping) { + for (int n = 0; n < num_rois; n++) { + // [start, end) interval for spatial sampling + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + int c_in = 0; + for (int c_out = 0; c_out < channels_out; ++c_out) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int index = + ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + + pw; + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = (sampling_ratio > 0) + ? 
sampling_ratio + : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + + T out_sum = 0; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = bilinear_interpolate( + offset_input, height, width, y, x, index); + out_sum += val; + } + } + + out_sum /= count; + output[index] = out_sum; + channel_mapping[index] = c_in; + c_in++; + } + } + } + } +} + +template +void bilinear_interpolate_gradient( + int height, + int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high, + int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + +template +inline void add(T* address, const T& val) { + *address += val; +} + +template +void ps_roi_align_backward_kernel_impl( + int nthreads, + const T* grad_output, + const int* channel_mapping, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + int channels_out, + T* grad_input, + const T* rois) { + for (int index = 0; index < nthreads; index++) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int n = index / pooled_width / pooled_height / channels_out; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + // Force too small ROIs to be 1x1 + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + int c_in = channel_mapping[index]; + T* grad_input_offset = + grad_input + (roi_batch_ind * channels + c_in) * height * width; + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + const T grad_output_this_bin = grad_output[index]; + + // We use roi_bin_grid to sample the grid and mimic integral + int 
roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + add(grad_input_offset + y_low * width + x_low, g1); + add(grad_input_offset + y_low * width + x_high, g2); + add(grad_input_offset + y_high * width + x_low, g3); + add(grad_input_offset + y_high * width + x_high, g4); + } // if + } // ix + } // iy + } +} + +std::tuple ps_roi_align_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { + // Check if input tensors are CPU tensors + TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK( + rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_align_forward_kernel"; + at::checkAllSameType(c, {input_t, rois_t}); + + int num_rois = rois.size(0); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + TORCH_CHECK( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros( + {num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = + at::zeros(output.sizes(), input.options().dtype(at::kInt)); + + if (output.numel() == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "ps_roi_align_forward_kernel", [&] { + ps_roi_align_forward_kernel_impl( + num_rois, + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois_.data_ptr(), + channels_out, + output.data_ptr(), + channel_mapping.data_ptr()); + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_align_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + // Check if input tensors are CPU tensors + TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK( + channel_mapping.device().is_cpu(), + "channel_mapping must be a CPU tensor"); + + 
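+  // The backward pass scatters each pooled bin's gradient back onto the four
+  // feature-map cells used for bilinear interpolation, accumulating into a
+  // zero-initialised grad_input of the original input shape.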
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "ps_roi_align_backward_kernel"; + at::checkAllSameType(c, {grad_t, rois_t}); + + auto grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + int channels_out = channels / (pooled_height * pooled_width); + + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "ps_roi_align_backward_kernel", [&] { + ps_roi_align_backward_kernel_impl( + grad.numel(), + grad_.data_ptr(), + channel_mapping.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + channels_out, + grad_input.data_ptr(), + rois_.data_ptr()); + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), + TORCH_FN(ps_roi_align_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), + TORCH_FN(ps_roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp b/product/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp new file mode 100644 index 00000000000..607cbe4bab6 --- /dev/null +++ b/product/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp @@ -0,0 +1,273 @@ +#include +#include + +namespace vision { +namespace ops { + +namespace { + +template +inline void add(T* address, const T& val) { + *address += val; +} + +template +void ps_roi_pool_forward_kernel_impl( + const T* input, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + const T* rois, + int channels_out, + int num_rois, + T* output, + int* channel_mapping) { + for (int n = 0; n < num_rois; ++n) { + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = round(offset_rois[1] * spatial_scale); + int roi_start_h = round(offset_rois[2] * spatial_scale); + int roi_end_w = round(offset_rois[3] * spatial_scale); + int roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w, 1); + int roi_height = std::max(roi_end_h - roi_start_h, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + int c_in = 0; + for (int c_out = 0; c_out < channels_out; ++c_out) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = + static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = + static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height - 1); + hend = std::min(std::max(hend + roi_start_h, 0), height - 1); + wstart = std::min(std::max(wstart + roi_start_w, 0), width - 1); + wend = std::min(std::max(wend + roi_start_w, 0), width - 1); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + const T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + + T out_sum = 0; + for (int h = hstart; h < hend; ++h) { 
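+          // Descriptive comment (added): sum every input pixel inside the clipped bin; the
+          // average over bin_area is written to output below, and an empty bin produces 0.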
+ for (int w = wstart; w < wend; ++w) { + int input_index = h * width + w; + out_sum += offset_input[input_index]; + } + } + + int index = + ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + + pw; + T bin_area = (hend - hstart) * (wend - wstart); + output[index] = is_empty ? static_cast(0) : out_sum / bin_area; + channel_mapping[index] = c_in; + c_in++; + } + } + } + } +} + +template +void ps_roi_pool_backward_kernel_impl( + const T* grad_output, + const int* channel_mapping, + int num_rois, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int channels_out, + T* grad_input, + const T* rois) { + for (int n = 0; n < num_rois; ++n) { + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = roundf(offset_rois[1] * spatial_scale); + int roi_start_h = roundf(offset_rois[2] * spatial_scale); + int roi_end_w = roundf(offset_rois[3] * spatial_scale); + int roi_end_h = roundf(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w, 1); + int roi_height = std::max(roi_end_h - roi_start_h, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height); + hend = std::min(std::max(hend + roi_start_h, 0), height); + wstart = std::min(std::max(wstart + roi_start_w, 0), width); + wend = std::min(std::max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + for (int c_out = 0; c_out < channels_out; ++c_out) { + int index = + ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + + pw; + int c_in = channel_mapping[index]; + + T* grad_input_offset = + grad_input + (roi_batch_ind * channels + c_in) * height * width; + T bin_area = (hend - hstart) * (wend - wstart); + T diff_val = + is_empty ? 
static_cast(0) : grad_output[index] / bin_area; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int grad_input_index = h * width + w; + add(grad_input_offset + grad_input_index, diff_val); + } + } + } + } + } + } +} + +std::tuple ps_roi_pool_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + // Check if input tensors are CPU tensors + TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK( + rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_pool_forward_kernel"; + at::checkAllSameType(c, {input_t, rois_t}); + + int num_rois = rois.size(0); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + TORCH_CHECK( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros( + {num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = + at::zeros(output.sizes(), input.options().dtype(at::kInt)); + + auto output_size = output.numel(); + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "ps_roi_pool_forward_kernel", [&] { + ps_roi_pool_forward_kernel_impl( + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.data_ptr(), + channels_out, + num_rois, + output.data_ptr(), + channel_mapping.data_ptr()); + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_pool_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + // Check if input tensors are CPU tensors + TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK( + channel_mapping.device().is_cpu(), + "channel_mapping must be a CPU tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "ps_roi_pool_backward_kernel"; + at::checkAllSameType(c, {grad_t, rois_t}); + + auto num_rois = rois.size(0); + auto grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + int channels_out = channels / (pooled_height * pooled_width); + + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] { + ps_roi_pool_backward_kernel_impl( + grad_.data_ptr(), + channel_mapping.data_ptr(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + channels_out, + grad_input.data_ptr(), + rois_.data_ptr()); + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl( + 
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), + TORCH_FN(ps_roi_pool_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), + TORCH_FN(ps_roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/cpu/roi_align_common.h b/product/include/torchvision/ops/cpu/roi_align_common.h new file mode 100644 index 00000000000..e10c67b5b79 --- /dev/null +++ b/product/include/torchvision/ops/cpu/roi_align_common.h @@ -0,0 +1,128 @@ +#pragma once + +#include + +namespace vision { +namespace ops { +namespace detail { + +template +struct PreCalc { + int pos1; + int pos2; + int pos3; + int pos4; + T w1; + T w2; + T w3; + T w4; +}; + +// This helper computes the interpolation weights (w1, w2...) for every sampling +// point of a given box. There are pool_height * pool_width * roi_bin_grid_h * +// roi_bin_grid_w such sampling points. +// +// The weights (w1, w2...) are computed as the areas in this figure: +// https://en.wikipedia.org/wiki/Bilinear_interpolation#/media/File:Bilinear_interpolation_visualisation.svg +// and pos1, pos2 etc correspond to the indices of their respective pixels. +// +// Note: the weights and indices are shared across all channels, which is why +// they are pre-calculated prior to the main loop in the RoIAlign kernel. +// implementation taken from Caffe2 +template +void pre_calc_for_bilinear_interpolate( + int height, + int width, + int pooled_height, + int pooled_width, + T roi_start_h, + T roi_start_w, + T bin_size_h, + T bin_size_w, + int roi_bin_grid_h, + int roi_bin_grid_w, + std::vector>& pre_calc) { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T x = xx; + T y = yy; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + PreCalc pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indices + PreCalc pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } + } +} + +} // namespace detail +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/cpu/roi_align_kernel.cpp b/product/include/torchvision/ops/cpu/roi_align_kernel.cpp new file mode 100644 index 00000000000..b787de6f6bb --- /dev/null +++ b/product/include/torchvision/ops/cpu/roi_align_kernel.cpp @@ -0,0 +1,400 @@ +#include +#include + +#include "./roi_align_common.h" + +namespace vision { +namespace ops { + +namespace { + +template +void roi_align_forward_kernel_impl( + int n_rois, + const T* input, + const T& spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + bool aligned, + const T* rois, + T* output) { + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) { + int index_n = n * channels * pooled_width * pooled_height; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros. + const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 + + // we want to precalculate indices and weights shared by all channels, + // this is the key point of optimization + std::vector> pre_calc( + roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); + detail::pre_calc_for_bilinear_interpolate( + height, + width, + pooled_height, + pooled_width, + roi_start_h, + roi_start_w, + bin_size_h, + bin_size_w, + roi_bin_grid_h, + roi_bin_grid_w, + pre_calc); + + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * pooled_width * pooled_height; + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + int index = index_n_c + ph * pooled_width + pw; + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + detail::PreCalc pc = pre_calc[pre_calc_index]; + output_val += pc.w1 * offset_input[pc.pos1] + + pc.w2 * offset_input[pc.pos2] + + pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= count; // Average pooling + + output[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n +} + +template +void bilinear_interpolate_gradient( + int height, + int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high, + int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + +template +inline void add(T* address, const T& val) { + *address += val; +} + +template +void roi_align_backward_kernel_impl( + int nthreads, + const T* grad_output, + const T& spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + int sampling_ratio, + bool aligned, + T* grad_input, + const T* rois, + int n_stride, + int c_stride, + int h_stride, + int w_stride) { + for (int index = 0; index < nthreads; index++) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? 
(T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + int output_offset = n * n_stride + c * c_stride; + const T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + // atomic add is not needed for now since it is single threaded + add(offset_grad_input + y_low * width + x_low, static_cast(g1)); + add(offset_grad_input + y_low * width + x_high, static_cast(g2)); + add(offset_grad_input + y_high * width + x_low, static_cast(g3)); + add(offset_grad_input + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // for +} + +at::Tensor roi_align_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_forward_kernel"; + at::checkAllSameType(c, {input_t, rois_t}); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + + if (output.numel() == 0) + return output; + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "roi_align_forward_kernel", [&] { + roi_align_forward_kernel_impl( + num_rois, + 
input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + rois_.data_ptr(), + output.data_ptr()); + }); + return output; +} + +at::Tensor roi_align_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_backward_kernel"; + at::checkAllSameType(c, {grad_t, rois_t}); + + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + // get stride values to ensure indexing into gradients is correct. + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "roi_align_backward_kernel", [&] { + roi_align_backward_kernel_impl( + grad.numel(), + grad.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + aligned, + grad_input.data_ptr(), + rois_.data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_align"), + TORCH_FN(roi_align_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), + TORCH_FN(roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/cpu/roi_pool_kernel.cpp b/product/include/torchvision/ops/cpu/roi_pool_kernel.cpp new file mode 100644 index 00000000000..b099523896a --- /dev/null +++ b/product/include/torchvision/ops/cpu/roi_pool_kernel.cpp @@ -0,0 +1,249 @@ +#include + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +template +inline void add(T* address, const T& val) { + *address += val; +} + +template +void roi_pool_forward_kernel_impl( + const T* input, + const T spatial_scale, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + const T* rois, + int num_rois, + T* output, + int* argmax_data) { + for (int n = 0; n < num_rois; ++n) { + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = round(offset_rois[1] * spatial_scale); + int roi_start_h = round(offset_rois[2] * spatial_scale); + int roi_end_w = round(offset_rois[3] * spatial_scale); + int roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w + 1, 1); + int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + 
int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height); + hend = std::min(std::max(hend + roi_start_h, 0), height); + wstart = std::min(std::max(wstart + roi_start_w, 0), width); + wend = std::min(std::max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + for (int c = 0; c < channels; ++c) { + // Define an empty pooling region to be zero + T maxval = is_empty ? 0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + + const T* input_offset = + input + (roi_batch_ind * channels + c) * height * width; + + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * width + w; + if (input_offset[input_index] > maxval) { + maxval = input_offset[input_index]; + maxidx = input_index; + } + } + } + int index = + ((n * channels + c) * pooled_height + ph) * pooled_width + pw; + output[index] = maxval; + argmax_data[index] = maxidx; + } // channels + } // pooled_width + } // pooled_height + } // num_rois +} + +template +void roi_pool_backward_kernel_impl( + const T* grad_output, + const int* argmax_data, + int num_rois, + int channels, + int height, + int width, + int pooled_height, + int pooled_width, + T* grad_input, + const T* rois, + int n_stride, + int c_stride, + int h_stride, + int w_stride) { + for (int n = 0; n < num_rois; ++n) { + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + for (int c = 0; c < channels; ++c) { + T* grad_input_offset = + grad_input + ((roi_batch_ind * channels + c) * height * width); + const int* argmax_data_offset = + argmax_data + (n * channels + c) * pooled_height * pooled_width; + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int output_offset = n * n_stride + c * c_stride; + int argmax = argmax_data_offset[ph * pooled_width + pw]; + + if (argmax != -1) { + add(grad_input_offset + argmax, + static_cast( + grad_output + [output_offset + ph * h_stride + pw * w_stride])); + } + } // pooled_width + } // pooled_height + } // channels + } // num_rois +} + +std::tuple roi_pool_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_pool_forward_kernel"; + at::checkAllSameType(c, {input_t, rois_t}); + + int num_rois = rois.size(0); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + at::Tensor argmax = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, + input.options().dtype(at::kInt)); + + if (output.numel() == 0) { + return std::make_tuple(output, argmax); + } + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "roi_pool_forward_kernel", [&] { + roi_pool_forward_kernel_impl( + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois_.data_ptr(), + num_rois, + output.data_ptr(), + argmax.data_ptr()); + }); + return std::make_tuple(output, 
argmax); +} + +at::Tensor roi_pool_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + // Check if input tensors are CPU tensors + TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); + TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); + TORCH_CHECK(argmax.device().is_cpu(), "argmax must be a CPU tensor"); + TORCH_CHECK( + rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_pool_backward_kernel"; + at::checkAllSameType(c, {grad_t, rois_t}); + + auto num_rois = rois.size(0); + + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + // get stride values to ensure indexing into gradients is correct. + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "roi_pool_backward_kernel", [&] { + roi_pool_backward_kernel_impl( + grad.data_ptr(), + argmax.data_ptr(), + num_rois, + channels, + height, + width, + pooled_height, + pooled_width, + grad_input.data_ptr(), + rois_.data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_pool"), + TORCH_FN(roi_pool_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), + TORCH_FN(roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/deform_conv2d.cpp b/product/include/torchvision/ops/deform_conv2d.cpp new file mode 100644 index 00000000000..3cda60fe0bc --- /dev/null +++ b/product/include/torchvision/ops/deform_conv2d.cpp @@ -0,0 +1,172 @@ +#include "deform_conv2d.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +at::Tensor deform_conv2d( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::deform_conv2d", "") + .typed(); + return op.call( + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); +} + +at::Tensor deform_conv2d_symint( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); + static auto op = c10::Dispatcher::singleton() + 
.findSchemaOrThrow("torchvision::deform_conv2d", "") + .typed(); + return op.call( + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); +} + +namespace detail { + +std::tuple +_deform_conv2d_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") + .typed(); + return op.call( + grad, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); +} + +std::tuple +_deform_conv2d_backward_symint( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") + .typed(); + return op.call( + grad, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); +} + +} // namespace detail + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/deform_conv2d.h b/product/include/torchvision/ops/deform_conv2d.h new file mode 100644 index 00000000000..cf1f142e648 --- /dev/null +++ b/product/include/torchvision/ops/deform_conv2d.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API at::Tensor deform_conv2d( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask); + +VISION_API at::Tensor deform_conv2d_symint( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask); + +namespace 
detail { + +std::tuple +_deform_conv2d_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask); + +std::tuple +_deform_conv2d_backward_symint( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask); + +} // namespace detail + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/mps/mps_helpers.h b/product/include/torchvision/ops/mps/mps_helpers.h new file mode 100644 index 00000000000..d3c0e8d94b7 --- /dev/null +++ b/product/include/torchvision/ops/mps/mps_helpers.h @@ -0,0 +1,6 @@ +constexpr int threadsPerBlock = 512; + +template +constexpr inline T ceil_div(T n, T m) { + return (n + m - 1) / m; +} diff --git a/product/include/torchvision/ops/mps/mps_kernels.h b/product/include/torchvision/ops/mps/mps_kernels.h new file mode 100644 index 00000000000..e720a1608f1 --- /dev/null +++ b/product/include/torchvision/ops/mps/mps_kernels.h @@ -0,0 +1,1102 @@ +#include + +namespace vision { +namespace ops { + +namespace mps { + +static const char* METAL_VISION = R"VISION_METAL( + +#include +#include +using namespace metal; + +/*----------Macros----------*/ + +#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \ + for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ + i += (tptg.x * n_tgs)) + +#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint) + +/*----------Helpers--------*/ + +template +inline T ceil_div(T n, T m) { + return (n + m - 1) / m; +} + +template +inline void atomic_add_float( device T* data_ptr, const T val) +{ +#if __METAL_VERSION__ >= 300 + // atomic_float is supported in Metal 3 (macOS Ventura) onward. + device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); +#else + // Custom atomic addition implementation + // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 + // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 + // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide) + + // Create an atomic uint pointer for atomic transaction. + device atomic_uint* atom_var = (device atomic_uint*)data_ptr; + // Create necessary storage. + uint fetched_uint, assigning_uint; + T fetched_float, assigning_float; + + // Replace the value in atom_var with 0 and return the previous value in atom_var. + fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed); + // Read out the previous value as float. + fetched_float = *( (thread T*) &fetched_uint ); + + // Do addition and represent the addition result in uint for atomic transaction. + assigning_float = fetched_float + val; + assigning_uint = *((thread uint*) &assigning_float); + + // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). 
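+  // Descriptive comment (added): the loop below publishes our running sum via atomic_exchange;
+  // if the slot was not 0, another thread's contribution landed in the meantime, so we fold it
+  // into our sum and retry until the value we leave behind in atom_var is 0.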
+ while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) { + // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. + // Try to assign 0 and get the previously assigned addition result. + uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); + T fetched_float_again = *( (thread T*) &fetched_uint_again ); + // Re-add again + fetched_float = *((thread T*) &(fetched_uint)); + // Previously assigned addition result + addition result from other threads. + assigning_float = fetched_float_again + fetched_float; + assigning_uint = *( (thread uint*) &assigning_float); + } +#endif +} + +template +inline T bilinear_interpolate( + constant T* input, + integer_t height, + integer_t width, + T y, + T x, + uint index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + integer_t y_low = (integer_t)y; + integer_t x_low = (integer_t)x; + integer_t y_high; + integer_t x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +inline void bilinear_interpolate_gradient( + integer_t height, + integer_t width, + T y, + T x, + thread T& w1, + thread T& w2, + thread T& w3, + thread T& w4, + thread integer_t& x_low, + thread integer_t& x_high, + thread integer_t& y_low, + thread integer_t& y_high, + uint index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (integer_t)y; + x_low = (integer_t)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + +template +inline bool IoU( + constant T & a, + threadgroup T & b, + const float threshold) { + auto xx1 = max(a.x, b.x); + auto yy1 = max(a.y, b.y); + auto xx2 = min(a.z, b.z); + auto yy2 = min(a.w, b.w); + auto w = max(static_cast(0), xx2 - xx1); + auto h = max(static_cast(0), yy2 - yy1); + // Upcast to float before multiplications to circumvent precision issues in half. 
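+  // Descriptive comment (added): intersection-over-union, inter / (area_a + area_b - inter).
+  // E.g. a = (0,0,2,2), b = (1,1,3,3): inter = 1, union = 4 + 4 - 1 = 7, IoU ~= 0.14.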
+ auto inter = static_cast(w) * static_cast(h); + auto area_b = static_cast(b.z - b.x) * static_cast(b.w - b.y); + auto area_a = static_cast(a.z - a.x) * static_cast(a.w - a.y); + return (inter / (area_a + area_b - inter)) > threshold; +} + +/*----------Kernels----------*/ + +// This should be in sync with the one in nms_kernel.mm. +// Since metal does not support dynamic array, +// we need to make it static instead of deriving it from [[threads_per_threadgroup]]. +constant int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; + +template +kernel void nms(constant T * dev_boxes [[buffer(0)]], + device uint64_t * mask [[buffer(1)]], + constant int64_t & n_boxes [[buffer(2)]], + constant float & iou_threshold [[buffer(3)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tid2 [[thread_position_in_threadgroup]]) { + + const uint row_start = tgid.y; + const uint col_start = tgid.x; + const uint tid = tid2.x; + const uint row_size = + min(n_boxes - row_start * nmsThreadsPerBlock, nmsThreadsPerBlock); + const uint col_size = + min(n_boxes - col_start * nmsThreadsPerBlock, nmsThreadsPerBlock); + + threadgroup T block_boxes[nmsThreadsPerBlock]; + block_boxes[tid] = dev_boxes[nmsThreadsPerBlock * col_start + tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid < row_size) { + const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid; + uint64_t t = 0; + uint start = 0; + + if (row_start == col_start) { + start = tid + 1; + } + + for (uint i = start; i < col_size; i++){ + if (IoU(dev_boxes[cur_box_idx], block_boxes[i], iou_threshold)){ + t |= static_cast(1) << i; // discard 1 keep 0 + } + } + const uint col_blocks = ceil_div(n_boxes, nmsThreadsPerBlock); + mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +#define REGISTER_NMS_OP(DTYPE) \ +template \ +[[host_name("nms_" #DTYPE)]] \ +kernel void nms( \ + constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \ + device uint64_t * mask [[buffer(1)]], \ + constant int64_t & n_boxes [[buffer(2)]], \ + constant float & iou_threshold [[buffer(3)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_align( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + constant int64_t & output_size [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & pooled_height [[buffer(7)]], + constant int64_t & pooled_width [[buffer(8)]], + constant int64_t & sampling_ratio [[buffer(9)]], + constant bool & aligned [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? 
(T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + constant T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros. + const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast(1)); // e.g. = 4 + + T output_val = 0.; + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + output_val += val; + } + } + output_val /= count; + + output[index] = output_val; + } +} + +#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_align_" #DTYPE)]] \ +kernel void roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & pooled_height [[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_align_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * grad_input [[buffer(2)]], + constant int64_t & output_size [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & pooled_height [[buffer(7)]], + constant int64_t & pooled_width [[buffer(8)]], + constant int64_t & sampling_ratio [[buffer(9)]], + constant bool & aligned [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + constant int64_t & n_stride [[buffer(12)]], + constant int64_t & c_stride [[buffer(13)]], + constant int64_t & h_stride [[buffer(14)]], + constant int64_t & w_stride [[buffer(15)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element 
in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We need to index the gradient using the tensor strides to access the + // correct values. + const integer_t output_offset = n * n_stride + c * c_stride; + constant T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + const integer_t input_offset = (roi_batch_ind * channels + c) * height * width; + + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + integer_t x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomic_add_float(grad_input + input_offset + y_low * width + x_low, static_cast(g1)); + atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast(g2)); + atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast(g3)); + atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast(g4)); + + } // if + } // ix + } // iy + } // MPS_1D_KERNEL_LOOP +} + +#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_align_backward_" #DTYPE)]] \ +kernel void roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * grad_input [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ 
+ constant int64_t & pooled_height [[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + constant int64_t & n_stride [[buffer(12)]], \ + constant int64_t & c_stride [[buffer(13)]], \ + constant int64_t & h_stride [[buffer(14)]], \ + constant int64_t & w_stride [[buffer(15)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_pool( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * argmax [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant float & spatial_scale [[buffer(10)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + integer_t roi_width = max(roi_end_w - roi_start_w + 1, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h + 1, static_cast(1)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + T maxval = is_empty ? 
0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + integer_t maxidx = -1; + constant T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + for (integer_t h = hstart; h < hend; ++h) { + for (integer_t w = wstart; w < wend; ++w) { + integer_t input_index = h * width + w; + if (offset_input[input_index] > maxval) { + maxval = offset_input[input_index]; + maxidx = input_index; + } + } + } + output[index] = maxval; + argmax[index] = maxidx; + } +} + +#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_pool_" #DTYPE)]] \ +kernel void roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * argmax_data [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_pool_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * argmax_data [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant float & spatial_scale [[buffer(10)]], + constant int64_t & n_stride [[buffer(11)]], + constant int64_t & c_stride [[buffer(12)]], + constant int64_t & h_stride [[buffer(13)]], + constant int64_t & w_stride [[buffer(14)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + const integer_t output_offset = n * n_stride + c * c_stride; + constant integer_t * argmax_data_offset = + argmax_data + (n * channels + c) * pooled_height * pooled_width; + const integer_t argmax = argmax_data_offset[ph * pooled_width + pw]; + const integer_t offset = (roi_batch_ind * channels + c) * height * width; + + if (argmax != -1) { + atomic_add_float(grad_input + offset + argmax, static_cast(grad_output[output_offset + ph * h_stride + pw * w_stride])); + } + + } // MPS_1D_KERNEL_LOOP +} + +#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_pool_backward_" #DTYPE)]] \ +kernel void roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * argmax_data [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + 
constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + constant int64_t & n_stride [[buffer(11)]], \ + constant int64_t & c_stride [[buffer(12)]], \ + constant int64_t & h_stride [[buffer(13)]], \ + constant int64_t & w_stride [[buffer(14)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_align( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * channel_mapping [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & sampling_ratio [[buffer(10)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c_out, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c_out = (index / pooled_width / pooled_height) % channels_out; + integer_t n = index / pooled_width / pooled_height / channels_out; + + // (n, c_in, ph, pw) is the associated element in the input + integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; + + // [start, end) interval for spatial sampling + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + constant T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + T out_sum = 0; + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + out_sum += val; + } + } + + out_sum /= count; + output[index] = out_sum; + channel_mapping[index] = c_in; + } +} + +#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_align_" #DTYPE)]] \ +kernel void ps_roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_align_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & sampling_ratio [[buffer(10)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, *, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t n = index / pooled_width / pooled_height / channels_out; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + // Force too small ROIs to be 1x1 + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + integer_t c_in = channel_mapping[index]; + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + const T 
grad_output_this_bin = grad_output[index]; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; + + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + integer_t x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomic_add_float(grad_input + offset + y_low * width + x_low, static_cast(g1)); + atomic_add_float(grad_input + offset + y_low * width + x_high, static_cast(g2)); + atomic_add_float(grad_input + offset + y_high * width + x_low, static_cast(g3)); + atomic_add_float(grad_input + offset + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } +} + +#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_align_backward_" #DTYPE)]] \ +kernel void ps_roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_pool( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * channel_mapping [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & channels_out [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c_out, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c_out = (index / (pooled_width * pooled_height)) % channels_out; + integer_t n = index / pooled_width / pooled_height 
/ channels_out;
+
+    // (n, c_in, ph, pw) is the associated element in the input
+    integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw;
+
+    // [start, end) interval for spatial sampling
+    constant T* offset_rois = rois + n * 5;
+    integer_t roi_batch_ind = offset_rois[0];
+    integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
+    integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
+    integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
+    integer_t roi_end_h = round(offset_rois[4] * spatial_scale);
+
+    // Force too small ROIs to be 1x1
+    integer_t roi_width = max(roi_end_w - roi_start_w, static_cast<integer_t>(1));
+    integer_t roi_height = max(roi_end_h - roi_start_h, static_cast<integer_t>(1));
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
+    integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
+    integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
+    integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));
+
+    // Add roi offsets and clip to input boundaries
+    hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height - 1));
+    hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height - 1));
+    wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width - 1));
+    wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width - 1));
+    bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+    constant T* offset_input =
+        input + (roi_batch_ind * channels + c_in) * height * width;
+    T out_sum = 0;
+    for (integer_t h = hstart; h < hend; ++h) {
+      for (integer_t w = wstart; w < wend; ++w) {
+        integer_t input_index = h * width + w;
+        out_sum += offset_input[input_index];
+      }
+    }
+
+    T bin_area = (hend - hstart) * (wend - wstart);
+    output[index] = is_empty ?
static_cast(0) : out_sum / bin_area; + channel_mapping[index] = c_in; + } +} + +#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_pool_" #DTYPE)]] \ +kernel void ps_roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_pool_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & channels_out [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, *, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t n = index / pooled_width / pooled_height / channels_out; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + integer_t c_in = channel_mapping[index]; + T bin_area = (hend - hstart) * (wend - wstart); + T diff_val = is_empty ? 
static_cast<T>(0) : grad_output[index] / bin_area;
+
+    const integer_t offset = (roi_batch_ind * channels + c_in) * height * width;
+
+    for (integer_t h = hstart; h < hend; ++h) {
+      for (integer_t w = wstart; w < wend; ++w) {
+        integer_t grad_input_index = h * width + w;
+        atomic_add_float(grad_input + offset + grad_input_index, diff_val);
+      }
+    }
+
+  } // MPS_1D_KERNEL_LOOP
+}
+
+#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE)     \
+template                                                       \
+[[host_name("ps_roi_pool_backward_" #DTYPE)]]                  \
+kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>(            \
+    constant DTYPE * grad_output [[buffer(0)]],                \
+    constant DTYPE * rois [[buffer(1)]],                       \
+    constant int64_t * channel_mapping [[buffer(2)]],          \
+    device DTYPE * grad_input [[buffer(3)]],                   \
+    constant int64_t & output_size [[buffer(4)]],              \
+    constant int64_t & channels [[buffer(5)]],                 \
+    constant int64_t & height [[buffer(6)]],                   \
+    constant int64_t & width [[buffer(7)]],                    \
+    constant int64_t & pooled_height [[buffer(8)]],            \
+    constant int64_t & pooled_width [[buffer(9)]],             \
+    constant int64_t & channels_out [[buffer(10)]],            \
+    constant float & spatial_scale [[buffer(11)]],             \
+    uint2 tgid [[threadgroup_position_in_grid]],               \
+    uint2 tptg [[threads_per_threadgroup]],                    \
+    uint2 tid2 [[thread_position_in_threadgroup]]);
+
+REGISTER_NMS_OP(float);
+REGISTER_NMS_OP(half);
+REGISTER_ROI_ALIGN_OP(float, int64_t);
+REGISTER_ROI_ALIGN_OP(half, int64_t);
+REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t);
+REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t);
+REGISTER_ROI_POOL_OP(float, int64_t);
+REGISTER_ROI_POOL_OP(half, int64_t);
+REGISTER_ROI_POOL_BACKWARD_OP(float, int64_t);
+REGISTER_ROI_POOL_BACKWARD_OP(half, int64_t);
+REGISTER_PS_ROI_ALIGN_OP(float, int64_t);
+REGISTER_PS_ROI_ALIGN_OP(half, int64_t);
+REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float, int64_t);
+REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half, int64_t);
+REGISTER_PS_ROI_POOL_OP(float, int64_t);
+REGISTER_PS_ROI_POOL_OP(half, int64_t);
+REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t);
+REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t);
+
+)VISION_METAL";
+
+static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) {
+  static id<MTLLibrary> visionLibrary = nil;
+  if (visionLibrary) {
+    return visionLibrary;
+  }
+
+  NSError* error = nil;
+  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
+  [options setLanguageVersion:MTLLanguageVersion2_3];
+  visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
+                                       options:options
+                                         error:&error];
+  TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]);
+  return visionLibrary;
+}
+
+static id<MTLComputePipelineState> visionPipelineState(id<MTLDevice> device, const std::string& kernel) {
+  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
+  id<MTLComputePipelineState> pso = psoCache[kernel];
+  if (pso) {
+    return pso;
+  }
+
+  NSError* error = nil;
+  id<MTLLibrary> visionLib = compileVisionOpsLibrary(device);
+  id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
+  TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel);
+  pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
+  TORCH_CHECK(pso, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);
+
+  psoCache[kernel] = pso;
+  return pso;
+}
+
+} // namespace mps
+} // namespace ops
+} // namespace vision
diff --git a/product/include/torchvision/ops/mps/nms_kernel.mm b/product/include/torchvision/ops/mps/nms_kernel.mm
new file mode 100644
index 00000000000..5ee9b5cbeae
--- /dev/null
+++ 
b/product/include/torchvision/ops/mps/nms_kernel.mm @@ -0,0 +1,109 @@ +#include +#include +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +// This should be in sync with `nmsThreadsPerBlock` in the metal kernel. +constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; + +at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { + using namespace at::native::mps; + TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor"); + TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor"); + + TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1)); + TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D"); + TORCH_CHECK(dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)) + + if (dets.numel() == 0) { + return at::empty({0}, dets.options().dtype(at::kLong)); + } + + auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + auto dets_sorted = dets.index_select(0, order_t).contiguous(); + int64_t dets_num = dets.size(0); + float iou_threshold_f = static_cast(iou_threshold); + + const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock; + at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + + id inputBuffer = getMTLBufferStorage(dets_sorted); + id outputBuffer = getMTLBufferStorage(mask); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1); + + const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores}); + + [computeEncoder setComputePipelineState:visionPSO]; + [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0]; + [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1]; + [computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2]; + [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3]; + + // A threadGroup is equivalent to a cuda's block. 
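+      // Note on the clamp below (an inference from the constant above, not stated
+      // in the original code): nmsThreadsPerBlock is sizeof(uint64_t) * 8 == 64,
+      // so limiting the threadgroup width to it presumably keeps each box's
+      // suppression bitmask within a single 64-bit word per column block, which
+      // is the layout the CPU decode loop further down relies on.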
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > nmsThreadsPerBlock) { + tgSize = nmsThreadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + + int64_t num_to_keep = 0; + + at::Tensor mask_cpu = mask.to(at::kCPU); + unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr(); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data_ptr(); + + for (int64_t i = 0; i < dets_num; i++) { + int64_t nblock = i / nmsThreadsPerBlock; + int64_t inblock = i % nmsThreadsPerBlock; + + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long* p = mask_host + i * col_blocks; + for (int64_t j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + + return order_t.index( + {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())}); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/mps/ps_roi_align_kernel.mm b/product/include/torchvision/ops/mps/ps_roi_align_kernel.mm new file mode 100644 index 00000000000..16b711ad5ef --- /dev/null +++ b/product/include/torchvision/ops/mps/ps_roi_align_kernel.mm @@ -0,0 +1,205 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +std::tuple ps_roi_align_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_align_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + + int64_t channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); + + int64_t output_size = output.numel(); + + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + 
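+  // The block below follows the same pattern as the other ops in this file:
+  // look up the dtype-specialized pipeline state (e.g. "ps_roi_align_float",
+  // produced by the REGISTER_* macros in mps_kernels.h), bind the tensor
+  // storage and scalar arguments at the buffer indices matching the kernel's
+  // [[buffer(n)]] slots, then dispatch on the current MPS stream.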
dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; + + // A threadGroup is equivalent to a cuda's block. 
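+      // Grid sizing note: ceil_div(output_size, 512) threadgroups are requested,
+      // capped at 4096, and threadsPerBlock (from mps_helpers.h) bounds the
+      // threadgroup width below; any elements beyond that coverage are
+      // presumably handled by the MPS_1D_KERNEL_LOOP grid-stride loop inside
+      // the kernel.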
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_align_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_align backward with float16 inputs."); + TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "ps_roi_align_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t output_size = grad.numel(); + int64_t channels_out = channels / (pooled_height * pooled_width); + + at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad_); + id roisBuffer = getMTLBufferStorage(rois_); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height 
length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/mps/ps_roi_pool_kernel.mm b/product/include/torchvision/ops/mps/ps_roi_pool_kernel.mm new file mode 100644 index 00000000000..fc24f6990fa --- /dev/null +++ b/product/include/torchvision/ops/mps/ps_roi_pool_kernel.mm @@ -0,0 +1,200 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +std::tuple ps_roi_pool_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_pool_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int64_t channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); + auto output_size = output.numel(); + + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type()); + id 
visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_pool_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_pool backward with float16 inputs."); + TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "ps_roi_pool_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + auto num_rois = rois.size(0); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t channels_out = channels / (pooled_height * pooled_width); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel"); + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad_); + id roisBuffer = getMTLBufferStorage(rois_); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id 
computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad_, rois_, channel_mapping}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), TORCH_FN(ps_roi_pool_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/mps/roi_align_kernel.mm b/product/include/torchvision/ops/mps/roi_align_kernel.mm new file mode 100644 index 00000000000..d4ed8b43fd2 --- /dev/null +++ b/product/include/torchvision/ops/mps/roi_align_kernel.mm @@ -0,0 +1,197 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +at::Tensor roi_align_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width 
= input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + + int64_t output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return output; + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. 
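+      // Note the asymmetry with the backward pass: the forward kernel indexes
+      // `input` with raw (n * channels + c) * height * width offsets, which is
+      // why input and rois are made contiguous above, while roi_align_backward
+      // passes grad's n/c/h/w strides explicitly instead of copying it.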
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return output; +} + +at::Tensor roi_align_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs."); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t n_stride = grad.stride(0); + int64_t c_stride = grad.stride(1); + int64_t h_stride = grad.stride(2); + int64_t w_stride = grad.stride(3); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("roi_align_backward_kernel"); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) 
atIndex:12]; + [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/mps/roi_pool_kernel.mm b/product/include/torchvision/ops/mps/roi_pool_kernel.mm new file mode 100644 index 00000000000..816d8d70863 --- /dev/null +++ b/product/include/torchvision/ops/mps/roi_pool_kernel.mm @@ -0,0 +1,196 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +std::tuple roi_pool_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_pool_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong)); + + int64_t output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return std::make_tuple(output, argmax); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id argmaxBuffer = getMTLBufferStorage(argmax); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder 
setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return std::make_tuple(output, argmax); +} + +at::Tensor roi_pool_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs."); + TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor"); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; + + at::CheckedFrom c = "roi_pool_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t n_stride = grad.stride(0); + int64_t c_stride = grad.stride(1); + int64_t h_stride = grad.stride(2); + int64_t w_stride = grad.stride(3); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("roi_pool_backward_kernel"); + auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad); + id roisBuffer = getMTLBufferStorage(rois_); + id argmaxBuffer = getMTLBufferStorage(argmax_); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + 
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_, argmax_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/nms.cpp b/product/include/torchvision/ops/nms.cpp new file mode 100644 index 00000000000..5ecf8812f1b --- /dev/null +++ b/product/include/torchvision/ops/nms.cpp @@ -0,0 +1,28 @@ +#include "nms.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +at::Tensor nms( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms.nms"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::nms", "") + .typed(); + return op.call(dets, scores, iou_threshold); +} + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.set_python_module("torchvision._meta_registrations"); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/nms.h b/product/include/torchvision/ops/nms.h new file mode 100644 index 00000000000..8c75a242bff --- /dev/null +++ b/product/include/torchvision/ops/nms.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API at::Tensor nms( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold); + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/ops.h 
b/product/include/torchvision/ops/ops.h new file mode 100644 index 00000000000..77995e44197 --- /dev/null +++ b/product/include/torchvision/ops/ops.h @@ -0,0 +1,8 @@ +#pragma once + +#include "deform_conv2d.h" +#include "nms.h" +#include "ps_roi_align.h" +#include "ps_roi_pool.h" +#include "roi_align.h" +#include "roi_pool.h" diff --git a/product/include/torchvision/ops/ps_roi_align.cpp b/product/include/torchvision/ops/ps_roi_align.cpp new file mode 100644 index 00000000000..de458c0d62d --- /dev/null +++ b/product/include/torchvision/ops/ps_roi_align.cpp @@ -0,0 +1,112 @@ +#include "ps_roi_align.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +std::tuple ps_roi_align( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_align", "") + .typed(); + return op.call( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + +std::tuple ps_roi_align_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_align", "") + .typed(); + return op.call( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + +namespace detail { + +at::Tensor _ps_roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") + .typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); +} + +at::Tensor _ps_roi_align_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") + .typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); +} + +} // namespace detail + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/ps_roi_align.h 
b/product/include/torchvision/ops/ps_roi_align.h new file mode 100644 index 00000000000..75650586bc6 --- /dev/null +++ b/product/include/torchvision/ops/ps_roi_align.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API std::tuple ps_roi_align( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio); + +VISION_API std::tuple ps_roi_align_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio); + +namespace detail { + +at::Tensor _ps_roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); + +at::Tensor _ps_roi_align_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width); + +} // namespace detail + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/ps_roi_pool.cpp b/product/include/torchvision/ops/ps_roi_pool.cpp new file mode 100644 index 00000000000..92469d5e380 --- /dev/null +++ b/product/include/torchvision/ops/ps_roi_pool.cpp @@ -0,0 +1,104 @@ +#include "ps_roi_pool.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +std::tuple ps_roi_pool( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + +std::tuple ps_roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + +namespace detail { + +at::Tensor _ps_roi_pool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") + .typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +} + +at::Tensor _ps_roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + static auto op = + c10::Dispatcher::singleton() + 
.findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") + .typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +} + +} // namespace detail + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/ps_roi_pool.h b/product/include/torchvision/ops/ps_roi_pool.h new file mode 100644 index 00000000000..4a3cc54e0e5 --- /dev/null +++ b/product/include/torchvision/ops/ps_roi_pool.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API std::tuple ps_roi_pool( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +VISION_API std::tuple ps_roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width); + +namespace detail { + +at::Tensor _ps_roi_pool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); + +at::Tensor _ps_roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width); + +} // namespace detail + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/roi_align.cpp b/product/include/torchvision/ops/roi_align.cpp new file mode 100644 index 00000000000..aa6dccb44f2 --- /dev/null +++ b/product/include/torchvision/ops/roi_align.cpp @@ -0,0 +1,132 @@ +#include "roi_align.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +at::Tensor roi_align( + const at::Tensor& input, // Input feature map. + const at::Tensor& rois, // List of ROIs to pool over. + double spatial_scale, // The scale of the image features. ROIs will be + // scaled to this. + int64_t pooled_height, // The height of the pooled feature map. + int64_t pooled_width, // The width of the pooled feature + int64_t sampling_ratio, // The number of points to sample in each bin + bool aligned) // The flag for pixel shift +// along each axis. +{ + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::roi_align", "") + .typed(); + return op.call( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + aligned); +} + +at::Tensor roi_align_symint( + const at::Tensor& input, // Input feature map. + const at::Tensor& rois, // List of ROIs to pool over. + double spatial_scale, // The scale of the image features. ROIs will be + // scaled to this. 
+ c10::SymInt pooled_height, // The height of the pooled feature map. + c10::SymInt pooled_width, // The width of the pooled feature + int64_t sampling_ratio, // The number of points to sample in each bin + bool aligned) // The flag for pixel shift +// along each axis. +{ + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::roi_align", "") + .typed(); + return op.call( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + aligned); +} + +namespace detail { + +at::Tensor _roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_roi_align_backward", "") + .typed(); + return op.call( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + aligned); +} + +at::Tensor _roi_align_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width, + int64_t sampling_ratio, + bool aligned) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_roi_align_backward", "") + .typed(); + return op.call( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + aligned); +} + +} // namespace detail + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/roi_align.h b/product/include/torchvision/ops/roi_align.h new file mode 100644 index 00000000000..072d6d4231c --- /dev/null +++ b/product/include/torchvision/ops/roi_align.h @@ -0,0 +1,58 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API at::Tensor roi_align( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned); + +VISION_API at::Tensor roi_align_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + int64_t sampling_ratio, + bool aligned); + +namespace detail { + +at::Tensor _roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned); + +at::Tensor _roi_align_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt 
pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width, + int64_t sampling_ratio, + bool aligned); + +} // namespace detail + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/roi_pool.cpp b/product/include/torchvision/ops/roi_pool.cpp new file mode 100644 index 00000000000..20ca3ca91e7 --- /dev/null +++ b/product/include/torchvision/ops/roi_pool.cpp @@ -0,0 +1,102 @@ +#include "roi_pool.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +std::tuple roi_pool( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + +std::tuple roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + +namespace detail { + +at::Tensor _roi_pool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_roi_pool_backward", "") + .typed(); + return op.call( + grad, + rois, + argmax, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +} + +at::Tensor _roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_roi_pool_backward", "") + .typed(); + return op.call( + grad, + rois, + argmax, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +} + +} // namespace detail + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/ops/roi_pool.h b/product/include/torchvision/ops/roi_pool.h new file mode 100644 index 00000000000..e2133240f4f --- /dev/null +++ b/product/include/torchvision/ops/roi_pool.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API std::tuple roi_pool( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +VISION_API std::tuple 
roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width); + +namespace detail { + +at::Tensor _roi_pool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); + +at::Tensor _roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width); + +} // namespace detail + +} // namespace ops +} // namespace vision diff --git a/product/include/torchvision/vision.cpp b/product/include/torchvision/vision.cpp new file mode 100644 index 00000000000..806e870a83f --- /dev/null +++ b/product/include/torchvision/vision.cpp @@ -0,0 +1,32 @@ +#include "vision.h" + +#include + +#ifdef WITH_CUDA +#include +#endif +#ifdef WITH_HIP +#include +#endif + +// If we are in a Windows environment, we need to define +// initialization functions for the _custom_ops extension. +#if !defined(MOBILE) && defined(_WIN32) +void* PyInit__C(void) { + return nullptr; +} +#endif // !defined(MOBILE) && defined(_WIN32) + +namespace vision { +int64_t cuda_version() { +#ifdef WITH_CUDA + return CUDA_VERSION; +#else + return -1; +#endif +} + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def("_cuda_version", &cuda_version); +} +} // namespace vision diff --git a/product/include/torchvision/vision.h b/product/include/torchvision/vision.h new file mode 100644 index 00000000000..651ef3ca143 --- /dev/null +++ b/product/include/torchvision/vision.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include "macros.h" + +namespace vision { +VISION_API int64_t cuda_version(); + +namespace detail { +extern "C" inline auto _register_ops = &cuda_version; +} // namespace detail +} // namespace vision diff --git a/product/share/cmake/TorchVision/TorchVisionConfig.cmake b/product/share/cmake/TorchVision/TorchVisionConfig.cmake new file mode 100644 index 00000000000..57b2b6caab7 --- /dev/null +++ b/product/share/cmake/TorchVision/TorchVisionConfig.cmake @@ -0,0 +1,74 @@ +# TorchVisionConfig.cmake +# -------------------- +# +# Exported targets:: Vision +# + + +####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run #### +####### The input file was TorchVisionConfig.cmake.in ######## + +get_filename_component(PACKAGE_${CMAKE_FIND_PACKAGE_NAME}_COUNTER_1 "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) + +macro(set_and_check _var _file) + set(${_var} "${_file}") + if(NOT EXISTS "${_file}") + message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") + endif() +endmacro() + +macro(check_required_components _NAME) + foreach(comp ${${_NAME}_FIND_COMPONENTS}) + if(NOT ${_NAME}_${comp}_FOUND) + if(${_NAME}_FIND_REQUIRED_${comp}) + set(${_NAME}_FOUND FALSE) + endif() + endif() + endforeach() +endmacro() + +#################################################################################### + +set(PN TorchVision) + +# location of include/torchvision +set(${PN}_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/include") + +set(${PN}_LIBRARY "") +set(${PN}_DEFINITIONS USING_${PN}) + +check_required_components(${PN}) + + +if(NOT (CMAKE_VERSION VERSION_LESS 
3.0)) +#----------------------------------------------------------------------------- +# Don't include targets if this file is being picked up by another +# project which has already built this as a subproject +#----------------------------------------------------------------------------- +if(NOT TARGET ${PN}::${PN}) +include("${CMAKE_CURRENT_LIST_DIR}/${PN}Targets.cmake") + +target_include_directories(${PN}::${PN} INTERFACE "${${PN}_INCLUDE_DIR}") + +if(OFF) + target_compile_definitions(${PN}::${PN} INTERFACE WITH_CUDA) +endif() + +find_package(Torch REQUIRED) +target_link_libraries(${PN}::${PN} INTERFACE torch) + +if(ON) + find_package(PNG REQUIRED) + target_link_libraries(${PN}::${PN} INTERFACE ${PNG_LIBRARY}) + target_compile_definitions(${PN}::${PN} INTERFACE PNG_FOUND) +endif() + +if(ON) + find_package(JPEG REQUIRED) + target_link_libraries(${PN}::${PN} INTERFACE ${JPEG_LIBRARIES}) + target_compile_definitions(${PN}::${PN} INTERFACE JPEG_FOUND) +endif() + +endif() +endif() diff --git a/product/share/cmake/TorchVision/TorchVisionConfigVersion.cmake b/product/share/cmake/TorchVision/TorchVisionConfigVersion.cmake new file mode 100644 index 00000000000..94b7114a138 --- /dev/null +++ b/product/share/cmake/TorchVision/TorchVisionConfigVersion.cmake @@ -0,0 +1,43 @@ +# This is a basic version file for the Config-mode of find_package(). +# It is used by write_basic_package_version_file() as input file for configure_file() +# to create a version-file which can be installed along a config.cmake file. +# +# The created file sets PACKAGE_VERSION_EXACT if the current version string and +# the requested version string are exactly the same and it sets +# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version. +# The variable CVF_VERSION must be set before calling configure_file(). 
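+# A rough consumer-side sketch (hypothetical project; the outcomes are assumed
+# from the version checks below rather than verified against every CMake release):
+#   find_package(TorchVision 0.20)  # expected to succeed: 0.20.0a0 is not VERSION_LESS 0.20
+#   find_package(TorchVision 0.21)  # expected to fail: 0.20.0a0 is VERSION_LESS 0.21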
+ +set(PACKAGE_VERSION "0.20.0a0") + +if (PACKAGE_FIND_VERSION_RANGE) + # Package version must be in the requested version range + if ((PACKAGE_FIND_VERSION_RANGE_MIN STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MIN) + OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_GREATER PACKAGE_FIND_VERSION_MAX) + OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_GREATER_EQUAL PACKAGE_FIND_VERSION_MAX))) + set(PACKAGE_VERSION_COMPATIBLE FALSE) + else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + endif() +else() + if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION) + set(PACKAGE_VERSION_COMPATIBLE FALSE) + else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) + set(PACKAGE_VERSION_EXACT TRUE) + endif() + endif() +endif() + + +# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: +if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "") + return() +endif() + +# check that the installed version has the same 32/64bit-ness as the one which is currently searching: +if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8") + math(EXPR installedBits "8 * 8") + set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") + set(PACKAGE_VERSION_UNSUITABLE TRUE) +endif() diff --git a/product/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake b/product/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake new file mode 100644 index 00000000000..91aa482bb9c --- /dev/null +++ b/product/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake @@ -0,0 +1,20 @@ +#---------------------------------------------------------------- +# Generated CMake target import file. +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Import target "TorchVision::TorchVision" for configuration "" +set_property(TARGET TorchVision::TorchVision APPEND PROPERTY IMPORTED_CONFIGURATIONS NOCONFIG) +set_target_properties(TorchVision::TorchVision PROPERTIES + IMPORTED_LINK_DEPENDENT_LIBRARIES_NOCONFIG "torch" + IMPORTED_LOCATION_NOCONFIG "${_IMPORT_PREFIX}/lib/libtorchvision.dylib" + IMPORTED_SONAME_NOCONFIG "@rpath/libtorchvision.dylib" + ) + +list(APPEND _cmake_import_check_targets TorchVision::TorchVision ) +list(APPEND _cmake_import_check_files_for_TorchVision::TorchVision "${_IMPORT_PREFIX}/lib/libtorchvision.dylib" ) + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) diff --git a/product/share/cmake/TorchVision/TorchVisionTargets.cmake b/product/share/cmake/TorchVision/TorchVisionTargets.cmake new file mode 100644 index 00000000000..1e07b7fc626 --- /dev/null +++ b/product/share/cmake/TorchVision/TorchVisionTargets.cmake @@ -0,0 +1,102 @@ +# Generated by CMake + +if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8) + message(FATAL_ERROR "CMake >= 2.8.0 required") +endif() +if(CMAKE_VERSION VERSION_LESS "2.8.3") + message(FATAL_ERROR "CMake >= 2.8.3 required") +endif() +cmake_policy(PUSH) +cmake_policy(VERSION 2.8.3...3.27) +#---------------------------------------------------------------- +# Generated CMake target import file. +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Protect against multiple inclusion, which would fail when already imported targets are added once more. 
+set(_cmake_targets_defined "") +set(_cmake_targets_not_defined "") +set(_cmake_expected_targets "") +foreach(_cmake_expected_target IN ITEMS TorchVision::TorchVision) + list(APPEND _cmake_expected_targets "${_cmake_expected_target}") + if(TARGET "${_cmake_expected_target}") + list(APPEND _cmake_targets_defined "${_cmake_expected_target}") + else() + list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}") + endif() +endforeach() +unset(_cmake_expected_target) +if(_cmake_targets_defined STREQUAL _cmake_expected_targets) + unset(_cmake_targets_defined) + unset(_cmake_targets_not_defined) + unset(_cmake_expected_targets) + unset(CMAKE_IMPORT_FILE_VERSION) + cmake_policy(POP) + return() +endif() +if(NOT _cmake_targets_defined STREQUAL "") + string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}") + string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}") + message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n") +endif() +unset(_cmake_targets_defined) +unset(_cmake_targets_not_defined) +unset(_cmake_expected_targets) + + +# Compute the installation prefix relative to this file. +get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +if(_IMPORT_PREFIX STREQUAL "/") + set(_IMPORT_PREFIX "") +endif() + +# Create imported target TorchVision::TorchVision +add_library(TorchVision::TorchVision SHARED IMPORTED) + +# Load information for each installed configuration. +file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/TorchVisionTargets-*.cmake") +foreach(_cmake_config_file IN LISTS _cmake_config_files) + include("${_cmake_config_file}") +endforeach() +unset(_cmake_config_file) +unset(_cmake_config_files) + +# Cleanup temporary variables. +set(_IMPORT_PREFIX) + +# Loop over all imported files and verify that they actually exist +foreach(_cmake_target IN LISTS _cmake_import_check_targets) + if(CMAKE_VERSION VERSION_LESS "3.28" + OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target} + OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}") + foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}") + if(NOT EXISTS "${_cmake_file}") + message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file + \"${_cmake_file}\" +but this file does not exist. Possible reasons include: +* The file was deleted, renamed, or moved to another location. +* An install or uninstall procedure did not complete successfully. +* The installation package was faulty and contained + \"${CMAKE_CURRENT_LIST_FILE}\" +but not all the files it references. +") + endif() + endforeach() + endif() + unset(_cmake_file) + unset("_cmake_import_check_files_for_${_cmake_target}") +endforeach() +unset(_cmake_target) +unset(_cmake_import_check_targets) + +# This file does not depend on other imported targets which have +# been exported from the same project but in a separate export set. + +# Commands beyond this point should not need to know the version. 
+set(CMAKE_IMPORT_FILE_VERSION) +cmake_policy(POP) diff --git a/test/playground/test_mps_import.py b/test/playground/test_mps_import.py new file mode 100644 index 00000000000..8452fe75fde --- /dev/null +++ b/test/playground/test_mps_import.py @@ -0,0 +1,6 @@ +import torch +import torchvision as tv + + + +print(torch.backends.mps.is_available()) diff --git a/torchvision/installTest.cpp b/torchvision/installTest.cpp new file mode 100644 index 00000000000..6e9f1c90e50 --- /dev/null +++ b/torchvision/installTest.cpp @@ -0,0 +1,5 @@ +#include +#include + + + From d838bf791c1090c65164b04191f719821af4bcf8 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Thu, 10 Oct 2024 16:36:59 +0200 Subject: [PATCH 02/31] Setting up for development --- .DS_Store | Bin 6148 -> 6148 bytes .gitignore | 6 ++++++ CMakePresets.json | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+) create mode 100644 CMakePresets.json diff --git a/.DS_Store b/.DS_Store index 1499086fa3f99cdf0adf902c657c89b9a25ecaf6..bb0be2e98d6e1941f8401a7b7ae0215a03c0ab17 100644 GIT binary patch delta 234 zcmZoMXfc@JFUrfnz`)4BAi%(o%23Xb&rrmWos+&o%1nFC~{G89cVWD%|>pe8rp#U&{xKM80TN72_6d%0qcIRe=Pb)^ug aO9a_L1sR6H$@#ej8w(8>H?wp6nx-^C> Date: Fri, 25 Oct 2024 19:17:42 +0200 Subject: [PATCH 03/31] Initial commit for deform_conv2d for MPS --- Deform_conv2d_kernals.metal | 11 ++++++++ build_xcode/VisionTests/VisionTests.mm | 36 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 Deform_conv2d_kernals.metal create mode 100644 build_xcode/VisionTests/VisionTests.mm diff --git a/Deform_conv2d_kernals.metal b/Deform_conv2d_kernals.metal new file mode 100644 index 00000000000..2b528c54007 --- /dev/null +++ b/Deform_conv2d_kernals.metal @@ -0,0 +1,11 @@ +// +// Deform_conv2d_kernals.metal +// torchvision +// +// Created by Thomas Martin on 14/10/2024. +// + +#include +using namespace metal; + + diff --git a/build_xcode/VisionTests/VisionTests.mm b/build_xcode/VisionTests/VisionTests.mm new file mode 100644 index 00000000000..62336a1b3d5 --- /dev/null +++ b/build_xcode/VisionTests/VisionTests.mm @@ -0,0 +1,36 @@ +// +// VisionTests.m +// VisionTests +// +// Created by Thomas Martin on 12/10/2024. +// + +#import + +@interface VisionTests : XCTestCase + +@end + +@implementation VisionTests + +- (void)setUp { + // Put setup code here. This method is called before the invocation of each test method in the class. +} + +- (void)tearDown { + // Put teardown code here. This method is called after the invocation of each test method in the class. +} + +- (void)testExample { + // This is an example of a functional test case. + // Use XCTAssert and related functions to verify your tests produce the correct results. +} + +- (void)testPerformanceExample { + // This is an example of a performance test case. + [self measureBlock:^{ + // Put the code you want to measure the time of here. 
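+        // A rough sketch of a body one might measure here (hypothetical; it
+        // assumes libtorch and the torchvision ops are linked into this test
+        // target and an MPS device is available):
+        //   auto input = torch::rand({1, 3, 64, 64}, torch::kMPS);
+        //   auto rois  = torch::tensor({0.f, 0.f, 0.f, 31.f, 31.f}, torch::kMPS).reshape({1, 5});
+        //   auto out   = vision::ops::roi_pool(input, rois, /*spatial_scale=*/1.0,
+        //                                      /*pooled_height=*/7, /*pooled_width=*/7);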
+ }]; +} + +@end From c53e1bd0dac1285c933d05300163679b70dfbab2 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Tue, 12 Nov 2024 21:18:14 +0100 Subject: [PATCH 04/31] New mps kernel for deform_conv2d and updated shader functions in kernel.h --- .../csrc/ops/mps/deform_conv2d_kernal.mm | 934 ++++++++++++++++++ torchvision/csrc/ops/mps/mps_kernels.h | 433 +++++++- 2 files changed, 1366 insertions(+), 1 deletion(-) create mode 100644 torchvision/csrc/ops/mps/deform_conv2d_kernal.mm diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernal.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernal.mm new file mode 100644 index 00000000000..2df529f25a2 --- /dev/null +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernal.mm @@ -0,0 +1,934 @@ +// vision::ops:: +// deform_conv2d_kernal.mm +// + +#include +#include +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + + +namespace vision { +namespace ops { + +namespace { + +const int64_t tkMaxParallelImgs = 32; + + +void deformable_im2col(const at::Tensor& input, + const at::Tensor& data_offset, + const at::Tensor& data_mask, + int64_t n_in_channels, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t out_h, + int64_t out_w, + int64_t parallel_imgs, + int64_t deformable_group, + bool use_mask, + at::Tensor data_col) { + using namespace at::native::mps; + + // Validate tensors as of type mps. + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(data_offset.is_mps(), "data_offset must be a MPS tensor"); + TORCH_CHECK(data_mask.is_mps(), "data_mask must be a MPS tensor"); + + at::TensorArg input_t{input, "input", 1}, + data_offset_t{data_offset, "data_offset", 2}, + data_mask_t{data_mask, "data_mask", 3}; + + at::CheckedFrom c = "deformable_im2col"; + at::checkAllSameGPU(c, {input_t, data_offset_t, data_mask_t}); + at::checkAllSameType(c, {input_t, data_offset_t, data_mask_t}); + + + const int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs; + + // These function parameters have all been made contiguous by the caller function deform_conv2d_forward_kernel + // Check if it is safe to skip the following: + auto input_c = input.contiguous(); + auto data_offset_c = data_offset.contiguous(); + auto data_mask_c = data_mask.contiguous(); + + // Get a raw pointer to the underlying data structure of the tensors and cast it as a pointer to an MTLBuffer. 
+  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_c);
+  id<MTLBuffer> data_offsetBuffer = getMTLBufferStorage(data_offset_c);
+  id<MTLBuffer> data_maskBuffer = getMTLBufferStorage(data_mask_c);
+  id<MTLBuffer> data_colBuffer = getMTLBufferStorage(data_col);
+
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+
+  MPSStream* mpsStream = getCurrentMPSStream();
+  dispatch_sync(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      const std::string kernel = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type());
+      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
+
+      MTLSize threadgroupsPerGrid = MTLSizeMake(
+          std::min(ceil_div(static_cast<int64_t>(num_kernels),
+                            static_cast<int64_t>(512)),
+                   static_cast<int64_t>(4096)),
+          1,
+          1);
+
+      // this function call is a no-op if MPS Profiler is not enabled
+      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_c, data_offset_c, data_mask_c});
+
+      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      [computeEncoder setComputePipelineState:visionPSO];
+
+      [computeEncoder setBuffer:inputBuffer offset:input_c.storage_offset() * input_c.element_size() atIndex:1];
+      [computeEncoder setBuffer:data_offsetBuffer offset:data_offset_c.storage_offset() * data_offset_c.element_size() atIndex:2];
+      [computeEncoder setBuffer:data_maskBuffer offset:data_mask_c.storage_offset() * data_mask_c.element_size() atIndex:3];
+      [computeEncoder setBuffer:data_colBuffer offset:data_col.storage_offset() * data_col.element_size() atIndex:20];
+
+      [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0];
+      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:4];
+      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:5];
+      [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:6];
+      [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:7];
+      [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:8];
+      [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:9];
+      [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:10];
+      [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:11];
+      [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:12];
+      [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:13];
+      [computeEncoder setBytes:&parallel_imgs length:sizeof(int64_t) atIndex:14];
+      [computeEncoder setBytes:&n_in_channels length:sizeof(int64_t) atIndex:15];
+      [computeEncoder setBytes:&deformable_group length:sizeof(int64_t) atIndex:16];
+      [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:17];
+      [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:18];
+      [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:19];
+
+      // A threadGroup is equivalent to a cuda's block.
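+      // Dispatch geometry, sketched with hypothetical numbers (assuming the
+      // threadsPerBlock constant from mps_helpers.h is 512):
+      //   num_kernels = 1'000'000 -> threadgroupsPerGrid.x = min(ceil_div(1'000'000, 512), 4096) = 1954
+      //   threadGroupSize.x = min(maxTotalThreadsPerThreadgroup, threadsPerBlock)
+      // Each thread then strides over the flattened index space via
+      // MPS_1D_KERNEL_LOOP until all num_kernels elements are covered.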
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + +} + +int get_greatest_divisor_below_bound(int n, int bound) { + for (int k = bound; k > 1; --k) { + if (n % k == 0) { + return k; + } + } + return 1; +} + +void compute_grad_input( + const at::Tensor& columns, + const at::Tensor& offset, + const at::Tensor& mask, + int64_t channels, + int64_t height, + int64_t width, + int64_t weight_h, //kernel_h + int64_t weight_w, //kernel_w + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t parallel_imgs, //batch_sz + int64_t n_offset_grps, + bool use_mask, + at::Tensor grad_im) { + using namespace at::native::mps; + + at::globalContext().alertNotDeterministic("compute_grad_input"); + + auto columns_c = columns.contiguous(); + auto offset_c = offset.contiguous(); + auto mask_c = mask.contiguous(); + + const int64_t out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + const int64_t out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + + const int64_t num_kernels = + (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs; + + id columnsBuffer = getMTLBufferStorage(columns_c); + id offsetBuffer = getMTLBufferStorage(offset_c); + id maskBuffer = getMTLBufferStorage(mask_c); + id grad_imBuffer = getMTLBufferStorage(grad_im); + + id device = MPSDevice::getInstance()->device(); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + + const std::string kernel = "deformable_col2im_" + scalarToMetalTypeString(columns.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns, offset, mask}); + + [computeEncoder setComputePipelineState:visionPSO]; + + [computeEncoder setBuffer:columnsBuffer offset:columns_c.storage_offset() * columns_c.element_size() atIndex:1]; + [computeEncoder setBuffer:offsetBuffer offset:offset_c.storage_offset() * offset_c.element_size() atIndex:2]; + [computeEncoder setBuffer:maskBuffer offset:mask_c.storage_offset() * mask_c.element_size() atIndex:3]; + [computeEncoder setBuffer:grad_imBuffer + offset:grad_im.storage_offset() * grad_im.element_size() + atIndex:20]; + + [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&dilation_w 
length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:&n_offset_grps length:sizeof(int64_t) atIndex:16]; + [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:17]; + [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:18]; + [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:19]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(num_kernels), static_cast(512)), static_cast(4096)), + 1, + 1); + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); +} + +void compute_grad_offset_and_mask( + const at::Tensor& columns, + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& mask, + int64_t channels, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t parallel_imgs, + int64_t n_offset_grps, + bool use_mask, + at::Tensor grad_offset, + at::Tensor grad_mask) { + + using namespace at::native::mps; + + auto columns_c = columns; //.contiguous(); + auto input_c = input; //.contiguous(); + auto offset_c = offset; //.contiguous(); + auto mask_c = mask; //.contiguous(); + + const int64_t out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + const int64_t out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + const int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w * + n_offset_grps * parallel_imgs; + + const int64_t offset_channels = 2 * weight_h * weight_w * n_offset_grps; + + id columnsBuffer = getMTLBufferStorage(columns_c); + id inputBuffer = getMTLBufferStorage(input_c); + id offsetBuffer = getMTLBufferStorage(offset_c); + id maskBuffer = getMTLBufferStorage(mask_c); + id grad_offsetBuffer = getMTLBufferStorage(grad_offset); + id grad_maskBuffer = getMTLBufferStorage(grad_mask); + + id device = MPSDevice::getInstance()->device(); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(num_kernels), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "deformable_col2im_coord_" + scalarToMetalTypeString(columns.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns_c, input_c, offset_c, mask_c}); + + [computeEncoder setComputePipelineState:visionPSO]; + + [computeEncoder setBuffer:columnsBuffer offset:columns_c.storage_offset() * columns_c.element_size() atIndex:1]; + [computeEncoder setBuffer:inputBuffer offset:input_c.storage_offset() * input_c.element_size() atIndex:2]; + [computeEncoder setBuffer:offsetBuffer offset:offset_c.storage_offset() * offset_c.element_size() atIndex:3]; + [computeEncoder setBuffer:maskBuffer offset:mask_c.storage_offset() * mask_c.element_size() atIndex:4]; + [computeEncoder 
setBuffer:grad_offsetBuffer + offset:grad_offset.storage_offset() * grad_offset.element_size() + atIndex:22]; + [computeEncoder setBuffer:grad_maskBuffer + offset:grad_mask.storage_offset() * grad_mask.element_size() + atIndex:23]; + + [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:16]; + [computeEncoder setBytes:&offset_channels length:sizeof(int64_t) atIndex:17]; + [computeEncoder setBytes:&n_offset_grps length:sizeof(int64_t) atIndex:18]; + [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:19]; + [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:20]; + [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:21]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); +} + +std::tuple backward_gradient_inputs( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor mask, + at::Tensor grad_out, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + int64_t n_parallel_imgs, + bool use_mask) { + + int64_t batch_sz = input.size(0); + int64_t n_in_channels = input.size(1); + int64_t in_h = input.size(2); + int64_t in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + int64_t n_out_channels = weight.size(0); + int64_t weight_h = weight.size(2); + int64_t weight_w = weight.size(3); + + int64_t out_w = + (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int64_t out_h = + (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + + auto grad_input = at::zeros_like(input); + auto grad_offset = at::zeros_like(offset); + auto grad_mask = at::zeros_like(mask); + + if (batch_sz == 0) { + return std::make_tuple(grad_input, grad_offset, grad_mask); + } + + auto columns = at::empty( + {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, + input.options()); + + // Separate into blocks + grad_input = grad_input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + input = input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + grad_offset = grad_offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + 
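+  // Shape sketch for this blocking (hypothetical sizes: batch_sz = 8,
+  // n_parallel_imgs = 4, n_offset_grps = 1, weight_h = weight_w = 3):
+  //   offset / grad_offset: [8, 2*1*3*3, out_h, out_w] -> [2, 4, 18, out_h, out_w]
+  // so each iteration of the per-block loop below processes 4 images at once.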
offset = offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + grad_mask = grad_mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + mask = mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + grad_out = grad_out + .reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}); + + weight = weight.reshape( + {n_weight_grps, + weight.size(0) / n_weight_grps, + weight.size(1), + weight.size(2), + weight.size(3)}); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int64_t elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + columns.zero_(); + // Separate into weight groups + for (int64_t g = 0; g < n_weight_grps; g++) { + columns[g] = columns[g].addmm_( + weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); + } + + compute_grad_offset_and_mask( + columns, + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_offset[elt], + grad_mask[elt]); + + compute_grad_input( + columns, + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_input[elt]); + } + + grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + if (use_mask) { + grad_mask = grad_mask.view( + {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); + } + + return std::make_tuple(grad_input, grad_offset, grad_mask); +} + +at::Tensor backward_gradient_parameters( + at::Tensor input, + const at::Tensor& weight, + at::Tensor offset, + at::Tensor mask, + const at::Tensor& grad_out, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + int64_t n_parallel_imgs, + bool use_mask) { + + int64_t batch_sz = input.size(0); + int64_t n_in_channels = input.size(1); + int64_t in_h = input.size(2); + int64_t in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + int64_t n_out_channels = weight.size(0); + int64_t weight_h = weight.size(2); + int64_t weight_w = weight.size(3); + + int64_t out_h = grad_out.size(2); + int64_t out_w = grad_out.size(3); + + auto grad_weight = at::zeros_like(weight); + if (batch_sz == 0) { + return grad_weight; + } + + at::Tensor grad_out_buf = grad_out + .reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}) + .contiguous(); + + input = input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + offset = offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask = mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + 
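+  // Grouping sketch (hypothetical sizes: n_weight_grps = 2, out_channels = 64,
+  // in_channels = 32, 3x3 kernel): grad_weight [64, 16, 3, 3] is viewed as
+  // [2, 32, 16, 3, 3], so each group g accumulates its slice independently via
+  // the addmm_ over the corresponding im2col columns further below.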
grad_weight = grad_weight.reshape( + {n_weight_grps, + grad_weight.size(0) / n_weight_grps, + grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3)}); + + auto columns = at::empty( + {n_weight_grps, + n_in_channels * weight_w * weight_h / n_weight_grps, + n_parallel_imgs * out_h * out_w}, + input.options()); + + for (int64_t elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + deformable_im2col( + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + for (int64_t g = 0; g < n_weight_grps; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_( + grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) + .view_as(grad_weight[g]); + } + } + + grad_weight = grad_weight.view( + {grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3), + grad_weight.size(4)}); + return grad_weight; +} + +at::Tensor deform_conv2d_forward_kernel( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + at::Tensor input_c = input.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + TORCH_CHECK(input_c.ndimension() == 4); + TORCH_CHECK(offset_c.ndimension() == 4); + TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); + TORCH_CHECK(weight_c.ndimension() == 4); + TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor"); + + at::DeviceGuard guard(input_c.device()); + + int batch_sz = input_c.size(0); + int in_channels = input_c.size(1); + int in_h = input_c.size(2); + int in_w = input_c.size(3); + + int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); + + int out_channels = weight_c.size(0); + int weight_h = weight_c.size(2); + int weight_w = weight_c.size(3); + + int ker_h = dilation_h * (weight_h - 1) + 1; + int ker_w = dilation_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK( + weight_h > 0 && weight_w > 0, + "weight_h: ", + weight_h, + " weight_w: ", + weight_w); + TORCH_CHECK( + stride_h > 0 && stride_w > 0, + "stride_h: ", + stride_h, + " stride_w: ", + stride_w); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); + TORCH_CHECK( + dilation_h > 0 && dilation_w > 0, + "dilation_h: ", + dilation_h, + " dilation_w: ", + dilation_w); + + TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); + TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); + TORCH_CHECK( + (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "offset.shape[1] is not valid: got: ", + offset_c.size(1), + " expected: ", + n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), + "mask.shape[1] is not valid: got: ", + mask_c.size(1), + " expected: ", + n_offset_grps * weight_h * weight_w); + TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); + + TORCH_CHECK( + (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); + 
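+  // Worked example of the output-size arithmetic above (hypothetical values:
+  // in_h = 32, pad_h = 1, dilation_h = 1, weight_h = 3, stride_h = 1):
+  //   ker_h = 1 * (3 - 1) + 1 = 3,  out_h = (32 + 2*1 - 3) / 1 + 1 = 32
+  // The offset (and mask) tensors must match this out_h x out_w grid, which
+  // the checks below enforce.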
TORCH_CHECK( + (offset_c.size(2) == out_h && offset_c.size(3) == out_w), + "offset output dims: (", + offset_c.size(2), + ", ", + offset_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask"); + TORCH_CHECK( + (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), + "mask output dims: (", + mask_c.size(2), + ", ", + mask_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", + out_h, + " out_w: ", + out_w); + + auto out = + at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); + if (batch_sz == 0) { + return out; + } + + // Separate batches into blocks + out = out.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + out_channels, + out_h, + out_w}); + input_c = input_c.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); + + offset_c = offset_c.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask_c = mask_c.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + at::Tensor out_buf = at::zeros( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs * out_h, + out_w}, + out.options()); + + // Separate channels into convolution groups + out_buf = out_buf.view( + {out_buf.size(0), + n_weight_grps, + out_buf.size(1) / n_weight_grps, + out_buf.size(2), + out_buf.size(3)}); + weight_c = weight_c.view( + {n_weight_grps, + weight_c.size(0) / n_weight_grps, + weight_c.size(1), + weight_c.size(2), + weight_c.size(3)}); + + // Sample points and perform convolution + auto columns = at::zeros( + {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, + input_c.options()); + + for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { + deformable_im2col( + input_c[b], + offset_c[b], + mask_c[b], + in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int g = 0; g < n_weight_grps; g++) { + out_buf[b][g] = out_buf[b][g] + .flatten(1) + .addmm_(weight_c[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); + } + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + out_buf = out_buf.view( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs, + out_h, + out_w}); + out_buf.transpose_(1, 2); + out.copy_(out_buf); + out = out.view({batch_sz, out_channels, out_h, out_w}); + + return out + bias_c.view({1, out_channels, 1, 1}); +} + +std::tuple +deform_conv2d_backward_kernel( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + at::Tensor grad_out_c = grad_out.contiguous(); + at::Tensor input_c = input.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor 
bias_c = bias.contiguous(); + + const int64_t batch_sz = input_c.size(0); + const int64_t n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); + + auto grad_input_and_offset_and_mask = backward_gradient_inputs( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + n_weight_grps, + n_offset_grps, + n_parallel_imgs, + use_mask); + + auto grad_input = std::get<0>(grad_input_and_offset_and_mask); + auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); + auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); + + auto grad_weight = backward_gradient_parameters( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + n_weight_grps, + n_offset_grps, + n_parallel_imgs, + use_mask); + + auto value = grad_out_c.sum({0, 2, 3}); + auto grad_bias = at::ones_like(bias_c) * value; + + return std::make_tuple( + grad_input, grad_weight, grad_offset, grad_mask, grad_bias); +} +} // namespace + + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN(deform_conv2d_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), + TORCH_FN(deform_conv2d_backward_kernel)); +} + +} // namespace ops +} // namespace vision + diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index e720a1608f1..002a3d4c242 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -1036,11 +1036,436 @@ kernel void ps_roi_pool_backward( \ constant int64_t & width [[buffer(7)]], \ constant int64_t & pooled_height [[buffer(8)]], \ constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(10)]], \ constant float & spatial_scale [[buffer(11)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); + + + + + + + + + +/*----------- START OF DEFORM_CONV2D KERNEL IMPLEMENTATION -----------------*/ + + + + + + + +template +kernel void deformable_im2col( + constant int64_t & n [[buffer(0)]], + constant scalar_t * input_ptr [[buffer(1)]], + constant scalar_t * offset_ptr [[buffer(2)]], + constant scalar_t * mask_ptr [[buffer(3)]], + constant int64_t & height [[buffer(4)]], + constant int64_t & width [[buffer(5)]], + constant int64_t & weight_h [[buffer(6)]], + constant int64_t & weight_w [[buffer(7)]], + constant int64_t & pad_h [[buffer(8)]], + constant int64_t & pad_w [[buffer(9)]], + constant int64_t & stride_h [[buffer(10)]], + constant int64_t & stride_w [[buffer(11)]], + constant int64_t & dilation_h [[buffer(12)]], + constant int64_t & dilation_w [[buffer(13)]], + constant int64_t & batch_sz [[buffer(14)]], + constant int64_t & n_in_channels [[buffer(15)]], + constant int64_t & n_offset_grps [[buffer(16)]], + constant int64_t & out_h [[buffer(17)]], + constant int64_t & out_w [[buffer(18)]], + constant bool & use_mask [[buffer(19)]], + device scalar_t * columns_ptr [[buffer(20)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]) { + MPS_1D_KERNEL_LOOP(index, n, 1) { + const integer_t out_x = index % out_w; + const integer_t out_y = (index / out_w) % out_h; + const integer_t out_b = (index / (out_w * out_h)) % 
batch_sz; + const integer_t in_c = index / (out_w * out_h * batch_sz); + const integer_t out_c = in_c * weight_h * weight_w; + + integer_t c_per_offset_grp = n_in_channels / n_offset_grps; + const integer_t grp_idx = in_c / c_per_offset_grp; + + columns_ptr += + (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + + out_y * out_w + out_x); + + input_ptr += + (out_b * (n_in_channels * height * width) + in_c * (height * width)); + + offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * + out_h * out_w; + + if (use_mask) { + mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * + out_h * out_w; + } + + for (int i = 0; i < weight_h; ++i) { + for (int j = 0; j < weight_w; ++j) { + const integer_t mask_idx = i * weight_w + j; + const integer_t offset_idx = 2 * mask_idx; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = + mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; + } + + const scalar_t offset_h = + offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_w = offset_ptr + [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t y = + (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = + (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + *columns_ptr = + mask_value * bilinear_interpolate(input_ptr, height, width, y, x, index); + columns_ptr += batch_sz * out_h * out_w; + } + } + } +} + +#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("deformable_im2col_" #DTYPE)]] \ +kernel void deformable_im2col( \ + constant int64_t & n [[buffer(0)]], \ + constant DTYPE * input_ptr [[buffer(1)]], \ + constant DTYPE * offset_ptr [[buffer(2)]], \ + constant DTYPE * mask_ptr [[buffer(3)]], \ + constant int64_t & height [[buffer(4)]], \ + constant int64_t & width [[buffer(5)]], \ + constant int64_t & weight_h [[buffer(6)]], \ + constant int64_t & weight_w [[buffer(7)]], \ + constant int64_t & pad_h [[buffer(8)]], \ + constant int64_t & pad_w [[buffer(9)]], \ + constant int64_t & stride_h [[buffer(10)]], \ + constant int64_t & stride_w [[buffer(11)]], \ + constant int64_t & dilation_h [[buffer(12)]], \ + constant int64_t & dilation_w [[buffer(13)]], \ + constant int64_t & batch_sz [[buffer(14)]], \ + constant int64_t & n_in_channels [[buffer(15)]], \ + constant int64_t & n_offset_grps [[buffer(16)]], \ + constant int64_t & out_h [[buffer(17)]], \ + constant int64_t & out_w [[buffer(18)]], \ + constant bool & use_mask [[buffer(19)]], \ + device DTYPE * columns_ptr [[buffer(20)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + + + +template +kernel void deformable_col2im( + constant int64_t & n [[buffer(0)]], + constant scalar_t * col [[buffer(1)]], + constant scalar_t * offset_ptr [[buffer(2)]], + constant scalar_t * mask_ptr [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & kernel_h [[buffer(7)]], + constant int64_t & kernel_w [[buffer(8)]], + constant int64_t & pad_h [[buffer(9)]], + constant int64_t & pad_w [[buffer(10)]], + constant int64_t & stride_h [[buffer(11)]], + constant int64_t & stride_w [[buffer(12)]], + constant int64_t & dilation_h [[buffer(13)]], + constant int64_t & dilation_w [[buffer(14)]], + constant int64_t & batch_sz [[buffer(15)]], + constant int64_t & n_offset_grps [[buffer(16)]], + 
constant int64_t & out_h [[buffer(17)]], + constant int64_t & out_w [[buffer(18)]], + constant bool & use_mask [[buffer(19)]], + device scalar_t * grad_im [[buffer(20)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + const integer_t grad_im_numel = width * height * channels * batch_sz; + + MPS_1D_KERNEL_LOOP(index, n, 1) { + const integer_t out_x = index % out_w; + const integer_t out_y = (index / out_w) % out_h; + const integer_t b = (index / (out_w * out_h)) % batch_sz; + const integer_t j = (index / (out_w * out_h * batch_sz)) % kernel_w; + const integer_t i = + (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; + const integer_t c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); + + integer_t c_per_offset_grp = channels / n_offset_grps; + const integer_t offset_grp = c / c_per_offset_grp; + + offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * + out_h * out_w; + + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * + out_h * out_w; + } + + const integer_t mask_idx = i * kernel_w + j; + const integer_t offset_idx = 2 * mask_idx; + + const integer_t offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; + const integer_t offset_w_ptr = + ((offset_idx + 1) * out_h + out_y) * out_w + out_x; + + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + + const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + for (integer_t dy = -1; dy <= 1; dy++) { + for (integer_t dx = -1; dx <= 1; dx++) { + integer_t yp = (integer_t)y + dy; + integer_t xp = (integer_t)x + dx; + if (0 <= yp && yp < height && 0 <= xp && xp < width && + abs(y - yp) < 1 && abs(x - xp) < 1) { + integer_t grad_pos = ((b * channels + c) * height + yp) * width + xp; + scalar_t weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); + // MSL doesn't support at::native::fastAtomicAdd + if (grad_pos >= 0 && grad_pos < grad_im_numel) { + // Atomically add the computed value directly + atomic_add_float(grad_im + grad_pos, static_cast(mask_value * weight * col[index])); + } + } + } + } + } +} + +#define REGISTER_DEFORMABLE_COL2IM_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("deformable_col2im_" #DTYPE)]] \ +kernel void deformable_col2im( \ + constant int64_t & n [[buffer(0)]], \ + constant DTYPE * col [[buffer(1)]], \ + constant DTYPE * offset_ptr [[buffer(2)]], \ + constant DTYPE * mask_ptr [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & kernel_h [[buffer(7)]], \ + constant int64_t & kernel_w [[buffer(8)]], \ + constant int64_t & pad_h [[buffer(9)]], \ + constant int64_t & pad_w [[buffer(10)]], \ + constant int64_t & stride_h [[buffer(11)]], \ + constant int64_t & stride_w [[buffer(12)]], \ + constant int64_t & dilation_h [[buffer(13)]], \ + constant int64_t & dilation_w [[buffer(14)]], \ + constant int64_t & batch_sz [[buffer(15)]], \ + constant int64_t & n_offset_grps [[buffer(16)]], \ + constant int64_t & out_h [[buffer(17)]], \ + constant int64_t & out_w [[buffer(18)]], \ + constant bool & use_mask [[buffer(19)]], \ + device DTYPE * grad_im [[buffer(20)]], \ + uint2 tgid 
[[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + + +template +scalar_t get_coordinate_weight( + constant scalar_t* im_data, + index_t height, + index_t width, + scalar_t y, + scalar_t x, + bool is_y_direction) { + index_t y_l = floor(y); + index_t x_l = floor(x); + index_t y_h = y_l + 1; + index_t x_h = x_l + 1; + + bool valid_y_l = 0 <= y_l && y_l < height; + bool valid_y_h = 0 <= y_h && y_h < height; + bool valid_x_l = 0 <= x_l && x_l < width; + bool valid_x_h = 0 <= x_h && x_h < width; + + scalar_t zero = 0; + scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; + scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; + scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; + scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero; + + if (is_y_direction) { + scalar_t dx = x - x_l; + return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); + } else { + scalar_t dy = y - y_l; + return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); + } +} + + + + + +template +kernel void deformable_col2im_coord( + constant int64_t & n [[buffer(0)]], + constant scalar_t * col_ptr [[buffer(1)]], + constant scalar_t * im_ptr [[buffer(2)]], + constant scalar_t * offset_ptr [[buffer(3)]], + constant scalar_t * mask_ptr [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & weight_h [[buffer(8)]], + constant int64_t & weight_w [[buffer(9)]], + constant int64_t & pad_h [[buffer(10)]], + constant int64_t & pad_w [[buffer(11)]], + constant int64_t & stride_h [[buffer(12)]], + constant int64_t & stride_w [[buffer(13)]], + constant int64_t & dilation_h [[buffer(14)]], + constant int64_t & dilation_w [[buffer(15)]], + constant int64_t & batch_sz [[buffer(16)]], + constant int64_t & offset_channels [[buffer(17)]], + constant int64_t & n_offset_grps [[buffer(18)]], + constant int64_t & out_h [[buffer(19)]], + constant int64_t & out_w [[buffer(20)]], + constant bool & use_mask [[buffer(21)]], + device scalar_t* grad_offset [[buffer(22)]], + device scalar_t* grad_mask [[buffer(23)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]) { + MPS_1D_KERNEL_LOOP(index, n, 1) { + scalar_t grad_offset_val = 0; + scalar_t grad_mask_val = 0; + integer_t w = index % out_w; + integer_t h = (index / out_w) % out_h; + integer_t w_w = (index / (out_w * out_h * 2)) % weight_w; + integer_t w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; + integer_t c = (index / (out_w * out_h)) % offset_channels; + integer_t b = index / (out_w * out_h * offset_channels); + + const integer_t offset_grp = c / (2 * weight_h * weight_w); + const integer_t col_step = weight_h * weight_w; + + integer_t c_per_offset_grp = channels / n_offset_grps; + + col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * + out_w * out_h; + im_ptr += + (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; + offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * + out_h * out_w; + + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * + out_h * out_w; + } + + const integer_t offset_c = c - offset_grp * 2 * weight_h * weight_w; + const bool is_y_direction = offset_c % 2 == 0; + + const integer_t c_bound = 
c_per_offset_grp * weight_h * weight_w; + for (integer_t col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { + const integer_t col_pos = + (((col_c * batch_sz + b) * out_h) + h) * out_w + w; + + integer_t out_x = col_pos % out_w; + integer_t out_y = (col_pos / out_w) % out_h; + integer_t j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; + integer_t i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; + + const integer_t mask_idx = i * weight_w + j; + + const integer_t offset_h_ptr = + (((2 * mask_idx) * out_h + out_y) * out_w + out_x); + const integer_t offset_w_ptr = + (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + + scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + const scalar_t weight = + get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); + grad_offset_val += mask_value * weight * col_ptr[col_pos]; + + if (use_mask && is_y_direction) { + grad_mask_val += col_ptr[col_pos] * + bilinear_interpolate(im_ptr, height, width, y, x, index); + } + + im_ptr += height * width; + } + + grad_offset[index] = grad_offset_val; + + if (use_mask && is_y_direction) { + const integer_t idx = + ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + + w_w) * + out_h + + h) * + out_w + + w; + grad_mask[idx] = grad_mask_val; + } + } +} + +#define REGISTER_DEFORMABLE_COL2IM_COORD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("deformable_col2im_coord_" #DTYPE)]] \ +kernel void deformable_col2im_coord( \ + constant int64_t & n [[buffer(0)]], \ + constant DTYPE * col_ptr [[buffer(1)]], \ + constant DTYPE * im_ptr [[buffer(2)]], \ + constant DTYPE * offset_ptr [[buffer(3)]], \ + constant DTYPE * mask_ptr [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & weight_h [[buffer(8)]], \ + constant int64_t & weight_w [[buffer(9)]], \ + constant int64_t & pad_h [[buffer(10)]], \ + constant int64_t & pad_w [[buffer(11)]], \ + constant int64_t & stride_h [[buffer(12)]], \ + constant int64_t & stride_w [[buffer(13)]], \ + constant int64_t & dilation_h [[buffer(14)]], \ + constant int64_t & dilation_w [[buffer(15)]], \ + constant int64_t & batch_sz [[buffer(16)]], \ + constant int64_t & offset_channels [[buffer(17)]], \ + constant int64_t & n_offset_grps [[buffer(18)]], \ + constant int64_t & out_h [[buffer(19)]], \ + constant int64_t & out_w [[buffer(20)]], \ + constant bool & use_mask [[buffer(21)]], \ + device DTYPE * grad_offset [[buffer(22)]], \ + device DTYPE * grad_mask [[buffer(23)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +/* ----------END OF DEFORM_CONV2D KERNELS ----------------------*/ + REGISTER_NMS_OP(float); REGISTER_NMS_OP(half); @@ -1060,6 +1485,12 @@ REGISTER_PS_ROI_POOL_OP(float, int64_t); REGISTER_PS_ROI_POOL_OP(half, int64_t); REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t); REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t); +REGISTER_DEFORMABLE_IM2COL_OP(float, int64_t); +REGISTER_DEFORMABLE_IM2COL_OP(half, int64_t); +REGISTER_DEFORMABLE_COL2IM_OP(float, int64_t); 
+REGISTER_DEFORMABLE_COL2IM_OP(half, int64_t); +REGISTER_DEFORMABLE_COL2IM_COORD_OP(float, int64_t); +REGISTER_DEFORMABLE_COL2IM_COORD_OP(half, int64_t); )VISION_METAL"; From 1153b847ec00a37577c359175c4caa9b9c9f8708 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Fri, 15 Nov 2024 15:30:09 +0100 Subject: [PATCH 05/31] Renaming source file. --- Deform_conv2d_kernals.metal | 447 ++++++++++++++++++++++++++++++++++++ 1 file changed, 447 insertions(+) diff --git a/Deform_conv2d_kernals.metal b/Deform_conv2d_kernals.metal index 2b528c54007..fb18af38ebf 100644 --- a/Deform_conv2d_kernals.metal +++ b/Deform_conv2d_kernals.metal @@ -5,7 +5,454 @@ // Created by Thomas Martin on 14/10/2024. // +// This include will only work when the remaining code is embedded in a C string in mps_kernels.h +//#include + #include +#include + using namespace metal; +// ********************************************************************** +// MACROS AND HELPER FUNCTIONS SHOULD NOT BE INCLUDED IN THE FINAL SOURCE +// AS THEY ARE ALREADY INCLUDED IN mps_kernels.h +// ********************************************************************** + +/*----------Macros----------*/ + +#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \ + for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ + i += (tptg.x * n_tgs)) + +#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint) + + +/*----------Helpers--------*/ + +template +inline T ceil_div(T n, T m) { + return (n + m - 1) / m; +} + + +template +inline void atomic_add_float( device T* data_ptr, const T val) +{ +#if __METAL_VERSION__ >= 300 + // atomic_float is supported in Metal 3 (macOS Ventura) onward. + atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); +#else + // Custom atomic addition implementation + // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 + // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 + // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide) + + // Create an atomic uint pointer for atomic transaction. + device atomic_uint* atom_var = (device atomic_uint*)data_ptr; + // Create necessary storage. + uint fetched_uint, assigning_uint; + T fetched_float, assigning_float; + + // Replace the value in atom_var with 0 and return the previous value in atom_var. + fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed); + // Read out the previous value as float. + fetched_float = *( (thread T*) &fetched_uint ); + + // Do addition and represent the addition result in uint for atomic transaction. + assigning_float = fetched_float + val; + assigning_uint = *((thread uint*) &assigning_float); + + // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). + while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) { + // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. + // Try to assign 0 and get the previously assigned addition result. + uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); + T fetched_float_again = *( (thread T*) &fetched_uint_again ); + // Re-add again + fetched_float = *((thread T*) &(fetched_uint)); + // Previously assigned addition result + addition result from other threads. 
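// Clarifying note on this fallback path: the exchange-with-zero acts as a stand-in
// for compare-and-swap. Here fetched_float holds the contribution another thread
// deposited (returned by the exchange in the loop condition) and
// fetched_float_again holds whatever was pulled back out of atom_var afterwards;
// their sum becomes the new candidate value and the store is retried until an
// exchange returns 0, meaning no concurrent deposit was lost.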
+ assigning_float = fetched_float_again + fetched_float; + assigning_uint = *( (thread uint*) &assigning_float); + } +#endif +} + + +template +kernel void deformable_im2col( + index_t n [[buffer(0)]], + constant scalar_t* input_ptr [[buffer(1)]], + constant scalar_t* offset_ptr [[buffer(2)]], + constant scalar_t* mask_ptr [[buffer(3)]], + index_t height [[buffer(4)]], + index_t width [[buffer(5)]], + index_t weight_h [[buffer(6)]], + index_t weight_w [[buffer(7)]], + index_t pad_h [[buffer(8)]], + index_t pad_w [[buffer(9)]], + index_t stride_h [[buffer(10)]], + index_t stride_w [[buffer(11)]], + index_t dilation_h [[buffer(12)]], + index_t dilation_w [[buffer(13)]], + index_t batch_sz [[buffer(14)]], // parallel_imgs + index_t n_in_channels [[buffer(15)]], + index_t n_offset_grps [[buffer(16)]], //deformable_grp + index_t out_h [[buffer(17)]], + index_t out_w [[buffer(18)]], + constant bool & use_mask [[buffer(19)]], + device scalar_t* columns_ptr [[buffer(20)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]) { + MPS_1D_KERNEL_LOOP(index, n, 1) { + const index_t out_x = index % out_w; + const index_t out_y = (index / out_w) % out_h; + const index_t out_b = (index / (out_w * out_h)) % batch_sz; + const index_t in_c = index / (out_w * out_h * batch_sz); + const index_t out_c = in_c * weight_h * weight_w; + + index_t c_per_offset_grp = n_in_channels / n_offset_grps; + const index_t grp_idx = in_c / c_per_offset_grp; + + columns_ptr += + (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + + out_y * out_w + out_x); + + input_ptr += + (out_b * (n_in_channels * height * width) + in_c * (height * width)); + + offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * + out_h * out_w; + + if (use_mask) { + mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * + out_h * out_w; + } + + for (int i = 0; i < weight_h; ++i) { + for (int j = 0; j < weight_w; ++j) { + const index_t mask_idx = i * weight_w + j; + const index_t offset_idx = 2 * mask_idx; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = + mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; + } + + const scalar_t offset_h = + offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_w = offset_ptr + [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t y = + (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = + (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + *columns_ptr = + mask_value * bilinear_interpolate(input_ptr, height, width, y, x); + columns_ptr += batch_sz * out_h * out_w; + } + } + } + +} + +#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE) \ +template \ +[[host_name("deformable_im2col_" #DTYPE)]] \ +template \ +kernel void deformable_im2col( \ +index_t n [[buffer(0)]], \ +constant scalar_t* input_ptr [[buffer(1)]], \ +constant scalar_t* offset_ptr [[buffer(2)]], \ +constant scalar_t* mask_ptr [[buffer(3)]], \ +index_t height [[buffer(4)]], \ +index_t width [[buffer(5)]], \ +index_t weight_h [[buffer(6)]], \ +index_t weight_w [[buffer(7)]], \ +index_t pad_h [[buffer(8)]], \ +index_t pad_w [[buffer(9)]], \ +index_t stride_h [[buffer(10)]], \ +index_t stride_w [[buffer(11)]], \ +index_t dilation_h [[buffer(12)]], \ +index_t dilation_w [[buffer(13)]], \ +index_t batch_sz [[buffer(14)]], \ +index_t n_in_channels [[buffer(15)]], \ +index_t n_offset_grps [[buffer(16)]], \ +index_t out_h 
[[buffer(17)]], \ +index_t out_w [[buffer(18)]], \ +constant bool & use_mask [[buffer(19)]], \ +device scalar_t* columns_ptr [[buffer(20)]], \ +uint2 tgid [[threadgroup_position_in_grid]], \ +uint2 tptg [[threads_per_threadgroup]], \ +uint2 tid2 [[thread_position_in_threadgroup]]); + + + + + + + + +template +kernel void deformable_col2im( + index_t n [[buffer(0)]], + constant scalar_t* col [[buffer(1)]], + constant scalar_t* offset_ptr [[buffer(2)]], + constant scalar_t* mask_ptr [[buffer(3)]], + index_t channels [[buffer(4)]], + index_t height [[buffer(5)]], + index_t width [[buffer(6)]], + index_t kernel_h [[buffer(7)]], + index_t kernel_w [[buffer(8)]], + index_t pad_h [[buffer(9)]], + index_t pad_w [[buffer(10)]], + index_t stride_h [[buffer(11)]], + index_t stride_w [[buffer(12)]], + index_t dilation_h [[buffer(13)]], + index_t dilation_w [[buffer(14)]], + index_t batch_sz [[buffer(15)]], //parallel_imgs + index_t n_offset_grps [[buffer(16)]], + index_t out_h [[buffer(17)]], + index_t out_w [[buffer(18)]], + constant bool & use_mask [[buffer(19)]], + constant scalar_t* grad_im [[buffer(20)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + const index_t grad_im_numel = width * height * channels * batch_sz; + + MPS_1D_KERNEL_LOOP(index, n, 1) { + const index_t out_x = index % out_w; + const index_t out_y = (index / out_w) % out_h; + const index_t b = (index / (out_w * out_h)) % batch_sz; + const index_t j = (index / (out_w * out_h * batch_sz)) % kernel_w; + const index_t i = + (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; + const index_t c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); + + index_t c_per_offset_grp = channels / n_offset_grps; + const index_t offset_grp = c / c_per_offset_grp; + + offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * + out_h * out_w; + + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * + out_h * out_w; + } + + const index_t mask_idx = i * kernel_w + j; + const index_t offset_idx = 2 * mask_idx; + + const index_t offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; + const index_t offset_w_ptr = + ((offset_idx + 1) * out_h + out_y) * out_w + out_x; + + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + + const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + for (index_t dy = -1; dy <= 1; dy++) { + for (index_t dx = -1; dx <= 1; dx++) { + index_t yp = (index_t)y + dy; + index_t xp = (index_t)x + dx; + if (0 <= yp && yp < height && 0 <= xp && xp < width && + abs(y - yp) < 1 && abs(x - xp) < 1) { + index_t grad_pos = ((b * channels + c) * height + yp) * width + xp; + scalar_t weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); + // MSL doesn't support at::native::fastAtomicAdd + if (grad_pos >= 0 && grad_pos < grad_im_numel) { + // Atomically add the computed value directly + atomic_add_float(grad_im + grad_pos, static_cast(mask_value * weight * col[index])); + } + } + } + } + } +} + +#define REGISTER_DEFORMABLE_COL2IM_OP(DTYPE) \ +template \ +[[host_name("deformable_col2im_" #DTYPE)]] \ +template \ +kernel void deformable_col2im( \ + index_t n [[buffer(0)]], \ + constant scalar_t* col [[buffer(1)]], \ + 
constant scalar_t* offset_ptr [[buffer(2)]], \ + constant scalar_t* mask_ptr [[buffer(3)]], \ + index_t channels [[buffer(4)]], \ + index_t height [[buffer(5)]], \ + index_t width [[buffer(6)]], \ + index_t kernel_h [[buffer(7)]], \ + index_t kernel_w [[buffer(8)]], \ + index_t pad_h [[buffer(9)]], \ + index_t pad_w [[buffer(10)]], \ + index_t stride_h [[buffer(11)]], \ + index_t stride_w [[buffer(12)]], \ + index_t dilation_h [[buffer(13)]], \ + index_t dilation_w [[buffer(14)]], \ + index_t batch_sz [[buffer(15)]], \ + index_t n_offset_grps [[buffer(16)]], \ + index_t out_h [[buffer(17)]], \ + index_t out_w [[buffer(18)]], \ + constant bool & use_mask [[buffer(19)]], \ + constant scalar_t* grad_im [[buffer(20)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + + + +template +kernel void deformable_col2im_coord( + index_t n [[buffer(0)]], + constant scalar_t* col_ptr [[buffer(1)]], + constant scalar_t* im_ptr [[buffer(2)]], //input + constant scalar_t* offset_ptr [[buffer(3)]], + constant scalar_t* mask_ptr [[buffer(4)]], + index_t channels [[buffer(5)]], + index_t height [[buffer(6)]], + index_t width [[buffer(7)]], + index_t weight_h [[buffer(8)]], + index_t weight_w [[buffer(9)]], + index_t pad_h [[buffer(10)]], + index_t pad_w [[buffer(11)]], + index_t stride_h [[buffer(12)]], + index_t stride_w [[buffer(13)]], + index_t dilation_h [[buffer(14)]], + index_t dilation_w [[buffer(15)]], + index_t batch_sz [[buffer(16)]], //parallel_imgs + index_t offset_channels [[buffer(17)]], + index_t n_offset_grps [[buffer(18)]], + index_t out_h [[buffer(19)]], + index_t out_w [[buffer(20)]], + constant bool & use_mask [[buffer(21)]], + constant scalar_t* grad_offset [[buffer(22)]], + constant scalar_t* grad_mask [[buffer(23)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]) { + MPS_1D_KERNEL_LOOP(index, n, 1) { + scalar_t grad_offset_val = 0; + scalar_t grad_mask_val = 0; + index_t w = index % out_w; + index_t h = (index / out_w) % out_h; + index_t w_w = (index / (out_w * out_h * 2)) % weight_w; + index_t w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; + index_t c = (index / (out_w * out_h)) % offset_channels; + index_t b = index / (out_w * out_h * offset_channels); + + const index_t offset_grp = c / (2 * weight_h * weight_w); + const index_t col_step = weight_h * weight_w; + + index_t c_per_offset_grp = channels / n_offset_grps; + + col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * + out_w * out_h; + im_ptr += + (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; + offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * + out_h * out_w; + + if (use_mask) { + mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * + out_h * out_w; + } + + const index_t offset_c = c - offset_grp * 2 * weight_h * weight_w; + const bool is_y_direction = offset_c % 2 == 0; + + const index_t c_bound = c_per_offset_grp * weight_h * weight_w; + for (index_t col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { + const index_t col_pos = + (((col_c * batch_sz + b) * out_h) + h) * out_w + w; + + index_t out_x = col_pos % out_w; + index_t out_y = (col_pos / out_w) % out_h; + index_t j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; + index_t i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; + + const index_t mask_idx = i * 
weight_w + j; + + const index_t offset_h_ptr = + (((2 * mask_idx) * out_h + out_y) * out_w + out_x); + const index_t offset_w_ptr = + (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); + const scalar_t offset_h = offset_ptr[offset_h_ptr]; + const scalar_t offset_w = offset_ptr[offset_w_ptr]; + + scalar_t mask_value = 1; + if (use_mask) { + mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; + } + + scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + + const scalar_t weight = + get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); + grad_offset_val += mask_value * weight * col_ptr[col_pos]; + + if (use_mask && is_y_direction) { + grad_mask_val += col_ptr[col_pos] * + bilinear_interpolate(im_ptr, height, width, y, x); + } + + im_ptr += height * width; + } + + grad_offset[index] = grad_offset_val; + + if (use_mask && is_y_direction) { + const index_t idx = + ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + + w_w) * + out_h + + h) * + out_w + + w; + grad_mask[idx] = grad_mask_val; + } + } +} +#define REGISTER_DEFORMABLE_COL2IM_COORD_OP(DTYPE) \ +template \ +[[host_name("deformable_col2im_coord_" #DTYPE)]] \ +template \ +kernel void deformable_col2im_coord( \ + index_t n [[buffer(0)]],\ + constant scalar_t* col_ptr [[buffer(1)]], \ + constant scalar_t* im_ptr [[buffer(2)]], \ + constant scalar_t* offset_ptr [[buffer(3)]], \ + constant scalar_t* mask_ptr [[buffer(4)]], \ + index_t channels [[buffer(5)]], \ + index_t height [[buffer(6)]], \ + index_t width [[buffer(7)]], \ + index_t weight_h [[buffer(8)]], \ + index_t weight_w [[buffer(9)]], \ + index_t pad_h [[buffer(10)]], \ + index_t pad_w [[buffer(11)]], \ + index_t stride_h [[buffer(12)]], \ + index_t stride_w [[buffer(13)]], \ + index_t dilation_h [[buffer(14)]], \ + index_t dilation_w [[buffer(15)]], \ + index_t batch_sz [[buffer(16)]], \ + index_t offset_channels [[buffer(17)]], \ + index_t n_offset_grps [[buffer(18)]], \ + index_t out_h [[buffer(19)]], \ + index_t out_w [[buffer(20)]], \ + constant bool & use_mask [[buffer(21)]], \ + constant scalar_t* grad_offset [[buffer(22)]], \ + constant scalar_t* grad_mask [[buffer(23)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); From 1c87a264e466f40333196bb62303a8d286a11f98 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Fri, 15 Nov 2024 15:37:25 +0100 Subject: [PATCH 06/31] Changed part of the file name from _kernal to _kernel --- .../csrc/ops/mps/deform_conv2d_kernel.mm | 936 ++++++++++++++++++ 1 file changed, 936 insertions(+) create mode 100644 torchvision/csrc/ops/mps/deform_conv2d_kernel.mm diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm new file mode 100644 index 00000000000..7295fc2caa5 --- /dev/null +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -0,0 +1,936 @@ +// vision::ops:: +// deform_conv2d_kernal.mm +// + +#include +#include +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + + +namespace vision { +namespace ops { + +namespace { + +const int64_t tkMaxParallelImgs = 32; + + +void deformable_im2col(const at::Tensor& input, + const at::Tensor& data_offset, + const at::Tensor& data_mask, + int64_t n_in_channels, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + 
int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t out_h, + int64_t out_w, + int64_t parallel_imgs, + int64_t deformable_group, + bool use_mask, + at::Tensor data_col) { + using namespace at::native::mps; + + // Validate tensors as of type mps. + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(data_offset.is_mps(), "data_offset must be a MPS tensor"); + TORCH_CHECK(data_mask.is_mps(), "data_mask must be a MPS tensor"); + + at::TensorArg input_t{input, "input", 1}, + data_offset_t{data_offset, "data_offset", 2}, + data_mask_t{data_mask, "data_mask", 3}; + + at::CheckedFrom c = "deformable_im2col"; + at::checkAllSameGPU(c, {input_t, data_offset_t, data_mask_t}); + at::checkAllSameType(c, {input_t, data_offset_t, data_mask_t}); + + + const int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs; + + // These function parameters have all been made contiguous by the caller function deform_conv2d_forward_kernel + // Check if it is safe to skip the following: + auto input_c = input.contiguous(); + auto data_offset_c = data_offset.contiguous(); + auto data_mask_c = data_mask.contiguous(); + + // Get a raw pointer to the underlying data structure of the tensors and cast it as a pointer to an MTLBuffer. + id inputBuffer = getMTLBufferStorage(input_c); + id data_offsetBuffer = getMTLBufferStorage(data_offset_c); + id data_maskBuffer = getMTLBufferStorage(data_mask_c); + id data_colBuffer = getMTLBufferStorage(data_col); + + id device = MPSDevice::getInstance()->device(); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + const std::string kernel = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(num_kernels), + static_cast(512)), + static_cast(4096)), + 1, + 1); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_c, data_offset_c, data_mask_c}); + + id computeEncoder = mpsStream->commandEncoder(); + [computeEncoder setComputePipelineState:visionPSO]; + + [computeEncoder setBuffer:inputBuffer offset:input_c.storage_offset() * input_c.element_size() atIndex:1]; + [computeEncoder setBuffer:data_offsetBuffer offset:data_offset_c.storage_offset() * data_offset_c.element_size() atIndex:2]; + [computeEncoder setBuffer:data_maskBuffer offset:data_mask_c.storage_offset() * data_mask_c.element_size() atIndex:3]; + [computeEncoder setBuffer:data_colBuffer offset:data_col.storage_offset() * data_col.element_size() atIndex:20]; + + [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:¶llel_imgs 
length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&n_in_channels length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:&deformable_group length:sizeof(int64_t) atIndex:16]; + [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:17]; + [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:18]; + [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:19]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + +} + +int get_greatest_divisor_below_bound(int n, int bound) { + for (int k = bound; k > 1; --k) { + if (n % k == 0) { + return k; + } + } + return 1; +} + +void compute_grad_input( + const at::Tensor& columns, + const at::Tensor& offset, + const at::Tensor& mask, + int64_t channels, + int64_t height, + int64_t width, + int64_t weight_h, //kernel_h + int64_t weight_w, //kernel_w + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t parallel_imgs, //batch_sz + int64_t n_offset_grps, + bool use_mask, + at::Tensor grad_im) { + using namespace at::native::mps; + + at::globalContext().alertNotDeterministic("compute_grad_input"); + + auto columns_c = columns.contiguous(); + auto offset_c = offset.contiguous(); + auto mask_c = mask.contiguous(); + + const int64_t out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + const int64_t out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + + const int64_t num_kernels = + (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs; + + id columnsBuffer = getMTLBufferStorage(columns_c); + id offsetBuffer = getMTLBufferStorage(offset_c); + id maskBuffer = getMTLBufferStorage(mask_c); + id grad_imBuffer = getMTLBufferStorage(grad_im); + + id device = MPSDevice::getInstance()->device(); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + + const std::string kernel = "deformable_col2im_" + scalarToMetalTypeString(columns.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns, offset, mask}); + + [computeEncoder setComputePipelineState:visionPSO]; + + [computeEncoder setBuffer:columnsBuffer offset:columns_c.storage_offset() * columns_c.element_size() atIndex:1]; + [computeEncoder setBuffer:offsetBuffer offset:offset_c.storage_offset() * offset_c.element_size() atIndex:2]; + [computeEncoder setBuffer:maskBuffer offset:mask_c.storage_offset() * mask_c.element_size() atIndex:3]; + [computeEncoder setBuffer:grad_imBuffer + offset:grad_im.storage_offset() * grad_im.element_size() + atIndex:20]; + + [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:7]; + [computeEncoder 
setBytes:&weight_w length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:&n_offset_grps length:sizeof(int64_t) atIndex:16]; + [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:17]; + [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:18]; + [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:19]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(num_kernels), static_cast(512)), static_cast(4096)), + 1, + 1); + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); +} + +void compute_grad_offset_and_mask( + const at::Tensor& columns, + const at::Tensor& input, + const at::Tensor& offset, + const at::Tensor& mask, + int64_t channels, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t parallel_imgs, + int64_t n_offset_grps, + bool use_mask, + at::Tensor grad_offset, + at::Tensor grad_mask) { + + using namespace at::native::mps; + + auto columns_c = columns; //.contiguous(); + auto input_c = input; //.contiguous(); + auto offset_c = offset; //.contiguous(); + auto mask_c = mask; //.contiguous(); + + const int64_t out_h = + (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + const int64_t out_w = + (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + const int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w * + n_offset_grps * parallel_imgs; + + const int64_t offset_channels = 2 * weight_h * weight_w * n_offset_grps; + + id columnsBuffer = getMTLBufferStorage(columns_c); + id inputBuffer = getMTLBufferStorage(input_c); + id offsetBuffer = getMTLBufferStorage(offset_c); + id maskBuffer = getMTLBufferStorage(mask_c); + id grad_offsetBuffer = getMTLBufferStorage(grad_offset); + id grad_maskBuffer = getMTLBufferStorage(grad_mask); + + id device = MPSDevice::getInstance()->device(); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(num_kernels), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "deformable_col2im_coord_" + scalarToMetalTypeString(columns.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns_c, input_c, offset_c, mask_c}); + + [computeEncoder setComputePipelineState:visionPSO]; + + [computeEncoder 
setBuffer:columnsBuffer offset:columns_c.storage_offset() * columns_c.element_size() atIndex:1]; + [computeEncoder setBuffer:inputBuffer offset:input_c.storage_offset() * input_c.element_size() atIndex:2]; + [computeEncoder setBuffer:offsetBuffer offset:offset_c.storage_offset() * offset_c.element_size() atIndex:3]; + [computeEncoder setBuffer:maskBuffer offset:mask_c.storage_offset() * mask_c.element_size() atIndex:4]; + [computeEncoder setBuffer:grad_offsetBuffer + offset:grad_offset.storage_offset() * grad_offset.element_size() + atIndex:22]; + [computeEncoder setBuffer:grad_maskBuffer + offset:grad_mask.storage_offset() * grad_mask.element_size() + atIndex:23]; + + [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:16]; + [computeEncoder setBytes:&offset_channels length:sizeof(int64_t) atIndex:17]; + [computeEncoder setBytes:&n_offset_grps length:sizeof(int64_t) atIndex:18]; + [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:19]; + [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:20]; + [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:21]; + + // A threadGroup is equivalent to a cuda's block. 
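// Launch geometry, following the same pattern as the other dispatches in this
// file: tgSize starts from the pipeline's maxTotalThreadsPerThreadgroup and is
// clamped to threadsPerBlock, while threadgroupsPerGrid was computed above as
//   min(ceil_div(num_kernels, 512), 4096).
// For example, num_kernels = 1'000'000 gives ceil_div(1'000'000, 512) = 1954
// threadgroups, each dispatched with tgSize threads (at most threadsPerBlock).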
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); +} + +std::tuple backward_gradient_inputs( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor mask, + at::Tensor grad_out, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + int64_t n_parallel_imgs, + bool use_mask) { + + int64_t batch_sz = input.size(0); + int64_t n_in_channels = input.size(1); + int64_t in_h = input.size(2); + int64_t in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + int64_t n_out_channels = weight.size(0); + int64_t weight_h = weight.size(2); + int64_t weight_w = weight.size(3); + + int64_t out_w = + (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; + int64_t out_h = + (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; + + auto grad_input = at::zeros_like(input); + auto grad_offset = at::zeros_like(offset); + auto grad_mask = at::zeros_like(mask); + + if (batch_sz == 0) { + return std::make_tuple(grad_input, grad_offset, grad_mask); + } + + auto columns = at::empty( + {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, + input.options()); + + // Separate into blocks + grad_input = grad_input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + input = input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + grad_offset = grad_offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + offset = offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + grad_mask = grad_mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + mask = mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + grad_out = grad_out + .reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}); + + weight = weight.reshape( + {n_weight_grps, + weight.size(0) / n_weight_grps, + weight.size(1), + weight.size(2), + weight.size(3)}); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int64_t elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + columns.zero_(); + // Separate into weight groups + for (int64_t g = 0; g < n_weight_grps; g++) { + columns[g] = columns[g].addmm_( + weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); + } + + compute_grad_offset_and_mask( + columns, + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_offset[elt], + grad_mask[elt]); + + compute_grad_input( + columns, + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + 
dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_input[elt]); + } + + grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); + grad_offset = grad_offset.view( + {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); + + if (use_mask) { + grad_mask = grad_mask.view( + {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); + } + + return std::make_tuple(grad_input, grad_offset, grad_mask); +} + +at::Tensor backward_gradient_parameters( + at::Tensor input, + const at::Tensor& weight, + at::Tensor offset, + at::Tensor mask, + const at::Tensor& grad_out, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + int64_t n_parallel_imgs, + bool use_mask) { + + int64_t batch_sz = input.size(0); + int64_t n_in_channels = input.size(1); + int64_t in_h = input.size(2); + int64_t in_w = input.size(3); + + n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); + + int64_t n_out_channels = weight.size(0); + int64_t weight_h = weight.size(2); + int64_t weight_w = weight.size(3); + + int64_t out_h = grad_out.size(2); + int64_t out_w = grad_out.size(3); + + auto grad_weight = at::zeros_like(weight); + if (batch_sz == 0) { + return grad_weight; + } + + at::Tensor grad_out_buf = grad_out + .reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_weight_grps, + n_out_channels / n_weight_grps, + out_h, + out_w}) + .permute({0, 2, 3, 1, 4, 5}) + .contiguous(); + + input = input.reshape( + {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); + + offset = offset.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask = mask.reshape( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + grad_weight = grad_weight.reshape( + {n_weight_grps, + grad_weight.size(0) / n_weight_grps, + grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3)}); + + auto columns = at::empty( + {n_weight_grps, + n_in_channels * weight_w * weight_h / n_weight_grps, + n_parallel_imgs * out_h * out_w}, + input.options()); + + for (int64_t elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { + deformable_im2col( + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + for (int64_t g = 0; g < n_weight_grps; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_( + grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) + .view_as(grad_weight[g]); + } + } + + grad_weight = grad_weight.view( + {grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3), + grad_weight.size(4)}); + return grad_weight; +} + +at::Tensor deform_conv2d_forward_kernel( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + + + at::Tensor input_c = input.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor mask_c = mask.contiguous(); + 
at::Tensor bias_c = bias.contiguous(); + + TORCH_CHECK(input_c.ndimension() == 4); + TORCH_CHECK(offset_c.ndimension() == 4); + TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); + TORCH_CHECK(weight_c.ndimension() == 4); + TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor"); + + at::DeviceGuard guard(input_c.device()); + + int batch_sz = input_c.size(0); + int in_channels = input_c.size(1); + int in_h = input_c.size(2); + int in_w = input_c.size(3); + + int n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); + + int out_channels = weight_c.size(0); + int weight_h = weight_c.size(2); + int weight_w = weight_c.size(3); + + int ker_h = dilation_h * (weight_h - 1) + 1; + int ker_w = dilation_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK( + weight_h > 0 && weight_w > 0, + "weight_h: ", + weight_h, + " weight_w: ", + weight_w); + TORCH_CHECK( + stride_h > 0 && stride_w > 0, + "stride_h: ", + stride_h, + " stride_w: ", + stride_w); + TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); + TORCH_CHECK( + dilation_h > 0 && dilation_w > 0, + "dilation_h: ", + dilation_h, + " dilation_w: ", + dilation_w); + + TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); + TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); + TORCH_CHECK( + (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), + "offset.shape[1] is not valid: got: ", + offset_c.size(1), + " expected: ", + n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK( + (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), + "mask.shape[1] is not valid: got: ", + mask_c.size(1), + " expected: ", + n_offset_grps * weight_h * weight_w); + TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); + + TORCH_CHECK( + (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); + TORCH_CHECK( + (offset_c.size(2) == out_h && offset_c.size(3) == out_w), + "offset output dims: (", + offset_c.size(2), + ", ", + offset_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask"); + TORCH_CHECK( + (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), + "mask output dims: (", + mask_c.size(2), + ", ", + mask_c.size(3), + ") - ", + "computed output dims: (", + out_h, + ", ", + out_w, + ")"); + TORCH_CHECK( + out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", + out_h, + " out_w: ", + out_w); + + auto out = + at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); + if (batch_sz == 0) { + return out; + } + + // Separate batches into blocks + out = out.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + out_channels, + out_h, + out_w}); + input_c = input_c.view( + {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); + + offset_c = offset_c.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w}); + + if (use_mask) { + mask_c = mask_c.view( + {batch_sz / n_parallel_imgs, + n_parallel_imgs, + n_offset_grps * weight_h * weight_w, + out_h, + out_w}); + } + + at::Tensor out_buf = at::zeros( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs * out_h, + out_w}, + out.options()); + + // Separate channels into convolution groups + out_buf = out_buf.view( + {out_buf.size(0), + n_weight_grps, + out_buf.size(1) / 
n_weight_grps, + out_buf.size(2), + out_buf.size(3)}); + weight_c = weight_c.view( + {n_weight_grps, + weight_c.size(0) / n_weight_grps, + weight_c.size(1), + weight_c.size(2), + weight_c.size(3)}); + + // Sample points and perform convolution + auto columns = at::zeros( + {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, + input_c.options()); + + for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { + deformable_im2col( + input_c[b], + offset_c[b], + mask_c[b], + in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + out_h, + out_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + columns); + + columns = columns.view( + {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + for (int g = 0; g < n_weight_grps; g++) { + out_buf[b][g] = out_buf[b][g] + .flatten(1) + .addmm_(weight_c[g].flatten(1), columns[g]) + .view_as(out_buf[b][g]); + } + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + out_buf = out_buf.view( + {batch_sz / n_parallel_imgs, + out_channels, + n_parallel_imgs, + out_h, + out_w}); + out_buf.transpose_(1, 2); + out.copy_(out_buf); + out = out.view({batch_sz, out_channels, out_h, out_w}); + + return out + bias_c.view({1, out_channels, 1, 1}); +} + +std::tuple +deform_conv2d_backward_kernel( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + at::Tensor grad_out_c = grad_out.contiguous(); + at::Tensor input_c = input.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + const int64_t batch_sz = input_c.size(0); + const int64_t n_parallel_imgs = + get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); + + auto grad_input_and_offset_and_mask = backward_gradient_inputs( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + n_weight_grps, + n_offset_grps, + n_parallel_imgs, + use_mask); + + auto grad_input = std::get<0>(grad_input_and_offset_and_mask); + auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); + auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); + + auto grad_weight = backward_gradient_parameters( + input_c, + weight_c, + offset_c, + mask_c, + grad_out_c, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + n_weight_grps, + n_offset_grps, + n_parallel_imgs, + use_mask); + + auto value = grad_out_c.sum({0, 2, 3}); + auto grad_bias = at::ones_like(bias_c) * value; + + return std::make_tuple( + grad_input, grad_weight, grad_offset, grad_mask, grad_bias); +} +} // namespace + + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN(deform_conv2d_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), + TORCH_FN(deform_conv2d_backward_kernel)); +} + +} // namespace ops +} // namespace vision + From 8a984dee5606eaa54aa7d8e63b2f0972a8980973 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sat, 16 Nov 2024 13:49:06 +0100 Subject: [PATCH 07/31] Remove files in product dir --- 
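Note (illustration only, not part of either patch): the TORCH_LIBRARY_IMPL block above registers the two kernels for the MPS dispatch key, so torchvision::deform_conv2d and torchvision::_deform_conv2d_backward calls on MPS tensors are routed to them by the PyTorch dispatcher. The forward kernel's shape checks use the standard dilated-convolution output size; a minimal, self-contained C++ sketch of that arithmetic, with hypothetical sample values not taken from the patch, is:

#include <cstdio>

// Illustrative sketch only; not part of the patch series.
// Mirrors the out_h / out_w computation in the forward kernel above:
//   effective kernel extent  ker = dilation * (weight - 1) + 1
//   output size              (in + 2 * pad - ker) / stride + 1
static int conv_out_size(int in, int pad, int weight, int stride, int dilation) {
  int ker = dilation * (weight - 1) + 1;
  return (in + 2 * pad - ker) / stride + 1;
}

int main() {
  // Hypothetical 64x64 input, 3x3 weight, stride 1, padding 1, dilation 1.
  std::printf("out_h = %d\n", conv_out_size(64, 1, 3, 1, 1)); // 64
  std::printf("out_w = %d\n", conv_out_size(64, 1, 3, 1, 1)); // 64
  return 0;
}

This is the same quantity the kernel requires to satisfy out_h > 0 && out_w > 0 before allocating the output tensor.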
product/.DS_Store | Bin 6148 -> 0 bytes .../torchvision/io/image/cpu/common_jpeg.cpp | 26 - .../torchvision/io/image/cpu/common_jpeg.h | 27 - .../torchvision/io/image/cpu/common_png.h | 6 - .../torchvision/io/image/cpu/decode_avif.cpp | 92 -- .../torchvision/io/image/cpu/decode_avif.h | 11 - .../torchvision/io/image/cpu/decode_gif.cpp | 173 --- .../torchvision/io/image/cpu/decode_gif.h | 12 - .../torchvision/io/image/cpu/decode_image.cpp | 77 - .../torchvision/io/image/cpu/decode_image.h | 15 - .../torchvision/io/image/cpu/decode_jpeg.cpp | 271 ---- .../torchvision/io/image/cpu/decode_jpeg.h | 18 - .../torchvision/io/image/cpu/decode_png.cpp | 232 --- .../torchvision/io/image/cpu/decode_png.h | 15 - .../torchvision/io/image/cpu/decode_webp.cpp | 40 - .../torchvision/io/image/cpu/decode_webp.h | 11 - .../torchvision/io/image/cpu/encode_jpeg.cpp | 113 -- .../torchvision/io/image/cpu/encode_jpeg.h | 13 - .../torchvision/io/image/cpu/encode_png.cpp | 180 --- .../torchvision/io/image/cpu/encode_png.h | 13 - .../include/torchvision/io/image/cpu/exif.h | 256 ---- .../io/image/cpu/giflib/dgif_lib.c | 1312 ----------------- .../io/image/cpu/giflib/gif_hash.c | 128 -- .../io/image/cpu/giflib/gif_hash.h | 42 - .../torchvision/io/image/cpu/giflib/gif_lib.h | 291 ---- .../io/image/cpu/giflib/gif_lib_private.h | 72 - .../io/image/cpu/giflib/gifalloc.c | 425 ------ .../image/cpu/giflib/openbsd-reallocarray.c | 73 - .../io/image/cpu/read_write_file.cpp | 108 -- .../io/image/cpu/read_write_file.h | 13 - .../io/image/cuda/decode_jpegs_cuda.cpp | 603 -------- .../io/image/cuda/decode_jpegs_cuda.h | 45 - .../io/image/cuda/encode_decode_jpegs_cuda.h | 59 - .../io/image/cuda/encode_jpegs_cuda.cpp | 274 ---- .../io/image/cuda/encode_jpegs_cuda.h | 33 - .../include/torchvision/io/image/image.cpp | 37 - product/include/torchvision/io/image/image.h | 12 - .../torchvision/io/image/image_read_mode.h | 17 - product/include/torchvision/macros.h | 11 - .../ops/autograd/deform_conv2d_kernel.cpp | 266 ---- .../ops/autograd/ps_roi_align_kernel.cpp | 167 --- .../ops/autograd/ps_roi_pool_kernel.cpp | 152 -- .../ops/autograd/roi_align_kernel.cpp | 167 --- .../ops/autograd/roi_pool_kernel.cpp | 152 -- .../ops/cpu/deform_conv2d_kernel.cpp | 1172 --------------- .../torchvision/ops/cpu/nms_kernel.cpp | 117 -- .../ops/cpu/ps_roi_align_kernel.cpp | 429 ------ .../ops/cpu/ps_roi_pool_kernel.cpp | 273 ---- .../torchvision/ops/cpu/roi_align_common.h | 128 -- .../torchvision/ops/cpu/roi_align_kernel.cpp | 400 ----- .../torchvision/ops/cpu/roi_pool_kernel.cpp | 249 ---- .../include/torchvision/ops/deform_conv2d.cpp | 172 --- .../include/torchvision/ops/deform_conv2d.h | 82 -- .../include/torchvision/ops/mps/mps_helpers.h | 6 - .../include/torchvision/ops/mps/mps_kernels.h | 1102 -------------- .../include/torchvision/ops/mps/nms_kernel.mm | 109 -- .../ops/mps/ps_roi_align_kernel.mm | 205 --- .../torchvision/ops/mps/ps_roi_pool_kernel.mm | 200 --- .../torchvision/ops/mps/roi_align_kernel.mm | 197 --- .../torchvision/ops/mps/roi_pool_kernel.mm | 196 --- product/include/torchvision/ops/nms.cpp | 28 - product/include/torchvision/ops/nms.h | 15 - product/include/torchvision/ops/ops.h | 8 - .../include/torchvision/ops/ps_roi_align.cpp | 112 -- .../include/torchvision/ops/ps_roi_align.h | 56 - .../include/torchvision/ops/ps_roi_pool.cpp | 104 -- product/include/torchvision/ops/ps_roi_pool.h | 52 - product/include/torchvision/ops/roi_align.cpp | 132 -- product/include/torchvision/ops/roi_align.h | 58 - 
product/include/torchvision/ops/roi_pool.cpp | 102 -- product/include/torchvision/ops/roi_pool.h | 52 - product/include/torchvision/vision.cpp | 32 - product/include/torchvision/vision.h | 12 - .../cmake/TorchVision/TorchVisionConfig.cmake | 74 - .../TorchVisionConfigVersion.cmake | 43 - .../TorchVisionTargets-noconfig.cmake | 20 - .../TorchVision/TorchVisionTargets.cmake | 102 -- .../csrc/ops/mps/deform_conv2d_kernel.mm | 5 +- 78 files changed, 3 insertions(+), 12101 deletions(-) delete mode 100644 product/.DS_Store delete mode 100644 product/include/torchvision/io/image/cpu/common_jpeg.cpp delete mode 100644 product/include/torchvision/io/image/cpu/common_jpeg.h delete mode 100644 product/include/torchvision/io/image/cpu/common_png.h delete mode 100644 product/include/torchvision/io/image/cpu/decode_avif.cpp delete mode 100644 product/include/torchvision/io/image/cpu/decode_avif.h delete mode 100644 product/include/torchvision/io/image/cpu/decode_gif.cpp delete mode 100644 product/include/torchvision/io/image/cpu/decode_gif.h delete mode 100644 product/include/torchvision/io/image/cpu/decode_image.cpp delete mode 100644 product/include/torchvision/io/image/cpu/decode_image.h delete mode 100644 product/include/torchvision/io/image/cpu/decode_jpeg.cpp delete mode 100644 product/include/torchvision/io/image/cpu/decode_jpeg.h delete mode 100644 product/include/torchvision/io/image/cpu/decode_png.cpp delete mode 100644 product/include/torchvision/io/image/cpu/decode_png.h delete mode 100644 product/include/torchvision/io/image/cpu/decode_webp.cpp delete mode 100644 product/include/torchvision/io/image/cpu/decode_webp.h delete mode 100644 product/include/torchvision/io/image/cpu/encode_jpeg.cpp delete mode 100644 product/include/torchvision/io/image/cpu/encode_jpeg.h delete mode 100644 product/include/torchvision/io/image/cpu/encode_png.cpp delete mode 100644 product/include/torchvision/io/image/cpu/encode_png.h delete mode 100644 product/include/torchvision/io/image/cpu/exif.h delete mode 100644 product/include/torchvision/io/image/cpu/giflib/dgif_lib.c delete mode 100644 product/include/torchvision/io/image/cpu/giflib/gif_hash.c delete mode 100644 product/include/torchvision/io/image/cpu/giflib/gif_hash.h delete mode 100644 product/include/torchvision/io/image/cpu/giflib/gif_lib.h delete mode 100644 product/include/torchvision/io/image/cpu/giflib/gif_lib_private.h delete mode 100644 product/include/torchvision/io/image/cpu/giflib/gifalloc.c delete mode 100644 product/include/torchvision/io/image/cpu/giflib/openbsd-reallocarray.c delete mode 100644 product/include/torchvision/io/image/cpu/read_write_file.cpp delete mode 100644 product/include/torchvision/io/image/cpu/read_write_file.h delete mode 100644 product/include/torchvision/io/image/cuda/decode_jpegs_cuda.cpp delete mode 100644 product/include/torchvision/io/image/cuda/decode_jpegs_cuda.h delete mode 100644 product/include/torchvision/io/image/cuda/encode_decode_jpegs_cuda.h delete mode 100644 product/include/torchvision/io/image/cuda/encode_jpegs_cuda.cpp delete mode 100644 product/include/torchvision/io/image/cuda/encode_jpegs_cuda.h delete mode 100644 product/include/torchvision/io/image/image.cpp delete mode 100644 product/include/torchvision/io/image/image.h delete mode 100644 product/include/torchvision/io/image/image_read_mode.h delete mode 100644 product/include/torchvision/macros.h delete mode 100644 product/include/torchvision/ops/autograd/deform_conv2d_kernel.cpp delete mode 100644 
product/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp delete mode 100644 product/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp delete mode 100644 product/include/torchvision/ops/autograd/roi_align_kernel.cpp delete mode 100644 product/include/torchvision/ops/autograd/roi_pool_kernel.cpp delete mode 100644 product/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp delete mode 100644 product/include/torchvision/ops/cpu/nms_kernel.cpp delete mode 100644 product/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp delete mode 100644 product/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp delete mode 100644 product/include/torchvision/ops/cpu/roi_align_common.h delete mode 100644 product/include/torchvision/ops/cpu/roi_align_kernel.cpp delete mode 100644 product/include/torchvision/ops/cpu/roi_pool_kernel.cpp delete mode 100644 product/include/torchvision/ops/deform_conv2d.cpp delete mode 100644 product/include/torchvision/ops/deform_conv2d.h delete mode 100644 product/include/torchvision/ops/mps/mps_helpers.h delete mode 100644 product/include/torchvision/ops/mps/mps_kernels.h delete mode 100644 product/include/torchvision/ops/mps/nms_kernel.mm delete mode 100644 product/include/torchvision/ops/mps/ps_roi_align_kernel.mm delete mode 100644 product/include/torchvision/ops/mps/ps_roi_pool_kernel.mm delete mode 100644 product/include/torchvision/ops/mps/roi_align_kernel.mm delete mode 100644 product/include/torchvision/ops/mps/roi_pool_kernel.mm delete mode 100644 product/include/torchvision/ops/nms.cpp delete mode 100644 product/include/torchvision/ops/nms.h delete mode 100644 product/include/torchvision/ops/ops.h delete mode 100644 product/include/torchvision/ops/ps_roi_align.cpp delete mode 100644 product/include/torchvision/ops/ps_roi_align.h delete mode 100644 product/include/torchvision/ops/ps_roi_pool.cpp delete mode 100644 product/include/torchvision/ops/ps_roi_pool.h delete mode 100644 product/include/torchvision/ops/roi_align.cpp delete mode 100644 product/include/torchvision/ops/roi_align.h delete mode 100644 product/include/torchvision/ops/roi_pool.cpp delete mode 100644 product/include/torchvision/ops/roi_pool.h delete mode 100644 product/include/torchvision/vision.cpp delete mode 100644 product/include/torchvision/vision.h delete mode 100644 product/share/cmake/TorchVision/TorchVisionConfig.cmake delete mode 100644 product/share/cmake/TorchVision/TorchVisionConfigVersion.cmake delete mode 100644 product/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake delete mode 100644 product/share/cmake/TorchVision/TorchVisionTargets.cmake diff --git a/product/.DS_Store b/product/.DS_Store deleted file mode 100644 index 773050d530cd0ef7bf753032e0c2257470adce8f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKO-sW-5S?wSR`iggHv?Y0_9k8uOB6g8@3oECLSl+i(33g(%e?y!{3*WK8PkMB z@Ki)*VE1k2V_)*JWV1x%#_M!M6cEu6jj^|nuERLatz|v+90Hx6V@)aDW%r}Wyl7>c z<1Z?}?{1Hl^gwgUsP_KymA~EH{%boUFOp=I6*EMJ$JggCLG*snm6iS?Yqj9jqHMAP z`?!-O&14;EB>H#*gGTE7&u0~X72WW*qYb|36O+;oJ<)=0>G6B*Dkq)XG^2IitJ8|N zR-YFe?(U2RKgG;%W(t@Brogrpz@E)EJP@?r6fgx$fl2}XK7?qDv0^FcKON}&5db)V z+Z)Dwmf)IDF;*-E5rH`=1xl&YBZiZ5_(RQ$6-z-WC+C?lk1{(wp*YVDf9TW6#e&wG z0;WJyfxcV@y#HVReEx41*_A0^3j8Yt+#s1GV?2`Xt*wXSy*5OfqOoyaDOiYC=<_-Rk2f$dd6odz6KLTC`>r8>ID)0rT`(E7u diff --git a/product/include/torchvision/io/image/cpu/common_jpeg.cpp b/product/include/torchvision/io/image/cpu/common_jpeg.cpp deleted file mode 100644 index 4c993106b45..00000000000 --- 
a/product/include/torchvision/io/image/cpu/common_jpeg.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include "common_jpeg.h" - -namespace vision { -namespace image { -namespace detail { - -#if JPEG_FOUND -void torch_jpeg_error_exit(j_common_ptr cinfo) { - /* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce - * pointer */ - torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err; - - /* Always display the message. */ - /* We could postpone this until after returning, if we chose. */ - // (*cinfo->err->output_message)(cinfo); - /* Create the message */ - (*(cinfo->err->format_message))(cinfo, myerr->jpegLastErrorMsg); - - /* Return control to the setjmp point */ - longjmp(myerr->setjmp_buffer, 1); -} -#endif - -} // namespace detail -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/common_jpeg.h b/product/include/torchvision/io/image/cpu/common_jpeg.h deleted file mode 100644 index 7f7f9f0ccf1..00000000000 --- a/product/include/torchvision/io/image/cpu/common_jpeg.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#if JPEG_FOUND -#include - -#include -#include - -namespace vision { -namespace image { -namespace detail { - -static const JOCTET EOI_BUFFER[1] = {JPEG_EOI}; -struct torch_jpeg_error_mgr { - struct jpeg_error_mgr pub; /* "public" fields */ - char jpegLastErrorMsg[JMSG_LENGTH_MAX]; /* error messages */ - jmp_buf setjmp_buffer; /* for return to caller */ -}; - -using torch_jpeg_error_ptr = struct torch_jpeg_error_mgr*; -void torch_jpeg_error_exit(j_common_ptr cinfo); - -} // namespace detail -} // namespace image -} // namespace vision - -#endif diff --git a/product/include/torchvision/io/image/cpu/common_png.h b/product/include/torchvision/io/image/cpu/common_png.h deleted file mode 100644 index 68400d48e05..00000000000 --- a/product/include/torchvision/io/image/cpu/common_png.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -#if PNG_FOUND -#include -#include -#endif diff --git a/product/include/torchvision/io/image/cpu/decode_avif.cpp b/product/include/torchvision/io/image/cpu/decode_avif.cpp deleted file mode 100644 index ec136743806..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_avif.cpp +++ /dev/null @@ -1,92 +0,0 @@ -#include "decode_avif.h" - -#if AVIF_FOUND -#include "avif/avif.h" -#endif // AVIF_FOUND - -namespace vision { -namespace image { - -#if !AVIF_FOUND -torch::Tensor decode_avif(const torch::Tensor& data) { - TORCH_CHECK( - false, "decode_avif: torchvision not compiled with libavif support"); -} -#else - -// This normally comes from avif_cxx.h, but it's not always present when -// installing libavif. So we just copy/paste it here. -struct UniquePtrDeleter { - void operator()(avifDecoder* decoder) const { - avifDecoderDestroy(decoder); - } -}; -using DecoderPtr = std::unique_ptr; - -torch::Tensor decode_avif(const torch::Tensor& encoded_data) { - // This is based on - // https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c - // Refer there for more detail about what each function does, and which - // structure/data is available after which call. 
- - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, - "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( - encoded_data.dim() == 1, - "Input tensor must be 1-dimensional, got ", - encoded_data.dim(), - " dims."); - - DecoderPtr decoder(avifDecoderCreate()); - TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder."); - - auto result = AVIF_RESULT_UNKNOWN_ERROR; - result = avifDecoderSetIOMemory( - decoder.get(), encoded_data.data_ptr(), encoded_data.numel()); - TORCH_CHECK( - result == AVIF_RESULT_OK, - "avifDecoderSetIOMemory failed:", - avifResultToString(result)); - - result = avifDecoderParse(decoder.get()); - TORCH_CHECK( - result == AVIF_RESULT_OK, - "avifDecoderParse failed: ", - avifResultToString(result)); - TORCH_CHECK( - decoder->imageCount == 1, "Avif file contains more than one image"); - TORCH_CHECK( - decoder->image->depth <= 8, - "avif images with bitdepth > 8 are not supported"); - - result = avifDecoderNextImage(decoder.get()); - TORCH_CHECK( - result == AVIF_RESULT_OK, - "avifDecoderNextImage failed:", - avifResultToString(result)); - - auto out = torch::empty( - {decoder->image->height, decoder->image->width, 3}, torch::kUInt8); - - avifRGBImage rgb; - memset(&rgb, 0, sizeof(rgb)); - avifRGBImageSetDefaults(&rgb, decoder->image); - rgb.format = AVIF_RGB_FORMAT_RGB; - rgb.pixels = out.data_ptr(); - rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb); - - result = avifImageYUVToRGB(decoder->image, &rgb); - TORCH_CHECK( - result == AVIF_RESULT_OK, - "avifImageYUVToRGB failed: ", - avifResultToString(result)); - - return out.permute({2, 0, 1}); // return CHW, channels-last -} -#endif // AVIF_FOUND - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_avif.h b/product/include/torchvision/io/image/cpu/decode_avif.h deleted file mode 100644 index 269bce52197..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_avif.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_avif(const torch::Tensor& data); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_gif.cpp b/product/include/torchvision/io/image/cpu/decode_gif.cpp deleted file mode 100644 index 183d42e86a4..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_gif.cpp +++ /dev/null @@ -1,173 +0,0 @@ -#include "decode_gif.h" -#include -#include "giflib/gif_lib.h" - -namespace vision { -namespace image { - -typedef struct reader_helper_t { - uint8_t const* encoded_data; // input tensor data pointer - size_t encoded_data_size; // size of input tensor in bytes - size_t num_bytes_read; // number of bytes read so far in the tensor -} reader_helper_t; - -// That function is used by GIFLIB routines to read the encoded bytes. -// This reads `len` bytes and writes them into `buf`. The data is read from the -// input tensor passed to decode_gif() starting at the `num_bytes_read` -// position. 
-int read_from_tensor(GifFileType* gifFile, GifByteType* buf, int len) { - // the UserData field was set in DGifOpen() - reader_helper_t* reader_helper = - static_cast(gifFile->UserData); - - size_t num_bytes_to_read = std::min( - (size_t)len, - reader_helper->encoded_data_size - reader_helper->num_bytes_read); - std::memcpy( - buf, reader_helper->encoded_data + reader_helper->num_bytes_read, len); - reader_helper->num_bytes_read += num_bytes_to_read; - return num_bytes_to_read; -} - -torch::Tensor decode_gif(const torch::Tensor& encoded_data) { - // LibGif docs: https://giflib.sourceforge.net/intro.html - // Refer over there for more details on the libgif API, API ref, and a - // detailed description of the GIF format. - - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, - "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( - encoded_data.dim() == 1, - "Input tensor must be 1-dimensional, got ", - encoded_data.dim(), - " dims."); - - int error = D_GIF_SUCCEEDED; - - // We're using DGidOpen. The other entrypoints of libgif are - // DGifOpenFileName and DGifOpenFileHandle but we don't want to use those, - // since we need to read the encoded bytes from a tensor of encoded bytes, not - // from a file (for consistency with existing jpeg and png decoders). Using - // DGifOpen is the only way to read from a custom source. - // For that we need to provide a reader function `read_from_tensor` that - // reads from the tensor, and we have to keep track of the number of bytes - // read so far: this is why we need the reader_helper struct. - - // TODO: We are potentially doing an unnecessary copy of the encoded bytes: - // - 1 copy in from file to tensor (in read_file()) - // - 1 copy from tensor to GIFLIB buffers (in read_from_tensor()) - // Since we're vendoring GIFLIB we can potentially modify the calls to - // InternalRead() and just set the `buf` pointer to the tensor data directly. - // That might even save allocation of those buffers. - // If we do that, we'd have to make sure the buffers are never written to by - // GIFLIB, otherwise we'd be overridding the tensor data. - reader_helper_t reader_helper; - reader_helper.encoded_data = encoded_data.data_ptr(); - reader_helper.encoded_data_size = encoded_data.numel(); - reader_helper.num_bytes_read = 0; - GifFileType* gifFile = - DGifOpen(static_cast(&reader_helper), read_from_tensor, &error); - - TORCH_CHECK( - (gifFile != nullptr) && (error == D_GIF_SUCCEEDED), - "DGifOpenFileName() failed - ", - error); - - if (DGifSlurp(gifFile) == GIF_ERROR) { - auto gifFileError = gifFile->Error; - DGifCloseFile(gifFile, &error); - TORCH_CHECK(false, "DGifSlurp() failed - ", gifFileError); - } - auto num_images = gifFile->ImageCount; - - // This check should already done within DGifSlurp(), just to be safe - TORCH_CHECK(num_images > 0, "GIF file should contain at least one image!"); - - GifColorType bg = {0, 0, 0}; - if (gifFile->SColorMap) { - bg = gifFile->SColorMap->Colors[gifFile->SBackGroundColor]; - } - - // The GIFLIB docs say that the canvas's height and width are potentially - // ignored by modern viewers, so to be on the safe side we set the output - // height to max(canvas_heigh, first_image_height). Same for width. 
- // https://giflib.sourceforge.net/whatsinagif/bits_and_bytes.html - auto out_h = - std::max(gifFile->SHeight, gifFile->SavedImages[0].ImageDesc.Height); - auto out_w = - std::max(gifFile->SWidth, gifFile->SavedImages[0].ImageDesc.Width); - - // We output a channels-last tensor for consistency with other image decoders. - // Torchvision's resize tends to be is faster on uint8 channels-last tensors. - auto options = torch::TensorOptions() - .dtype(torch::kU8) - .memory_format(torch::MemoryFormat::ChannelsLast); - auto out = torch::empty( - {int64_t(num_images), 3, int64_t(out_h), int64_t(out_w)}, options); - auto out_a = out.accessor(); - for (int i = 0; i < num_images; i++) { - const SavedImage& img = gifFile->SavedImages[i]; - - GraphicsControlBlock gcb; - DGifSavedExtensionToGCB(gifFile, i, &gcb); - - const GifImageDesc& desc = img.ImageDesc; - const ColorMapObject* cmap = - desc.ColorMap ? desc.ColorMap : gifFile->SColorMap; - TORCH_CHECK( - cmap != nullptr, - "Global and local color maps are missing. This should never happen!"); - - // When going from one image to another, there is a "disposal method" which - // specifies how to handle the transition. E.g. DISPOSE_DO_NOT means that - // the current image should essentially be drawn on top of the previous - // canvas. The pixels of that previous canvas will appear on the new one if - // either: - // - a pixel is transparent in the current image - // - the current image is smaller than the canvas, hence exposing its pixels - // The "background" disposal method means that the current canvas should be - // set to the background color. - // We only support these 2 modes and default to "background" when the - // disposal method is unspecified, or when it's set to "DISPOSE_PREVIOUS" - // which according to GIFLIB is not widely supported. - // (https://giflib.sourceforge.net/whatsinagif/animation_and_transparency.html). - if (i > 0 && gcb.DisposalMode == DISPOSE_DO_NOT) { - out[i] = out[i - 1]; - } else { - // Background. 
If bg wasn't defined, it will be (0, 0, 0) - for (int h = 0; h < gifFile->SHeight; h++) { - for (int w = 0; w < gifFile->SWidth; w++) { - out_a[i][0][h][w] = bg.Red; - out_a[i][1][h][w] = bg.Green; - out_a[i][2][h][w] = bg.Blue; - } - } - } - - for (int h = 0; h < desc.Height; h++) { - for (int w = 0; w < desc.Width; w++) { - auto c = img.RasterBits[h * desc.Width + w]; - if (c == gcb.TransparentColor) { - continue; - } - GifColorType rgb = cmap->Colors[c]; - out_a[i][0][h + desc.Top][w + desc.Left] = rgb.Red; - out_a[i][1][h + desc.Top][w + desc.Left] = rgb.Green; - out_a[i][2][h + desc.Top][w + desc.Left] = rgb.Blue; - } - } - } - - out = out.squeeze(0); // remove batch dim if there's only one image - - DGifCloseFile(gifFile, &error); - TORCH_CHECK(error == D_GIF_SUCCEEDED, "DGifCloseFile() failed - ", error); - - return out; -} - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_gif.h b/product/include/torchvision/io/image/cpu/decode_gif.h deleted file mode 100644 index 68d5073c91b..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_gif.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace image { - -// encoded_data tensor must be 1D uint8 and contiguous -C10_EXPORT torch::Tensor decode_gif(const torch::Tensor& encoded_data); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_image.cpp b/product/include/torchvision/io/image/cpu/decode_image.cpp deleted file mode 100644 index 75c7e06195a..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_image.cpp +++ /dev/null @@ -1,77 +0,0 @@ -#include "decode_image.h" - -#include "decode_avif.h" -#include "decode_gif.h" -#include "decode_jpeg.h" -#include "decode_png.h" -#include "decode_webp.h" - -namespace vision { -namespace image { - -torch::Tensor decode_image( - const torch::Tensor& data, - ImageReadMode mode, - bool apply_exif_orientation) { - // Check that tensor is a CPU tensor - TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor"); - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); - - auto err_msg = - "Unsupported image file. Only jpeg, png and gif are currently supported."; - - auto datap = data.data_ptr(); - - const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF" - TORCH_CHECK(data.numel() >= 3, err_msg); - if (memcmp(jpeg_signature, datap, 3) == 0) { - return decode_jpeg(data, mode, apply_exif_orientation); - } - - const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" - TORCH_CHECK(data.numel() >= 4, err_msg); - if (memcmp(png_signature, datap, 4) == 0) { - return decode_png(data, mode, apply_exif_orientation); - } - - const uint8_t gif_signature_1[6] = { - 0x47, 0x49, 0x46, 0x38, 0x39, 0x61}; // == "GIF89a" - const uint8_t gif_signature_2[6] = { - 0x47, 0x49, 0x46, 0x38, 0x37, 0x61}; // == "GIF87a" - TORCH_CHECK(data.numel() >= 6, err_msg); - if (memcmp(gif_signature_1, datap, 6) == 0 || - memcmp(gif_signature_2, datap, 6) == 0) { - return decode_gif(data); - } - - // We assume the signature of an avif file is - // 0000 0020 6674 7970 6176 6966 - // xxxx xxxx f t y p a v i f - // We only check for the "ftyp avif" part. 
- // This is probably not perfect, but hopefully this should cover most files. - const uint8_t avif_signature[8] = { - 0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66}; // == "ftypavif" - TORCH_CHECK(data.numel() >= 12, err_msg); - if ((memcmp(avif_signature, datap + 4, 8) == 0)) { - return decode_avif(data); - } - - const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" - const uint8_t webp_signature_end[7] = { - 0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8" - TORCH_CHECK(data.numel() >= 15, err_msg); - if ((memcmp(webp_signature_begin, datap, 4) == 0) && - (memcmp(webp_signature_end, datap + 8, 7) == 0)) { - return decode_webp(data); - } - - TORCH_CHECK(false, err_msg); -} - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_image.h b/product/include/torchvision/io/image/cpu/decode_image.h deleted file mode 100644 index f0e66d397ac..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_image.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include -#include "../image_read_mode.h" - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_image( - const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, - bool apply_exif_orientation = false); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_jpeg.cpp b/product/include/torchvision/io/image/cpu/decode_jpeg.cpp deleted file mode 100644 index ec5953e4106..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_jpeg.cpp +++ /dev/null @@ -1,271 +0,0 @@ -#include "decode_jpeg.h" -#include "common_jpeg.h" -#include "exif.h" - -namespace vision { -namespace image { - -#if !JPEG_FOUND -torch::Tensor decode_jpeg( - const torch::Tensor& data, - ImageReadMode mode, - bool apply_exif_orientation) { - TORCH_CHECK( - false, "decode_jpeg: torchvision not compiled with libjpeg support"); -} -#else - -using namespace detail; -using namespace exif_private; - -namespace { - -struct torch_jpeg_mgr { - struct jpeg_source_mgr pub; - const JOCTET* data; - size_t len; -}; - -static void torch_jpeg_init_source(j_decompress_ptr cinfo) {} - -static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) { - // No more data. Probably an incomplete image; Raise exception. - torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err; - strcpy(myerr->jpegLastErrorMsg, "Image is incomplete or truncated"); - longjmp(myerr->setjmp_buffer, 1); -} - -static void torch_jpeg_skip_input_data(j_decompress_ptr cinfo, long num_bytes) { - torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src; - if (src->pub.bytes_in_buffer < (size_t)num_bytes) { - // Skipping over all of remaining data; output EOI. - src->pub.next_input_byte = EOI_BUFFER; - src->pub.bytes_in_buffer = 1; - } else { - // Skipping over only some of the remaining data. 
- src->pub.next_input_byte += num_bytes; - src->pub.bytes_in_buffer -= num_bytes; - } -} - -static void torch_jpeg_term_source(j_decompress_ptr cinfo) {} - -static void torch_jpeg_set_source_mgr( - j_decompress_ptr cinfo, - const unsigned char* data, - size_t len) { - torch_jpeg_mgr* src; - if (cinfo->src == 0) { // if this is first time; allocate memory - cinfo->src = (struct jpeg_source_mgr*)(*cinfo->mem->alloc_small)( - (j_common_ptr)cinfo, JPOOL_PERMANENT, sizeof(torch_jpeg_mgr)); - } - src = (torch_jpeg_mgr*)cinfo->src; - src->pub.init_source = torch_jpeg_init_source; - src->pub.fill_input_buffer = torch_jpeg_fill_input_buffer; - src->pub.skip_input_data = torch_jpeg_skip_input_data; - src->pub.resync_to_restart = jpeg_resync_to_restart; // default - src->pub.term_source = torch_jpeg_term_source; - // fill the buffers - src->data = (const JOCTET*)data; - src->len = len; - src->pub.bytes_in_buffer = len; - src->pub.next_input_byte = src->data; - - jpeg_save_markers(cinfo, APP1, 0xffff); -} - -inline unsigned char clamped_cmyk_rgb_convert( - unsigned char k, - unsigned char cmy) { - // Inspired from Pillow: - // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569 - int v = k * cmy + 128; - v = ((v >> 8) + v) >> 8; - return std::clamp(k - v, 0, 255); -} - -void convert_line_cmyk_to_rgb( - j_decompress_ptr cinfo, - const unsigned char* cmyk_line, - unsigned char* rgb_line) { - int width = cinfo->output_width; - for (int i = 0; i < width; ++i) { - int c = cmyk_line[i * 4 + 0]; - int m = cmyk_line[i * 4 + 1]; - int y = cmyk_line[i * 4 + 2]; - int k = cmyk_line[i * 4 + 3]; - - rgb_line[i * 3 + 0] = clamped_cmyk_rgb_convert(k, 255 - c); - rgb_line[i * 3 + 1] = clamped_cmyk_rgb_convert(k, 255 - m); - rgb_line[i * 3 + 2] = clamped_cmyk_rgb_convert(k, 255 - y); - } -} - -inline unsigned char rgb_to_gray(int r, int g, int b) { - // Inspired from Pillow: - // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226 - return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16; -} - -void convert_line_cmyk_to_gray( - j_decompress_ptr cinfo, - const unsigned char* cmyk_line, - unsigned char* gray_line) { - int width = cinfo->output_width; - for (int i = 0; i < width; ++i) { - int c = cmyk_line[i * 4 + 0]; - int m = cmyk_line[i * 4 + 1]; - int y = cmyk_line[i * 4 + 2]; - int k = cmyk_line[i * 4 + 3]; - - int r = clamped_cmyk_rgb_convert(k, 255 - c); - int g = clamped_cmyk_rgb_convert(k, 255 - m); - int b = clamped_cmyk_rgb_convert(k, 255 - y); - - gray_line[i] = rgb_to_gray(r, g, b); - } -} - -} // namespace - -torch::Tensor decode_jpeg( - const torch::Tensor& data, - ImageReadMode mode, - bool apply_exif_orientation) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); - - struct jpeg_decompress_struct cinfo; - struct torch_jpeg_error_mgr jerr; - - auto datap = data.data_ptr(); - // Setup decompression structure - cinfo.err = jpeg_std_error(&jerr.pub); - jerr.pub.error_exit = torch_jpeg_error_exit; - /* Establish the setjmp return context for my_error_exit to use. */ - if (setjmp(jerr.setjmp_buffer)) { - /* If we get here, the JPEG code has signaled an error. 
- * We need to clean up the JPEG object. - */ - jpeg_destroy_decompress(&cinfo); - TORCH_CHECK(false, jerr.jpegLastErrorMsg); - } - - jpeg_create_decompress(&cinfo); - torch_jpeg_set_source_mgr(&cinfo, datap, data.numel()); - - // read info from header. - jpeg_read_header(&cinfo, TRUE); - - int channels = cinfo.num_components; - bool cmyk_to_rgb_or_gray = false; - - if (mode != IMAGE_READ_MODE_UNCHANGED) { - switch (mode) { - case IMAGE_READ_MODE_GRAY: - if (cinfo.jpeg_color_space == JCS_CMYK || - cinfo.jpeg_color_space == JCS_YCCK) { - cinfo.out_color_space = JCS_CMYK; - cmyk_to_rgb_or_gray = true; - } else { - cinfo.out_color_space = JCS_GRAYSCALE; - } - channels = 1; - break; - case IMAGE_READ_MODE_RGB: - if (cinfo.jpeg_color_space == JCS_CMYK || - cinfo.jpeg_color_space == JCS_YCCK) { - cinfo.out_color_space = JCS_CMYK; - cmyk_to_rgb_or_gray = true; - } else { - cinfo.out_color_space = JCS_RGB; - } - channels = 3; - break; - /* - * Libjpeg does not support converting from CMYK to grayscale etc. There - * is a way to do this but it involves converting it manually to RGB: - * https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313 - */ - default: - jpeg_destroy_decompress(&cinfo); - TORCH_CHECK(false, "The provided mode is not supported for JPEG files"); - } - - jpeg_calc_output_dimensions(&cinfo); - } - - int exif_orientation = -1; - if (apply_exif_orientation) { - exif_orientation = fetch_jpeg_exif_orientation(&cinfo); - } - - jpeg_start_decompress(&cinfo); - - int height = cinfo.output_height; - int width = cinfo.output_width; - - int stride = width * channels; - auto tensor = - torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); - auto ptr = tensor.data_ptr(); - torch::Tensor cmyk_line_tensor; - if (cmyk_to_rgb_or_gray) { - cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8); - } - - while (cinfo.output_scanline < cinfo.output_height) { - /* jpeg_read_scanlines expects an array of pointers to scanlines. - * Here the array is only one element long, but you could ask for - * more than one scanline at a time if that's more convenient. 
- */ - if (cmyk_to_rgb_or_gray) { - auto cmyk_line_ptr = cmyk_line_tensor.data_ptr(); - jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1); - - if (channels == 3) { - convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr); - } else if (channels == 1) { - convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr); - } - } else { - jpeg_read_scanlines(&cinfo, &ptr, 1); - } - ptr += stride; - } - - jpeg_finish_decompress(&cinfo); - jpeg_destroy_decompress(&cinfo); - auto output = tensor.permute({2, 0, 1}); - - if (apply_exif_orientation) { - return exif_orientation_transform(output, exif_orientation); - } - return output; -} -#endif // #if !JPEG_FOUND - -int64_t _jpeg_version() { -#if JPEG_FOUND - return JPEG_LIB_VERSION; -#else - return -1; -#endif -} - -bool _is_compiled_against_turbo() { -#ifdef LIBJPEG_TURBO_VERSION - return true; -#else - return false; -#endif -} - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_jpeg.h b/product/include/torchvision/io/image/cpu/decode_jpeg.h deleted file mode 100644 index e0c9a24c846..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_jpeg.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include -#include "../image_read_mode.h" - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_jpeg( - const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, - bool apply_exif_orientation = false); - -C10_EXPORT int64_t _jpeg_version(); -C10_EXPORT bool _is_compiled_against_turbo(); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_png.cpp b/product/include/torchvision/io/image/cpu/decode_png.cpp deleted file mode 100644 index ac14ae934a4..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_png.cpp +++ /dev/null @@ -1,232 +0,0 @@ -#include "decode_png.h" -#include "common_png.h" -#include "exif.h" - -namespace vision { -namespace image { - -using namespace exif_private; - -#if !PNG_FOUND -torch::Tensor decode_png( - const torch::Tensor& data, - ImageReadMode mode, - bool apply_exif_orientation) { - TORCH_CHECK( - false, "decode_png: torchvision not compiled with libPNG support"); -} -#else - -bool is_little_endian() { - uint32_t x = 1; - return *(uint8_t*)&x; -} - -torch::Tensor decode_png( - const torch::Tensor& data, - ImageReadMode mode, - bool apply_exif_orientation) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png"); - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); - - auto png_ptr = - png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); - TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") - auto info_ptr = png_create_info_struct(png_ptr); - if (!info_ptr) { - png_destroy_read_struct(&png_ptr, nullptr, nullptr); - // Seems redundant with the if statement. done here to avoid leaking memory. 
- TORCH_CHECK(info_ptr, "libpng info structure allocation failed!") - } - - auto accessor = data.accessor(); - auto datap = accessor.data(); - auto datap_len = accessor.size(0); - - if (setjmp(png_jmpbuf(png_ptr)) != 0) { - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "Internal error."); - } - TORCH_CHECK(datap_len >= 8, "Content is too small for png!") - auto is_png = !png_sig_cmp(datap, 0, 8); - TORCH_CHECK(is_png, "Content is not png!") - - struct Reader { - png_const_bytep ptr; - png_size_t count; - } reader; - reader.ptr = png_const_bytep(datap) + 8; - reader.count = datap_len - 8; - - auto read_callback = [](png_structp png_ptr, - png_bytep output, - png_size_t bytes) { - auto reader = static_cast(png_get_io_ptr(png_ptr)); - TORCH_CHECK( - reader->count >= bytes, - "Out of bound read in decode_png. Probably, the input image is corrupted"); - std::copy(reader->ptr, reader->ptr + bytes, output); - reader->ptr += bytes; - reader->count -= bytes; - }; - png_set_sig_bytes(png_ptr, 8); - png_set_read_fn(png_ptr, &reader, read_callback); - png_read_info(png_ptr, info_ptr); - - png_uint_32 width, height; - int bit_depth, color_type; - int interlace_type; - auto retval = png_get_IHDR( - png_ptr, - info_ptr, - &width, - &height, - &bit_depth, - &color_type, - &interlace_type, - nullptr, - nullptr); - - if (retval != 1) { - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(retval == 1, "Could read image metadata from content.") - } - - if (bit_depth > 8 && bit_depth != 16) { - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK( - false, - "bit depth of png image is " + std::to_string(bit_depth) + - ". Only <=8 and 16 are supported.") - } - - int channels = png_get_channels(png_ptr, info_ptr); - - if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8) - png_set_expand_gray_1_2_4_to_8(png_ptr); - - int number_of_passes; - if (interlace_type == PNG_INTERLACE_ADAM7) { - number_of_passes = png_set_interlace_handling(png_ptr); - } else { - number_of_passes = 1; - } - - if (mode != IMAGE_READ_MODE_UNCHANGED) { - // TODO: consider supporting PNG_INFO_tRNS - bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0; - bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0; - bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0; - - switch (mode) { - case IMAGE_READ_MODE_GRAY: - if (color_type != PNG_COLOR_TYPE_GRAY) { - if (is_palette) { - png_set_palette_to_rgb(png_ptr); - has_alpha = true; - } - - if (has_alpha) { - png_set_strip_alpha(png_ptr); - } - - if (has_color) { - png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); - } - channels = 1; - } - break; - case IMAGE_READ_MODE_GRAY_ALPHA: - if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) { - if (is_palette) { - png_set_palette_to_rgb(png_ptr); - has_alpha = true; - } - - if (!has_alpha) { - png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); - } - - if (has_color) { - png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); - } - channels = 2; - } - break; - case IMAGE_READ_MODE_RGB: - if (color_type != PNG_COLOR_TYPE_RGB) { - if (is_palette) { - png_set_palette_to_rgb(png_ptr); - has_alpha = true; - } else if (!has_color) { - png_set_gray_to_rgb(png_ptr); - } - - if (has_alpha) { - png_set_strip_alpha(png_ptr); - } - channels = 3; - } - break; - case IMAGE_READ_MODE_RGB_ALPHA: - if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) { - if (is_palette) { - png_set_palette_to_rgb(png_ptr); - has_alpha = true; - } else if (!has_color) { - png_set_gray_to_rgb(png_ptr); - } - - if 
(!has_alpha) { - png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); - } - channels = 4; - } - break; - default: - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "The provided mode is not supported for PNG files"); - } - - png_read_update_info(png_ptr, info_ptr); - } - - auto num_pixels_per_row = width * channels; - auto is_16_bits = bit_depth == 16; - auto tensor = torch::empty( - {int64_t(height), int64_t(width), channels}, - is_16_bits ? at::kUInt16 : torch::kU8); - if (is_little_endian()) { - png_set_swap(png_ptr); - } - auto t_ptr = (uint8_t*)tensor.data_ptr(); - for (int pass = 0; pass < number_of_passes; pass++) { - for (png_uint_32 i = 0; i < height; ++i) { - png_read_row(png_ptr, t_ptr, nullptr); - t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1); - } - t_ptr = (uint8_t*)tensor.data_ptr(); - } - - int exif_orientation = -1; - if (apply_exif_orientation) { - exif_orientation = fetch_png_exif_orientation(png_ptr, info_ptr); - } - - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - - auto output = tensor.permute({2, 0, 1}); - if (apply_exif_orientation) { - return exif_orientation_transform(output, exif_orientation); - } - return output; -} -#endif - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_png.h b/product/include/torchvision/io/image/cpu/decode_png.h deleted file mode 100644 index 0866711e987..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_png.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include -#include "../image_read_mode.h" - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_png( - const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, - bool apply_exif_orientation = false); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_webp.cpp b/product/include/torchvision/io/image/cpu/decode_webp.cpp deleted file mode 100644 index 844ce61a3e3..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_webp.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "decode_webp.h" - -#if WEBP_FOUND -#include "webp/decode.h" -#endif // WEBP_FOUND - -namespace vision { -namespace image { - -#if !WEBP_FOUND -torch::Tensor decode_webp(const torch::Tensor& data) { - TORCH_CHECK( - false, "decode_webp: torchvision not compiled with libwebp support"); -} -#else - -torch::Tensor decode_webp(const torch::Tensor& encoded_data) { - TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); - TORCH_CHECK( - encoded_data.dtype() == torch::kU8, - "Input tensor must have uint8 data type, got ", - encoded_data.dtype()); - TORCH_CHECK( - encoded_data.dim() == 1, - "Input tensor must be 1-dimensional, got ", - encoded_data.dim(), - " dims."); - - int width = 0; - int height = 0; - auto decoded_data = WebPDecodeRGB( - encoded_data.data_ptr(), encoded_data.numel(), &width, &height); - TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed."); - auto out = torch::from_blob(decoded_data, {height, width, 3}, torch::kUInt8); - return out.permute({2, 0, 1}); // return CHW, channels-last -} -#endif // WEBP_FOUND - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/decode_webp.h b/product/include/torchvision/io/image/cpu/decode_webp.h deleted file mode 100644 index 00a0c3362f7..00000000000 --- a/product/include/torchvision/io/image/cpu/decode_webp.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include - 
-namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_webp(const torch::Tensor& data); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/encode_jpeg.cpp b/product/include/torchvision/io/image/cpu/encode_jpeg.cpp deleted file mode 100644 index d2ed73071a2..00000000000 --- a/product/include/torchvision/io/image/cpu/encode_jpeg.cpp +++ /dev/null @@ -1,113 +0,0 @@ -#include "encode_jpeg.h" - -#include "common_jpeg.h" - -namespace vision { -namespace image { - -#if !JPEG_FOUND - -torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { - TORCH_CHECK( - false, "encode_jpeg: torchvision not compiled with libjpeg support"); -} - -#else -// For libjpeg version <= 9b, the out_size parameter in jpeg_mem_dest() is -// defined as unsigned long, whereas in later version, it is defined as size_t. -#if !defined(JPEG_LIB_VERSION_MAJOR) || JPEG_LIB_VERSION_MAJOR < 9 || \ - (JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR <= 2) -using JpegSizeType = unsigned long; -#else -using JpegSizeType = size_t; -#endif - -using namespace detail; - -torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.encode_jpeg.encode_jpeg"); - // Define compression structures and error handling - struct jpeg_compress_struct cinfo {}; - struct torch_jpeg_error_mgr jerr {}; - - // Define buffer to write JPEG information to and its size - JpegSizeType jpegSize = 0; - uint8_t* jpegBuf = nullptr; - - cinfo.err = jpeg_std_error(&jerr.pub); - jerr.pub.error_exit = torch_jpeg_error_exit; - - /* Establish the setjmp return context for my_error_exit to use. */ - if (setjmp(jerr.setjmp_buffer)) { - /* If we get here, the JPEG code has signaled an error. - * We need to clean up the JPEG object and the buffer. - */ - jpeg_destroy_compress(&cinfo); - if (jpegBuf != nullptr) { - free(jpegBuf); - } - - TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg); - } - - // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); - - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); - - // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); - - // Get image info - int channels = data.size(0); - int height = data.size(1); - int width = data.size(2); - auto input = data.permute({1, 2, 0}).contiguous(); - - TORCH_CHECK( - channels == 1 || channels == 3, - "The number of channels should be 1 or 3, got: ", - channels); - - // Initialize JPEG structure - jpeg_create_compress(&cinfo); - - // Set output image information - cinfo.image_width = width; - cinfo.image_height = height; - cinfo.input_components = channels; - cinfo.in_color_space = channels == 1 ? 
JCS_GRAYSCALE : JCS_RGB; - - jpeg_set_defaults(&cinfo); - jpeg_set_quality(&cinfo, quality, TRUE); - - // Save JPEG output to a buffer - jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize); - - // Start JPEG compression - jpeg_start_compress(&cinfo, TRUE); - - auto stride = width * channels; - auto ptr = input.data_ptr(); - - // Encode JPEG file - while (cinfo.next_scanline < cinfo.image_height) { - jpeg_write_scanlines(&cinfo, &ptr, 1); - ptr += stride; - } - - jpeg_finish_compress(&cinfo); - jpeg_destroy_compress(&cinfo); - - torch::TensorOptions options = torch::TensorOptions{torch::kU8}; - auto out_tensor = - torch::from_blob(jpegBuf, {(long)jpegSize}, ::free, options); - jpegBuf = nullptr; - return out_tensor; -} -#endif - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/encode_jpeg.h b/product/include/torchvision/io/image/cpu/encode_jpeg.h deleted file mode 100644 index 25084e154d6..00000000000 --- a/product/include/torchvision/io/image/cpu/encode_jpeg.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor encode_jpeg( - const torch::Tensor& data, - int64_t quality); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/encode_png.cpp b/product/include/torchvision/io/image/cpu/encode_png.cpp deleted file mode 100644 index 5596d3a6789..00000000000 --- a/product/include/torchvision/io/image/cpu/encode_png.cpp +++ /dev/null @@ -1,180 +0,0 @@ -#include "encode_jpeg.h" - -#include "common_png.h" - -namespace vision { -namespace image { - -#if !PNG_FOUND - -torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { - TORCH_CHECK( - false, "encode_png: torchvision not compiled with libpng support"); -} - -#else - -namespace { - -struct torch_mem_encode { - char* buffer; - size_t size; -}; - -struct torch_png_error_mgr { - const char* pngLastErrorMsg; /* error messages */ - jmp_buf setjmp_buffer; /* for return to caller */ -}; - -using torch_png_error_mgr_ptr = torch_png_error_mgr*; - -void torch_png_error(png_structp png_ptr, png_const_charp error_msg) { - /* png_ptr->err really points to a torch_png_error_mgr struct, so coerce - * pointer */ - auto error_ptr = (torch_png_error_mgr_ptr)png_get_error_ptr(png_ptr); - /* Replace the error message on the error structure */ - error_ptr->pngLastErrorMsg = error_msg; - /* Return control to the setjmp point */ - longjmp(error_ptr->setjmp_buffer, 1); -} - -void torch_png_write_data( - png_structp png_ptr, - png_bytep data, - png_size_t length) { - struct torch_mem_encode* p = - (struct torch_mem_encode*)png_get_io_ptr(png_ptr); - size_t nsize = p->size + length; - - /* allocate or grow buffer */ - if (p->buffer) - p->buffer = (char*)realloc(p->buffer, nsize); - else - p->buffer = (char*)malloc(nsize); - - if (!p->buffer) - png_error(png_ptr, "Write Error"); - - /* copy new bytes to end of buffer */ - memcpy(p->buffer + p->size, data, length); - p->size += length; -} - -} // namespace - -torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png"); - // Define compression structures and error handling - png_structp png_write; - png_infop info_ptr; - struct torch_png_error_mgr err_ptr; - - // Define output buffer - struct torch_mem_encode buf_info; - buf_info.buffer = nullptr; - buf_info.size = 0; - - /* Establish the setjmp return context for my_error_exit to use. 
*/ - if (setjmp(err_ptr.setjmp_buffer)) { - /* If we get here, the PNG code has signaled an error. - * We need to clean up the PNG object and the buffer. - */ - if (info_ptr != nullptr) { - png_destroy_info_struct(png_write, &info_ptr); - } - - if (png_write != nullptr) { - png_destroy_write_struct(&png_write, nullptr); - } - - if (buf_info.buffer != nullptr) { - free(buf_info.buffer); - } - - TORCH_CHECK(false, err_ptr.pngLastErrorMsg); - } - - // Check that the compression level is between 0 and 9 - TORCH_CHECK( - compression_level >= 0 && compression_level <= 9, - "Compression level should be between 0 and 9"); - - // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); - - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); - - // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); - - // Get image info - int channels = data.size(0); - int height = data.size(1); - int width = data.size(2); - auto input = data.permute({1, 2, 0}).contiguous(); - - TORCH_CHECK( - channels == 1 || channels == 3, - "The number of channels should be 1 or 3, got: ", - channels); - - // Initialize PNG structures - png_write = png_create_write_struct( - PNG_LIBPNG_VER_STRING, &err_ptr, torch_png_error, nullptr); - - info_ptr = png_create_info_struct(png_write); - - // Define custom buffer output - png_set_write_fn(png_write, &buf_info, torch_png_write_data, nullptr); - - // Set output image information - auto color_type = channels == 1 ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB; - png_set_IHDR( - png_write, - info_ptr, - width, - height, - 8, - color_type, - PNG_INTERLACE_NONE, - PNG_COMPRESSION_TYPE_DEFAULT, - PNG_FILTER_TYPE_DEFAULT); - - // Set image compression level - png_set_compression_level(png_write, compression_level); - - // Write file header - png_write_info(png_write, info_ptr); - - auto stride = width * channels; - auto ptr = input.data_ptr(); - - // Encode PNG file - for (int y = 0; y < height; ++y) { - png_write_row(png_write, ptr); - ptr += stride; - } - - // Write EOF - png_write_end(png_write, info_ptr); - - // Destroy structures - png_destroy_write_struct(&png_write, &info_ptr); - - torch::TensorOptions options = torch::TensorOptions{torch::kU8}; - auto outTensor = torch::empty({(long)buf_info.size}, options); - - // Copy memory from png buffer, since torch cannot get ownership of it via - // `from_blob` - auto outPtr = outTensor.data_ptr(); - std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel()); - free(buf_info.buffer); - - return outTensor; -} - -#endif - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/encode_png.h b/product/include/torchvision/io/image/cpu/encode_png.h deleted file mode 100644 index 86a67c8706e..00000000000 --- a/product/include/torchvision/io/image/cpu/encode_png.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor encode_png( - const torch::Tensor& data, - int64_t compression_level); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/exif.h b/product/include/torchvision/io/image/cpu/exif.h deleted file mode 100644 index 61948bfe16a..00000000000 --- a/product/include/torchvision/io/image/cpu/exif.h +++ /dev/null @@ -1,256 +0,0 @@ 
-/*M/////////////////////////////////////////////////////////////////////////////////////// -// -// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. -// -// By downloading, copying, installing or using the software you agree to this -license. -// If you do not agree to this license, do not download, install, -// copy or use the software. -// -// -// License Agreement -// For Open Source Computer Vision Library -// -// Copyright (C) 2000-2008, Intel Corporation, all rights reserved. -// Copyright (C) 2009, Willow Garage Inc., all rights reserved. -// Third party copyrights are property of their respective owners. -// -// Redistribution and use in source and binary forms, with or without -modification, -// are permitted provided that the following conditions are met: -// -// * Redistribution's of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// -// * Redistribution's in binary form must reproduce the above copyright -notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// * The name of the copyright holders may not be used to endorse or promote -products -// derived from this software without specific prior written permission. -// -// This software is provided by the copyright holders and contributors "as is" -and -// any express or implied warranties, including, but not limited to, the implied -// warranties of merchantability and fitness for a particular purpose are -disclaimed. -// In no event shall the Intel Corporation or contributors be liable for any -direct, -// indirect, incidental, special, exemplary, or consequential damages -// (including, but not limited to, procurement of substitute goods or services; -// loss of use, data, or profits; or business interruption) however caused -// and on any theory of liability, whether in contract, strict liability, -// or tort (including negligence or otherwise) arising in any way out of -// the use of this software, even if advised of the possibility of such damage. 
-// -//M*/ -#pragma once -// Functions in this module are taken from OpenCV -// https://github.com/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/imgcodecs/src/exif.cpp - -#if JPEG_FOUND -#include -#endif -#if PNG_FOUND -#include -#endif - -#include - -namespace vision { -namespace image { -namespace exif_private { - -constexpr uint16_t APP1 = 0xe1; -constexpr uint16_t ENDIANNESS_INTEL = 0x49; -constexpr uint16_t ENDIANNESS_MOTO = 0x4d; -constexpr uint16_t REQ_EXIF_TAG_MARK = 0x2a; -constexpr uint16_t ORIENTATION_EXIF_TAG = 0x0112; -constexpr uint16_t INCORRECT_TAG = -1; - -class ExifDataReader { - public: - ExifDataReader(unsigned char* p, size_t s) : _ptr(p), _size(s) {} - size_t size() const { - return _size; - } - const unsigned char& operator[](size_t index) const { - TORCH_CHECK(index >= 0 && index < _size); - return _ptr[index]; - } - - protected: - unsigned char* _ptr; - size_t _size; -}; - -inline uint16_t get_endianness(const ExifDataReader& exif_data) { - if ((exif_data.size() < 1) || - (exif_data.size() > 1 && exif_data[0] != exif_data[1])) { - return 0; - } - if (exif_data[0] == 'I') { - return ENDIANNESS_INTEL; - } - if (exif_data[0] == 'M') { - return ENDIANNESS_MOTO; - } - return 0; -} - -inline uint16_t get_uint16( - const ExifDataReader& exif_data, - uint16_t endianness, - const size_t offset) { - if (offset + 1 >= exif_data.size()) { - return INCORRECT_TAG; - } - - if (endianness == ENDIANNESS_INTEL) { - return exif_data[offset] + (exif_data[offset + 1] << 8); - } - return (exif_data[offset] << 8) + exif_data[offset + 1]; -} - -inline uint32_t get_uint32( - const ExifDataReader& exif_data, - uint16_t endianness, - const size_t offset) { - if (offset + 3 >= exif_data.size()) { - return INCORRECT_TAG; - } - - if (endianness == ENDIANNESS_INTEL) { - return exif_data[offset] + (exif_data[offset + 1] << 8) + - (exif_data[offset + 2] << 16) + (exif_data[offset + 3] << 24); - } - return (exif_data[offset] << 24) + (exif_data[offset + 1] << 16) + - (exif_data[offset + 2] << 8) + exif_data[offset + 3]; -} - -inline int fetch_exif_orientation(unsigned char* exif_data_ptr, size_t size) { - int exif_orientation = -1; - - // Exif binary structure looks like this - // First 6 bytes: [E, x, i, f, 0, 0] - // Endianness, 2 bytes : [M, M] or [I, I] - // Tag mark, 2 bytes: [0, 0x2a] - // Offset, 4 bytes - // Num entries, 2 bytes - // Tag entries and data, tag has 2 bytes and its data has 10 bytes - // For more details: - // http://www.media.mit.edu/pia/Research/deepview/exif.html - - ExifDataReader exif_data(exif_data_ptr, size); - auto endianness = get_endianness(exif_data); - - // Checking whether Tag Mark (0x002A) correspond to one contained in the - // Jpeg file - uint16_t tag_mark = get_uint16(exif_data, endianness, 2); - if (tag_mark == REQ_EXIF_TAG_MARK) { - auto offset = get_uint32(exif_data, endianness, 4); - size_t num_entry = get_uint16(exif_data, endianness, offset); - offset += 2; // go to start of tag fields - constexpr size_t tiff_field_size = 12; - for (size_t entry = 0; entry < num_entry; entry++) { - // Here we just search for orientation tag and parse it - auto tag_num = get_uint16(exif_data, endianness, offset); - if (tag_num == INCORRECT_TAG) { - break; - } - if (tag_num == ORIENTATION_EXIF_TAG) { - exif_orientation = get_uint16(exif_data, endianness, offset + 8); - break; - } - offset += tiff_field_size; - } - } - return exif_orientation; -} - -#if JPEG_FOUND -inline int fetch_jpeg_exif_orientation(j_decompress_ptr cinfo) { - // Check for Exif 
marker APP1 - jpeg_saved_marker_ptr exif_marker = 0; - jpeg_saved_marker_ptr cmarker = cinfo->marker_list; - while (cmarker && exif_marker == 0) { - if (cmarker->marker == APP1) { - exif_marker = cmarker; - } - cmarker = cmarker->next; - } - - if (!exif_marker) { - return -1; - } - - constexpr size_t start_offset = 6; - if (exif_marker->data_length <= start_offset) { - return -1; - } - - auto* exif_data_ptr = exif_marker->data + start_offset; - auto size = exif_marker->data_length - start_offset; - - return fetch_exif_orientation(exif_data_ptr, size); -} -#endif // #if JPEG_FOUND - -#if PNG_FOUND && defined(PNG_eXIf_SUPPORTED) -inline int fetch_png_exif_orientation(png_structp png_ptr, png_infop info_ptr) { - png_uint_32 num_exif = 0; - png_bytep exif = 0; - - // Exif info could be in info_ptr - if (png_get_valid(png_ptr, info_ptr, PNG_INFO_eXIf)) { - png_get_eXIf_1(png_ptr, info_ptr, &num_exif, &exif); - } - - if (exif && num_exif > 0) { - return fetch_exif_orientation(exif, num_exif); - } - return -1; -} -#endif // #if PNG_FOUND && defined(PNG_eXIf_SUPPORTED) - -constexpr uint16_t IMAGE_ORIENTATION_TL = 1; // normal orientation -constexpr uint16_t IMAGE_ORIENTATION_TR = 2; // needs horizontal flip -constexpr uint16_t IMAGE_ORIENTATION_BR = 3; // needs 180 rotation -constexpr uint16_t IMAGE_ORIENTATION_BL = 4; // needs vertical flip -constexpr uint16_t IMAGE_ORIENTATION_LT = - 5; // mirrored horizontal & rotate 270 CW -constexpr uint16_t IMAGE_ORIENTATION_RT = 6; // rotate 90 CW -constexpr uint16_t IMAGE_ORIENTATION_RB = - 7; // mirrored horizontal & rotate 90 CW -constexpr uint16_t IMAGE_ORIENTATION_LB = 8; // needs 270 CW rotation - -inline torch::Tensor exif_orientation_transform( - const torch::Tensor& image, - int orientation) { - if (orientation == IMAGE_ORIENTATION_TL) { - return image; - } else if (orientation == IMAGE_ORIENTATION_TR) { - return image.flip(-1); - } else if (orientation == IMAGE_ORIENTATION_BR) { - // needs 180 rotation equivalent to - // flip both horizontally and vertically - return image.flip({-2, -1}); - } else if (orientation == IMAGE_ORIENTATION_BL) { - return image.flip(-2); - } else if (orientation == IMAGE_ORIENTATION_LT) { - return image.transpose(-1, -2); - } else if (orientation == IMAGE_ORIENTATION_RT) { - return image.transpose(-1, -2).flip(-1); - } else if (orientation == IMAGE_ORIENTATION_RB) { - return image.transpose(-1, -2).flip({-2, -1}); - } else if (orientation == IMAGE_ORIENTATION_LB) { - return image.transpose(-1, -2).flip(-2); - } - return image; -} - -} // namespace exif_private -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/giflib/dgif_lib.c b/product/include/torchvision/io/image/cpu/giflib/dgif_lib.c deleted file mode 100644 index 297f12f15c4..00000000000 --- a/product/include/torchvision/io/image/cpu/giflib/dgif_lib.c +++ /dev/null @@ -1,1312 +0,0 @@ -/****************************************************************************** - -dgif_lib.c - GIF decoding - -The functions here and in egif_lib.c are partitioned carefully so that -if you only require one of read and write capability, only one of these -two modules will be linked. Preserve this property! - -*****************************************************************************/ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Copyright (C) Eric S. 
Raymond - -#include -#include -#include -#include -#include -#include - -#ifdef _WIN32 -#include -#else -#include -#endif /* _WIN32 */ - -#include "gif_lib.h" -#include "gif_lib_private.h" - -/* compose unsigned little endian value */ -#define UNSIGNED_LITTLE_ENDIAN(lo, hi) ((lo) | ((hi) << 8)) - -/* avoid extra function call in case we use fread (TVT) */ -static int InternalRead(GifFileType *gif, GifByteType *buf, int len) { - // fprintf(stderr, "### Read: %d\n", len); - return (((GifFilePrivateType *)gif->Private)->Read - ? ((GifFilePrivateType *)gif->Private)->Read(gif, buf, len) - : fread(buf, 1, len, - ((GifFilePrivateType *)gif->Private)->File)); -} - -static int DGifGetWord(GifFileType *GifFile, GifWord *Word); -static int DGifSetupDecompress(GifFileType *GifFile); -static int DGifDecompressLine(GifFileType *GifFile, GifPixelType *Line, - int LineLen); -static int DGifGetPrefixChar(const GifPrefixType *Prefix, int Code, - int ClearCode); -static int DGifDecompressInput(GifFileType *GifFile, int *Code); -static int DGifBufferedInput(GifFileType *GifFile, GifByteType *Buf, - GifByteType *NextByte); - -/****************************************************************************** - Open a new GIF file for read, given by its name. - Returns dynamically allocated GifFileType pointer which serves as the GIF - info record. -******************************************************************************/ -GifFileType *DGifOpenFileName(const char *FileName, int *Error) { - int FileHandle; - GifFileType *GifFile; - - if ((FileHandle = open(FileName, O_RDONLY)) == -1) { - if (Error != NULL) { - *Error = D_GIF_ERR_OPEN_FAILED; - } - return NULL; - } - - GifFile = DGifOpenFileHandle(FileHandle, Error); - return GifFile; -} - -/****************************************************************************** - Update a new GIF file, given its file handle. - Returns dynamically allocated GifFileType pointer which serves as the GIF - info record. -******************************************************************************/ -GifFileType *DGifOpenFileHandle(int FileHandle, int *Error) { - char Buf[GIF_STAMP_LEN + 1]; - GifFileType *GifFile; - GifFilePrivateType *Private; - FILE *f; - - GifFile = (GifFileType *)malloc(sizeof(GifFileType)); - if (GifFile == NULL) { - if (Error != NULL) { - *Error = D_GIF_ERR_NOT_ENOUGH_MEM; - } - (void)close(FileHandle); - return NULL; - } - - /*@i1@*/ memset(GifFile, '\0', sizeof(GifFileType)); - - /* Belt and suspenders, in case the null pointer isn't zero */ - GifFile->SavedImages = NULL; - GifFile->SColorMap = NULL; - - Private = (GifFilePrivateType *)calloc(1, sizeof(GifFilePrivateType)); - if (Private == NULL) { - if (Error != NULL) { - *Error = D_GIF_ERR_NOT_ENOUGH_MEM; - } - (void)close(FileHandle); - free((char *)GifFile); - return NULL; - } - - /*@i1@*/ memset(Private, '\0', sizeof(GifFilePrivateType)); - -#ifdef _WIN32 - _setmode(FileHandle, O_BINARY); /* Make sure it is in binary mode. 
*/ -#endif /* _WIN32 */ - - f = fdopen(FileHandle, "rb"); /* Make it into a stream: */ - - /*@-mustfreeonly@*/ - GifFile->Private = (void *)Private; - Private->FileHandle = FileHandle; - Private->File = f; - Private->FileState = FILE_STATE_READ; - Private->Read = NULL; /* don't use alternate input method (TVT) */ - GifFile->UserData = NULL; /* TVT */ - /*@=mustfreeonly@*/ - - /* Let's see if this is a GIF file: */ - /* coverity[check_return] */ - if (InternalRead(GifFile, (unsigned char *)Buf, GIF_STAMP_LEN) != - GIF_STAMP_LEN) { - if (Error != NULL) { - *Error = D_GIF_ERR_READ_FAILED; - } - (void)fclose(f); - free((char *)Private); - free((char *)GifFile); - return NULL; - } - - /* Check for GIF prefix at start of file */ - Buf[GIF_STAMP_LEN] = 0; - if (strncmp(GIF_STAMP, Buf, GIF_VERSION_POS) != 0) { - if (Error != NULL) { - *Error = D_GIF_ERR_NOT_GIF_FILE; - } - (void)fclose(f); - free((char *)Private); - free((char *)GifFile); - return NULL; - } - - if (DGifGetScreenDesc(GifFile) == GIF_ERROR) { - (void)fclose(f); - free((char *)Private); - free((char *)GifFile); - return NULL; - } - - GifFile->Error = 0; - - /* What version of GIF? */ - Private->gif89 = (Buf[GIF_VERSION_POS + 1] == '9'); - - return GifFile; -} - -/****************************************************************************** - GifFileType constructor with user supplied input function (TVT) -******************************************************************************/ -GifFileType *DGifOpen(void *userData, InputFunc readFunc, int *Error) { - char Buf[GIF_STAMP_LEN + 1]; - GifFileType *GifFile; - GifFilePrivateType *Private; - - GifFile = (GifFileType *)malloc(sizeof(GifFileType)); - if (GifFile == NULL) { - if (Error != NULL) { - *Error = D_GIF_ERR_NOT_ENOUGH_MEM; - } - return NULL; - } - - memset(GifFile, '\0', sizeof(GifFileType)); - - /* Belt and suspenders, in case the null pointer isn't zero */ - GifFile->SavedImages = NULL; - GifFile->SColorMap = NULL; - - Private = (GifFilePrivateType *)calloc(1, sizeof(GifFilePrivateType)); - if (!Private) { - if (Error != NULL) { - *Error = D_GIF_ERR_NOT_ENOUGH_MEM; - } - free((char *)GifFile); - return NULL; - } - /*@i1@*/ memset(Private, '\0', sizeof(GifFilePrivateType)); - - GifFile->Private = (void *)Private; - Private->FileHandle = 0; - Private->File = NULL; - Private->FileState = FILE_STATE_READ; - - Private->Read = readFunc; /* TVT */ - GifFile->UserData = userData; /* TVT */ - - /* Lets see if this is a GIF file: */ - /* coverity[check_return] */ - if (InternalRead(GifFile, (unsigned char *)Buf, GIF_STAMP_LEN) != - GIF_STAMP_LEN) { - if (Error != NULL) { - *Error = D_GIF_ERR_READ_FAILED; - } - free((char *)Private); - free((char *)GifFile); - return NULL; - } - - /* Check for GIF prefix at start of file */ - Buf[GIF_STAMP_LEN] = '\0'; - if (strncmp(GIF_STAMP, Buf, GIF_VERSION_POS) != 0) { - if (Error != NULL) { - *Error = D_GIF_ERR_NOT_GIF_FILE; - } - free((char *)Private); - free((char *)GifFile); - return NULL; - } - - if (DGifGetScreenDesc(GifFile) == GIF_ERROR) { - free((char *)Private); - free((char *)GifFile); - if (Error != NULL) { - *Error = D_GIF_ERR_NO_SCRN_DSCR; - } - return NULL; - } - - GifFile->Error = 0; - - /* What version of GIF? */ - Private->gif89 = (Buf[GIF_VERSION_POS + 1] == '9'); - - return GifFile; -} - -/****************************************************************************** - This routine should be called before any other DGif calls. Note that - this routine is called automatically from DGif file open routines. 
-******************************************************************************/ -int DGifGetScreenDesc(GifFileType *GifFile) { - int BitsPerPixel; - bool SortFlag; - GifByteType Buf[3]; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - if (!IS_READABLE(Private)) { - /* This file was NOT open for reading: */ - GifFile->Error = D_GIF_ERR_NOT_READABLE; - return GIF_ERROR; - } - - /* Put the screen descriptor into the file: */ - if (DGifGetWord(GifFile, &GifFile->SWidth) == GIF_ERROR || - DGifGetWord(GifFile, &GifFile->SHeight) == GIF_ERROR) { - return GIF_ERROR; - } - - if (InternalRead(GifFile, Buf, 3) != 3) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - GifFreeMapObject(GifFile->SColorMap); - GifFile->SColorMap = NULL; - return GIF_ERROR; - } - GifFile->SColorResolution = (((Buf[0] & 0x70) + 1) >> 4) + 1; - SortFlag = (Buf[0] & 0x08) != 0; - BitsPerPixel = (Buf[0] & 0x07) + 1; - GifFile->SBackGroundColor = Buf[1]; - GifFile->AspectByte = Buf[2]; - if (Buf[0] & 0x80) { /* Do we have global color map? */ - int i; - - GifFile->SColorMap = GifMakeMapObject(1 << BitsPerPixel, NULL); - if (GifFile->SColorMap == NULL) { - GifFile->Error = D_GIF_ERR_NOT_ENOUGH_MEM; - return GIF_ERROR; - } - - /* Get the global color map: */ - GifFile->SColorMap->SortFlag = SortFlag; - for (i = 0; i < GifFile->SColorMap->ColorCount; i++) { - /* coverity[check_return] */ - if (InternalRead(GifFile, Buf, 3) != 3) { - GifFreeMapObject(GifFile->SColorMap); - GifFile->SColorMap = NULL; - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; - } - GifFile->SColorMap->Colors[i].Red = Buf[0]; - GifFile->SColorMap->Colors[i].Green = Buf[1]; - GifFile->SColorMap->Colors[i].Blue = Buf[2]; - } - } else { - GifFile->SColorMap = NULL; - } - - /* - * No check here for whether the background color is in range for the - * screen color map. Possibly there should be. - */ - - return GIF_OK; -} - -const char *DGifGetGifVersion(GifFileType *GifFile) { - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - if (Private->gif89) { - return GIF89_STAMP; - } else { - return GIF87_STAMP; - } -} - -/****************************************************************************** - This routine should be called before any attempt to read an image. 
-******************************************************************************/ -int DGifGetRecordType(GifFileType *GifFile, GifRecordType *Type) { - GifByteType Buf; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - if (!IS_READABLE(Private)) { - /* This file was NOT open for reading: */ - GifFile->Error = D_GIF_ERR_NOT_READABLE; - return GIF_ERROR; - } - - /* coverity[check_return] */ - if (InternalRead(GifFile, &Buf, 1) != 1) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; - } - - // fprintf(stderr, "### DGifGetRecordType: %02x\n", Buf); - switch (Buf) { - case DESCRIPTOR_INTRODUCER: - *Type = IMAGE_DESC_RECORD_TYPE; - break; - case EXTENSION_INTRODUCER: - *Type = EXTENSION_RECORD_TYPE; - break; - case TERMINATOR_INTRODUCER: - *Type = TERMINATE_RECORD_TYPE; - break; - default: - *Type = UNDEFINED_RECORD_TYPE; - GifFile->Error = D_GIF_ERR_WRONG_RECORD; - return GIF_ERROR; - } - - return GIF_OK; -} - -int DGifGetImageHeader(GifFileType *GifFile) { - unsigned int BitsPerPixel; - GifByteType Buf[3]; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - if (!IS_READABLE(Private)) { - /* This file was NOT open for reading: */ - GifFile->Error = D_GIF_ERR_NOT_READABLE; - return GIF_ERROR; - } - - if (DGifGetWord(GifFile, &GifFile->Image.Left) == GIF_ERROR || - DGifGetWord(GifFile, &GifFile->Image.Top) == GIF_ERROR || - DGifGetWord(GifFile, &GifFile->Image.Width) == GIF_ERROR || - DGifGetWord(GifFile, &GifFile->Image.Height) == GIF_ERROR) { - return GIF_ERROR; - } - if (InternalRead(GifFile, Buf, 1) != 1) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - GifFreeMapObject(GifFile->Image.ColorMap); - GifFile->Image.ColorMap = NULL; - return GIF_ERROR; - } - BitsPerPixel = (Buf[0] & 0x07) + 1; - GifFile->Image.Interlace = (Buf[0] & 0x40) ? true : false; - - /* Setup the colormap */ - if (GifFile->Image.ColorMap) { - GifFreeMapObject(GifFile->Image.ColorMap); - GifFile->Image.ColorMap = NULL; - } - /* Does this image have local color map? */ - if (Buf[0] & 0x80) { - int i; - - GifFile->Image.ColorMap = - GifMakeMapObject(1 << BitsPerPixel, NULL); - if (GifFile->Image.ColorMap == NULL) { - GifFile->Error = D_GIF_ERR_NOT_ENOUGH_MEM; - return GIF_ERROR; - } - - /* Get the image local color map: */ - for (i = 0; i < GifFile->Image.ColorMap->ColorCount; i++) { - /* coverity[check_return] */ - if (InternalRead(GifFile, Buf, 3) != 3) { - GifFreeMapObject(GifFile->Image.ColorMap); - GifFile->Error = D_GIF_ERR_READ_FAILED; - GifFile->Image.ColorMap = NULL; - return GIF_ERROR; - } - GifFile->Image.ColorMap->Colors[i].Red = Buf[0]; - GifFile->Image.ColorMap->Colors[i].Green = Buf[1]; - GifFile->Image.ColorMap->Colors[i].Blue = Buf[2]; - } - } - - Private->PixelCount = - (long)GifFile->Image.Width * (long)GifFile->Image.Height; - - /* Reset decompress algorithm parameters. */ - return DGifSetupDecompress(GifFile); -} - -/****************************************************************************** - This routine should be called before any attempt to read an image. - Note it is assumed the Image desc. header has been read. 
-******************************************************************************/ -int DGifGetImageDesc(GifFileType *GifFile) { - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - SavedImage *sp; - - if (!IS_READABLE(Private)) { - /* This file was NOT open for reading: */ - GifFile->Error = D_GIF_ERR_NOT_READABLE; - return GIF_ERROR; - } - - if (DGifGetImageHeader(GifFile) == GIF_ERROR) { - return GIF_ERROR; - } - - if (GifFile->SavedImages) { - SavedImage *new_saved_images = (SavedImage *)reallocarray( - GifFile->SavedImages, (GifFile->ImageCount + 1), - sizeof(SavedImage)); - if (new_saved_images == NULL) { - GifFile->Error = D_GIF_ERR_NOT_ENOUGH_MEM; - return GIF_ERROR; - } - GifFile->SavedImages = new_saved_images; - } else { - if ((GifFile->SavedImages = - (SavedImage *)malloc(sizeof(SavedImage))) == NULL) { - GifFile->Error = D_GIF_ERR_NOT_ENOUGH_MEM; - return GIF_ERROR; - } - } - - sp = &GifFile->SavedImages[GifFile->ImageCount]; - memcpy(&sp->ImageDesc, &GifFile->Image, sizeof(GifImageDesc)); - if (GifFile->Image.ColorMap != NULL) { - sp->ImageDesc.ColorMap = - GifMakeMapObject(GifFile->Image.ColorMap->ColorCount, - GifFile->Image.ColorMap->Colors); - if (sp->ImageDesc.ColorMap == NULL) { - GifFile->Error = D_GIF_ERR_NOT_ENOUGH_MEM; - return GIF_ERROR; - } - } - sp->RasterBits = (unsigned char *)NULL; - sp->ExtensionBlockCount = 0; - sp->ExtensionBlocks = (ExtensionBlock *)NULL; - - GifFile->ImageCount++; - - return GIF_OK; -} - -/****************************************************************************** - Get one full scanned line (Line) of length LineLen from GIF file. -******************************************************************************/ -int DGifGetLine(GifFileType *GifFile, GifPixelType *Line, int LineLen) { - GifByteType *Dummy; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - if (!IS_READABLE(Private)) { - /* This file was NOT open for reading: */ - GifFile->Error = D_GIF_ERR_NOT_READABLE; - return GIF_ERROR; - } - - if (!LineLen) { - LineLen = GifFile->Image.Width; - } - - if ((Private->PixelCount -= LineLen) > 0xffff0000UL) { - GifFile->Error = D_GIF_ERR_DATA_TOO_BIG; - return GIF_ERROR; - } - - if (DGifDecompressLine(GifFile, Line, LineLen) == GIF_OK) { - if (Private->PixelCount == 0) { - /* We probably won't be called any more, so let's clean - * up everything before we return: need to flush out all - * the rest of image until an empty block (size 0) - * detected. We use GetCodeNext. - */ - do { - if (DGifGetCodeNext(GifFile, &Dummy) == - GIF_ERROR) { - return GIF_ERROR; - } - } while (Dummy != NULL); - } - return GIF_OK; - } else { - return GIF_ERROR; - } -} - -/****************************************************************************** - Put one pixel (Pixel) into GIF file. 
-******************************************************************************/ -int DGifGetPixel(GifFileType *GifFile, GifPixelType Pixel) { - GifByteType *Dummy; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - if (!IS_READABLE(Private)) { - /* This file was NOT open for reading: */ - GifFile->Error = D_GIF_ERR_NOT_READABLE; - return GIF_ERROR; - } - if (--Private->PixelCount > 0xffff0000UL) { - GifFile->Error = D_GIF_ERR_DATA_TOO_BIG; - return GIF_ERROR; - } - - if (DGifDecompressLine(GifFile, &Pixel, 1) == GIF_OK) { - if (Private->PixelCount == 0) { - /* We probably won't be called any more, so let's clean - * up everything before we return: need to flush out all - * the rest of image until an empty block (size 0) - * detected. We use GetCodeNext. - */ - do { - if (DGifGetCodeNext(GifFile, &Dummy) == - GIF_ERROR) { - return GIF_ERROR; - } - } while (Dummy != NULL); - } - return GIF_OK; - } else { - return GIF_ERROR; - } -} - -/****************************************************************************** - Get an extension block (see GIF manual) from GIF file. This routine only - returns the first data block, and DGifGetExtensionNext should be called - after this one until NULL extension is returned. - The Extension should NOT be freed by the user (not dynamically allocated). - Note it is assumed the Extension description header has been read. -******************************************************************************/ -int DGifGetExtension(GifFileType *GifFile, int *ExtCode, - GifByteType **Extension) { - GifByteType Buf; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - // fprintf(stderr, "### -> DGifGetExtension:\n"); - if (!IS_READABLE(Private)) { - /* This file was NOT open for reading: */ - GifFile->Error = D_GIF_ERR_NOT_READABLE; - return GIF_ERROR; - } - - /* coverity[check_return] */ - if (InternalRead(GifFile, &Buf, 1) != 1) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; - } - *ExtCode = Buf; - // fprintf(stderr, "### <- DGifGetExtension: %02x, about to call - // next\n", Buf); - - return DGifGetExtensionNext(GifFile, Extension); -} - -/****************************************************************************** - Get a following extension block (see GIF manual) from GIF file. This - routine should be called until NULL Extension is returned. - The Extension should NOT be freed by the user (not dynamically allocated). -******************************************************************************/ -int DGifGetExtensionNext(GifFileType *GifFile, GifByteType **Extension) { - GifByteType Buf; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - // fprintf(stderr, "### -> DGifGetExtensionNext\n"); - if (InternalRead(GifFile, &Buf, 1) != 1) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; - } - // fprintf(stderr, "### DGifGetExtensionNext sees %d\n", Buf); - - if (Buf > 0) { - *Extension = Private->Buf; /* Use private unused buffer. */ - (*Extension)[0] = - Buf; /* Pascal strings notation (pos. 0 is len.). 
*/ - /* coverity[tainted_data,check_return] */ - if (InternalRead(GifFile, &((*Extension)[1]), Buf) != Buf) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; - } - } else { - *Extension = NULL; - } - // fprintf(stderr, "### <- DGifGetExtensionNext: %p\n", Extension); - - return GIF_OK; -} - -/****************************************************************************** - Extract a Graphics Control Block from raw extension data -******************************************************************************/ - -int DGifExtensionToGCB(const size_t GifExtensionLength, - const GifByteType *GifExtension, - GraphicsControlBlock *GCB) { - if (GifExtensionLength != 4) { - return GIF_ERROR; - } - - GCB->DisposalMode = (GifExtension[0] >> 2) & 0x07; - GCB->UserInputFlag = (GifExtension[0] & 0x02) != 0; - GCB->DelayTime = - UNSIGNED_LITTLE_ENDIAN(GifExtension[1], GifExtension[2]); - if (GifExtension[0] & 0x01) { - GCB->TransparentColor = (int)GifExtension[3]; - } else { - GCB->TransparentColor = NO_TRANSPARENT_COLOR; - } - - return GIF_OK; -} - -/****************************************************************************** - Extract the Graphics Control Block for a saved image, if it exists. -******************************************************************************/ - -int DGifSavedExtensionToGCB(GifFileType *GifFile, int ImageIndex, - GraphicsControlBlock *GCB) { - int i; - - if (ImageIndex < 0 || ImageIndex > GifFile->ImageCount - 1) { - return GIF_ERROR; - } - - GCB->DisposalMode = DISPOSAL_UNSPECIFIED; - GCB->UserInputFlag = false; - GCB->DelayTime = 0; - GCB->TransparentColor = NO_TRANSPARENT_COLOR; - - for (i = 0; i < GifFile->SavedImages[ImageIndex].ExtensionBlockCount; - i++) { - ExtensionBlock *ep = - &GifFile->SavedImages[ImageIndex].ExtensionBlocks[i]; - if (ep->Function == GRAPHICS_EXT_FUNC_CODE) { - return DGifExtensionToGCB(ep->ByteCount, ep->Bytes, - GCB); - } - } - - return GIF_ERROR; -} - -/****************************************************************************** - This routine should be called last, to close the GIF file. 
-******************************************************************************/ -int DGifCloseFile(GifFileType *GifFile, int *ErrorCode) { - GifFilePrivateType *Private; - - if (GifFile == NULL || GifFile->Private == NULL) { - return GIF_ERROR; - } - - if (GifFile->Image.ColorMap) { - GifFreeMapObject(GifFile->Image.ColorMap); - GifFile->Image.ColorMap = NULL; - } - - if (GifFile->SColorMap) { - GifFreeMapObject(GifFile->SColorMap); - GifFile->SColorMap = NULL; - } - - if (GifFile->SavedImages) { - GifFreeSavedImages(GifFile); - GifFile->SavedImages = NULL; - } - - GifFreeExtensions(&GifFile->ExtensionBlockCount, - &GifFile->ExtensionBlocks); - - Private = (GifFilePrivateType *)GifFile->Private; - - if (!IS_READABLE(Private)) { - /* This file was NOT open for reading: */ - if (ErrorCode != NULL) { - *ErrorCode = D_GIF_ERR_NOT_READABLE; - } - free((char *)GifFile->Private); - free(GifFile); - return GIF_ERROR; - } - - if (Private->File && (fclose(Private->File) != 0)) { - if (ErrorCode != NULL) { - *ErrorCode = D_GIF_ERR_CLOSE_FAILED; - } - free((char *)GifFile->Private); - free(GifFile); - return GIF_ERROR; - } - - free((char *)GifFile->Private); - free(GifFile); - if (ErrorCode != NULL) { - *ErrorCode = D_GIF_SUCCEEDED; - } - return GIF_OK; -} - -/****************************************************************************** - Get 2 bytes (word) from the given file: -******************************************************************************/ -static int DGifGetWord(GifFileType *GifFile, GifWord *Word) { - unsigned char c[2]; - - /* coverity[check_return] */ - if (InternalRead(GifFile, c, 2) != 2) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; - } - - *Word = (GifWord)UNSIGNED_LITTLE_ENDIAN(c[0], c[1]); - return GIF_OK; -} - -/****************************************************************************** - Get the image code in compressed form. This routine can be called if the - information needed to be piped out as is. Obviously this is much faster - than decoding and encoding again. This routine should be followed by calls - to DGifGetCodeNext, until NULL block is returned. - The block should NOT be freed by the user (not dynamically allocated). -******************************************************************************/ -int DGifGetCode(GifFileType *GifFile, int *CodeSize, GifByteType **CodeBlock) { - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - if (!IS_READABLE(Private)) { - /* This file was NOT open for reading: */ - GifFile->Error = D_GIF_ERR_NOT_READABLE; - return GIF_ERROR; - } - - *CodeSize = Private->BitsPerPixel; - - return DGifGetCodeNext(GifFile, CodeBlock); -} - -/****************************************************************************** - Continue to get the image code in compressed form. This routine should be - called until NULL block is returned. - The block should NOT be freed by the user (not dynamically allocated). -******************************************************************************/ -int DGifGetCodeNext(GifFileType *GifFile, GifByteType **CodeBlock) { - GifByteType Buf; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - /* coverity[tainted_data_argument] */ - /* coverity[check_return] */ - if (InternalRead(GifFile, &Buf, 1) != 1) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; - } - - /* coverity[lower_bounds] */ - if (Buf > 0) { - *CodeBlock = Private->Buf; /* Use private unused buffer. */ - (*CodeBlock)[0] = - Buf; /* Pascal strings notation (pos. 
0 is len.). */ - /* coverity[tainted_data] */ - if (InternalRead(GifFile, &((*CodeBlock)[1]), Buf) != Buf) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; - } - } else { - *CodeBlock = NULL; - Private->Buf[0] = 0; /* Make sure the buffer is empty! */ - Private->PixelCount = - 0; /* And local info. indicate image read. */ - } - - return GIF_OK; -} - -/****************************************************************************** - Setup the LZ decompression for this image: -******************************************************************************/ -static int DGifSetupDecompress(GifFileType *GifFile) { - int i, BitsPerPixel; - GifByteType CodeSize; - GifPrefixType *Prefix; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - /* coverity[check_return] */ - if (InternalRead(GifFile, &CodeSize, 1) < - 1) { /* Read Code size from file. */ - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; /* Failed to read Code size. */ - } - BitsPerPixel = CodeSize; - - /* this can only happen on a severely malformed GIF */ - if (BitsPerPixel > 8) { - GifFile->Error = - D_GIF_ERR_READ_FAILED; /* somewhat bogus error code */ - return GIF_ERROR; /* Failed to read Code size. */ - } - - Private->Buf[0] = 0; /* Input Buffer empty. */ - Private->BitsPerPixel = BitsPerPixel; - Private->ClearCode = (1 << BitsPerPixel); - Private->EOFCode = Private->ClearCode + 1; - Private->RunningCode = Private->EOFCode + 1; - Private->RunningBits = BitsPerPixel + 1; /* Number of bits per code. */ - Private->MaxCode1 = 1 << Private->RunningBits; /* Max. code + 1. */ - Private->StackPtr = 0; /* No pixels on the pixel stack. */ - Private->LastCode = NO_SUCH_CODE; - Private->CrntShiftState = 0; /* No information in CrntShiftDWord. */ - Private->CrntShiftDWord = 0; - - Prefix = Private->Prefix; - for (i = 0; i <= LZ_MAX_CODE; i++) { - Prefix[i] = NO_SUCH_CODE; - } - - return GIF_OK; -} - -/****************************************************************************** - The LZ decompression routine: - This version decompress the given GIF file into Line of length LineLen. - This routine can be called few times (one per scan line, for example), in - order the complete the whole image. -******************************************************************************/ -static int DGifDecompressLine(GifFileType *GifFile, GifPixelType *Line, - int LineLen) { - int i = 0; - int j, CrntCode, EOFCode, ClearCode, CrntPrefix, LastCode, StackPtr; - GifByteType *Stack, *Suffix; - GifPrefixType *Prefix; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - StackPtr = Private->StackPtr; - Prefix = Private->Prefix; - Suffix = Private->Suffix; - Stack = Private->Stack; - EOFCode = Private->EOFCode; - ClearCode = Private->ClearCode; - LastCode = Private->LastCode; - - if (StackPtr > LZ_MAX_CODE) { - return GIF_ERROR; - } - - if (StackPtr != 0) { - /* Let pop the stack off before continueing to read the GIF - * file: */ - while (StackPtr != 0 && i < LineLen) { - Line[i++] = Stack[--StackPtr]; - } - } - - while (i < LineLen) { /* Decode LineLen items. */ - if (DGifDecompressInput(GifFile, &CrntCode) == GIF_ERROR) { - return GIF_ERROR; - } - - if (CrntCode == EOFCode) { - /* Note however that usually we will not be here as we - * will stop decoding as soon as we got all the pixel, - * or EOF code will not be read at all, and - * DGifGetLine/Pixel clean everything. 
*/ - GifFile->Error = D_GIF_ERR_EOF_TOO_SOON; - return GIF_ERROR; - } else if (CrntCode == ClearCode) { - /* We need to start over again: */ - for (j = 0; j <= LZ_MAX_CODE; j++) { - Prefix[j] = NO_SUCH_CODE; - } - Private->RunningCode = Private->EOFCode + 1; - Private->RunningBits = Private->BitsPerPixel + 1; - Private->MaxCode1 = 1 << Private->RunningBits; - LastCode = Private->LastCode = NO_SUCH_CODE; - } else { - /* Its regular code - if in pixel range simply add it to - * output stream, otherwise trace to codes linked list - * until the prefix is in pixel range: */ - if (CrntCode < ClearCode) { - /* This is simple - its pixel scalar, so add it - * to output: */ - Line[i++] = CrntCode; - } else { - /* Its a code to needed to be traced: trace the - * linked list until the prefix is a pixel, - * while pushing the suffix pixels on our stack. - * If we done, pop the stack in reverse (thats - * what stack is good for!) order to output. */ - if (Prefix[CrntCode] == NO_SUCH_CODE) { - CrntPrefix = LastCode; - - /* Only allowed if CrntCode is exactly - * the running code: In that case - * CrntCode = XXXCode, CrntCode or the - * prefix code is last code and the - * suffix char is exactly the prefix of - * last code! */ - if (CrntCode == - Private->RunningCode - 2) { - Suffix[Private->RunningCode - - 2] = Stack[StackPtr++] = - DGifGetPrefixChar( - Prefix, LastCode, - ClearCode); - } else { - Suffix[Private->RunningCode - - 2] = Stack[StackPtr++] = - DGifGetPrefixChar( - Prefix, CrntCode, - ClearCode); - } - } else { - CrntPrefix = CrntCode; - } - - /* Now (if image is O.K.) we should not get a - * NO_SUCH_CODE during the trace. As we might - * loop forever, in case of defective image, we - * use StackPtr as loop counter and stop before - * overflowing Stack[]. */ - while (StackPtr < LZ_MAX_CODE && - CrntPrefix > ClearCode && - CrntPrefix <= LZ_MAX_CODE) { - Stack[StackPtr++] = Suffix[CrntPrefix]; - CrntPrefix = Prefix[CrntPrefix]; - } - if (StackPtr >= LZ_MAX_CODE || - CrntPrefix > LZ_MAX_CODE) { - GifFile->Error = D_GIF_ERR_IMAGE_DEFECT; - return GIF_ERROR; - } - /* Push the last character on stack: */ - Stack[StackPtr++] = CrntPrefix; - - /* Now lets pop all the stack into output: */ - while (StackPtr != 0 && i < LineLen) { - Line[i++] = Stack[--StackPtr]; - } - } - if (LastCode != NO_SUCH_CODE && - Private->RunningCode - 2 < (LZ_MAX_CODE + 1) && - Prefix[Private->RunningCode - 2] == NO_SUCH_CODE) { - Prefix[Private->RunningCode - 2] = LastCode; - - if (CrntCode == Private->RunningCode - 2) { - /* Only allowed if CrntCode is exactly - * the running code: In that case - * CrntCode = XXXCode, CrntCode or the - * prefix code is last code and the - * suffix char is exactly the prefix of - * last code! */ - Suffix[Private->RunningCode - 2] = - DGifGetPrefixChar(Prefix, LastCode, - ClearCode); - } else { - Suffix[Private->RunningCode - 2] = - DGifGetPrefixChar(Prefix, CrntCode, - ClearCode); - } - } - LastCode = CrntCode; - } - } - - Private->LastCode = LastCode; - Private->StackPtr = StackPtr; - - return GIF_OK; -} - -/****************************************************************************** - Routine to trace the Prefixes linked list until we get a prefix which is - not code, but a pixel value (less than ClearCode). Returns that pixel value. - If image is defective, we might loop here forever, so we limit the loops to - the maximum possible if image O.k. - LZ_MAX_CODE times. 
-******************************************************************************/ -static int DGifGetPrefixChar(const GifPrefixType *Prefix, int Code, - int ClearCode) { - int i = 0; - - while (Code > ClearCode && i++ <= LZ_MAX_CODE) { - if (Code > LZ_MAX_CODE) { - return NO_SUCH_CODE; - } - Code = Prefix[Code]; - } - return Code; -} - -/****************************************************************************** - Interface for accessing the LZ codes directly. Set Code to the real code - (12bits), or to -1 if EOF code is returned. -******************************************************************************/ -int DGifGetLZCodes(GifFileType *GifFile, int *Code) { - GifByteType *CodeBlock; - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - if (!IS_READABLE(Private)) { - /* This file was NOT open for reading: */ - GifFile->Error = D_GIF_ERR_NOT_READABLE; - return GIF_ERROR; - } - - if (DGifDecompressInput(GifFile, Code) == GIF_ERROR) { - return GIF_ERROR; - } - - if (*Code == Private->EOFCode) { - /* Skip rest of codes (hopefully only NULL terminating block): - */ - do { - if (DGifGetCodeNext(GifFile, &CodeBlock) == GIF_ERROR) { - return GIF_ERROR; - } - } while (CodeBlock != NULL); - - *Code = -1; - } else if (*Code == Private->ClearCode) { - /* We need to start over again: */ - Private->RunningCode = Private->EOFCode + 1; - Private->RunningBits = Private->BitsPerPixel + 1; - Private->MaxCode1 = 1 << Private->RunningBits; - } - - return GIF_OK; -} - -/****************************************************************************** - The LZ decompression input routine: - This routine is responsable for the decompression of the bit stream from - 8 bits (bytes) packets, into the real codes. - Returns GIF_OK if read successfully. -******************************************************************************/ -static int DGifDecompressInput(GifFileType *GifFile, int *Code) { - static const unsigned short CodeMasks[] = { - 0x0000, 0x0001, 0x0003, 0x0007, 0x000f, 0x001f, 0x003f, - 0x007f, 0x00ff, 0x01ff, 0x03ff, 0x07ff, 0x0fff}; - - GifFilePrivateType *Private = (GifFilePrivateType *)GifFile->Private; - - GifByteType NextByte; - - /* The image can't contain more than LZ_BITS per code. */ - if (Private->RunningBits > LZ_BITS) { - GifFile->Error = D_GIF_ERR_IMAGE_DEFECT; - return GIF_ERROR; - } - - while (Private->CrntShiftState < Private->RunningBits) { - /* Needs to get more bytes from input stream for next code: */ - if (DGifBufferedInput(GifFile, Private->Buf, &NextByte) == - GIF_ERROR) { - return GIF_ERROR; - } - Private->CrntShiftDWord |= ((unsigned long)NextByte) - << Private->CrntShiftState; - Private->CrntShiftState += 8; - } - *Code = Private->CrntShiftDWord & CodeMasks[Private->RunningBits]; - - Private->CrntShiftDWord >>= Private->RunningBits; - Private->CrntShiftState -= Private->RunningBits; - - /* If code cannot fit into RunningBits bits, must raise its size. Note - * however that codes above 4095 are used for special signaling. - * If we're using LZ_BITS bits already and we're at the max code, just - * keep using the table as it is, don't increment Private->RunningCode. 
- */ - if (Private->RunningCode < LZ_MAX_CODE + 2 && - ++Private->RunningCode > Private->MaxCode1 && - Private->RunningBits < LZ_BITS) { - Private->MaxCode1 <<= 1; - Private->RunningBits++; - } - return GIF_OK; -} - -/****************************************************************************** - This routines read one GIF data block at a time and buffers it internally - so that the decompression routine could access it. - The routine returns the next byte from its internal buffer (or read next - block in if buffer empty) and returns GIF_OK if succesful. -******************************************************************************/ -static int DGifBufferedInput(GifFileType *GifFile, GifByteType *Buf, - GifByteType *NextByte) { - if (Buf[0] == 0) { - /* Needs to read the next buffer - this one is empty: */ - /* coverity[check_return] */ - if (InternalRead(GifFile, Buf, 1) != 1) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; - } - /* There shouldn't be any empty data blocks here as the LZW spec - * says the LZW termination code should come first. Therefore - * we shouldn't be inside this routine at that point. - */ - if (Buf[0] == 0) { - GifFile->Error = D_GIF_ERR_IMAGE_DEFECT; - return GIF_ERROR; - } - if (InternalRead(GifFile, &Buf[1], Buf[0]) != Buf[0]) { - GifFile->Error = D_GIF_ERR_READ_FAILED; - return GIF_ERROR; - } - *NextByte = Buf[1]; - Buf[1] = 2; /* We use now the second place as last char read! */ - Buf[0]--; - } else { - *NextByte = Buf[Buf[1]++]; - Buf[0]--; - } - - return GIF_OK; -} - -/****************************************************************************** - This routine is called in case of error during parsing image. We need to - decrease image counter and reallocate memory for saved images. Not decreasing - ImageCount may lead to null pointer dereference, because the last element in - SavedImages may point to the spoilt image and null pointer buffers. -*******************************************************************************/ -void DGifDecreaseImageCounter(GifFileType *GifFile) { - GifFile->ImageCount--; - if (GifFile->SavedImages[GifFile->ImageCount].RasterBits != NULL) { - free(GifFile->SavedImages[GifFile->ImageCount].RasterBits); - } - - // Realloc array according to the new image counter. - SavedImage *correct_saved_images = (SavedImage *)reallocarray( - GifFile->SavedImages, GifFile->ImageCount, sizeof(SavedImage)); - if (correct_saved_images != NULL) { - GifFile->SavedImages = correct_saved_images; - } -} - -/****************************************************************************** - This routine reads an entire GIF into core, hanging all its state info off - the GifFileType pointer. Call DGifOpenFileName() or DGifOpenFileHandle() - first to initialize I/O. Its inverse is EGifSpew(). 
-*******************************************************************************/ -int DGifSlurp(GifFileType *GifFile) { - size_t ImageSize; - GifRecordType RecordType; - SavedImage *sp; - GifByteType *ExtData; - int ExtFunction; - - GifFile->ExtensionBlocks = NULL; - GifFile->ExtensionBlockCount = 0; - - do { - if (DGifGetRecordType(GifFile, &RecordType) == GIF_ERROR) { - return (GIF_ERROR); - } - - switch (RecordType) { - case IMAGE_DESC_RECORD_TYPE: - if (DGifGetImageDesc(GifFile) == GIF_ERROR) { - return (GIF_ERROR); - } - - sp = &GifFile->SavedImages[GifFile->ImageCount - 1]; - /* Allocate memory for the image */ - if (sp->ImageDesc.Width <= 0 || - sp->ImageDesc.Height <= 0 || - sp->ImageDesc.Width > - (INT_MAX / sp->ImageDesc.Height)) { - DGifDecreaseImageCounter(GifFile); - return GIF_ERROR; - } - ImageSize = sp->ImageDesc.Width * sp->ImageDesc.Height; - - if (ImageSize > (SIZE_MAX / sizeof(GifPixelType))) { - DGifDecreaseImageCounter(GifFile); - return GIF_ERROR; - } - sp->RasterBits = (unsigned char *)reallocarray( - NULL, ImageSize, sizeof(GifPixelType)); - - if (sp->RasterBits == NULL) { - DGifDecreaseImageCounter(GifFile); - return GIF_ERROR; - } - - if (sp->ImageDesc.Interlace) { - int i, j; - /* - * The way an interlaced image should be read - - * offsets and jumps... - */ - static const int InterlacedOffset[] = {0, 4, 2, - 1}; - static const int InterlacedJumps[] = {8, 8, 4, - 2}; - /* Need to perform 4 passes on the image */ - for (i = 0; i < 4; i++) { - for (j = InterlacedOffset[i]; - j < sp->ImageDesc.Height; - j += InterlacedJumps[i]) { - if (DGifGetLine( - GifFile, - sp->RasterBits + - j * sp->ImageDesc - .Width, - sp->ImageDesc.Width) == - GIF_ERROR) { - DGifDecreaseImageCounter( - GifFile); - return GIF_ERROR; - } - } - } - } else { - if (DGifGetLine(GifFile, sp->RasterBits, - ImageSize) == GIF_ERROR) { - DGifDecreaseImageCounter(GifFile); - return GIF_ERROR; - } - } - - if (GifFile->ExtensionBlocks) { - sp->ExtensionBlocks = GifFile->ExtensionBlocks; - sp->ExtensionBlockCount = - GifFile->ExtensionBlockCount; - - GifFile->ExtensionBlocks = NULL; - GifFile->ExtensionBlockCount = 0; - } - break; - - case EXTENSION_RECORD_TYPE: - if (DGifGetExtension(GifFile, &ExtFunction, &ExtData) == - GIF_ERROR) { - return (GIF_ERROR); - } - /* Create an extension block with our data */ - if (ExtData != NULL) { - if (GifAddExtensionBlock( - &GifFile->ExtensionBlockCount, - &GifFile->ExtensionBlocks, ExtFunction, - ExtData[0], &ExtData[1]) == GIF_ERROR) { - return (GIF_ERROR); - } - } - for (;;) { - if (DGifGetExtensionNext(GifFile, &ExtData) == - GIF_ERROR) { - return (GIF_ERROR); - } - if (ExtData == NULL) { - break; - } - /* Continue the extension block */ - if (GifAddExtensionBlock( - &GifFile->ExtensionBlockCount, - &GifFile->ExtensionBlocks, - CONTINUE_EXT_FUNC_CODE, ExtData[0], - &ExtData[1]) == GIF_ERROR) { - return (GIF_ERROR); - } - } - break; - - case TERMINATE_RECORD_TYPE: - break; - - default: /* Should be trapped by DGifGetRecordType */ - break; - } - } while (RecordType != TERMINATE_RECORD_TYPE); - - /* Sanity check for corrupted file */ - if (GifFile->ImageCount == 0) { - GifFile->Error = D_GIF_ERR_NO_IMAG_DSCR; - return (GIF_ERROR); - } - - return (GIF_OK); -} - -/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/gif_hash.c b/product/include/torchvision/io/image/cpu/giflib/gif_hash.c deleted file mode 100644 index e63a72accd4..00000000000 --- a/product/include/torchvision/io/image/cpu/giflib/gif_hash.c +++ /dev/null @@ -1,128 +0,0 @@ 
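The dgif_lib.c removed above implements the whole read path (DGifOpenFileName, DGifSlurp, DGifCloseFile) on top of the vendored gif_lib.h. The sketch below is a minimal, hypothetical decode example against that API; it assumes the vendored header is on the include path, and the input file name is illustrative.

```cpp
// Minimal decode sketch using the vendored giflib API removed above
// (DGifOpenFileName / DGifSlurp / DGifCloseFile). Error reporting uses
// GifErrorString() and the D_GIF_* codes declared in gif_lib.h.
#include <cstdio>
#include "gif_lib.h"

int main() {
  int err = D_GIF_SUCCEEDED;
  GifFileType* gif = DGifOpenFileName("example.gif", &err);
  if (gif == nullptr) {
    std::fprintf(stderr, "open failed: %s\n", GifErrorString(err));
    return 1;
  }

  // Read the entire file into gif->SavedImages in one call.
  if (DGifSlurp(gif) != GIF_OK) {
    std::fprintf(stderr, "decode failed: %s\n", GifErrorString(gif->Error));
    DGifCloseFile(gif, &err);
    return 1;
  }

  std::printf("canvas %dx%d, %d frame(s)\n", gif->SWidth, gif->SHeight,
              gif->ImageCount);
  for (int i = 0; i < gif->ImageCount; ++i) {
    const SavedImage& frame = gif->SavedImages[i];
    // frame.RasterBits holds Width*Height palette indices; colors come from
    // frame.ImageDesc.ColorMap or, if that is NULL, from gif->SColorMap.
    std::printf("frame %d: %dx%d\n", i, frame.ImageDesc.Width,
                frame.ImageDesc.Height);
  }

  DGifCloseFile(gif, &err);
  return 0;
}
```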
-/***************************************************************************** - -gif_hash.c -- module to support the following operations: - -1. InitHashTable - initialize hash table. -2. ClearHashTable - clear the hash table to an empty state. -2. InsertHashTable - insert one item into data structure. -3. ExistsHashTable - test if item exists in data structure. - -This module is used to hash the GIF codes during encoding. - -*****************************************************************************/ -// SPDX-License-Identifier: MIT -// SPDX-File-Copyright-Txt: (C) Copyright 1989 Gershon Elber - -#include -#include -#include -#include -#include - -#include "gif_hash.h" -#include "gif_lib.h" -#include "gif_lib_private.h" - -/* #define DEBUG_HIT_RATE Debug number of misses per hash Insert/Exists. */ - -#ifdef DEBUG_HIT_RATE -static long NumberOfTests = 0, NumberOfMisses = 0; -#endif /* DEBUG_HIT_RATE */ - -static int KeyItem(uint32_t Item); - -/****************************************************************************** - Initialize HashTable - allocate the memory needed and clear it. * -******************************************************************************/ -GifHashTableType *_InitHashTable(void) { - GifHashTableType *HashTable; - - if ((HashTable = (GifHashTableType *)malloc( - sizeof(GifHashTableType))) == NULL) { - return NULL; - } - - _ClearHashTable(HashTable); - - return HashTable; -} - -/****************************************************************************** - Routine to clear the HashTable to an empty state. * - This part is a little machine depended. Use the commented part otherwise. * -******************************************************************************/ -void _ClearHashTable(GifHashTableType *HashTable) { - memset(HashTable->HTable, 0xFF, HT_SIZE * sizeof(uint32_t)); -} - -/****************************************************************************** - Routine to insert a new Item into the HashTable. The data is assumed to be * - new one. * -******************************************************************************/ -void _InsertHashTable(GifHashTableType *HashTable, uint32_t Key, int Code) { - int HKey = KeyItem(Key); - uint32_t *HTable = HashTable->HTable; - -#ifdef DEBUG_HIT_RATE - NumberOfTests++; - NumberOfMisses++; -#endif /* DEBUG_HIT_RATE */ - - while (HT_GET_KEY(HTable[HKey]) != 0xFFFFFL) { -#ifdef DEBUG_HIT_RATE - NumberOfMisses++; -#endif /* DEBUG_HIT_RATE */ - HKey = (HKey + 1) & HT_KEY_MASK; - } - HTable[HKey] = HT_PUT_KEY(Key) | HT_PUT_CODE(Code); -} - -/****************************************************************************** - Routine to test if given Key exists in HashTable and if so returns its code * - Returns the Code if key was found, -1 if not. * -******************************************************************************/ -int _ExistsHashTable(GifHashTableType *HashTable, uint32_t Key) { - int HKey = KeyItem(Key); - uint32_t *HTable = HashTable->HTable, HTKey; - -#ifdef DEBUG_HIT_RATE - NumberOfTests++; - NumberOfMisses++; -#endif /* DEBUG_HIT_RATE */ - - while ((HTKey = HT_GET_KEY(HTable[HKey])) != 0xFFFFFL) { -#ifdef DEBUG_HIT_RATE - NumberOfMisses++; -#endif /* DEBUG_HIT_RATE */ - if (Key == HTKey) { - return HT_GET_CODE(HTable[HKey]); - } - HKey = (HKey + 1) & HT_KEY_MASK; - } - - return -1; -} - -/****************************************************************************** - Routine to generate an HKey for the hashtable out of the given unique key. 
* - The given Key is assumed to be 20 bits as follows: lower 8 bits are the * - new postfix character, while the upper 12 bits are the prefix code. * - Because the average hit ratio is only 2 (2 hash references per entry), * - evaluating more complex keys (such as twin prime keys) does not worth it! * -******************************************************************************/ -static int KeyItem(uint32_t Item) { - return ((Item >> 12) ^ Item) & HT_KEY_MASK; -} - -#ifdef DEBUG_HIT_RATE -/****************************************************************************** - Debugging routine to print the hit ratio - number of times the hash table * - was tested per operation. This routine was used to test the KeyItem routine * -******************************************************************************/ -void HashTablePrintHitRatio(void) { - printf("Hash Table Hit Ratio is %ld/%ld = %ld%%.\n", NumberOfMisses, - NumberOfTests, NumberOfMisses * 100 / NumberOfTests); -} -#endif /* DEBUG_HIT_RATE */ - -/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/gif_hash.h b/product/include/torchvision/io/image/cpu/giflib/gif_hash.h deleted file mode 100644 index 009cb5b8081..00000000000 --- a/product/include/torchvision/io/image/cpu/giflib/gif_hash.h +++ /dev/null @@ -1,42 +0,0 @@ -/****************************************************************************** - -gif_hash.h - magfic constants and declarations for GIF LZW - -******************************************************************************/ -// SPDX-License-Identifier: MIT - -#ifndef _GIF_HASH_H_ -#define _GIF_HASH_H_ - -#ifndef _WIN32 -#include -#endif /* _WIN32 */ -#include - -#define HT_SIZE 8192 /* 12bits = 4096 or twice as big! */ -#define HT_KEY_MASK 0x1FFF /* 13bits keys */ -#define HT_KEY_NUM_BITS 13 /* 13bits keys */ -#define HT_MAX_KEY 8191 /* 13bits - 1, maximal code possible */ -#define HT_MAX_CODE 4095 /* Biggest code possible in 12 bits. */ - -/* The 32 bits of the long are divided into two parts for the key & code: */ -/* 1. The code is 12 bits as our compression algorithm is limited to 12bits */ -/* 2. The key is 12 bits Prefix code + 8 bit new char or 20 bits. */ -/* The key is the upper 20 bits. The code is the lower 12. 
*/ -#define HT_GET_KEY(l) (l >> 12) -#define HT_GET_CODE(l) (l & 0x0FFF) -#define HT_PUT_KEY(l) (l << 12) -#define HT_PUT_CODE(l) (l & 0x0FFF) - -typedef struct GifHashTableType { - uint32_t HTable[HT_SIZE]; -} GifHashTableType; - -GifHashTableType *_InitHashTable(void); -void _ClearHashTable(GifHashTableType *HashTable); -void _InsertHashTable(GifHashTableType *HashTable, uint32_t Key, int Code); -int _ExistsHashTable(GifHashTableType *HashTable, uint32_t Key); - -#endif /* _GIF_HASH_H_ */ - -/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/gif_lib.h b/product/include/torchvision/io/image/cpu/giflib/gif_lib.h deleted file mode 100644 index d0c61d51682..00000000000 --- a/product/include/torchvision/io/image/cpu/giflib/gif_lib.h +++ /dev/null @@ -1,291 +0,0 @@ -/****************************************************************************** - -gif_lib.h - service library for decoding and encoding GIF images - -SPDX-License-Identifier: MIT - -*****************************************************************************/ - -#ifndef _GIF_LIB_H_ -#define _GIF_LIB_H_ 1 - -#ifdef __cplusplus -extern "C" { -#endif /* __cplusplus */ - -#define GIFLIB_MAJOR 5 -#define GIFLIB_MINOR 2 -#define GIFLIB_RELEASE 2 - -#define GIF_ERROR 0 -#define GIF_OK 1 - -#include -#include - -#define GIF_STAMP "GIFVER" /* First chars in file - GIF stamp. */ -#define GIF_STAMP_LEN sizeof(GIF_STAMP) - 1 -#define GIF_VERSION_POS 3 /* Version first character in stamp. */ -#define GIF87_STAMP "GIF87a" /* First chars in file - GIF stamp. */ -#define GIF89_STAMP "GIF89a" /* First chars in file - GIF stamp. */ - -typedef unsigned char GifPixelType; -typedef unsigned char *GifRowType; -typedef unsigned char GifByteType; -typedef unsigned int GifPrefixType; -typedef int GifWord; - -typedef struct GifColorType { - GifByteType Red, Green, Blue; -} GifColorType; - -typedef struct ColorMapObject { - int ColorCount; - int BitsPerPixel; - bool SortFlag; - GifColorType *Colors; /* on malloc(3) heap */ -} ColorMapObject; - -typedef struct GifImageDesc { - GifWord Left, Top, Width, Height; /* Current image dimensions. */ - bool Interlace; /* Sequential/Interlaced lines. */ - ColorMapObject *ColorMap; /* The local color map */ -} GifImageDesc; - -typedef struct ExtensionBlock { - int ByteCount; - GifByteType *Bytes; /* on malloc(3) heap */ - int Function; /* The block function code */ -#define CONTINUE_EXT_FUNC_CODE 0x00 /* continuation subblock */ -#define COMMENT_EXT_FUNC_CODE 0xfe /* comment */ -#define GRAPHICS_EXT_FUNC_CODE 0xf9 /* graphics control (GIF89) */ -#define PLAINTEXT_EXT_FUNC_CODE 0x01 /* plaintext */ -#define APPLICATION_EXT_FUNC_CODE 0xff /* application block (GIF89) */ -} ExtensionBlock; - -typedef struct SavedImage { - GifImageDesc ImageDesc; - GifByteType *RasterBits; /* on malloc(3) heap */ - int ExtensionBlockCount; /* Count of extensions before image */ - ExtensionBlock *ExtensionBlocks; /* Extensions before image */ -} SavedImage; - -typedef struct GifFileType { - GifWord SWidth, SHeight; /* Size of virtual canvas */ - GifWord SColorResolution; /* How many colors can we generate? */ - GifWord SBackGroundColor; /* Background color for virtual canvas */ - GifByteType AspectByte; /* Used to compute pixel aspect ratio */ - ColorMapObject *SColorMap; /* Global colormap, NULL if nonexistent. 
*/ - int ImageCount; /* Number of current image (both APIs) */ - GifImageDesc Image; /* Current image (low-level API) */ - SavedImage *SavedImages; /* Image sequence (high-level API) */ - int ExtensionBlockCount; /* Count extensions past last image */ - ExtensionBlock *ExtensionBlocks; /* Extensions past last image */ - int Error; /* Last error condition reported */ - void *UserData; /* hook to attach user data (TVT) */ - void *Private; /* Don't mess with this! */ -} GifFileType; - -#define GIF_ASPECT_RATIO(n) ((n) + 15.0 / 64.0) - -typedef enum { - UNDEFINED_RECORD_TYPE, - SCREEN_DESC_RECORD_TYPE, - IMAGE_DESC_RECORD_TYPE, /* Begin with ',' */ - EXTENSION_RECORD_TYPE, /* Begin with '!' */ - TERMINATE_RECORD_TYPE /* Begin with ';' */ -} GifRecordType; - -/* func type to read gif data from arbitrary sources (TVT) */ -typedef int (*InputFunc)(GifFileType *, GifByteType *, int); - -/* func type to write gif data to arbitrary targets. - * Returns count of bytes written. (MRB) - */ -typedef int (*OutputFunc)(GifFileType *, const GifByteType *, int); - -/****************************************************************************** - GIF89 structures -******************************************************************************/ - -typedef struct GraphicsControlBlock { - int DisposalMode; -#define DISPOSAL_UNSPECIFIED 0 /* No disposal specified. */ -#define DISPOSE_DO_NOT 1 /* Leave image in place */ -#define DISPOSE_BACKGROUND 2 /* Set area too background color */ -#define DISPOSE_PREVIOUS 3 /* Restore to previous content */ - bool UserInputFlag; /* User confirmation required before disposal */ - int DelayTime; /* pre-display delay in 0.01sec units */ - int TransparentColor; /* Palette index for transparency, -1 if none */ -#define NO_TRANSPARENT_COLOR -1 -} GraphicsControlBlock; - -/****************************************************************************** - GIF encoding routines -******************************************************************************/ - -/* Main entry points */ -GifFileType *EGifOpenFileName(const char *GifFileName, - const bool GifTestExistence, int *Error); -GifFileType *EGifOpenFileHandle(const int GifFileHandle, int *Error); -GifFileType *EGifOpen(void *userPtr, OutputFunc writeFunc, int *Error); -int EGifSpew(GifFileType *GifFile); -const char *EGifGetGifVersion(GifFileType *GifFile); /* new in 5.x */ -int EGifCloseFile(GifFileType *GifFile, int *ErrorCode); - -#define E_GIF_SUCCEEDED 0 -#define E_GIF_ERR_OPEN_FAILED 1 /* And EGif possible errors. */ -#define E_GIF_ERR_WRITE_FAILED 2 -#define E_GIF_ERR_HAS_SCRN_DSCR 3 -#define E_GIF_ERR_HAS_IMAG_DSCR 4 -#define E_GIF_ERR_NO_COLOR_MAP 5 -#define E_GIF_ERR_DATA_TOO_BIG 6 -#define E_GIF_ERR_NOT_ENOUGH_MEM 7 -#define E_GIF_ERR_DISK_IS_FULL 8 -#define E_GIF_ERR_CLOSE_FAILED 9 -#define E_GIF_ERR_NOT_WRITEABLE 10 - -/* These are legacy. 
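They form the low-level, record-by-record writer interface that EGifSpew() is
built on.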
You probably do not want to call them directly */ -int EGifPutScreenDesc(GifFileType *GifFile, const int GifWidth, - const int GifHeight, const int GifColorRes, - const int GifBackGround, - const ColorMapObject *GifColorMap); -int EGifPutImageDesc(GifFileType *GifFile, const int GifLeft, const int GifTop, - const int GifWidth, const int GifHeight, - const bool GifInterlace, - const ColorMapObject *GifColorMap); -void EGifSetGifVersion(GifFileType *GifFile, const bool gif89); -int EGifPutLine(GifFileType *GifFile, GifPixelType *GifLine, int GifLineLen); -int EGifPutPixel(GifFileType *GifFile, const GifPixelType GifPixel); -int EGifPutComment(GifFileType *GifFile, const char *GifComment); -int EGifPutExtensionLeader(GifFileType *GifFile, const int GifExtCode); -int EGifPutExtensionBlock(GifFileType *GifFile, const int GifExtLen, - const void *GifExtension); -int EGifPutExtensionTrailer(GifFileType *GifFile); -int EGifPutExtension(GifFileType *GifFile, const int GifExtCode, - const int GifExtLen, const void *GifExtension); -int EGifPutCode(GifFileType *GifFile, int GifCodeSize, - const GifByteType *GifCodeBlock); -int EGifPutCodeNext(GifFileType *GifFile, const GifByteType *GifCodeBlock); - -/****************************************************************************** - GIF decoding routines -******************************************************************************/ - -/* Main entry points */ -GifFileType *DGifOpenFileName(const char *GifFileName, int *Error); -GifFileType *DGifOpenFileHandle(int GifFileHandle, int *Error); -int DGifSlurp(GifFileType *GifFile); -GifFileType *DGifOpen(void *userPtr, InputFunc readFunc, - int *Error); /* new one (TVT) */ -int DGifCloseFile(GifFileType *GifFile, int *ErrorCode); - -#define D_GIF_SUCCEEDED 0 -#define D_GIF_ERR_OPEN_FAILED 101 /* And DGif possible errors. */ -#define D_GIF_ERR_READ_FAILED 102 -#define D_GIF_ERR_NOT_GIF_FILE 103 -#define D_GIF_ERR_NO_SCRN_DSCR 104 -#define D_GIF_ERR_NO_IMAG_DSCR 105 -#define D_GIF_ERR_NO_COLOR_MAP 106 -#define D_GIF_ERR_WRONG_RECORD 107 -#define D_GIF_ERR_DATA_TOO_BIG 108 -#define D_GIF_ERR_NOT_ENOUGH_MEM 109 -#define D_GIF_ERR_CLOSE_FAILED 110 -#define D_GIF_ERR_NOT_READABLE 111 -#define D_GIF_ERR_IMAGE_DEFECT 112 -#define D_GIF_ERR_EOF_TOO_SOON 113 - -/* These are legacy. You probably do not want to call them directly */ -int DGifGetScreenDesc(GifFileType *GifFile); -int DGifGetRecordType(GifFileType *GifFile, GifRecordType *GifType); -int DGifGetImageHeader(GifFileType *GifFile); -int DGifGetImageDesc(GifFileType *GifFile); -int DGifGetLine(GifFileType *GifFile, GifPixelType *GifLine, int GifLineLen); -int DGifGetPixel(GifFileType *GifFile, GifPixelType GifPixel); -int DGifGetExtension(GifFileType *GifFile, int *GifExtCode, - GifByteType **GifExtension); -int DGifGetExtensionNext(GifFileType *GifFile, GifByteType **GifExtension); -int DGifGetCode(GifFileType *GifFile, int *GifCodeSize, - GifByteType **GifCodeBlock); -int DGifGetCodeNext(GifFileType *GifFile, GifByteType **GifCodeBlock); -int DGifGetLZCodes(GifFileType *GifFile, int *GifCode); -const char *DGifGetGifVersion(GifFileType *GifFile); - -/****************************************************************************** - Error handling and reporting. -******************************************************************************/ -extern const char *GifErrorString(int ErrorCode); /* new in 2012 - ESR */ - -/***************************************************************************** - it g in core. 
-******************************************************************************/ - -/****************************************************************************** - Color map handling from gif_alloc.c -******************************************************************************/ - -extern ColorMapObject *GifMakeMapObject(int ColorCount, - const GifColorType *ColorMap); -extern void GifFreeMapObject(ColorMapObject *Object); -extern ColorMapObject *GifUnionColorMap(const ColorMapObject *ColorIn1, - const ColorMapObject *ColorIn2, - GifPixelType ColorTransIn2[]); -extern int GifBitSize(int n); - -/****************************************************************************** - Support for the in-core structures allocation (slurp mode). -******************************************************************************/ - -extern void GifApplyTranslation(SavedImage *Image, - const GifPixelType Translation[]); -extern int GifAddExtensionBlock(int *ExtensionBlock_Count, - ExtensionBlock **ExtensionBlocks, int Function, - unsigned int Len, unsigned char ExtData[]); -extern void GifFreeExtensions(int *ExtensionBlock_Count, - ExtensionBlock **ExtensionBlocks); -extern SavedImage *GifMakeSavedImage(GifFileType *GifFile, - const SavedImage *CopyFrom); -extern void GifFreeSavedImages(GifFileType *GifFile); - -/****************************************************************************** - 5.x functions for GIF89 graphics control blocks -******************************************************************************/ - -int DGifExtensionToGCB(const size_t GifExtensionLength, - const GifByteType *GifExtension, - GraphicsControlBlock *GCB); -size_t EGifGCBToExtension(const GraphicsControlBlock *GCB, - GifByteType *GifExtension); - -int DGifSavedExtensionToGCB(GifFileType *GifFile, int ImageIndex, - GraphicsControlBlock *GCB); -int EGifGCBToSavedExtension(const GraphicsControlBlock *GCB, - GifFileType *GifFile, int ImageIndex); - -/****************************************************************************** - The library's internal utility font -******************************************************************************/ - -#define GIF_FONT_WIDTH 8 -#define GIF_FONT_HEIGHT 8 -extern const unsigned char GifAsciiTable8x8[][GIF_FONT_WIDTH]; - -extern void GifDrawText8x8(SavedImage *Image, const int x, const int y, - const char *legend, const int color); - -extern void GifDrawBox(SavedImage *Image, const int x, const int y, const int w, - const int d, const int color); - -extern void GifDrawRectangle(SavedImage *Image, const int x, const int y, - const int w, const int d, const int color); - -extern void GifDrawBoxedText8x8(SavedImage *Image, const int x, const int y, - const char *legend, const int border, - const int bg, const int fg); - -#ifdef __cplusplus -} -#endif /* __cplusplus */ -#endif /* _GIF_LIB_H */ - -/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/gif_lib_private.h b/product/include/torchvision/io/image/cpu/giflib/gif_lib_private.h deleted file mode 100644 index 19578d4530c..00000000000 --- a/product/include/torchvision/io/image/cpu/giflib/gif_lib_private.h +++ /dev/null @@ -1,72 +0,0 @@ -/**************************************************************************** - -gif_lib_private.h - internal giflib routines and structures - -SPDX-License-Identifier: MIT - -****************************************************************************/ - -#ifndef _GIF_LIB_PRIVATE_H -#define _GIF_LIB_PRIVATE_H - -#include "gif_hash.h" -#include "gif_lib.h" - -#ifndef SIZE_MAX 
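/* SIZE_MAX comes from C99's <stdint.h>; the fallback below covers older
   toolchains, where UINTPTR_MAX is at least as wide as size_t on the
   platforms giflib targets. */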
-#define SIZE_MAX UINTPTR_MAX -#endif - -#define EXTENSION_INTRODUCER 0x21 -#define DESCRIPTOR_INTRODUCER 0x2c -#define TERMINATOR_INTRODUCER 0x3b - -#define LZ_MAX_CODE 4095 /* Biggest code possible in 12 bits. */ -#define LZ_BITS 12 - -#define FLUSH_OUTPUT 4096 /* Impossible code, to signal flush. */ -#define FIRST_CODE 4097 /* Impossible code, to signal first. */ -#define NO_SUCH_CODE 4098 /* Impossible code, to signal empty. */ - -#define FILE_STATE_WRITE 0x01 -#define FILE_STATE_SCREEN 0x02 -#define FILE_STATE_IMAGE 0x04 -#define FILE_STATE_READ 0x08 - -#define IS_READABLE(Private) (Private->FileState & FILE_STATE_READ) -#define IS_WRITEABLE(Private) (Private->FileState & FILE_STATE_WRITE) - -typedef struct GifFilePrivateType { - GifWord FileState, FileHandle, /* Where all this data goes to! */ - BitsPerPixel, /* Bits per pixel (Codes uses at least this + 1). */ - ClearCode, /* The CLEAR LZ code. */ - EOFCode, /* The EOF LZ code. */ - RunningCode, /* The next code algorithm can generate. */ - RunningBits, /* The number of bits required to represent - RunningCode. */ - MaxCode1, /* 1 bigger than max. possible code, in RunningBits bits. - */ - LastCode, /* The code before the current code. */ - CrntCode, /* Current algorithm code. */ - StackPtr, /* For character stack (see below). */ - CrntShiftState; /* Number of bits in CrntShiftDWord. */ - unsigned long CrntShiftDWord; /* For bytes decomposition into codes. */ - unsigned long PixelCount; /* Number of pixels in image. */ - FILE *File; /* File as stream. */ - InputFunc Read; /* function to read gif input (TVT) */ - OutputFunc Write; /* function to write gif output (MRB) */ - GifByteType Buf[256]; /* Compressed input is buffered here. */ - GifByteType Stack[LZ_MAX_CODE]; /* Decoded pixels are stacked here. */ - GifByteType Suffix[LZ_MAX_CODE + 1]; /* So we can trace the codes. */ - GifPrefixType Prefix[LZ_MAX_CODE + 1]; - GifHashTableType *HashTable; - bool gif89; -} GifFilePrivateType; - -#ifndef HAVE_REALLOCARRAY -extern void *openbsd_reallocarray(void *optr, size_t nmemb, size_t size); -#define reallocarray openbsd_reallocarray -#endif - -#endif /* _GIF_LIB_PRIVATE_H */ - -/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/gifalloc.c b/product/include/torchvision/io/image/cpu/giflib/gifalloc.c deleted file mode 100644 index 926d54ebcf7..00000000000 --- a/product/include/torchvision/io/image/cpu/giflib/gifalloc.c +++ /dev/null @@ -1,425 +0,0 @@ -/***************************************************************************** - - GIF construction tools - -****************************************************************************/ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Copyright (C) Eric S. Raymond - -#include -#include -#include - -#include "gif_lib.h" -#include "gif_lib_private.h" - -#define MAX(x, y) (((x) > (y)) ? 
(x) : (y)) - -/****************************************************************************** - Miscellaneous utility functions -******************************************************************************/ - -/* return smallest bitfield size n will fit in */ -int GifBitSize(int n) { - int i; - - for (i = 1; i <= 8; i++) { - if ((1 << i) >= n) { - break; - } - } - return (i); -} - -/****************************************************************************** - Color map object functions -******************************************************************************/ - -/* - * Allocate a color map of given size; initialize with contents of - * ColorMap if that pointer is non-NULL. - */ -ColorMapObject *GifMakeMapObject(int ColorCount, const GifColorType *ColorMap) { - ColorMapObject *Object; - - /*** FIXME: Our ColorCount has to be a power of two. Is it necessary to - * make the user know that or should we automatically round up instead? - */ - if (ColorCount != (1 << GifBitSize(ColorCount))) { - return ((ColorMapObject *)NULL); - } - - Object = (ColorMapObject *)malloc(sizeof(ColorMapObject)); - if (Object == (ColorMapObject *)NULL) { - return ((ColorMapObject *)NULL); - } - - Object->Colors = - (GifColorType *)calloc(ColorCount, sizeof(GifColorType)); - if (Object->Colors == (GifColorType *)NULL) { - free(Object); - return ((ColorMapObject *)NULL); - } - - Object->ColorCount = ColorCount; - Object->BitsPerPixel = GifBitSize(ColorCount); - Object->SortFlag = false; - - if (ColorMap != NULL) { - memcpy((char *)Object->Colors, (char *)ColorMap, - ColorCount * sizeof(GifColorType)); - } - - return (Object); -} - -/******************************************************************************* - Free a color map object -*******************************************************************************/ -void GifFreeMapObject(ColorMapObject *Object) { - if (Object != NULL) { - (void)free(Object->Colors); - (void)free(Object); - } -} - -#ifdef DEBUG -void DumpColorMap(ColorMapObject *Object, FILE *fp) { - if (Object != NULL) { - int i, j, Len = Object->ColorCount; - - for (i = 0; i < Len; i += 4) { - for (j = 0; j < 4 && j < Len; j++) { - (void)fprintf(fp, "%3d: %02x %02x %02x ", - i + j, Object->Colors[i + j].Red, - Object->Colors[i + j].Green, - Object->Colors[i + j].Blue); - } - (void)fprintf(fp, "\n"); - } - } -} -#endif /* DEBUG */ - -/******************************************************************************* - Compute the union of two given color maps and return it. If result can't - fit into 256 colors, NULL is returned, the allocated union otherwise. - ColorIn1 is copied as is to ColorUnion, while colors from ColorIn2 are - copied iff they didn't exist before. ColorTransIn2 maps the old - ColorIn2 into the ColorUnion color map table./ -*******************************************************************************/ -ColorMapObject *GifUnionColorMap(const ColorMapObject *ColorIn1, - const ColorMapObject *ColorIn2, - GifPixelType ColorTransIn2[]) { - int i, j, CrntSlot, RoundUpTo, NewGifBitSize; - ColorMapObject *ColorUnion; - - /* - * We don't worry about duplicates within either color map; if - * the caller wants to resolve those, he can perform unions - * with an empty color map. - */ - - /* Allocate table which will hold the result for sure. */ - ColorUnion = GifMakeMapObject( - MAX(ColorIn1->ColorCount, ColorIn2->ColorCount) * 2, NULL); - - if (ColorUnion == NULL) { - return (NULL); - } - - /* - * Copy ColorIn1 to ColorUnion. 
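   A worked example (illustrative sizes): uniting a 16-color map with a
   13-color map that shares 5 of its colors copies 16 slots, then adds the 8
   genuinely new colors, so CrntSlot ends at 24 (assuming map 1 has no
   trailing all-black slots); GifBitSize(24) is 5, the result reports
   1 << 5 = 32 colors with BitsPerPixel 5, and ColorTransIn2[] holds, for
   every ColorIn2 index, the slot it landed on in the union.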
- */ - for (i = 0; i < ColorIn1->ColorCount; i++) { - ColorUnion->Colors[i] = ColorIn1->Colors[i]; - } - CrntSlot = ColorIn1->ColorCount; - - /* - * Potentially obnoxious hack: - * - * Back CrntSlot down past all contiguous {0, 0, 0} slots at the end - * of table 1. This is very useful if your display is limited to - * 16 colors. - */ - while (ColorIn1->Colors[CrntSlot - 1].Red == 0 && - ColorIn1->Colors[CrntSlot - 1].Green == 0 && - ColorIn1->Colors[CrntSlot - 1].Blue == 0) { - CrntSlot--; - } - - /* Copy ColorIn2 to ColorUnion (use old colors if they exist): */ - for (i = 0; i < ColorIn2->ColorCount && CrntSlot <= 256; i++) { - /* Let's see if this color already exists: */ - for (j = 0; j < ColorIn1->ColorCount; j++) { - if (memcmp(&ColorIn1->Colors[j], &ColorIn2->Colors[i], - sizeof(GifColorType)) == 0) { - break; - } - } - - if (j < ColorIn1->ColorCount) { - ColorTransIn2[i] = j; /* color exists in Color1 */ - } else { - /* Color is new - copy it to a new slot: */ - ColorUnion->Colors[CrntSlot] = ColorIn2->Colors[i]; - ColorTransIn2[i] = CrntSlot++; - } - } - - if (CrntSlot > 256) { - GifFreeMapObject(ColorUnion); - return ((ColorMapObject *)NULL); - } - - NewGifBitSize = GifBitSize(CrntSlot); - RoundUpTo = (1 << NewGifBitSize); - - if (RoundUpTo != ColorUnion->ColorCount) { - GifColorType *Map = ColorUnion->Colors; - - /* - * Zero out slots up to next power of 2. - * We know these slots exist because of the way ColorUnion's - * start dimension was computed. - */ - for (j = CrntSlot; j < RoundUpTo; j++) { - Map[j].Red = Map[j].Green = Map[j].Blue = 0; - } - - /* perhaps we can shrink the map? */ - if (RoundUpTo < ColorUnion->ColorCount) { - GifColorType *new_map = (GifColorType *)reallocarray( - Map, RoundUpTo, sizeof(GifColorType)); - if (new_map == NULL) { - GifFreeMapObject(ColorUnion); - return ((ColorMapObject *)NULL); - } - ColorUnion->Colors = new_map; - } - } - - ColorUnion->ColorCount = RoundUpTo; - ColorUnion->BitsPerPixel = NewGifBitSize; - - return (ColorUnion); -} - -/******************************************************************************* - Apply a given color translation to the raster bits of an image -*******************************************************************************/ -void GifApplyTranslation(SavedImage *Image, const GifPixelType Translation[]) { - int i; - int RasterSize = - Image->ImageDesc.Height * Image->ImageDesc.Width; - - for (i = 0; i < RasterSize; i++) { - Image->RasterBits[i] = Translation[Image->RasterBits[i]]; - } -} - -/****************************************************************************** - Extension record functions -******************************************************************************/ -int GifAddExtensionBlock(int *ExtensionBlockCount, - ExtensionBlock **ExtensionBlocks, int Function, - unsigned int Len, unsigned char ExtData[]) { - ExtensionBlock *ep; - - if (*ExtensionBlocks == NULL) { - *ExtensionBlocks = - (ExtensionBlock *)malloc(sizeof(ExtensionBlock)); - } else { - ExtensionBlock *ep_new = (ExtensionBlock *)reallocarray( - *ExtensionBlocks, (*ExtensionBlockCount + 1), - sizeof(ExtensionBlock)); - if (ep_new == NULL) { - return (GIF_ERROR); - } - *ExtensionBlocks = ep_new; - } - - if (*ExtensionBlocks == NULL) { - return (GIF_ERROR); - } - - ep = &(*ExtensionBlocks)[(*ExtensionBlockCount)++]; - - ep->Function = Function; - ep->ByteCount = Len; - ep->Bytes = (GifByteType *)malloc(ep->ByteCount); - if (ep->Bytes == NULL) { - return (GIF_ERROR); - } - - if (ExtData != NULL) { - memcpy(ep->Bytes, ExtData, Len); - 
} - - return (GIF_OK); -} - -void GifFreeExtensions(int *ExtensionBlockCount, - ExtensionBlock **ExtensionBlocks) { - ExtensionBlock *ep; - - if (*ExtensionBlocks == NULL) { - return; - } - - for (ep = *ExtensionBlocks; - ep < (*ExtensionBlocks + *ExtensionBlockCount); ep++) { - (void)free((char *)ep->Bytes); - } - (void)free((char *)*ExtensionBlocks); - *ExtensionBlocks = NULL; - *ExtensionBlockCount = 0; -} - -/****************************************************************************** - Image block allocation functions -******************************************************************************/ - -/* Private Function: - * Frees the last image in the GifFile->SavedImages array - */ -void FreeLastSavedImage(GifFileType *GifFile) { - SavedImage *sp; - - if ((GifFile == NULL) || (GifFile->SavedImages == NULL)) { - return; - } - - /* Remove one SavedImage from the GifFile */ - GifFile->ImageCount--; - sp = &GifFile->SavedImages[GifFile->ImageCount]; - - /* Deallocate its Colormap */ - if (sp->ImageDesc.ColorMap != NULL) { - GifFreeMapObject(sp->ImageDesc.ColorMap); - sp->ImageDesc.ColorMap = NULL; - } - - /* Deallocate the image data */ - if (sp->RasterBits != NULL) { - free((char *)sp->RasterBits); - } - - /* Deallocate any extensions */ - GifFreeExtensions(&sp->ExtensionBlockCount, &sp->ExtensionBlocks); - - /*** FIXME: We could realloc the GifFile->SavedImages structure but is - * there a point to it? Saves some memory but we'd have to do it every - * time. If this is used in GifFreeSavedImages then it would be - * inefficient (The whole array is going to be deallocated.) If we just - * use it when we want to free the last Image it's convenient to do it - * here. - */ -} - -/* - * Append an image block to the SavedImages array - */ -SavedImage *GifMakeSavedImage(GifFileType *GifFile, - const SavedImage *CopyFrom) { - // cppcheck-suppress ctunullpointer - if (GifFile->SavedImages == NULL) { - GifFile->SavedImages = (SavedImage *)malloc(sizeof(SavedImage)); - } else { - SavedImage *newSavedImages = (SavedImage *)reallocarray( - GifFile->SavedImages, (GifFile->ImageCount + 1), - sizeof(SavedImage)); - if (newSavedImages == NULL) { - return ((SavedImage *)NULL); - } - GifFile->SavedImages = newSavedImages; - } - if (GifFile->SavedImages == NULL) { - return ((SavedImage *)NULL); - } else { - SavedImage *sp = &GifFile->SavedImages[GifFile->ImageCount++]; - - if (CopyFrom != NULL) { - memcpy((char *)sp, CopyFrom, sizeof(SavedImage)); - - /* - * Make our own allocated copies of the heap fields in - * the copied record. This guards against potential - * aliasing problems. 
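 * (If the copy kept CopyFrom's ColorMap, RasterBits and ExtensionBlocks
 * pointers, two SavedImages would own the same heap blocks and
 * GifFreeSavedImages() would free them twice.)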
- */ - - /* first, the local color map */ - if (CopyFrom->ImageDesc.ColorMap != NULL) { - sp->ImageDesc.ColorMap = GifMakeMapObject( - CopyFrom->ImageDesc.ColorMap->ColorCount, - CopyFrom->ImageDesc.ColorMap->Colors); - if (sp->ImageDesc.ColorMap == NULL) { - FreeLastSavedImage(GifFile); - return (SavedImage *)(NULL); - } - } - - /* next, the raster */ - sp->RasterBits = (unsigned char *)reallocarray( - NULL, - (CopyFrom->ImageDesc.Height * - CopyFrom->ImageDesc.Width), - sizeof(GifPixelType)); - if (sp->RasterBits == NULL) { - FreeLastSavedImage(GifFile); - return (SavedImage *)(NULL); - } - memcpy(sp->RasterBits, CopyFrom->RasterBits, - sizeof(GifPixelType) * - CopyFrom->ImageDesc.Height * - CopyFrom->ImageDesc.Width); - - /* finally, the extension blocks */ - if (CopyFrom->ExtensionBlocks != NULL) { - sp->ExtensionBlocks = - (ExtensionBlock *)reallocarray( - NULL, CopyFrom->ExtensionBlockCount, - sizeof(ExtensionBlock)); - if (sp->ExtensionBlocks == NULL) { - FreeLastSavedImage(GifFile); - return (SavedImage *)(NULL); - } - memcpy(sp->ExtensionBlocks, - CopyFrom->ExtensionBlocks, - sizeof(ExtensionBlock) * - CopyFrom->ExtensionBlockCount); - } - } else { - memset((char *)sp, '\0', sizeof(SavedImage)); - } - - return (sp); - } -} - -void GifFreeSavedImages(GifFileType *GifFile) { - SavedImage *sp; - - if ((GifFile == NULL) || (GifFile->SavedImages == NULL)) { - return; - } - for (sp = GifFile->SavedImages; - sp < GifFile->SavedImages + GifFile->ImageCount; sp++) { - if (sp->ImageDesc.ColorMap != NULL) { - GifFreeMapObject(sp->ImageDesc.ColorMap); - sp->ImageDesc.ColorMap = NULL; - } - - if (sp->RasterBits != NULL) { - free((char *)sp->RasterBits); - } - - GifFreeExtensions(&sp->ExtensionBlockCount, - &sp->ExtensionBlocks); - } - free((char *)GifFile->SavedImages); - GifFile->SavedImages = NULL; -} - -/* end */ diff --git a/product/include/torchvision/io/image/cpu/giflib/openbsd-reallocarray.c b/product/include/torchvision/io/image/cpu/giflib/openbsd-reallocarray.c deleted file mode 100644 index e09ab245ad4..00000000000 --- a/product/include/torchvision/io/image/cpu/giflib/openbsd-reallocarray.c +++ /dev/null @@ -1,73 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (C) 2008 Otto Moerbeek - * SPDX-License-Identifier: MIT - */ - -#include -#include -#include -#include - -#ifndef SIZE_MAX -#define SIZE_MAX UINTPTR_MAX -#endif - -/* - * This is sqrt(SIZE_MAX+1), as s1*s2 <= SIZE_MAX - * if both s1 < MUL_NO_OVERFLOW and s2 < MUL_NO_OVERFLOW - */ -#define MUL_NO_OVERFLOW ((size_t)1 << (sizeof(size_t) * 4)) - -void *openbsd_reallocarray(void *optr, size_t nmemb, size_t size) { - if ((nmemb >= MUL_NO_OVERFLOW || size >= MUL_NO_OVERFLOW) && - nmemb > 0 && SIZE_MAX / nmemb < size) { - errno = ENOMEM; - return NULL; - } - /* - * Head off variations in realloc behavior on different - * platforms (reported by MarkR ) - * - * The behaviour of reallocarray is implementation-defined if - * nmemb or size is zero. It can return NULL or non-NULL - * depending on the platform. - * https://www.securecoding.cert.org/confluence/display/c/MEM04-C.Beware+of+zero-lengthallocations - * - * Here are some extracts from realloc man pages on different platforms. - * - * void realloc( void memblock, size_t size ); - * - * Windows: - * - * If there is not enough available memory to expand the block - * to the given size, the original block is left unchanged, - * and NULL is returned. 
If size is zero, then the block - * pointed to by memblock is freed; the return value is NULL, - * and memblock is left pointing at a freed block. - * - * OpenBSD: - * - * If size or nmemb is equal to 0, a unique pointer to an - * access protected, zero sized object is returned. Access via - * this pointer will generate a SIGSEGV exception. - * - * Linux: - * - * If size was equal to 0, either NULL or a pointer suitable - * to be passed to free() is returned. - * - * OS X: - * - * If size is zero and ptr is not NULL, a new, minimum sized - * object is allocated and the original object is freed. - * - * It looks like images with zero width or height can trigger - * this, and fuzzing behaviour will differ by platform, so - * fuzzing on one platform may not detect zero-size allocation - * problems on other platforms. - */ - if (size == 0 || nmemb == 0) { - return NULL; - } - return realloc(optr, size * nmemb); -} diff --git a/product/include/torchvision/io/image/cpu/read_write_file.cpp b/product/include/torchvision/io/image/cpu/read_write_file.cpp deleted file mode 100644 index 06de72a5053..00000000000 --- a/product/include/torchvision/io/image/cpu/read_write_file.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include "read_write_file.h" - -#include - -#ifdef _WIN32 -#define WIN32_LEAN_AND_MEAN -#include -#endif - -namespace vision { -namespace image { - -#ifdef _WIN32 -namespace { -std::wstring utf8_decode(const std::string& str) { - if (str.empty()) { - return std::wstring(); - } - int size_needed = MultiByteToWideChar( - CP_UTF8, 0, str.c_str(), static_cast(str.size()), nullptr, 0); - TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode"); - std::wstring wstrTo(size_needed, 0); - MultiByteToWideChar( - CP_UTF8, - 0, - str.c_str(), - static_cast(str.size()), - &wstrTo[0], - size_needed); - return wstrTo; -} -} // namespace -#endif - -torch::Tensor read_file(const std::string& filename) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.read_write_file.read_file"); -#ifdef _WIN32 - // According to - // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019, - // we should use struct __stat64 and _wstat64 for 64-bit file size on Windows. - struct __stat64 stat_buf; - auto fileW = utf8_decode(filename); - int rc = _wstat64(fileW.c_str(), &stat_buf); -#else - struct stat stat_buf; - int rc = stat(filename.c_str(), &stat_buf); -#endif - // errno is a variable defined in errno.h - TORCH_CHECK( - rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'"); - - int64_t size = stat_buf.st_size; - - TORCH_CHECK(size > 0, "Expected a non empty file"); - -#ifdef _WIN32 - // TODO: Once torch::from_file handles UTF-8 paths correctly, we should move - // back to use the following implementation since it uses file mapping. 
- // auto data = - // torch::from_file(filename, /*shared=*/false, /*size=*/size, - // torch::kU8).clone() - FILE* infile = _wfopen(fileW.c_str(), L"rb"); - - TORCH_CHECK(infile != nullptr, "Error opening input file"); - - auto data = torch::empty({size}, torch::kU8); - auto dataBytes = data.data_ptr(); - - fread(dataBytes, sizeof(uint8_t), size, infile); - fclose(infile); -#else - auto data = - torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8); -#endif - - return data; -} - -void write_file(const std::string& filename, torch::Tensor& data) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.read_write_file.write_file"); - // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); - - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); - - // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor"); - - auto fileBytes = data.data_ptr(); - auto fileCStr = filename.c_str(); -#ifdef _WIN32 - auto fileW = utf8_decode(filename); - FILE* outfile = _wfopen(fileW.c_str(), L"wb"); -#else - FILE* outfile = fopen(fileCStr, "wb"); -#endif - - TORCH_CHECK(outfile != nullptr, "Error opening output file"); - - fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile); - fclose(outfile); -} - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cpu/read_write_file.h b/product/include/torchvision/io/image/cpu/read_write_file.h deleted file mode 100644 index a5a712dd8e2..00000000000 --- a/product/include/torchvision/io/image/cpu/read_write_file.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor read_file(const std::string& filename); - -C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.cpp b/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.cpp deleted file mode 100644 index 6314ececef1..00000000000 --- a/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.cpp +++ /dev/null @@ -1,603 +0,0 @@ -#include "decode_jpegs_cuda.h" -#if !NVJPEG_FOUND -namespace vision { -namespace image { -std::vector decode_jpegs_cuda( - const std::vector& encoded_images, - vision::image::ImageReadMode mode, - torch::Device device) { - TORCH_CHECK( - false, "decode_jpegs_cuda: torchvision not compiled with nvJPEG support"); -} -} // namespace image -} // namespace vision - -#else -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -namespace vision { -namespace image { - -std::mutex decoderMutex; -std::unique_ptr cudaJpegDecoder; - -std::vector decode_jpegs_cuda( - const std::vector& encoded_images, - vision::image::ImageReadMode mode, - torch::Device device) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); - - std::lock_guard lock(decoderMutex); - std::vector contig_images; - contig_images.reserve(encoded_images.size()); - - TORCH_CHECK( - device.is_cuda(), "Expected the device parameter to be a cuda device"); - - for (auto& encoded_image : encoded_images) { - TORCH_CHECK( - encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - - TORCH_CHECK( - !encoded_image.is_cuda(), - "The 
input tensor must be on CPU when decoding with nvjpeg") - - TORCH_CHECK( - encoded_image.dim() == 1 && encoded_image.numel() > 0, - "Expected a non empty 1-dimensional tensor"); - - // nvjpeg requires images to be contiguous - if (encoded_image.is_contiguous()) { - contig_images.push_back(encoded_image); - } else { - contig_images.push_back(encoded_image.contiguous()); - } - } - - int major_version; - int minor_version; - nvjpegStatus_t get_major_property_status = - nvjpegGetProperty(MAJOR_VERSION, &major_version); - nvjpegStatus_t get_minor_property_status = - nvjpegGetProperty(MINOR_VERSION, &minor_version); - - TORCH_CHECK( - get_major_property_status == NVJPEG_STATUS_SUCCESS, - "nvjpegGetProperty failed: ", - get_major_property_status); - TORCH_CHECK( - get_minor_property_status == NVJPEG_STATUS_SUCCESS, - "nvjpegGetProperty failed: ", - get_minor_property_status); - if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) { - TORCH_WARN_ONCE( - "There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. " - "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); - } - - at::cuda::CUDAGuard device_guard(device); - - if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device) { - if (cudaJpegDecoder != nullptr) - cudaJpegDecoder.reset(new CUDAJpegDecoder(device)); - else { - cudaJpegDecoder = std::make_unique(device); - std::atexit([]() { cudaJpegDecoder.reset(); }); - } - } - - nvjpegOutputFormat_t output_format; - - switch (mode) { - case vision::image::IMAGE_READ_MODE_UNCHANGED: - // Using NVJPEG_OUTPUT_UNCHANGED causes differently sized output channels - // which is related to the subsampling used I'm not sure why this is the - // case, but for now we're just using RGB and later removing channels from - // grayscale images. - output_format = NVJPEG_OUTPUT_UNCHANGED; - break; - case vision::image::IMAGE_READ_MODE_GRAY: - output_format = NVJPEG_OUTPUT_Y; - break; - case vision::image::IMAGE_READ_MODE_RGB: - output_format = NVJPEG_OUTPUT_RGB; - break; - default: - TORCH_CHECK( - false, "The provided mode is not supported for JPEG decoding on GPU"); - } - - try { - at::cuda::CUDAEvent event; - auto result = cudaJpegDecoder->decode_images(contig_images, output_format); - auto current_stream{ - device.has_index() ? at::cuda::getCurrentCUDAStream( - cudaJpegDecoder->original_device.index()) - : at::cuda::getCurrentCUDAStream()}; - event.record(cudaJpegDecoder->stream); - event.block(current_stream); - return result; - } catch (const std::exception& e) { - if (typeid(e) != typeid(std::runtime_error)) { - TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); - } else { - throw; - } - } -} - -CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) - : original_device{torch::kCUDA, torch::cuda::current_device()}, - target_device{target_device}, - stream{ - target_device.has_index() - ? 
at::cuda::getStreamFromPool(false, target_device.index()) - : at::cuda::getStreamFromPool(false)} { - nvjpegStatus_t status; - - hw_decode_available = true; - status = nvjpegCreateEx( - NVJPEG_BACKEND_HARDWARE, - NULL, - NULL, - NVJPEG_FLAGS_DEFAULT, - &nvjpeg_handle); - if (status == NVJPEG_STATUS_ARCH_MISMATCH) { - status = nvjpegCreateEx( - NVJPEG_BACKEND_DEFAULT, - NULL, - NULL, - NVJPEG_FLAGS_DEFAULT, - &nvjpeg_handle); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to initialize nvjpeg with default backend: ", - status); - hw_decode_available = false; - } else { - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to initialize nvjpeg with hardware backend: ", - status); - } - - status = nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create nvjpeg state: ", - status); - - status = nvjpegDecoderCreate( - nvjpeg_handle, NVJPEG_BACKEND_DEFAULT, &nvjpeg_decoder); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create nvjpeg decoder: ", - status); - - status = nvjpegDecoderStateCreate( - nvjpeg_handle, nvjpeg_decoder, &nvjpeg_decoupled_state); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create nvjpeg decoder state: ", - status); - - status = nvjpegBufferPinnedCreate(nvjpeg_handle, NULL, &pinned_buffers[0]); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create pinned buffer: ", - status); - - status = nvjpegBufferPinnedCreate(nvjpeg_handle, NULL, &pinned_buffers[1]); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create pinned buffer: ", - status); - - status = nvjpegBufferDeviceCreate(nvjpeg_handle, NULL, &device_buffer); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create device buffer: ", - status); - - status = nvjpegJpegStreamCreate(nvjpeg_handle, &jpeg_streams[0]); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create jpeg stream: ", - status); - - status = nvjpegJpegStreamCreate(nvjpeg_handle, &jpeg_streams[1]); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create jpeg stream: ", - status); - - status = nvjpegDecodeParamsCreate(nvjpeg_handle, &nvjpeg_decode_params); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create decode params: ", - status); -} - -CUDAJpegDecoder::~CUDAJpegDecoder() { - /* - The below code works on Mac and Linux, but fails on Windows. - This is because on Windows, the atexit hook which calls this - destructor executes after cuda is already shut down causing SIGSEGV. - We do not have a solution to this problem at the moment, so we'll - just leak the libnvjpeg & cuda variables for the time being and hope - that the CUDA runtime handles cleanup for us. - Please send a PR if you have a solution for this problem. 
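  The commented-out calls below record the teardown that would otherwise be
  performed (decode params, the two jpeg streams, the pinned and device
  buffers, the decoupled state, the decoder, the plain state, and finally the
  nvjpeg handle), should a safe shutdown hook become available.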
- */ - - // nvjpegStatus_t status; - - // status = nvjpegDecodeParamsDestroy(nvjpeg_decode_params); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg decode params: ", - // status); - - // status = nvjpegJpegStreamDestroy(jpeg_streams[0]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy jpeg stream: ", - // status); - - // status = nvjpegJpegStreamDestroy(jpeg_streams[1]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy jpeg stream: ", - // status); - - // status = nvjpegBufferPinnedDestroy(pinned_buffers[0]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy pinned buffer[0]: ", - // status); - - // status = nvjpegBufferPinnedDestroy(pinned_buffers[1]); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy pinned buffer[1]: ", - // status); - - // status = nvjpegBufferDeviceDestroy(device_buffer); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy device buffer: ", - // status); - - // status = nvjpegJpegStateDestroy(nvjpeg_decoupled_state); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg decoupled state: ", - // status); - - // status = nvjpegDecoderDestroy(nvjpeg_decoder); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg decoder: ", - // status); - - // status = nvjpegJpegStateDestroy(nvjpeg_state); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg state: ", - // status); - - // status = nvjpegDestroy(nvjpeg_handle); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); -} - -std::tuple< - std::vector, - std::vector, - std::vector> -CUDAJpegDecoder::prepare_buffers( - const std::vector& encoded_images, - const nvjpegOutputFormat_t& output_format) { - /* - This function scans the encoded images' jpeg headers and - allocates decoding buffers based on the metadata found - - Args: - - encoded_images (std::vector): a vector of tensors - containing the jpeg bitstreams to be decoded. Each tensor must have dtype - torch.uint8 and device cpu - - output_format (nvjpegOutputFormat_t): NVJPEG_OUTPUT_RGB, NVJPEG_OUTPUT_Y - or NVJPEG_OUTPUT_UNCHANGED - - Returns: - - decoded_images (std::vector): a vector of nvjpegImages - containing pointers to the memory of the decoded images - - output_tensors (std::vector): a vector of Tensors - containing the decoded images. 
`decoded_images` points to the memory of - output_tensors - - channels (std::vector): a vector of ints containing the number of - output image channels for every image - */ - - int width[NVJPEG_MAX_COMPONENT]; - int height[NVJPEG_MAX_COMPONENT]; - std::vector channels(encoded_images.size()); - nvjpegChromaSubsampling_t subsampling; - nvjpegStatus_t status; - - std::vector output_tensors{encoded_images.size()}; - std::vector decoded_images{encoded_images.size()}; - - for (std::vector::size_type i = 0; i < encoded_images.size(); - i++) { - // extract bitstream meta data to figure out the number of channels, height, - // width for every image - status = nvjpegGetImageInfo( - nvjpeg_handle, - (unsigned char*)encoded_images[i].data_ptr(), - encoded_images[i].numel(), - &channels[i], - &subsampling, - width, - height); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, "Failed to get image info: ", status); - - TORCH_CHECK( - subsampling != NVJPEG_CSS_UNKNOWN, "Unknown chroma subsampling"); - - // output channels may be different from the actual number of channels in - // the image, e.g. we decode a grayscale image as RGB and slice off the - // extra channels later - int output_channels = 3; - if (output_format == NVJPEG_OUTPUT_RGB || - output_format == NVJPEG_OUTPUT_UNCHANGED) { - output_channels = 3; - } else if (output_format == NVJPEG_OUTPUT_Y) { - output_channels = 1; - } - - // reserve output buffer - auto output_tensor = torch::empty( - {int64_t(output_channels), int64_t(height[0]), int64_t(width[0])}, - torch::dtype(torch::kU8).device(target_device)); - output_tensors[i] = output_tensor; - - // fill nvjpegImage_t struct - for (int c = 0; c < output_channels; c++) { - decoded_images[i].channel[c] = output_tensor[c].data_ptr(); - decoded_images[i].pitch[c] = width[0]; - } - for (int c = output_channels; c < NVJPEG_MAX_COMPONENT; c++) { - decoded_images[i].channel[c] = NULL; - decoded_images[i].pitch[c] = 0; - } - } - return {decoded_images, output_tensors, channels}; -} - -std::vector CUDAJpegDecoder::decode_images( - const std::vector& encoded_images, - const nvjpegOutputFormat_t& output_format) { - /* - This function decodes a batch of jpeg bitstreams. - We scan all encoded bitstreams and sort them into two groups: - 1. Baseline JPEGs: Can be decoded with hardware support on A100+ GPUs. - 2. Other JPEGs (e.g. progressive JPEGs): Can also be decoded on the - GPU (albeit with software support only) but need some preprocessing on the - host first. - - See - https://github.com/NVIDIA/CUDALibrarySamples/blob/f17940ac4e705bf47a8c39f5365925c1665f6c98/nvJPEG/nvJPEG-Decoder/nvjpegDecoder.cpp#L33 - for reference. - - Args: - - encoded_images (std::vector): a vector of tensors - containing the jpeg bitstreams to be decoded - - output_format (nvjpegOutputFormat_t): NVJPEG_OUTPUT_RGB, NVJPEG_OUTPUT_Y - or NVJPEG_OUTPUT_UNCHANGED - - device (torch::Device): The desired CUDA device for the returned Tensors - - Returns: - - output_tensors (std::vector): a vector of Tensors - containing the decoded images - */ - - auto [decoded_imgs_buf, output_tensors, channels] = - prepare_buffers(encoded_images, output_format); - - nvjpegStatus_t status; - cudaError_t cudaStatus; - - cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( - cudaStatus == cudaSuccess, - "Failed to synchronize CUDA stream: ", - cudaStatus); - - // baseline JPEGs can be batch decoded with hardware support on A100+ GPUs - // ultra fast! 
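  // When hardware decoding is available, nvjpegJpegStreamParseHeader() +
  // nvjpegDecodeBatchedSupported() below decide, per bitstream, which path an
  // image takes: supported (baseline) streams are collected into the hw_*
  // buffers and decoded in a single batched call, everything else goes
  // through the decoupled software path (host-side parse, transfer, then
  // device-side decode), alternating between the two pinned buffers to avoid
  // an extra synchronization between images. Without hardware support every
  // image takes the software path.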
- std::vector hw_input_buffer; - std::vector hw_input_buffer_size; - std::vector hw_output_buffer; - - // other JPEG types such as progressive JPEGs can be decoded one-by-one in - // software slow :( - std::vector sw_input_buffer; - std::vector sw_input_buffer_size; - std::vector sw_output_buffer; - - if (hw_decode_available) { - for (std::vector::size_type i = 0; i < encoded_images.size(); - ++i) { - // extract bitstream meta data to figure out whether a bit-stream can be - // decoded - nvjpegJpegStreamParseHeader( - nvjpeg_handle, - encoded_images[i].data_ptr(), - encoded_images[i].numel(), - jpeg_streams[0]); - int isSupported = -1; - nvjpegDecodeBatchedSupported( - nvjpeg_handle, jpeg_streams[0], &isSupported); - - if (isSupported == 0) { - hw_input_buffer.push_back(encoded_images[i].data_ptr()); - hw_input_buffer_size.push_back(encoded_images[i].numel()); - hw_output_buffer.push_back(decoded_imgs_buf[i]); - } else { - sw_input_buffer.push_back(encoded_images[i].data_ptr()); - sw_input_buffer_size.push_back(encoded_images[i].numel()); - sw_output_buffer.push_back(decoded_imgs_buf[i]); - } - } - } else { - for (std::vector::size_type i = 0; i < encoded_images.size(); - ++i) { - sw_input_buffer.push_back(encoded_images[i].data_ptr()); - sw_input_buffer_size.push_back(encoded_images[i].numel()); - sw_output_buffer.push_back(decoded_imgs_buf[i]); - } - } - - if (hw_input_buffer.size() > 0) { - // UNCHANGED behaves weird, so we use RGB instead - status = nvjpegDecodeBatchedInitialize( - nvjpeg_handle, - nvjpeg_state, - hw_input_buffer.size(), - 1, - output_format == NVJPEG_OUTPUT_UNCHANGED ? NVJPEG_OUTPUT_RGB - : output_format); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to initialize batch decoding: ", - status); - - status = nvjpegDecodeBatched( - nvjpeg_handle, - nvjpeg_state, - hw_input_buffer.data(), - hw_input_buffer_size.data(), - hw_output_buffer.data(), - stream); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, "Failed to decode batch: ", status); - } - - if (sw_input_buffer.size() > 0) { - status = - nvjpegStateAttachDeviceBuffer(nvjpeg_decoupled_state, device_buffer); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to attach device buffer: ", - status); - int buffer_index = 0; - // UNCHANGED behaves weird, so we use RGB instead - status = nvjpegDecodeParamsSetOutputFormat( - nvjpeg_decode_params, - output_format == NVJPEG_OUTPUT_UNCHANGED ? 
NVJPEG_OUTPUT_RGB - : output_format); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to set output format: ", - status); - for (std::vector::size_type i = 0; i < sw_input_buffer.size(); - ++i) { - status = nvjpegJpegStreamParse( - nvjpeg_handle, - sw_input_buffer[i], - sw_input_buffer_size[i], - 0, - 0, - jpeg_streams[buffer_index]); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to parse jpeg stream: ", - status); - - status = nvjpegStateAttachPinnedBuffer( - nvjpeg_decoupled_state, pinned_buffers[buffer_index]); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to attach pinned buffer: ", - status); - - status = nvjpegDecodeJpegHost( - nvjpeg_handle, - nvjpeg_decoder, - nvjpeg_decoupled_state, - nvjpeg_decode_params, - jpeg_streams[buffer_index]); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to decode jpeg stream: ", - status); - - cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( - cudaStatus == cudaSuccess, - "Failed to synchronize CUDA stream: ", - cudaStatus); - - status = nvjpegDecodeJpegTransferToDevice( - nvjpeg_handle, - nvjpeg_decoder, - nvjpeg_decoupled_state, - jpeg_streams[buffer_index], - stream); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to transfer jpeg to device: ", - status); - - buffer_index = 1 - buffer_index; // switch pinned buffer in pipeline mode - // to avoid an extra sync - - status = nvjpegDecodeJpegDevice( - nvjpeg_handle, - nvjpeg_decoder, - nvjpeg_decoupled_state, - &sw_output_buffer[i], - stream); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to decode jpeg stream: ", - status); - } - } - - cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( - cudaStatus == cudaSuccess, - "Failed to synchronize CUDA stream: ", - cudaStatus); - - // prune extraneous channels from single channel images - if (output_format == NVJPEG_OUTPUT_UNCHANGED) { - for (std::vector::size_type i = 0; i < output_tensors.size(); - ++i) { - if (channels[i] == 1) { - output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); - } - } - } - - return output_tensors; -} - -} // namespace image -} // namespace vision - -#endif diff --git a/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.h b/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.h deleted file mode 100644 index 2458a103a3a..00000000000 --- a/product/include/torchvision/io/image/cuda/decode_jpegs_cuda.h +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once -#include -#include -#include "../image_read_mode.h" - -#if NVJPEG_FOUND -#include -#include - -namespace vision { -namespace image { -class CUDAJpegDecoder { - public: - CUDAJpegDecoder(const torch::Device& target_device); - ~CUDAJpegDecoder(); - - std::vector decode_images( - const std::vector& encoded_images, - const nvjpegOutputFormat_t& output_format); - - const torch::Device original_device; - const torch::Device target_device; - const c10::cuda::CUDAStream stream; - - private: - std::tuple< - std::vector, - std::vector, - std::vector> - prepare_buffers( - const std::vector& encoded_images, - const nvjpegOutputFormat_t& output_format); - nvjpegJpegState_t nvjpeg_state; - nvjpegJpegState_t nvjpeg_decoupled_state; - nvjpegBufferPinned_t pinned_buffers[2]; - nvjpegBufferDevice_t device_buffer; - nvjpegJpegStream_t jpeg_streams[2]; - nvjpegDecodeParams_t nvjpeg_decode_params; - nvjpegJpegDecoder_t nvjpeg_decoder; - bool hw_decode_available{false}; - nvjpegHandle_t nvjpeg_handle; -}; -} // namespace image -} // namespace vision -#endif diff --git 
a/product/include/torchvision/io/image/cuda/encode_decode_jpegs_cuda.h b/product/include/torchvision/io/image/cuda/encode_decode_jpegs_cuda.h deleted file mode 100644 index 3fdf715b00f..00000000000 --- a/product/include/torchvision/io/image/cuda/encode_decode_jpegs_cuda.h +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include -#include "../image_read_mode.h" -#include "decode_jpegs_cuda.h" -#include "encode_jpegs_cuda.h" - -namespace vision { -namespace image { - -/* - -Fast jpeg decoding with CUDA. -A100+ GPUs have dedicated hardware support for jpeg decoding. - -Args: - - encoded_images (const std::vector&): a vector of tensors - containing the jpeg bitstreams to be decoded. Each tensor must have dtype - torch.uint8 and device cpu - - mode (ImageReadMode): IMAGE_READ_MODE_UNCHANGED, IMAGE_READ_MODE_GRAY and -IMAGE_READ_MODE_RGB are supported - - device (torch::Device): The desired CUDA device to run the decoding on and -which will contain the output tensors - -Returns: - - decoded_images (std::vector): a vector of torch::Tensors of -dtype torch.uint8 on the specified containing the decoded images - -Notes: - - If a single image fails, the whole batch fails. - - This function is thread-safe -*/ -C10_EXPORT std::vector decode_jpegs_cuda( - const std::vector& encoded_images, - vision::image::ImageReadMode mode, - torch::Device device); - -/* -Fast jpeg encoding with CUDA. - -Args: - - decoded_images (const std::vector&): a vector of contiguous -CUDA tensors of dtype torch.uint8 to be encoded. - - quality (int64_t): 0-100, 75 is the default - -Returns: - - encoded_images (std::vector): a vector of CUDA -torch::Tensors of dtype torch.uint8 containing the encoded images - -Notes: - - If a single image fails, the whole batch fails. - - This function is thread-safe -*/ -C10_EXPORT std::vector encode_jpegs_cuda( - const std::vector& decoded_images, - const int64_t quality); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.cpp b/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.cpp deleted file mode 100644 index 1f10327ddbf..00000000000 --- a/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.cpp +++ /dev/null @@ -1,274 +0,0 @@ -#include "encode_jpegs_cuda.h" -#if !NVJPEG_FOUND -namespace vision { -namespace image { -std::vector encode_jpegs_cuda( - const std::vector& decoded_images, - const int64_t quality) { - TORCH_CHECK( - false, "encode_jpegs_cuda: torchvision not compiled with nvJPEG support"); -} -} // namespace image -} // namespace vision -#else - -#include -#include -#include -#include -#include -#include -#include -#include -#include "c10/core/ScalarType.h" - -namespace vision { -namespace image { - -// We use global variables to cache the encoder and decoder instances and -// reuse them across calls to the corresponding pytorch functions -std::mutex encoderMutex; -std::unique_ptr cudaJpegEncoder; - -std::vector encode_jpegs_cuda( - const std::vector& decoded_images, - const int64_t quality) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cuda.encode_jpegs_cuda.encode_jpegs_cuda"); - - // Some nvjpeg structures are not thread safe so we're keeping it single - // threaded for now. 
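  // The file-scope encoderMutex above serializes every call, and the cached
  // CUDAJpegEncoder below is created lazily and reused by subsequent calls on
  // the same device.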
In the future this may be an opportunity to unlock - // further speedups - std::lock_guard lock(encoderMutex); - TORCH_CHECK(decoded_images.size() > 0, "Empty input tensor list"); - torch::Device device = decoded_images[0].device(); - at::cuda::CUDAGuard device_guard(device); - - // lazy init of the encoder class - // the encoder object holds on to a lot of state and is expensive to create, - // so we reuse it across calls. NB: the cached structures are device specific - // and cannot be reused across devices - if (cudaJpegEncoder == nullptr || device != cudaJpegEncoder->target_device) { - if (cudaJpegEncoder != nullptr) - delete cudaJpegEncoder.release(); - - cudaJpegEncoder = std::make_unique(device); - - // Unfortunately, we cannot rely on the smart pointer releasing the encoder - // object correctly upon program exit. This is because, when cudaJpegEncoder - // gets destroyed, the CUDA runtime may already be shut down, rendering all - // destroy* calls in the encoder destructor invalid. Instead, we use an - // atexit hook which executes after main() finishes, but hopefully before - // CUDA shuts down when the program exits. If CUDA is already shut down the - // destructor will detect this and will not attempt to destroy any encoder - // structures. - std::atexit([]() { delete cudaJpegEncoder.release(); }); - } - - std::vector contig_images; - contig_images.reserve(decoded_images.size()); - for (const auto& image : decoded_images) { - TORCH_CHECK( - image.dtype() == torch::kU8, "Input tensor dtype should be uint8"); - - TORCH_CHECK( - image.device() == device, - "All input tensors must be on the same CUDA device when encoding with nvjpeg") - - TORCH_CHECK( - image.dim() == 3 && image.numel() > 0, - "Input data should be a 3-dimensional tensor"); - - TORCH_CHECK( - image.size(0) == 3, - "The number of channels should be 3, got: ", - image.size(0)); - - // nvjpeg requires images to be contiguous - if (image.is_contiguous()) { - contig_images.push_back(image); - } else { - contig_images.push_back(image.contiguous()); - } - } - - cudaJpegEncoder->set_quality(quality); - std::vector encoded_images; - at::cuda::CUDAEvent event; - event.record(cudaJpegEncoder->stream); - for (const auto& image : contig_images) { - auto encoded_image = cudaJpegEncoder->encode_jpeg(image); - encoded_images.push_back(encoded_image); - } - - // We use a dedicated stream to do the encoding and even though the results - // may be ready on that stream we cannot assume that they are also available - // on the current stream of the calling context when this function returns. We - // use a blocking event to ensure that this is indeed the case. Crucially, we - // do not want to block the host at this particular point - // (which is what cudaStreamSynchronize would do.) Events allow us to - // synchronize the streams without blocking the host. - event.block(at::cuda::getCurrentCUDAStream( - cudaJpegEncoder->original_device.has_index() - ? cudaJpegEncoder->original_device.index() - : 0)); - return encoded_images; -} - -CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device) - : original_device{torch::kCUDA, torch::cuda::current_device()}, - target_device{target_device}, - stream{ - target_device.has_index() - ? 
at::cuda::getStreamFromPool(false, target_device.index()) - : at::cuda::getStreamFromPool(false)} { - nvjpegStatus_t status; - status = nvjpegCreateSimple(&nvjpeg_handle); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create nvjpeg handle: ", - status); - - status = nvjpegEncoderStateCreate(nvjpeg_handle, &nv_enc_state, stream); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create nvjpeg encoder state: ", - status); - - status = nvjpegEncoderParamsCreate(nvjpeg_handle, &nv_enc_params, stream); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to create nvjpeg encoder params: ", - status); -} - -CUDAJpegEncoder::~CUDAJpegEncoder() { - /* - The below code works on Mac and Linux, but fails on Windows. - This is because on Windows, the atexit hook which calls this - destructor executes after cuda is already shut down causing SIGSEGV. - We do not have a solution to this problem at the moment, so we'll - just leak the libnvjpeg & cuda variables for the time being and hope - that the CUDA runtime handles cleanup for us. - Please send a PR if you have a solution for this problem. - */ - - // // We run cudaGetDeviceCount as a dummy to test if the CUDA runtime is - // still - // // initialized. If it is not, we can skip the rest of this function as it - // is - // // unsafe to execute. - // int deviceCount = 0; - // cudaError_t error = cudaGetDeviceCount(&deviceCount); - // if (error != cudaSuccess) - // return; // CUDA runtime has already shut down. There's nothing we can do - // // now. - - // nvjpegStatus_t status; - - // status = nvjpegEncoderParamsDestroy(nv_enc_params); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg encoder params: ", - // status); - - // status = nvjpegEncoderStateDestroy(nv_enc_state); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, - // "Failed to destroy nvjpeg encoder state: ", - // status); - - // cudaStreamSynchronize(stream); - - // status = nvjpegDestroy(nvjpeg_handle); - // TORCH_CHECK( - // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); -} - -torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) { - int channels = src_image.size(0); - int height = src_image.size(1); - int width = src_image.size(2); - - nvjpegStatus_t status; - cudaError_t cudaStatus; - status = nvjpegEncoderParamsSetSamplingFactors( - nv_enc_params, NVJPEG_CSS_444, stream); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to set nvjpeg encoder params sampling factors: ", - status); - - nvjpegImage_t target_image; - for (int c = 0; c < channels; c++) { - target_image.channel[c] = src_image[c].data_ptr(); - // this is why we need contiguous tensors - target_image.pitch[c] = width; - } - for (int c = channels; c < NVJPEG_MAX_COMPONENT; c++) { - target_image.channel[c] = nullptr; - target_image.pitch[c] = 0; - } - // Encode the image - status = nvjpegEncodeImage( - nvjpeg_handle, - nv_enc_state, - nv_enc_params, - &target_image, - NVJPEG_INPUT_RGB, - width, - height, - stream); - - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, "image encoding failed: ", status); - // Retrieve length of the encoded image - size_t length; - status = nvjpegEncodeRetrieveBitstreamDevice( - nvjpeg_handle, nv_enc_state, NULL, &length, stream); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to retrieve encoded image stream state: ", - status); - - // Synchronize the stream to ensure that the encoded image is ready - cudaStatus = cudaStreamSynchronize(stream); - 
TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); - - // Reserve buffer for the encoded image - torch::Tensor encoded_image = torch::empty( - {static_cast(length)}, - torch::TensorOptions() - .dtype(torch::kByte) - .layout(torch::kStrided) - .device(target_device) - .requires_grad(false)); - cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus); - // Retrieve the encoded image - status = nvjpegEncodeRetrieveBitstreamDevice( - nvjpeg_handle, - nv_enc_state, - encoded_image.data_ptr(), - &length, - 0); - TORCH_CHECK( - status == NVJPEG_STATUS_SUCCESS, - "Failed to retrieve encoded image: ", - status); - return encoded_image; -} - -void CUDAJpegEncoder::set_quality(const int64_t quality) { - nvjpegStatus_t paramsQualityStatus = - nvjpegEncoderParamsSetQuality(nv_enc_params, quality, stream); - TORCH_CHECK( - paramsQualityStatus == NVJPEG_STATUS_SUCCESS, - "Failed to set nvjpeg encoder params quality: ", - paramsQualityStatus); -} - -} // namespace image -} // namespace vision - -#endif // NVJPEG_FOUND diff --git a/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.h b/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.h deleted file mode 100644 index 543940f1585..00000000000 --- a/product/include/torchvision/io/image/cuda/encode_jpegs_cuda.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once -#include -#include -#if NVJPEG_FOUND - -#include -#include -#include - -namespace vision { -namespace image { - -class CUDAJpegEncoder { - public: - CUDAJpegEncoder(const torch::Device& device); - ~CUDAJpegEncoder(); - - torch::Tensor encode_jpeg(const torch::Tensor& src_image); - - void set_quality(const int64_t quality); - - const torch::Device original_device; - const torch::Device target_device; - const c10::cuda::CUDAStream stream; - - protected: - nvjpegEncoderState_t nv_enc_state; - nvjpegEncoderParams_t nv_enc_params; - nvjpegHandle_t nvjpeg_handle; -}; -} // namespace image -} // namespace vision -#endif diff --git a/product/include/torchvision/io/image/image.cpp b/product/include/torchvision/io/image/image.cpp deleted file mode 100644 index 43e8ecbe4a2..00000000000 --- a/product/include/torchvision/io/image/image.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "image.h" - -#include - -// If we are in a Windows environment, we need to define -// initialization functions for the _custom_ops extension -#ifdef _WIN32 -void* PyInit_image(void) { - return nullptr; -} -#endif - -namespace vision { -namespace image { - -static auto registry = - torch::RegisterOperators() - .op("image::decode_gif", &decode_gif) - .op("image::decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_png) - .op("image::encode_png", &encode_png) - .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_jpeg) - .op("image::decode_webp", &decode_webp) - .op("image::decode_avif", &decode_avif) - .op("image::encode_jpeg", &encode_jpeg) - .op("image::read_file", &read_file) - .op("image::write_file", &write_file) - .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_image) - .op("image::decode_jpegs_cuda", &decode_jpegs_cuda) - .op("image::encode_jpegs_cuda", &encode_jpegs_cuda) - .op("image::_jpeg_version", &_jpeg_version) - .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo); - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/io/image/image.h 
b/product/include/torchvision/io/image/image.h deleted file mode 100644 index 91a5144fa1c..00000000000 --- a/product/include/torchvision/io/image/image.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include "cpu/decode_avif.h" -#include "cpu/decode_gif.h" -#include "cpu/decode_image.h" -#include "cpu/decode_jpeg.h" -#include "cpu/decode_png.h" -#include "cpu/decode_webp.h" -#include "cpu/encode_jpeg.h" -#include "cpu/encode_png.h" -#include "cpu/read_write_file.h" -#include "cuda/encode_decode_jpegs_cuda.h" diff --git a/product/include/torchvision/io/image/image_read_mode.h b/product/include/torchvision/io/image/image_read_mode.h deleted file mode 100644 index 84425265c34..00000000000 --- a/product/include/torchvision/io/image/image_read_mode.h +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace image { - -/* Should be kept in-sync with Python ImageReadMode enum */ -using ImageReadMode = int64_t; -const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0; -const ImageReadMode IMAGE_READ_MODE_GRAY = 1; -const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2; -const ImageReadMode IMAGE_READ_MODE_RGB = 3; -const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4; - -} // namespace image -} // namespace vision diff --git a/product/include/torchvision/macros.h b/product/include/torchvision/macros.h deleted file mode 100644 index f907280e24e..00000000000 --- a/product/include/torchvision/macros.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#if defined(_WIN32) && !defined(TORCHVISION_BUILD_STATIC_LIBS) -#if defined(torchvision_EXPORTS) -#define VISION_API __declspec(dllexport) -#else -#define VISION_API __declspec(dllimport) -#endif -#else -#define VISION_API -#endif diff --git a/product/include/torchvision/ops/autograd/deform_conv2d_kernel.cpp b/product/include/torchvision/ops/autograd/deform_conv2d_kernel.cpp deleted file mode 100644 index 0a7bbf9014e..00000000000 --- a/product/include/torchvision/ops/autograd/deform_conv2d_kernel.cpp +++ /dev/null @@ -1,266 +0,0 @@ -#include "../deform_conv2d.h" - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -class DeformConv2dFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& weight, - const torch::autograd::Variable& offset, - const torch::autograd::Variable& mask, - const torch::autograd::Variable& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - at::AutoDispatchBelowADInplaceOrView g; - auto output = deform_conv2d_symint( - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); - - ctx->save_for_backward({input, weight, offset, mask, bias}); - ctx->saved_data["stride_h"] = stride_h; - ctx->saved_data["stride_w"] = stride_w; - ctx->saved_data["pad_h"] = pad_h; - ctx->saved_data["pad_w"] = pad_w; - ctx->saved_data["dilation_h"] = dilation_h; - ctx->saved_data["dilation_w"] = dilation_w; - ctx->saved_data["groups"] = groups; - ctx->saved_data["offset_groups"] = offset_groups; - ctx->saved_data["use_mask"] = use_mask; - - return { - output, - }; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const 
torch::autograd::variable_list& grad_output) { - auto saved = ctx->get_saved_variables(); - auto input = saved[0]; - auto weight = saved[1]; - auto offset = saved[2]; - auto mask = saved[3]; - auto bias = saved[4]; - - auto stride_h = ctx->saved_data["stride_h"].toSymInt(); - auto stride_w = ctx->saved_data["stride_w"].toSymInt(); - auto pad_h = ctx->saved_data["pad_h"].toSymInt(); - auto pad_w = ctx->saved_data["pad_w"].toSymInt(); - auto dilation_h = ctx->saved_data["dilation_h"].toSymInt(); - auto dilation_w = ctx->saved_data["dilation_w"].toSymInt(); - auto groups = ctx->saved_data["groups"].toSymInt(); - auto offset_groups = ctx->saved_data["offset_groups"].toSymInt(); - auto use_mask = ctx->saved_data["use_mask"].toBool(); - - auto grads = detail::_deform_conv2d_backward_symint( - grad_output[0], - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); - auto grad_input = std::get<0>(grads); - auto grad_weight = std::get<1>(grads); - auto grad_offset = std::get<2>(grads); - auto grad_mask = std::get<3>(grads); - auto grad_bias = std::get<4>(grads); - - return { - grad_input, - grad_weight, - grad_offset, - grad_mask, - grad_bias, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - }; - } -}; - -// TODO: There should be an easier way to do this -class DeformConv2dBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& input, - const torch::autograd::Variable& weight, - const torch::autograd::Variable& offset, - const torch::autograd::Variable& mask, - const torch::autograd::Variable& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - at::AutoDispatchBelowADInplaceOrView g; - auto result = detail::_deform_conv2d_backward_symint( - grad, - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); - - auto grad_input = std::get<0>(result); - auto grad_weight = std::get<1>(result); - auto grad_offset = std::get<2>(result); - auto grad_mask = std::get<3>(result); - auto grad_bias = std::get<4>(result); - - return { - grad_input, - grad_weight, - grad_offset, - grad_mask, - grad_bias, - }; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on deform_conv2d not supported"); - } -}; - -at::Tensor deform_conv2d_autograd( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - return DeformConv2dFunction::apply( - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - 
dilation_w, - groups, - offset_groups, - use_mask)[0]; -} - -std::tuple -deform_conv2d_backward_autograd( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - auto result = DeformConv2dBackwardFunction::apply( - grad, - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); - - return std::make_tuple(result[0], result[1], result[2], result[3], result[4]); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), - TORCH_FN(deform_conv2d_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), - TORCH_FN(deform_conv2d_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp b/product/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp deleted file mode 100644 index 7205e9b15db..00000000000 --- a/product/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp +++ /dev/null @@ -1,167 +0,0 @@ -#include "../ps_roi_align.h" - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -class PSROIAlignFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["sampling_ratio"] = sampling_ratio; - ctx->saved_data["input_shape"] = input.sym_sizes(); - at::AutoDispatchBelowADInplaceOrView g; - auto result = ps_roi_align_symint( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio); - - auto output = std::get<0>(result); - auto channel_mapping = std::get<1>(result); - ctx->save_for_backward({rois, channel_mapping}); - ctx->mark_non_differentiable({channel_mapping}); - - return {output, channel_mapping}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto channel_mapping = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_ps_roi_align_backward_symint( - grad_output[0], - rois, - channel_mapping, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - ctx->saved_data["sampling_ratio"].toInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt()); - - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this 
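[Editor's note, not part of the patch: every autograd kernel removed in this hunk (deform_conv2d, ps_roi_align, ps_roi_pool, roi_align, roi_pool) follows the same torch::autograd::Function recipe: tensors go through save_for_backward, scalars through saved_data, and each non-tensor forward argument gets an undefined Variable in the returned gradient list. A minimal, hypothetical example of that recipe, unrelated to the ops themselves:]

#include <torch/torch.h>

// Computes y = scale * x^2 with a hand-written backward, mirroring the
// structure of the deleted kernels (not their math).
struct ScaleSquare : public torch::autograd::Function<ScaleSquare> {
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& input,
      double scale) {
    ctx->saved_data["scale"] = scale;  // non-tensor state
    ctx->save_for_backward({input});   // tensor state
    return input.pow(2) * scale;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    auto saved = ctx->get_saved_variables();
    auto input = saved[0];
    auto scale = ctx->saved_data["scale"].toDouble();
    // One slot per forward argument (after ctx); the non-differentiable
    // double gets an undefined Variable, just like stride/pad/etc. above.
    return {grad_output[0] * 2.0 * scale * input, torch::autograd::Variable()};
  }
};

// Usage: auto y = ScaleSquare::apply(x, 3.0);  // x created with requires_grad
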
-class PSROIAlignBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - const torch::autograd::Variable& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_ps_roi_align_backward_symint( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width); - - return {grad_in}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on ps_roi_align not supported"); - } -}; - -std::tuple ps_roi_align_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio) { - auto result = PSROIAlignFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); - - return std::make_tuple(result[0], result[1]); -} - -at::Tensor ps_roi_align_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - return PSROIAlignBackwardFunction::apply( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), - TORCH_FN(ps_roi_align_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), - TORCH_FN(ps_roi_align_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp b/product/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp deleted file mode 100644 index 39b83819f94..00000000000 --- a/product/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp +++ /dev/null @@ -1,152 +0,0 @@ -#include "../ps_roi_pool.h" - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -class PSROIPoolFunction : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["input_shape"] = input.sym_sizes(); - at::AutoDispatchBelowADInplaceOrView g; - auto result = ps_roi_pool_symint( - input, rois, spatial_scale, pooled_height, pooled_width); - - auto output = std::get<0>(result); - auto channel_mapping = std::get<1>(result); - ctx->save_for_backward({rois, channel_mapping}); - ctx->mark_non_differentiable({channel_mapping}); - - return {output, channel_mapping}; - 
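[Editor's note, not part of the patch: ps_roi_pool, like ps_roi_align and roi_pool in this hunk, returns a second integer-valued tensor (channel_mapping / argmax) whose only purpose is to let the backward kernel route gradients, so it is passed through mark_non_differentiable. A small, hypothetical illustration of the same two-output pattern, using a max-with-index op rather than the ROI ops above:]

#include <torch/torch.h>

struct MaxWithIndex : public torch::autograd::Function<MaxWithIndex> {
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& input) {
    auto [values, indices] = input.max(/*dim=*/0);
    ctx->save_for_backward({input, indices});
    ctx->mark_non_differentiable({indices});  // indices carry no gradient
    return {values, indices};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    auto saved = ctx->get_saved_variables();
    auto input = saved[0];
    auto indices = saved[1];
    // Scatter the incoming gradient back to the argmax positions; the
    // grad_output slot of the non-differentiable output is ignored.
    auto grad_input = torch::zeros_like(input);
    grad_input.scatter_(0, indices.unsqueeze(0), grad_output[0].unsqueeze(0));
    return {grad_input};
  }
};
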
} - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto channel_mapping = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_ps_roi_pool_backward_symint( - grad_output[0], - rois, - channel_mapping, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt()); - - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class PSROIPoolBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - const torch::autograd::Variable& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_ps_roi_pool_backward_symint( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); - - return {grad_in}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on ps_roi_pool not supported"); - } -}; - -std::tuple ps_roi_pool_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - auto result = PSROIPoolFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width); - - return std::make_tuple(result[0], result[1]); -} - -at::Tensor ps_roi_pool_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - return PSROIPoolBackwardFunction::apply( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), - TORCH_FN(ps_roi_pool_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), - TORCH_FN(ps_roi_pool_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/autograd/roi_align_kernel.cpp b/product/include/torchvision/ops/autograd/roi_align_kernel.cpp deleted file mode 100644 index 6d792fe09d9..00000000000 --- a/product/include/torchvision/ops/autograd/roi_align_kernel.cpp +++ /dev/null @@ -1,167 +0,0 @@ -#include "../roi_align.h" - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -class ROIAlignFunction : public torch::autograd::Function { - public: - static torch::autograd::variable_list 
forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - bool aligned) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["sampling_ratio"] = sampling_ratio; - ctx->saved_data["aligned"] = aligned; - ctx->saved_data["input_shape"] = input.sym_sizes(); - ctx->save_for_backward({rois}); - at::AutoDispatchBelowADInplaceOrView g; - auto result = roi_align_symint( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); - return {result}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_roi_align_backward_symint( - grad_output[0], - rois, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt(), - ctx->saved_data["sampling_ratio"].toInt(), - ctx->saved_data["aligned"].toBool()); - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class ROIAlignBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned) { - at::AutoDispatchBelowADInplaceOrView g; - auto result = detail::_roi_align_backward_symint( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned); - return {result}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on roi_align not supported"); - } -}; - -at::Tensor roi_align_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - bool aligned) { - return ROIAlignFunction::apply( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned)[0]; -} - -at::Tensor roi_align_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned) { - return ROIAlignBackwardFunction::apply( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - 
channels, - height, - width, - sampling_ratio, - aligned)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN(roi_align_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), - TORCH_FN(roi_align_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/autograd/roi_pool_kernel.cpp b/product/include/torchvision/ops/autograd/roi_pool_kernel.cpp deleted file mode 100644 index 508bafb2b1e..00000000000 --- a/product/include/torchvision/ops/autograd/roi_pool_kernel.cpp +++ /dev/null @@ -1,152 +0,0 @@ -#include "../roi_pool.h" - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -class ROIPoolFunction : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["input_shape"] = input.sym_sizes(); - at::AutoDispatchBelowADInplaceOrView g; - auto result = roi_pool_symint( - input, rois, spatial_scale, pooled_height, pooled_width); - - auto output = std::get<0>(result); - auto argmax = std::get<1>(result); - ctx->save_for_backward({rois, argmax}); - ctx->mark_non_differentiable({argmax}); - - return {output, argmax}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto argmax = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_roi_pool_backward_symint( - grad_output[0], - rois, - argmax, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt()); - - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class ROIPoolBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - const torch::autograd::Variable& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_roi_pool_backward_symint( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); - - return {grad_in}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on roi_pool not supported"); - } -}; - -std::tuple roi_pool_autograd( - const at::Tensor& input, - 
const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - auto result = ROIPoolFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width); - - return std::make_tuple(result[0], result[1]); -} - -at::Tensor roi_pool_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - return ROIPoolBackwardFunction::apply( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_pool"), - TORCH_FN(roi_pool_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), - TORCH_FN(roi_pool_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp b/product/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp deleted file mode 100644 index c5e59077aa6..00000000000 --- a/product/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp +++ /dev/null @@ -1,1172 +0,0 @@ -/*! - ******************* BEGIN Caffe Copyright Notice and Disclaimer - ***************** - * - * COPYRIGHT - * - * All contributions by the University of California: - * Copyright (c) 2014-2017 The Regents of the University of California (Regents) - * All rights reserved. - * - * All other contributions: - * Copyright (c) 2014-2017, the respective contributors - * All rights reserved. - * - * Caffe uses a shared copyright model: each contributor holds copyright over - * their contributions to Caffe. The project versioning records all such - * contribution and copyright details. If a contributor wants to further mark - * their specific copyright on a particular contribution, they should indicate - * their copyright solely in the commit message of the change when it is - * committed. - * - * LICENSE - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE - *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - * CONTRIBUTION AGREEMENT - * - * By contributing to the BVLC/caffe repository through pull-request, comment, - * or otherwise, the contributor releases their content to the - * license and copyright terms herein. - * - ***************** END Caffe Copyright Notice and Disclaimer - ********************* - * - * Copyright (c) 2018 Microsoft - * Licensed under The MIT License [see LICENSE for details] - * \file modulated_deformable_im2col.cuh - * \brief Function definitions of converting an image to - * column matrix based on kernel, padding, dilation, and offset. - * These functions are mainly used in deformable convolution operators. - * \ref: https://arxiv.org/abs/1703.06211 - * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng - */ - -// modified from -// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu - -// modified from -// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -const int kMaxParallelImgs = 32; - -template -scalar_t bilinear_interpolate( - const scalar_t* in, - int height, - int width, - scalar_t h, - scalar_t w) { - if (h <= -1 || height <= h || w <= -1 || width <= w) { - return 0; - } - - int h_low = floor(h); - int w_low = floor(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - scalar_t lh = h - h_low; - scalar_t lw = w - w_low; - scalar_t hh = 1 - lh, hw = 1 - lw; - - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) - v1 = in[h_low * width + w_low]; - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - v2 = in[h_low * width + w_high]; - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - v3 = in[h_high * width + w_low]; - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - v4 = in[h_high * width + w_high]; - - scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; -} - -template -void deformable_im2col_kernel( - int n, - const scalar_t* input, - const scalar_t* offset, - const scalar_t* mask, - int height, - int width, - int weight_h, - int weight_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int batch_sz, - int n_in_channels, - int n_offset_grps, - int out_h, - int out_w, - bool use_mask, - scalar_t* columns) { - for (int index = 0; index != n; ++index) { - const int out_x = index % out_w; - const int out_y = (index / out_w) % out_h; - const int out_b = (index / (out_w * out_h)) % batch_sz; - const int in_c = index / (out_w * out_h * batch_sz); - const int out_c = in_c * weight_h * weight_w; - - int c_per_offset_grp = n_in_channels / n_offset_grps; - const int grp_idx = in_c / c_per_offset_grp; - - auto columns_ptr = columns + - (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + - out_y * out_w + out_x); - - auto input_ptr = input + - (out_b * (n_in_channels * height * width) + in_c * (height * width)); - - auto offset_ptr = offset + - (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * - out_w; - - auto mask_ptr = mask; - if (use_mask) { - mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * - out_h * out_w; - } - - for (int i = 0; i < weight_h; ++i) { - for (int j = 0; j < weight_w; ++j) { - const int mask_idx = i * weight_w + j; - const int offset_idx = 2 * mask_idx; - - scalar_t mask_value = 1; - if (use_mask) { - 
mask_value = - mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; - } - - const scalar_t offset_h = - offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t offset_w = offset_ptr - [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t y = - (out_y * stride_h - pad_h) + i * dilation_h + offset_h; - const scalar_t x = - (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - *columns_ptr = - mask_value * bilinear_interpolate(input_ptr, height, width, y, x); - columns_ptr += batch_sz * out_h * out_w; - } - } - } -} - -void deformable_im2col( - const at::Tensor& input, - const at::Tensor& data_offset, - const at::Tensor& data_mask, - int n_in_channels, - int height, - int width, - int weight_h, - int weight_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int out_h, - int out_w, - int parallel_imgs, - int deformable_group, - bool use_mask, - at::Tensor data_col) { - int num_kernels = n_in_channels * out_h * out_w * parallel_imgs; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "deformable_im2col", ([&] { - deformable_im2col_kernel( - num_kernels, - input.data_ptr(), - data_offset.data_ptr(), - data_mask.data_ptr(), - height, - width, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - parallel_imgs, - n_in_channels, - deformable_group, - out_h, - out_w, - use_mask, - data_col.data_ptr()); - })); -} - -int get_greatest_divisor_below_bound(int n, int bound) { - for (int k = bound; k > 1; --k) { - if (n % k == 0) { - return k; - } - } - return 1; -} - -template -void deformable_col2im_kernel( - int n, - const scalar_t* col, - const scalar_t* offset, - const scalar_t* mask, - int channels, - int height, - int width, - int kernel_h, - int kernel_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int batch_sz, - int n_offset_grps, - int out_h, - int out_w, - bool use_mask, - scalar_t* grad_im) { - for (int index = 0; index != n; ++index) { - const int out_x = index % out_w; - const int out_y = (index / out_w) % out_h; - const int b = (index / (out_w * out_h)) % batch_sz; - const int j = (index / (out_w * out_h * batch_sz)) % kernel_w; - const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; - const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); - - int c_per_offset_grp = channels / n_offset_grps; - const int offset_grp = c / c_per_offset_grp; - - auto offset_ptr = offset + - (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * - out_w; - - auto mask_ptr = mask; - if (use_mask) { - mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * - out_h * out_w; - } - - const int mask_idx = i * kernel_w + j; - const int offset_idx = 2 * mask_idx; - - const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; - const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; - - const scalar_t offset_h = offset_ptr[offset_h_ptr]; - const scalar_t offset_w = offset_ptr[offset_w_ptr]; - - scalar_t mask_value = 1; - if (use_mask) { - mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; - } - - const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; - const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - - for (int dy = -1; dy <= 1; dy++) { - for (int dx = -1; dx <= 1; dx++) { - int yp = int(y) + dy; - int xp = int(x) + dx; - if (0 <= yp && yp < 
height && 0 <= xp && xp < width && - std::abs(y - yp) < 1 && std::abs(x - xp) < 1) { - int grad_pos = ((b * channels + c) * height + yp) * width + xp; - scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); - grad_im[grad_pos] += mask_value * weight * col[index]; - } - } - } - } -} - -void compute_grad_input( - const at::Tensor& columns, - const at::Tensor& offset, - const at::Tensor& mask, - int channels, - int height, - int width, - int weight_h, - int weight_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int parallel_imgs, - int n_offset_grps, - bool use_mask, - at::Tensor grad_im) { - int out_h = - (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = - (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int num_kernels = - channels * weight_h * weight_w * out_h * out_w * parallel_imgs; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - columns.scalar_type(), "compute_grad_input", ([&] { - deformable_col2im_kernel( - num_kernels, - columns.data_ptr(), - offset.data_ptr(), - mask.data_ptr(), - channels, - height, - width, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - parallel_imgs, - n_offset_grps, - out_h, - out_w, - use_mask, - grad_im.data_ptr()); - })); -} - -template -scalar_t get_coordinate_weight( - const scalar_t* im_data, - int height, - int width, - scalar_t y, - scalar_t x, - bool is_y_direction) { - int y_l = floor(y); - int x_l = floor(x); - int y_h = y_l + 1; - int x_h = x_l + 1; - - bool valid_y_l = 0 <= y_l && y_l < height; - bool valid_y_h = 0 <= y_h && y_h < height; - bool valid_x_l = 0 <= x_l && x_l < width; - bool valid_x_h = 0 <= x_h && x_h < width; - - scalar_t zero = 0; - scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; - scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; - scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; - scalar_t v_YX = (valid_y_h && valid_x_h) ? 
im_data[y_h * width + x_h] : zero; - - if (is_y_direction) { - scalar_t dx = x - x_l; - return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); - } else { - scalar_t dy = y - y_l; - return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); - } -} - -template -void deformable_col2im_coord_kernel( - int n, - const scalar_t* col, - const scalar_t* im, - const scalar_t* offset, - const scalar_t* mask, - int channels, - int height, - int width, - int weight_h, - int weight_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int batch_sz, - int offset_channels, - int n_offset_grps, - int out_h, - int out_w, - bool use_mask, - scalar_t* grad_offset, - scalar_t* grad_mask) { - for (int index = 0; index != n; ++index) { - scalar_t grad_offset_val = 0; - scalar_t grad_mask_val = 0; - - int w = index % out_w; - int h = (index / out_w) % out_h; - int w_w = (index / (out_w * out_h * 2)) % weight_w; - int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; - int c = (index / (out_w * out_h)) % offset_channels; - int b = index / (out_w * out_h * offset_channels); - - const int offset_grp = c / (2 * weight_h * weight_w); - const int col_step = weight_h * weight_w; - - int c_per_offset_grp = channels / n_offset_grps; - - auto col_ptr = col + - offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * - out_h; - auto im_ptr = im + - (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; - auto offset_ptr = offset + - (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * - out_w; - - auto mask_ptr = mask; - if (use_mask) { - mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * - out_h * out_w; - } - - const int offset_c = c - offset_grp * 2 * weight_h * weight_w; - const bool is_y_direction = offset_c % 2 == 0; - - const int c_bound = c_per_offset_grp * weight_h * weight_w; - for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { - const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w; - - int out_x = col_pos % out_w; - int out_y = (col_pos / out_w) % out_h; - int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; - int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; - - const int mask_idx = i * weight_w + j; - - const int offset_h_idx = - (((2 * mask_idx) * out_h + out_y) * out_w + out_x); - const int offset_w_idx = - (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); - const scalar_t offset_h = offset_ptr[offset_h_idx]; - const scalar_t offset_w = offset_ptr[offset_w_idx]; - - scalar_t mask_value = 1; - if (use_mask) { - mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; - } - - scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; - scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - - const scalar_t weight = - get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); - grad_offset_val += mask_value * weight * col_ptr[col_pos]; - - if (use_mask && is_y_direction) { - grad_mask_val += col_ptr[col_pos] * - bilinear_interpolate(im_ptr, height, width, y, x); - } - - im_ptr += height * width; - } - - grad_offset[index] = grad_offset_val; - - if (use_mask && is_y_direction) { - const int idx = - ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + - w_w) * - out_h + - h) * - out_w + - w; - grad_mask[idx] = grad_mask_val; - } - } -} - -void compute_grad_offset_and_mask( - const at::Tensor& columns, - const at::Tensor& input, - const at::Tensor& offset, - const at::Tensor& 
mask, - int channels, - int height, - int width, - int weight_h, - int weight_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int parallel_imgs, - int n_offset_grps, - bool use_mask, - at::Tensor grad_offset, - at::Tensor grad_mask) { - int out_h = - (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = - (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int num_kernels = - out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - columns.scalar_type(), "compute_grad_offset_and_mask", ([&] { - deformable_col2im_coord_kernel( - num_kernels, - columns.data_ptr(), - input.data_ptr(), - offset.data_ptr(), - mask.data_ptr(), - channels, - height, - width, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - parallel_imgs, - 2 * weight_h * weight_w * n_offset_grps, - n_offset_grps, - out_h, - out_w, - use_mask, - grad_offset.data_ptr(), - grad_mask.data_ptr()); - })); -} - -std::tuple backward_gradient_inputs( - at::Tensor input, - at::Tensor weight, - at::Tensor offset, - at::Tensor mask, - at::Tensor grad_out, - int stride_h, - int stride_w, - int pad_h, - int pad_w, - int dilation_h, - int dilation_w, - int n_weight_grps, - int n_offset_grps, - int n_parallel_imgs, - bool use_mask) { - int batch_sz = input.size(0); - int n_in_channels = input.size(1); - int in_h = input.size(2); - int in_w = input.size(3); - - n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); - - long n_out_channels = weight.size(0); - int weight_h = weight.size(2); - int weight_w = weight.size(3); - - long out_h = - (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - long out_w = - (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - - auto grad_input = at::zeros_like(input); - auto grad_offset = at::zeros_like(offset); - auto grad_mask = at::zeros_like(mask); - - if (batch_sz == 0) { - return std::make_tuple(grad_input, grad_offset, grad_mask); - } - - auto columns = at::empty( - {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, - input.options()); - - // Separate into blocks - grad_input = grad_input.reshape( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - input = input.reshape( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - - grad_offset = grad_offset.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - offset = offset.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - grad_mask = grad_mask.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - mask = mask.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - grad_out = grad_out - .reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_weight_grps, - n_out_channels / n_weight_grps, - out_h, - out_w}) - .permute({0, 2, 3, 1, 4, 5}); - - weight = weight.reshape( - {n_weight_grps, - weight.size(0) / n_weight_grps, - weight.size(1), - weight.size(2), - weight.size(3)}); - - columns = columns.view( - {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - - for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { - 
columns.zero_(); - // Separate into weight groups - for (int g = 0; g < n_weight_grps; g++) { - columns[g] = columns[g].addmm_( - weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); - } - - compute_grad_offset_and_mask( - columns, - input[elt], - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - grad_offset[elt], - grad_mask[elt]); - - compute_grad_input( - columns, - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - grad_input[elt]); - } - - grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); - grad_offset = grad_offset.view( - {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - - if (use_mask) { - grad_mask = grad_mask.view( - {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); - } - - return std::make_tuple(grad_input, grad_offset, grad_mask); -} - -at::Tensor backward_gradient_parameters( - at::Tensor input, - const at::Tensor& weight, - at::Tensor offset, - at::Tensor mask, - const at::Tensor& grad_out, - int stride_h, - int stride_w, - int pad_h, - int pad_w, - int dilation_h, - int dilation_w, - int n_weight_grps, - int n_offset_grps, - int n_parallel_imgs, - bool use_mask) { - int batch_sz = input.size(0); - int n_in_channels = input.size(1); - int in_h = input.size(2); - int in_w = input.size(3); - - n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); - - long n_out_channels = weight.size(0); - int weight_h = weight.size(2); - int weight_w = weight.size(3); - - long out_h = grad_out.size(2); - long out_w = grad_out.size(3); - - auto grad_weight = at::zeros_like(weight); - if (batch_sz == 0) { - return grad_weight; - } - - at::Tensor grad_out_buf = grad_out - .reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_weight_grps, - n_out_channels / n_weight_grps, - out_h, - out_w}) - .permute({0, 2, 3, 1, 4, 5}) - .contiguous(); - - input = input.reshape( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - - offset = offset.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - mask = mask.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - grad_weight = grad_weight.view( - {n_weight_grps, - grad_weight.size(0) / n_weight_grps, - grad_weight.size(1), - grad_weight.size(2), - grad_weight.size(3)}); - - auto columns = at::empty( - {n_weight_grps, - n_in_channels * weight_w * weight_h / n_weight_grps, - n_parallel_imgs * out_h * out_w}, - input.options()); - - for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { - deformable_im2col( - input[elt], - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - out_h, - out_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - columns); - - for (int g = 0; g < n_weight_grps; g++) { - grad_weight[g] = - grad_weight[g] - .flatten(1) - .addmm_( - grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) - .view_as(grad_weight[g]); - } - } - - grad_weight = grad_weight.view( - {grad_weight.size(0) * grad_weight.size(1), - grad_weight.size(2), - grad_weight.size(3), - 
grad_weight.size(4)}); - return grad_weight; -} - -at::Tensor deform_conv2d_forward_kernel( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask) { - at::Tensor input_c = input.contiguous(); - at::Tensor offset_c = offset.contiguous(); - at::Tensor weight_c = weight.contiguous(); - at::Tensor mask_c = mask.contiguous(); - at::Tensor bias_c = bias.contiguous(); - - TORCH_CHECK(input_c.ndimension() == 4); - TORCH_CHECK(offset_c.ndimension() == 4); - TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); - TORCH_CHECK(weight_c.ndimension() == 4); - TORCH_CHECK(input_c.device().is_cpu(), "input must be a CPU tensor"); - - int batch_sz = input_c.size(0); - int n_in_channels = input_c.size(1); - int in_h = input_c.size(2); - int in_w = input_c.size(3); - - int n_parallel_imgs = - get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - - // Unpack shapes and args - int out_channels = weight_c.size(0); - int weight_h = weight_c.size(2); - int weight_w = weight_c.size(3); - - int ker_h = dilation_h * (weight_h - 1) + 1; - int ker_w = dilation_w * (weight_w - 1) + 1; - int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; - int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; - - TORCH_CHECK( - weight_h > 0 && weight_w > 0, - "weight_h: ", - weight_h, - " weight_w: ", - weight_w); - TORCH_CHECK( - stride_h > 0 && stride_w > 0, - "stride_h: ", - stride_h, - " stride_w: ", - stride_w); - TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); - TORCH_CHECK( - dilation_h > 0 && dilation_w > 0, - "dilation_h: ", - dilation_h, - " dilation_w: ", - dilation_w); - - TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); - TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); - TORCH_CHECK( - (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), - "offset.shape[1] is not valid: got: ", - offset_c.size(1), - " expected: ", - n_offset_grps * 2 * weight_h * weight_w); - TORCH_CHECK( - (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), - "mask.shape[1] is not valid: got: ", - mask_c.size(1), - " expected: ", - n_offset_grps * weight_h * weight_w); - TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); - - TORCH_CHECK( - (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); - TORCH_CHECK( - (offset_c.size(2) == out_h && offset_c.size(3) == out_w), - "offset output dims: (", - offset_c.size(2), - ", ", - offset_c.size(3), - ") - ", - "computed output dims: (", - out_h, - ", ", - out_w, - ")"); - TORCH_CHECK( - (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask"); - TORCH_CHECK( - (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), - "mask output dims: (", - mask_c.size(2), - ", ", - mask_c.size(3), - ") - ", - "computed output dims: (", - out_h, - ", ", - out_w, - ")"); - TORCH_CHECK( - out_h > 0 && out_w > 0, - "Calculated output size too small - out_h: ", - out_h, - " out_w: ", - out_w); - - auto out = - at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); - if (batch_sz == 0) { - return out; - } - - // Separate batches into blocks - out = out.view( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - out_channels, - out_h, - out_w}); - input_c = input_c.view( - {batch_sz / n_parallel_imgs, n_parallel_imgs, 
n_in_channels, in_h, in_w}); - - offset_c = offset_c.view( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - mask_c = mask_c.view( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - at::Tensor out_buf = at::zeros( - {batch_sz / n_parallel_imgs, - out_channels, - n_parallel_imgs * out_h, - out_w}, - out.options()); - - // Separate channels into convolution groups - out_buf = out_buf.view( - {out_buf.size(0), - n_weight_grps, - out_buf.size(1) / n_weight_grps, - out_buf.size(2), - out_buf.size(3)}); - weight_c = weight_c.view( - {n_weight_grps, - weight_c.size(0) / n_weight_grps, - weight_c.size(1), - weight_c.size(2), - weight_c.size(3)}); - - // Sample points and perform convolution - auto columns = at::zeros( - {n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, - input_c.options()); - for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { - deformable_im2col( - input_c[b], - offset_c[b], - mask_c[b], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - out_h, - out_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - columns); - - columns = columns.view( - {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - for (int g = 0; g < n_weight_grps; g++) { - out_buf[b][g] = out_buf[b][g] - .flatten(1) - .addmm_(weight_c[g].flatten(1), columns[g]) - .view_as(out_buf[b][g]); - } - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - } - - out_buf = out_buf.view( - {batch_sz / n_parallel_imgs, - out_channels, - n_parallel_imgs, - out_h, - out_w}); - out_buf.transpose_(1, 2); - out.copy_(out_buf); - out = out.view({batch_sz, out_channels, out_h, out_w}); - - return out + bias_c.view({1, out_channels, 1, 1}); -} - -std::tuple -deform_conv2d_backward_kernel( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask) { - at::Tensor grad_out_c = grad_out.contiguous(); - at::Tensor input_c = input.contiguous(); - at::Tensor weight_c = weight.contiguous(); - at::Tensor offset_c = offset.contiguous(); - at::Tensor mask_c = mask.contiguous(); - at::Tensor bias_c = bias.contiguous(); - - const int batch_sz = input_c.size(0); - const int n_parallel_imgs = - get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - - auto grad_input_and_offset_and_mask = backward_gradient_inputs( - input_c, - weight_c, - offset_c, - mask_c, - grad_out_c, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - n_weight_grps, - n_offset_grps, - n_parallel_imgs, - use_mask); - - auto grad_input = std::get<0>(grad_input_and_offset_and_mask); - auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); - auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); - - auto grad_weight = backward_gradient_parameters( - input_c, - weight_c, - offset_c, - mask_c, - grad_out_c, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - n_weight_grps, - n_offset_grps, - n_parallel_imgs, - use_mask); - - auto grad_bias = at::ones_like(bias_c) * grad_out_c.sum({0, 2, 3}); - - return std::make_tuple( - grad_input, 
grad_weight, grad_offset, grad_mask, grad_bias); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), - TORCH_FN(deform_conv2d_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), - TORCH_FN(deform_conv2d_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/cpu/nms_kernel.cpp b/product/include/torchvision/ops/cpu/nms_kernel.cpp deleted file mode 100644 index 50479066cbd..00000000000 --- a/product/include/torchvision/ops/cpu/nms_kernel.cpp +++ /dev/null @@ -1,117 +0,0 @@ -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -at::Tensor nms_kernel_impl( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold) { - TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); - TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); - TORCH_CHECK( - dets.scalar_type() == scores.scalar_type(), - "dets should have the same type as scores"); - - if (dets.numel() == 0) - return at::empty({0}, dets.options().dtype(at::kLong)); - - auto x1_t = dets.select(1, 0).contiguous(); - auto y1_t = dets.select(1, 1).contiguous(); - auto x2_t = dets.select(1, 2).contiguous(); - auto y2_t = dets.select(1, 3).contiguous(); - - at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); - - auto order_t = std::get<1>( - scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); - - auto ndets = dets.size(0); - at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); - at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); - - auto suppressed = suppressed_t.data_ptr(); - auto keep = keep_t.data_ptr(); - auto order = order_t.data_ptr(); - auto x1 = x1_t.data_ptr(); - auto y1 = y1_t.data_ptr(); - auto x2 = x2_t.data_ptr(); - auto y2 = y2_t.data_ptr(); - auto areas = areas_t.data_ptr(); - - int64_t num_to_keep = 0; - - for (int64_t _i = 0; _i < ndets; _i++) { - auto i = order[_i]; - if (suppressed[i] == 1) - continue; - keep[num_to_keep++] = i; - auto ix1 = x1[i]; - auto iy1 = y1[i]; - auto ix2 = x2[i]; - auto iy2 = y2[i]; - auto iarea = areas[i]; - - for (int64_t _j = _i + 1; _j < ndets; _j++) { - auto j = order[_j]; - if (suppressed[j] == 1) - continue; - auto xx1 = std::max(ix1, x1[j]); - auto yy1 = std::max(iy1, y1[j]); - auto xx2 = std::min(ix2, x2[j]); - auto yy2 = std::min(iy2, y2[j]); - - auto w = std::max(static_cast(0), xx2 - xx1); - auto h = std::max(static_cast(0), yy2 - yy1); - auto inter = w * h; - auto ovr = inter / (iarea + areas[j] - inter); - if (ovr > iou_threshold) - suppressed[j] = 1; - } - } - return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); -} - -at::Tensor nms_kernel( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold) { - TORCH_CHECK( - dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK( - dets.size(1) == 4, - "boxes should have 4 elements in dimension 1, got ", - dets.size(1)); - TORCH_CHECK( - scores.dim() == 1, - "scores should be a 1d tensor, got ", - scores.dim(), - "D"); - TORCH_CHECK( - dets.size(0) == scores.size(0), - "boxes and scores should have same number of elements in ", - "dimension 0, got ", - dets.size(0), - " and ", - scores.size(0)); - - auto result = at::empty({0}, dets.options()); - - AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { - result = nms_kernel_impl(dets, scores, iou_threshold); - }); - return result; 
-} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp b/product/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp deleted file mode 100644 index 1c272427d3f..00000000000 --- a/product/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp +++ /dev/null @@ -1,429 +0,0 @@ -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -T bilinear_interpolate( - const T* input, - int height, - int width, - T y, - T x, - int index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - return 0; - } - - if (y <= 0) - y = 0; - if (x <= 0) - x = 0; - - int y_low = (int)y; - int x_low = (int)x; - int y_high; - int x_high; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - - // do bilinear interpolation - T v1 = input[y_low * width + x_low]; - T v2 = input[y_low * width + x_high]; - T v3 = input[y_high * width + x_low]; - T v4 = input[y_high * width + x_high]; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - return val; -} - -template -void ps_roi_align_forward_kernel_impl( - int num_rois, - const T* input, - const T spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - int sampling_ratio, - const T* rois, - int channels_out, - T* output, - int* channel_mapping) { - for (int n = 0; n < num_rois; n++) { - // [start, end) interval for spatial sampling - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); - T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); - T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); - T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - int c_in = 0; - for (int c_out = 0; c_out < channels_out; ++c_out) { - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int index = - ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + - pw; - - // Do not using floor/ceil; this implementation detail is critical - T hstart = static_cast(ph) * bin_size_h + roi_start_h; - T wstart = static_cast(pw) * bin_size_w + roi_start_w; - - // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); - int roi_bin_grid_w = (sampling_ratio > 0) - ? 
sampling_ratio - : ceil(roi_width / pooled_width); - const T count = roi_bin_grid_h * roi_bin_grid_w; - - const T* offset_input = - input + (roi_batch_ind * channels + c_in) * height * width; - - T out_sum = 0; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = hstart + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = wstart + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - T val = bilinear_interpolate( - offset_input, height, width, y, x, index); - out_sum += val; - } - } - - out_sum /= count; - output[index] = out_sum; - channel_mapping[index] = c_in; - c_in++; - } - } - } - } -} - -template -void bilinear_interpolate_gradient( - int height, - int width, - T y, - T x, - T& w1, - T& w2, - T& w3, - T& w4, - int& x_low, - int& x_high, - int& y_low, - int& y_high, - int index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - w1 = w2 = w3 = w4 = 0.; - x_low = x_high = y_low = y_high = -1; - return; - } - - if (y <= 0) - y = 0; - if (x <= 0) - x = 0; - - y_low = (int)y; - x_low = (int)x; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - - // reference in forward - // T v1 = input[y_low * width + x_low]; - // T v2 = input[y_low * width + x_high]; - // T v3 = input[y_high * width + x_low]; - // T v4 = input[y_high * width + x_high]; - // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; -} - -template -inline void add(T* address, const T& val) { - *address += val; -} - -template -void ps_roi_align_backward_kernel_impl( - int nthreads, - const T* grad_output, - const int* channel_mapping, - const T spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - int sampling_ratio, - int channels_out, - T* grad_input, - const T* rois) { - for (int index = 0; index < nthreads; index++) { - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int n = index / pooled_width / pooled_height / channels_out; - - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); - T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); - T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); - T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); - - // Force too small ROIs to be 1x1 - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - T bin_size_h = roi_height / static_cast(pooled_height); - T bin_size_w = roi_width / static_cast(pooled_width); - - int c_in = channel_mapping[index]; - T* grad_input_offset = - grad_input + (roi_batch_ind * channels + c_in) * height * width; - - // Do not using floor/ceil; this implementation detail is critical - T hstart = static_cast(ph) * bin_size_h + roi_start_h; - T wstart = static_cast(pw) * bin_size_w + roi_start_w; - - const T grad_output_this_bin = grad_output[index]; - - // We use roi_bin_grid to sample the grid and mimic integral - int 
roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - const T count = roi_bin_grid_h * roi_bin_grid_w; - - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = hstart + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = wstart + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - - bilinear_interpolate_gradient( - height, - width, - y, - x, - w1, - w2, - w3, - w4, - x_low, - x_high, - y_low, - y_high, - index); - - T g1 = grad_output_this_bin * w1 / count; - T g2 = grad_output_this_bin * w2 / count; - T g3 = grad_output_this_bin * w3 / count; - T g4 = grad_output_this_bin * w4 / count; - - if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - add(grad_input_offset + y_low * width + x_low, g1); - add(grad_input_offset + y_low * width + x_high, g2); - add(grad_input_offset + y_high * width + x_low, g3); - add(grad_input_offset + y_high * width + x_high, g4); - } // if - } // ix - } // iy - } -} - -std::tuple ps_roi_align_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio) { - // Check if input tensors are CPU tensors - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( - rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_align_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); - - int num_rois = rois.size(0); - int channels = input.size(1); - int height = input.size(2); - int width = input.size(3); - - TORCH_CHECK( - channels % (pooled_height * pooled_width) == 0, - "input channels must be a multiple of pooling height * pooling width"); - int channels_out = channels / (pooled_height * pooled_width); - - auto output = at::zeros( - {num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = - at::zeros(output.sizes(), input.options().dtype(at::kInt)); - - if (output.numel() == 0) { - return std::make_tuple(output, channel_mapping); - } - - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_align_forward_kernel", [&] { - ps_roi_align_forward_kernel_impl( - num_rois, - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - rois_.data_ptr(), - channels_out, - output.data_ptr(), - channel_mapping.data_ptr()); - }); - return std::make_tuple(output, channel_mapping); -} - -at::Tensor ps_roi_align_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - // Check if input tensors are CPU tensors - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( - channel_mapping.device().is_cpu(), - "channel_mapping must be a CPU tensor"); - - 
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = "ps_roi_align_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); - - auto grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); - - // handle possibly empty gradients - if (grad.numel() == 0) { - return grad_input; - } - - int channels_out = channels / (pooled_height * pooled_width); - - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_align_backward_kernel", [&] { - ps_roi_align_backward_kernel_impl( - grad.numel(), - grad_.data_ptr(), - channel_mapping.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - channels_out, - grad_input.data_ptr(), - rois_.data_ptr()); - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), - TORCH_FN(ps_roi_align_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), - TORCH_FN(ps_roi_align_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp b/product/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp deleted file mode 100644 index 607cbe4bab6..00000000000 --- a/product/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp +++ /dev/null @@ -1,273 +0,0 @@ -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -inline void add(T* address, const T& val) { - *address += val; -} - -template -void ps_roi_pool_forward_kernel_impl( - const T* input, - const T spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - const T* rois, - int channels_out, - int num_rois, - T* output, - int* channel_mapping) { - for (int n = 0; n < num_rois; ++n) { - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - int roi_start_w = round(offset_rois[1] * spatial_scale); - int roi_start_h = round(offset_rois[2] * spatial_scale); - int roi_end_w = round(offset_rois[3] * spatial_scale); - int roi_end_h = round(offset_rois[4] * spatial_scale); - - // Force too small ROIs to be 1x1 - int roi_width = std::max(roi_end_w - roi_start_w, 1); - int roi_height = std::max(roi_end_h - roi_start_h, 1); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - int c_in = 0; - for (int c_out = 0; c_out < channels_out; ++c_out) { - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = - static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = - static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = std::min(std::max(hstart + roi_start_h, 0), height - 1); - hend = std::min(std::max(hend + roi_start_h, 0), height - 1); - wstart = std::min(std::max(wstart + roi_start_w, 0), width - 1); - wend = std::min(std::max(wend + roi_start_w, 0), width - 1); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - const T* offset_input = - input + (roi_batch_ind * channels + c_in) * height * width; - - T out_sum = 0; - for (int h = hstart; h < hend; 
++h) { - for (int w = wstart; w < wend; ++w) { - int input_index = h * width + w; - out_sum += offset_input[input_index]; - } - } - - int index = - ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + - pw; - T bin_area = (hend - hstart) * (wend - wstart); - output[index] = is_empty ? static_cast(0) : out_sum / bin_area; - channel_mapping[index] = c_in; - c_in++; - } - } - } - } -} - -template -void ps_roi_pool_backward_kernel_impl( - const T* grad_output, - const int* channel_mapping, - int num_rois, - const T spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - int channels_out, - T* grad_input, - const T* rois) { - for (int n = 0; n < num_rois; ++n) { - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - int roi_start_w = roundf(offset_rois[1] * spatial_scale); - int roi_start_h = roundf(offset_rois[2] * spatial_scale); - int roi_end_w = roundf(offset_rois[3] * spatial_scale); - int roi_end_h = roundf(offset_rois[4] * spatial_scale); - - // Force too small ROIs to be 1x1 - int roi_width = std::max(roi_end_w - roi_start_w, 1); - int roi_height = std::max(roi_end_h - roi_start_h, 1); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = std::min(std::max(hstart + roi_start_h, 0), height); - hend = std::min(std::max(hend + roi_start_h, 0), height); - wstart = std::min(std::max(wstart + roi_start_w, 0), width); - wend = std::min(std::max(wend + roi_start_w, 0), width); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - for (int c_out = 0; c_out < channels_out; ++c_out) { - int index = - ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + - pw; - int c_in = channel_mapping[index]; - - T* grad_input_offset = - grad_input + (roi_batch_ind * channels + c_in) * height * width; - T bin_area = (hend - hstart) * (wend - wstart); - T diff_val = - is_empty ? 
static_cast(0) : grad_output[index] / bin_area; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int grad_input_index = h * width + w; - add(grad_input_offset + grad_input_index, diff_val); - } - } - } - } - } - } -} - -std::tuple ps_roi_pool_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - // Check if input tensors are CPU tensors - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( - rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_pool_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); - - int num_rois = rois.size(0); - int channels = input.size(1); - int height = input.size(2); - int width = input.size(3); - - TORCH_CHECK( - channels % (pooled_height * pooled_width) == 0, - "input channels must be a multiple of pooling height * pooling width"); - int channels_out = channels / (pooled_height * pooled_width); - - auto output = at::zeros( - {num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = - at::zeros(output.sizes(), input.options().dtype(at::kInt)); - - auto output_size = output.numel(); - if (output_size == 0) { - return std::make_tuple(output, channel_mapping); - } - - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_pool_forward_kernel", [&] { - ps_roi_pool_forward_kernel_impl( - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - rois_.data_ptr(), - channels_out, - num_rois, - output.data_ptr(), - channel_mapping.data_ptr()); - }); - return std::make_tuple(output, channel_mapping); -} - -at::Tensor ps_roi_pool_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - // Check if input tensors are CPU tensors - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( - channel_mapping.device().is_cpu(), - "channel_mapping must be a CPU tensor"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = "ps_roi_pool_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); - - auto num_rois = rois.size(0); - auto grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); - - // handle possibly empty gradients - if (grad.numel() == 0) { - return grad_input; - } - - int channels_out = channels / (pooled_height * pooled_width); - - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] { - ps_roi_pool_backward_kernel_impl( - grad_.data_ptr(), - channel_mapping.data_ptr(), - num_rois, - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - channels_out, - grad_input.data_ptr(), - rois_.data_ptr()); - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - 
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), - TORCH_FN(ps_roi_pool_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), - TORCH_FN(ps_roi_pool_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/cpu/roi_align_common.h b/product/include/torchvision/ops/cpu/roi_align_common.h deleted file mode 100644 index e10c67b5b79..00000000000 --- a/product/include/torchvision/ops/cpu/roi_align_common.h +++ /dev/null @@ -1,128 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace ops { -namespace detail { - -template -struct PreCalc { - int pos1; - int pos2; - int pos3; - int pos4; - T w1; - T w2; - T w3; - T w4; -}; - -// This helper computes the interpolation weights (w1, w2...) for every sampling -// point of a given box. There are pool_height * pool_width * roi_bin_grid_h * -// roi_bin_grid_w such sampling points. -// -// The weights (w1, w2...) are computed as the areas in this figure: -// https://en.wikipedia.org/wiki/Bilinear_interpolation#/media/File:Bilinear_interpolation_visualisation.svg -// and pos1, pos2 etc correspond to the indices of their respective pixels. -// -// Note: the weights and indices are shared across all channels, which is why -// they are pre-calculated prior to the main loop in the RoIAlign kernel. -// implementation taken from Caffe2 -template -void pre_calc_for_bilinear_interpolate( - int height, - int width, - int pooled_height, - int pooled_width, - T roi_start_h, - T roi_start_w, - T bin_size_h, - T bin_size_w, - int roi_bin_grid_h, - int roi_bin_grid_w, - std::vector>& pre_calc) { - int pre_calc_index = 0; - for (int ph = 0; ph < pooled_height; ph++) { - for (int pw = 0; pw < pooled_width; pw++) { - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T yy = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T xx = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T x = xx; - T y = yy; - // deal with: inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - PreCalc pc; - pc.pos1 = 0; - pc.pos2 = 0; - pc.pos3 = 0; - pc.pos4 = 0; - pc.w1 = 0; - pc.w2 = 0; - pc.w3 = 0; - pc.w4 = 0; - pre_calc[pre_calc_index] = pc; - pre_calc_index += 1; - continue; - } - - if (y <= 0) { - y = 0; - } - if (x <= 0) { - x = 0; - } - - int y_low = (int)y; - int x_low = (int)x; - int y_high; - int x_high; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. 
- lx; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - // save weights and indices - PreCalc pc; - pc.pos1 = y_low * width + x_low; - pc.pos2 = y_low * width + x_high; - pc.pos3 = y_high * width + x_low; - pc.pos4 = y_high * width + x_high; - pc.w1 = w1; - pc.w2 = w2; - pc.w3 = w3; - pc.w4 = w4; - pre_calc[pre_calc_index] = pc; - - pre_calc_index += 1; - } - } - } - } -} - -} // namespace detail -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/cpu/roi_align_kernel.cpp b/product/include/torchvision/ops/cpu/roi_align_kernel.cpp deleted file mode 100644 index b787de6f6bb..00000000000 --- a/product/include/torchvision/ops/cpu/roi_align_kernel.cpp +++ /dev/null @@ -1,400 +0,0 @@ -#include -#include - -#include "./roi_align_common.h" - -namespace vision { -namespace ops { - -namespace { - -template -void roi_align_forward_kernel_impl( - int n_rois, - const T* input, - const T& spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - int sampling_ratio, - bool aligned, - const T* rois, - T* output) { - // (n, c, ph, pw) is an element in the pooled output - // can be parallelized using omp - // #pragma omp parallel for num_threads(32) - for (int n = 0; n < n_rois; n++) { - int index_n = n * channels * pooled_width * pooled_height; - - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T offset = aligned ? (T)0.5 : (T)0.0; - T roi_start_w = offset_rois[1] * spatial_scale - offset; - T roi_start_h = offset_rois[2] * spatial_scale - offset; - T roi_end_w = offset_rois[3] * spatial_scale - offset; - T roi_end_h = offset_rois[4] * spatial_scale - offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = std::max(roi_width, (T)1.); - roi_height = std::max(roi_height, (T)1.); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - // When the grid is empty, output zeros. - const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 - - // we want to precalculate indices and weights shared by all channels, - // this is the key point of optimization - std::vector> pre_calc( - roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); - detail::pre_calc_for_bilinear_interpolate( - height, - width, - pooled_height, - pooled_width, - roi_start_h, - roi_start_w, - bin_size_h, - bin_size_w, - roi_bin_grid_h, - roi_bin_grid_w, - pre_calc); - - for (int c = 0; c < channels; c++) { - int index_n_c = index_n + c * pooled_width * pooled_height; - const T* offset_input = - input + (roi_batch_ind * channels + c) * height * width; - int pre_calc_index = 0; - - for (int ph = 0; ph < pooled_height; ph++) { - for (int pw = 0; pw < pooled_width; pw++) { - int index = index_n_c + ph * pooled_width + pw; - - T output_val = 0.; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - detail::PreCalc pc = pre_calc[pre_calc_index]; - output_val += pc.w1 * offset_input[pc.pos1] + - pc.w2 * offset_input[pc.pos2] + - pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4]; - - pre_calc_index += 1; - } - } - output_val /= count; // Average pooling - - output[index] = output_val; - } // for pw - } // for ph - } // for c - } // for n -} - -template -void bilinear_interpolate_gradient( - int height, - int width, - T y, - T x, - T& w1, - T& w2, - T& w3, - T& w4, - int& x_low, - int& x_high, - int& y_low, - int& y_high, - int index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - w1 = w2 = w3 = w4 = 0.; - x_low = x_high = y_low = y_high = -1; - return; - } - - if (y <= 0) - y = 0; - if (x <= 0) - x = 0; - - y_low = (int)y; - x_low = (int)x; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - - // reference in forward - // T v1 = input[y_low * width + x_low]; - // T v2 = input[y_low * width + x_high]; - // T v3 = input[y_high * width + x_low]; - // T v4 = input[y_high * width + x_high]; - // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; -} - -template -inline void add(T* address, const T& val) { - *address += val; -} - -template -void roi_align_backward_kernel_impl( - int nthreads, - const T* grad_output, - const T& spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - int sampling_ratio, - bool aligned, - T* grad_input, - const T* rois, - int n_stride, - int c_stride, - int h_stride, - int w_stride) { - for (int index = 0; index < nthreads; index++) { - // (n, c, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; - - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T offset = aligned ? 
(T)0.5 : (T)0.0; - T roi_start_w = offset_rois[1] * spatial_scale - offset; - T roi_start_h = offset_rois[2] * spatial_scale - offset; - T roi_end_w = offset_rois[3] * spatial_scale - offset; - T roi_end_h = offset_rois[4] * spatial_scale - offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = std::max(roi_width, (T)1.); - roi_height = std::max(roi_height, (T)1.); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - T* offset_grad_input = - grad_input + ((roi_batch_ind * channels + c) * height * width); - - int output_offset = n * n_stride + c * c_stride; - const T* offset_grad_output = grad_output + output_offset; - const T grad_output_this_bin = - offset_grad_output[ph * h_stride + pw * w_stride]; - - // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 - - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - - bilinear_interpolate_gradient( - height, - width, - y, - x, - w1, - w2, - w3, - w4, - x_low, - x_high, - y_low, - y_high, - index); - - T g1 = grad_output_this_bin * w1 / count; - T g2 = grad_output_this_bin * w2 / count; - T g3 = grad_output_this_bin * w3 / count; - T g4 = grad_output_this_bin * w4 / count; - - if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - // atomic add is not needed for now since it is single threaded - add(offset_grad_input + y_low * width + x_low, static_cast(g1)); - add(offset_grad_input + y_low * width + x_high, static_cast(g2)); - add(offset_grad_input + y_high * width + x_low, static_cast(g3)); - add(offset_grad_input + y_high * width + x_high, static_cast(g4)); - } // if - } // ix - } // iy - } // for -} - -at::Tensor roi_align_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned) { - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_align_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); - - auto num_rois = rois.size(0); - auto channels = input.size(1); - auto height = input.size(2); - auto width = input.size(3); - - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); - - if (output.numel() == 0) - return output; - - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_align_forward_kernel", [&] { - roi_align_forward_kernel_impl( - num_rois, - 
input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - aligned, - rois_.data_ptr(), - output.data_ptr()); - }); - return output; -} - -at::Tensor roi_align_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_align_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); - - at::Tensor grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); - - // handle possibly empty gradients - if (grad.numel() == 0) { - return grad_input; - } - - // get stride values to ensure indexing into gradients is correct. - int n_stride = grad.stride(0); - int c_stride = grad.stride(1); - int h_stride = grad.stride(2); - int w_stride = grad.stride(3); - - auto rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_align_backward_kernel", [&] { - roi_align_backward_kernel_impl( - grad.numel(), - grad.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - aligned, - grad_input.data_ptr(), - rois_.data_ptr(), - n_stride, - c_stride, - h_stride, - w_stride); - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN(roi_align_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), - TORCH_FN(roi_align_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/cpu/roi_pool_kernel.cpp b/product/include/torchvision/ops/cpu/roi_pool_kernel.cpp deleted file mode 100644 index b099523896a..00000000000 --- a/product/include/torchvision/ops/cpu/roi_pool_kernel.cpp +++ /dev/null @@ -1,249 +0,0 @@ -#include - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -inline void add(T* address, const T& val) { - *address += val; -} - -template -void roi_pool_forward_kernel_impl( - const T* input, - const T spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - const T* rois, - int num_rois, - T* output, - int* argmax_data) { - for (int n = 0; n < num_rois; ++n) { - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - int roi_start_w = round(offset_rois[1] * spatial_scale); - int roi_start_h = round(offset_rois[2] * spatial_scale); - int roi_end_w = round(offset_rois[3] * spatial_scale); - int roi_end_h = round(offset_rois[4] * spatial_scale); - - // Force malformed ROIs to be 1x1 - int roi_width = std::max(roi_end_w - roi_start_w + 1, 1); - int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); 
- int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = std::min(std::max(hstart + roi_start_h, 0), height); - hend = std::min(std::max(hend + roi_start_h, 0), height); - wstart = std::min(std::max(wstart + roi_start_w, 0), width); - wend = std::min(std::max(wend + roi_start_w, 0), width); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - for (int c = 0; c < channels; ++c) { - // Define an empty pooling region to be zero - T maxval = is_empty ? 0 : -FLT_MAX; - // If nothing is pooled, argmax = -1 causes nothing to be backprop'd - int maxidx = -1; - - const T* input_offset = - input + (roi_batch_ind * channels + c) * height * width; - - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int input_index = h * width + w; - if (input_offset[input_index] > maxval) { - maxval = input_offset[input_index]; - maxidx = input_index; - } - } - } - int index = - ((n * channels + c) * pooled_height + ph) * pooled_width + pw; - output[index] = maxval; - argmax_data[index] = maxidx; - } // channels - } // pooled_width - } // pooled_height - } // num_rois -} - -template -void roi_pool_backward_kernel_impl( - const T* grad_output, - const int* argmax_data, - int num_rois, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - T* grad_input, - const T* rois, - int n_stride, - int c_stride, - int h_stride, - int w_stride) { - for (int n = 0; n < num_rois; ++n) { - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - - for (int c = 0; c < channels; ++c) { - T* grad_input_offset = - grad_input + ((roi_batch_ind * channels + c) * height * width); - const int* argmax_data_offset = - argmax_data + (n * channels + c) * pooled_height * pooled_width; - - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int output_offset = n * n_stride + c * c_stride; - int argmax = argmax_data_offset[ph * pooled_width + pw]; - - if (argmax != -1) { - add(grad_input_offset + argmax, - static_cast( - grad_output - [output_offset + ph * h_stride + pw * w_stride])); - } - } // pooled_width - } // pooled_height - } // channels - } // num_rois -} - -std::tuple roi_pool_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_pool_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); - - int num_rois = rois.size(0); - int channels = input.size(1); - int height = input.size(2); - int width = input.size(3); - - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); - at::Tensor argmax = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, - input.options().dtype(at::kInt)); - - if (output.numel() == 0) { - return std::make_tuple(output, argmax); - } - - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_pool_forward_kernel", [&] { - roi_pool_forward_kernel_impl( - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - rois_.data_ptr(), - num_rois, - output.data_ptr(), - argmax.data_ptr()); - }); - return std::make_tuple(output, 
argmax); -} - -at::Tensor roi_pool_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - // Check if input tensors are CPU tensors - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK(argmax.device().is_cpu(), "argmax must be a CPU tensor"); - TORCH_CHECK( - rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_pool_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); - - auto num_rois = rois.size(0); - - at::Tensor grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); - - // handle possibly empty gradients - if (grad.numel() == 0) { - return grad_input; - } - - // get stride values to ensure indexing into gradients is correct. - int n_stride = grad.stride(0); - int c_stride = grad.stride(1); - int h_stride = grad.stride(2); - int w_stride = grad.stride(3); - - auto rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_pool_backward_kernel", [&] { - roi_pool_backward_kernel_impl( - grad.data_ptr(), - argmax.data_ptr(), - num_rois, - channels, - height, - width, - pooled_height, - pooled_width, - grad_input.data_ptr(), - rois_.data_ptr(), - n_stride, - c_stride, - h_stride, - w_stride); - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_pool"), - TORCH_FN(roi_pool_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), - TORCH_FN(roi_pool_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/deform_conv2d.cpp b/product/include/torchvision/ops/deform_conv2d.cpp deleted file mode 100644 index 3cda60fe0bc..00000000000 --- a/product/include/torchvision/ops/deform_conv2d.cpp +++ /dev/null @@ -1,172 +0,0 @@ -#include "deform_conv2d.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -at::Tensor deform_conv2d( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::deform_conv2d", "") - .typed(); - return op.call( - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); -} - -at::Tensor deform_conv2d_symint( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); - static auto op = c10::Dispatcher::singleton() - 
.findSchemaOrThrow("torchvision::deform_conv2d", "") - .typed(); - return op.call( - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); -} - -namespace detail { - -std::tuple -_deform_conv2d_backward( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") - .typed(); - return op.call( - grad, - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); -} - -std::tuple -_deform_conv2d_backward_symint( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") - .typed(); - return op.call( - grad, - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/deform_conv2d.h b/product/include/torchvision/ops/deform_conv2d.h deleted file mode 100644 index cf1f142e648..00000000000 --- a/product/include/torchvision/ops/deform_conv2d.h +++ /dev/null @@ -1,82 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor deform_conv2d( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask); - -VISION_API at::Tensor deform_conv2d_symint( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask); - 
-namespace detail { - -std::tuple -_deform_conv2d_backward( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask); - -std::tuple -_deform_conv2d_backward_symint( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask); - -} // namespace detail - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/mps/mps_helpers.h b/product/include/torchvision/ops/mps/mps_helpers.h deleted file mode 100644 index d3c0e8d94b7..00000000000 --- a/product/include/torchvision/ops/mps/mps_helpers.h +++ /dev/null @@ -1,6 +0,0 @@ -constexpr int threadsPerBlock = 512; - -template -constexpr inline T ceil_div(T n, T m) { - return (n + m - 1) / m; -} diff --git a/product/include/torchvision/ops/mps/mps_kernels.h b/product/include/torchvision/ops/mps/mps_kernels.h deleted file mode 100644 index e720a1608f1..00000000000 --- a/product/include/torchvision/ops/mps/mps_kernels.h +++ /dev/null @@ -1,1102 +0,0 @@ -#include - -namespace vision { -namespace ops { - -namespace mps { - -static const char* METAL_VISION = R"VISION_METAL( - -#include -#include -using namespace metal; - -/*----------Macros----------*/ - -#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \ - for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ - i += (tptg.x * n_tgs)) - -#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint) - -/*----------Helpers--------*/ - -template -inline T ceil_div(T n, T m) { - return (n + m - 1) / m; -} - -template -inline void atomic_add_float( device T* data_ptr, const T val) -{ -#if __METAL_VERSION__ >= 300 - // atomic_float is supported in Metal 3 (macOS Ventura) onward. - device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); -#else - // Custom atomic addition implementation - // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 - // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 - // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide) - - // Create an atomic uint pointer for atomic transaction. - device atomic_uint* atom_var = (device atomic_uint*)data_ptr; - // Create necessary storage. - uint fetched_uint, assigning_uint; - T fetched_float, assigning_float; - - // Replace the value in atom_var with 0 and return the previous value in atom_var. - fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed); - // Read out the previous value as float. - fetched_float = *( (thread T*) &fetched_uint ); - - // Do addition and represent the addition result in uint for atomic transaction. - assigning_float = fetched_float + val; - assigning_uint = *((thread uint*) &assigning_float); - - // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). 
- while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) { - // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. - // Try to assign 0 and get the previously assigned addition result. - uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); - T fetched_float_again = *( (thread T*) &fetched_uint_again ); - // Re-add again - fetched_float = *((thread T*) &(fetched_uint)); - // Previously assigned addition result + addition result from other threads. - assigning_float = fetched_float_again + fetched_float; - assigning_uint = *( (thread uint*) &assigning_float); - } -#endif -} - -template -inline T bilinear_interpolate( - constant T* input, - integer_t height, - integer_t width, - T y, - T x, - uint index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - return 0; - } - - if (y <= 0) - y = 0; - if (x <= 0) - x = 0; - - integer_t y_low = (integer_t)y; - integer_t x_low = (integer_t)x; - integer_t y_high; - integer_t x_high; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - - // do bilinear interpolation - T v1 = input[y_low * width + x_low]; - T v2 = input[y_low * width + x_high]; - T v3 = input[y_high * width + x_low]; - T v4 = input[y_high * width + x_high]; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - return val; -} - -template -inline void bilinear_interpolate_gradient( - integer_t height, - integer_t width, - T y, - T x, - thread T& w1, - thread T& w2, - thread T& w3, - thread T& w4, - thread integer_t& x_low, - thread integer_t& x_high, - thread integer_t& y_low, - thread integer_t& y_high, - uint index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - w1 = w2 = w3 = w4 = 0.; - x_low = x_high = y_low = y_high = -1; - return; - } - - if (y <= 0) - y = 0; - if (x <= 0) - x = 0; - - y_low = (integer_t)y; - x_low = (integer_t)x; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - - // reference in forward - // T v1 = input[y_low * width + x_low]; - // T v2 = input[y_low * width + x_high]; - // T v3 = input[y_high * width + x_low]; - // T v4 = input[y_high * width + x_high]; - // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; -} - -template -inline bool IoU( - constant T & a, - threadgroup T & b, - const float threshold) { - auto xx1 = max(a.x, b.x); - auto yy1 = max(a.y, b.y); - auto xx2 = min(a.z, b.z); - auto yy2 = min(a.w, b.w); - auto w = max(static_cast(0), xx2 - xx1); - auto h = max(static_cast(0), yy2 - yy1); - // Upcast to float before multiplications to circumvent precision issues in half. 
- auto inter = static_cast(w) * static_cast(h); - auto area_b = static_cast(b.z - b.x) * static_cast(b.w - b.y); - auto area_a = static_cast(a.z - a.x) * static_cast(a.w - a.y); - return (inter / (area_a + area_b - inter)) > threshold; -} - -/*----------Kernels----------*/ - -// This should be in sync with the one in nms_kernel.mm. -// Since metal does not support dynamic array, -// we need to make it static instead of deriving it from [[threads_per_threadgroup]]. -constant int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; - -template -kernel void nms(constant T * dev_boxes [[buffer(0)]], - device uint64_t * mask [[buffer(1)]], - constant int64_t & n_boxes [[buffer(2)]], - constant float & iou_threshold [[buffer(3)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tid2 [[thread_position_in_threadgroup]]) { - - const uint row_start = tgid.y; - const uint col_start = tgid.x; - const uint tid = tid2.x; - const uint row_size = - min(n_boxes - row_start * nmsThreadsPerBlock, nmsThreadsPerBlock); - const uint col_size = - min(n_boxes - col_start * nmsThreadsPerBlock, nmsThreadsPerBlock); - - threadgroup T block_boxes[nmsThreadsPerBlock]; - block_boxes[tid] = dev_boxes[nmsThreadsPerBlock * col_start + tid]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tid < row_size) { - const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid; - uint64_t t = 0; - uint start = 0; - - if (row_start == col_start) { - start = tid + 1; - } - - for (uint i = start; i < col_size; i++){ - if (IoU(dev_boxes[cur_box_idx], block_boxes[i], iou_threshold)){ - t |= static_cast(1) << i; // discard 1 keep 0 - } - } - const uint col_blocks = ceil_div(n_boxes, nmsThreadsPerBlock); - mask[cur_box_idx * col_blocks + col_start] = t; - } -} - -#define REGISTER_NMS_OP(DTYPE) \ -template \ -[[host_name("nms_" #DTYPE)]] \ -kernel void nms( \ - constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \ - device uint64_t * mask [[buffer(1)]], \ - constant int64_t & n_boxes [[buffer(2)]], \ - constant float & iou_threshold [[buffer(3)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void roi_align( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - constant int64_t & output_size [[buffer(3)]], - constant int64_t & channels [[buffer(4)]], - constant int64_t & height [[buffer(5)]], - constant int64_t & width [[buffer(6)]], - constant int64_t & pooled_height [[buffer(7)]], - constant int64_t & pooled_width [[buffer(8)]], - constant int64_t & sampling_ratio [[buffer(9)]], - constant bool & aligned [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c = (index / pooled_width / pooled_height) % channels; - integer_t n = index / pooled_width / pooled_height / channels; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T offset = aligned ? 
(T)0.5 : (T)0.0; - T roi_start_w = offset_rois[1] * spatial_scale - offset; - T roi_start_h = offset_rois[2] * spatial_scale - offset; - T roi_end_w = offset_rois[3] * spatial_scale - offset; - T roi_end_h = offset_rois[4] * spatial_scale - offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = max(roi_width, (T)1.); - roi_height = max(roi_height, (T)1.); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - constant T* offset_input = - input + (roi_batch_ind * channels + c) * height * width; - - // We use roi_bin_grid to sample the grid and mimic integral - integer_t roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - integer_t roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - // When the grid is empty, output zeros. - const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast(1)); // e.g. = 4 - - T output_val = 0.; - for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 - { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T val = bilinear_interpolate(offset_input, height, width, y, x, index); - output_val += val; - } - } - output_val /= count; - - output[index] = output_val; - } -} - -#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_align_" #DTYPE)]] \ -kernel void roi_align( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - constant int64_t & output_size [[buffer(3)]], \ - constant int64_t & channels [[buffer(4)]], \ - constant int64_t & height [[buffer(5)]], \ - constant int64_t & width [[buffer(6)]], \ - constant int64_t & pooled_height [[buffer(7)]], \ - constant int64_t & pooled_width [[buffer(8)]], \ - constant int64_t & sampling_ratio [[buffer(9)]], \ - constant bool & aligned [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void roi_align_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * grad_input [[buffer(2)]], - constant int64_t & output_size [[buffer(3)]], - constant int64_t & channels [[buffer(4)]], - constant int64_t & height [[buffer(5)]], - constant int64_t & width [[buffer(6)]], - constant int64_t & pooled_height [[buffer(7)]], - constant int64_t & pooled_width [[buffer(8)]], - constant int64_t & sampling_ratio [[buffer(9)]], - constant bool & aligned [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], - constant int64_t & n_stride [[buffer(12)]], - constant int64_t & c_stride [[buffer(13)]], - constant int64_t & h_stride [[buffer(14)]], - constant int64_t & w_stride [[buffer(15)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c, ph, pw) is an element 
in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c = (index / pooled_width / pooled_height) % channels; - integer_t n = index / pooled_width / pooled_height / channels; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T offset = aligned ? (T)0.5 : (T)0.0; - T roi_start_w = offset_rois[1] * spatial_scale - offset; - T roi_start_h = offset_rois[2] * spatial_scale - offset; - T roi_end_w = offset_rois[3] * spatial_scale - offset; - T roi_end_h = offset_rois[4] * spatial_scale - offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = max(roi_width, (T)1.); - roi_height = max(roi_height, (T)1.); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - // We need to index the gradient using the tensor strides to access the - // correct values. - const integer_t output_offset = n * n_stride + c * c_stride; - constant T* offset_grad_output = grad_output + output_offset; - const T grad_output_this_bin = - offset_grad_output[ph * h_stride + pw * w_stride]; - - // We use roi_bin_grid to sample the grid and mimic integral - integer_t roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - integer_t roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 - - const integer_t input_offset = (roi_batch_ind * channels + c) * height * width; - - for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 - { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T w1, w2, w3, w4; - integer_t x_low, x_high, y_low, y_high; - - bilinear_interpolate_gradient( - height, - width, - y, - x, - w1, - w2, - w3, - w4, - x_low, - x_high, - y_low, - y_high, - index); - - T g1 = grad_output_this_bin * w1 / count; - T g2 = grad_output_this_bin * w2 / count; - T g3 = grad_output_this_bin * w3 / count; - T g4 = grad_output_this_bin * w4 / count; - - if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - atomic_add_float(grad_input + input_offset + y_low * width + x_low, static_cast(g1)); - atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast(g2)); - atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast(g3)); - atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast(g4)); - - } // if - } // ix - } // iy - } // MPS_1D_KERNEL_LOOP -} - -#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_align_backward_" #DTYPE)]] \ -kernel void roi_align_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * grad_input [[buffer(2)]], \ - constant int64_t & output_size [[buffer(3)]], \ - constant int64_t & channels [[buffer(4)]], \ - constant int64_t & height [[buffer(5)]], \ - constant int64_t & width [[buffer(6)]], \ 
- constant int64_t & pooled_height [[buffer(7)]], \ - constant int64_t & pooled_width [[buffer(8)]], \ - constant int64_t & sampling_ratio [[buffer(9)]], \ - constant bool & aligned [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - constant int64_t & n_stride [[buffer(12)]], \ - constant int64_t & c_stride [[buffer(13)]], \ - constant int64_t & h_stride [[buffer(14)]], \ - constant int64_t & w_stride [[buffer(15)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void roi_pool( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - device int64_t * argmax [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant float & spatial_scale [[buffer(10)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c = (index / pooled_width / pooled_height) % channels; - integer_t n = index / pooled_width / pooled_height / channels; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - integer_t roi_start_w = round(offset_rois[1] * spatial_scale); - integer_t roi_start_h = round(offset_rois[2] * spatial_scale); - integer_t roi_end_w = round(offset_rois[3] * spatial_scale); - integer_t roi_end_h = round(offset_rois[4] * spatial_scale); - - // Force malformed ROIs to be 1x1 - integer_t roi_width = max(roi_end_w - roi_start_w + 1, static_cast(1)); - integer_t roi_height = max(roi_end_h - roi_start_h + 1, static_cast(1)); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); - hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); - wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); - wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - // Define an empty pooling region to be zero - T maxval = is_empty ? 
0 : -FLT_MAX; - // If nothing is pooled, argmax = -1 causes nothing to be backprop'd - integer_t maxidx = -1; - constant T* offset_input = - input + (roi_batch_ind * channels + c) * height * width; - for (integer_t h = hstart; h < hend; ++h) { - for (integer_t w = wstart; w < wend; ++w) { - integer_t input_index = h * width + w; - if (offset_input[input_index] > maxval) { - maxval = offset_input[input_index]; - maxidx = input_index; - } - } - } - output[index] = maxval; - argmax[index] = maxidx; - } -} - -#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_pool_" #DTYPE)]] \ -kernel void roi_pool( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * argmax_data [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant float & spatial_scale [[buffer(10)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void roi_pool_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - constant int64_t * argmax_data [[buffer(2)]], - device T * grad_input [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant float & spatial_scale [[buffer(10)]], - constant int64_t & n_stride [[buffer(11)]], - constant int64_t & c_stride [[buffer(12)]], - constant int64_t & h_stride [[buffer(13)]], - constant int64_t & w_stride [[buffer(14)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c = (index / pooled_width / pooled_height) % channels; - integer_t n = index / pooled_width / pooled_height / channels; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - const integer_t output_offset = n * n_stride + c * c_stride; - constant integer_t * argmax_data_offset = - argmax_data + (n * channels + c) * pooled_height * pooled_width; - const integer_t argmax = argmax_data_offset[ph * pooled_width + pw]; - const integer_t offset = (roi_batch_ind * channels + c) * height * width; - - if (argmax != -1) { - atomic_add_float(grad_input + offset + argmax, static_cast(grad_output[output_offset + ph * h_stride + pw * w_stride])); - } - - } // MPS_1D_KERNEL_LOOP -} - -#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_pool_backward_" #DTYPE)]] \ -kernel void roi_pool_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * argmax_data [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - 
constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant float & spatial_scale [[buffer(10)]], \ - constant int64_t & n_stride [[buffer(11)]], \ - constant int64_t & c_stride [[buffer(12)]], \ - constant int64_t & h_stride [[buffer(13)]], \ - constant int64_t & w_stride [[buffer(14)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void ps_roi_align( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - device int64_t * channel_mapping [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant int64_t & sampling_ratio [[buffer(10)]], - constant int64_t & channels_out [[buffer(11)]], - constant float & spatial_scale [[buffer(12)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c_out, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c_out = (index / pooled_width / pooled_height) % channels_out; - integer_t n = index / pooled_width / pooled_height / channels_out; - - // (n, c_in, ph, pw) is the associated element in the input - integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; - - // [start, end) interval for spatial sampling - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); - T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); - T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); - T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - // Do not using floor/ceil; this implementation detail is critical - T hstart = static_cast(ph) * bin_size_h + roi_start_h; - T wstart = static_cast(pw) * bin_size_w + roi_start_w; - - // We use roi_bin_grid to sample the grid and mimic integral - integer_t roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); - integer_t roi_bin_grid_w = - (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); - const T count = roi_bin_grid_h * roi_bin_grid_w; - - constant T* offset_input = - input + (roi_batch_ind * channels + c_in) * height * width; - T out_sum = 0; - for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = hstart + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = wstart + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - T val = bilinear_interpolate(offset_input, height, width, y, x, index); - out_sum += val; - } - } - - out_sum /= count; - output[index] = out_sum; - channel_mapping[index] = c_in; - } -} - -#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_align_" #DTYPE)]] \ -kernel void ps_roi_align( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * channel_mapping [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & sampling_ratio [[buffer(10)]], \ - constant int64_t & channels_out [[buffer(11)]], \ - constant float & spatial_scale [[buffer(12)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void ps_roi_align_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - constant int64_t * channel_mapping [[buffer(2)]], - device T * grad_input [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant int64_t & sampling_ratio [[buffer(10)]], - constant int64_t & channels_out [[buffer(11)]], - constant float & spatial_scale [[buffer(12)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, *, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t n = index / pooled_width / pooled_height / channels_out; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); - T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); - T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); - T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); - - // Force too small ROIs to be 1x1 - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - T bin_size_h = roi_height / static_cast(pooled_height); - T bin_size_w = roi_width / static_cast(pooled_width); - - integer_t c_in = channel_mapping[index]; - - // Do not using floor/ceil; this implementation detail is critical - T hstart = static_cast(ph) * bin_size_h + roi_start_h; - T wstart = static_cast(pw) * bin_size_w + roi_start_w; - - const T 
grad_output_this_bin = grad_output[index]; - - // We use roi_bin_grid to sample the grid and mimic integral - integer_t roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - integer_t roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - const T count = roi_bin_grid_h * roi_bin_grid_w; - - const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; - - for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = hstart + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = wstart + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T w1, w2, w3, w4; - integer_t x_low, x_high, y_low, y_high; - - bilinear_interpolate_gradient( - height, - width, - y, - x, - w1, - w2, - w3, - w4, - x_low, - x_high, - y_low, - y_high, - index); - - T g1 = grad_output_this_bin * w1 / count; - T g2 = grad_output_this_bin * w2 / count; - T g3 = grad_output_this_bin * w3 / count; - T g4 = grad_output_this_bin * w4 / count; - - if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - atomic_add_float(grad_input + offset + y_low * width + x_low, static_cast(g1)); - atomic_add_float(grad_input + offset + y_low * width + x_high, static_cast(g2)); - atomic_add_float(grad_input + offset + y_high * width + x_low, static_cast(g3)); - atomic_add_float(grad_input + offset + y_high * width + x_high, static_cast(g4)); - } // if - } // ix - } // iy - } -} - -#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_align_backward_" #DTYPE)]] \ -kernel void ps_roi_align_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * channel_mapping [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & sampling_ratio [[buffer(10)]], \ - constant int64_t & channels_out [[buffer(11)]], \ - constant float & spatial_scale [[buffer(12)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void ps_roi_pool( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - device int64_t * channel_mapping [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant int64_t & channels_out [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c_out, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c_out = (index / (pooled_width * pooled_height)) % channels_out; - integer_t n = index / pooled_width / pooled_height 
/ channels_out; - - // (n, c_in, ph, pw) is the associated element in the input - integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; - - // [start, end) interval for spatial sampling - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - integer_t roi_start_w = round(offset_rois[1] * spatial_scale); - integer_t roi_start_h = round(offset_rois[2] * spatial_scale); - integer_t roi_end_w = round(offset_rois[3] * spatial_scale); - integer_t roi_end_h = round(offset_rois[4] * spatial_scale); - - // Force too small ROIs to be 1x1 - integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); - integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height - 1)); - hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height - 1)); - wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width - 1)); - wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width - 1)); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - constant T* offset_input = - input + (roi_batch_ind * channels + c_in) * height * width; - T out_sum = 0; - for (integer_t h = hstart; h < hend; ++h) { - for (integer_t w = wstart; w < wend; ++w) { - integer_t input_index = h * width + w; - out_sum += offset_input[input_index]; - } - } - - T bin_area = (hend - hstart) * (wend - wstart); - output[index] = is_empty ? 
static_cast(0) : out_sum / bin_area; - channel_mapping[index] = c_in; - } -} - -#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_pool_" #DTYPE)]] \ -kernel void ps_roi_pool( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * channel_mapping [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void ps_roi_pool_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - constant int64_t * channel_mapping [[buffer(2)]], - device T * grad_input [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant int64_t & channels_out [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, *, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t n = index / pooled_width / pooled_height / channels_out; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - integer_t roi_start_w = round(offset_rois[1] * spatial_scale); - integer_t roi_start_h = round(offset_rois[2] * spatial_scale); - integer_t roi_end_w = round(offset_rois[3] * spatial_scale); - integer_t roi_end_h = round(offset_rois[4] * spatial_scale); - - // Force too small ROIs to be 1x1 - integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); - integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); - hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); - wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); - wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - integer_t c_in = channel_mapping[index]; - T bin_area = (hend - hstart) * (wend - wstart); - T diff_val = is_empty ? 
static_cast(0) : grad_output[index] / bin_area; - - const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; - - for (integer_t h = hstart; h < hend; ++h) { - for (integer_t w = wstart; w < wend; ++w) { - integer_t grad_input_index = h * width + w; - atomic_add_float(grad_input + offset + grad_input_index, diff_val); - } - } - - } // MPS_1D_KERNEL_LOOP -} - -#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_pool_backward_" #DTYPE)]] \ -kernel void ps_roi_pool_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * channel_mapping [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -REGISTER_NMS_OP(float); -REGISTER_NMS_OP(half); -REGISTER_ROI_ALIGN_OP(float, int64_t); -REGISTER_ROI_ALIGN_OP(half, int64_t); -REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t); -REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t); -REGISTER_ROI_POOL_OP(float, int64_t); -REGISTER_ROI_POOL_OP(half, int64_t); -REGISTER_ROI_POOL_BACKWARD_OP(float, int64_t); -REGISTER_ROI_POOL_BACKWARD_OP(half, int64_t); -REGISTER_PS_ROI_ALIGN_OP(float, int64_t); -REGISTER_PS_ROI_ALIGN_OP(half, int64_t); -REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float, int64_t); -REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half, int64_t); -REGISTER_PS_ROI_POOL_OP(float, int64_t); -REGISTER_PS_ROI_POOL_OP(half, int64_t); -REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t); -REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t); - -)VISION_METAL"; - -static id compileVisionOpsLibrary(id device) { - static id visionLibrary = nil; - if (visionLibrary) { - return visionLibrary; - } - - NSError* error = nil; - MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion:MTLLanguageVersion2_3]; - visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding] - options:options - error:&error]; - TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]); - return visionLibrary; -} - -static id visionPipelineState(id device, const std::string& kernel) { - static std::unordered_map> psoCache; - id pso = psoCache[kernel]; - if (pso) { - return pso; - } - - NSError* error = nil; - id visionLib = compileVisionOpsLibrary(device); - id visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; - TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel); - pso = [device newComputePipelineStateWithFunction:visionFunc error:&error]; - TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); - - psoCache[kernel] = pso; - return pso; -} - -} // namespace mps -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/mps/nms_kernel.mm b/product/include/torchvision/ops/mps/nms_kernel.mm deleted file mode 100644 index 5ee9b5cbeae..00000000000 --- 
a/product/include/torchvision/ops/mps/nms_kernel.mm +++ /dev/null @@ -1,109 +0,0 @@ -#include -#include -#include "mps_kernels.h" - -namespace vision { -namespace ops { - -namespace { - -// This should be in sync with `nmsThreadsPerBlock` in the metal kernel. -constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; - -at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { - using namespace at::native::mps; - TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor"); - TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor"); - - TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1)); - TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D"); - TORCH_CHECK(dets.size(0) == scores.size(0), - "boxes and scores should have same number of elements in ", - "dimension 0, got ", - dets.size(0), - " and ", - scores.size(0)) - - if (dets.numel() == 0) { - return at::empty({0}, dets.options().dtype(at::kLong)); - } - - auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); - auto dets_sorted = dets.index_select(0, order_t).contiguous(); - int64_t dets_num = dets.size(0); - float iou_threshold_f = static_cast(iou_threshold); - - const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock; - at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); - - id inputBuffer = getMTLBufferStorage(dets_sorted); - id outputBuffer = getMTLBufferStorage(mask); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1); - - const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores}); - - [computeEncoder setComputePipelineState:visionPSO]; - [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0]; - [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1]; - [computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2]; - [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3]; - - // A threadGroup is equivalent to a cuda's block. 
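The kernel writes, for every box, one 64-bit word per column block marking which later boxes overlap it beyond iou_threshold; the host loop further down walks those words to assemble the keep list. For reference, the same greedy decision can be sketched single-threaded in C++ (hypothetical helper; boxes are x1,y1,x2,y2 and assumed already sorted by descending score):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct Box { float x1, y1, x2, y2; };

    // Intersection-over-union of two axis-aligned boxes.
    static float iou(const Box& a, const Box& b) {
      const float w = std::max(0.f, std::min(a.x2, b.x2) - std::max(a.x1, b.x1));
      const float h = std::max(0.f, std::min(a.y2, b.y2) - std::max(a.y1, b.y1));
      const float inter = w * h;
      const float area_a = (a.x2 - a.x1) * (a.y2 - a.y1);
      const float area_b = (b.x2 - b.x1) * (b.y2 - b.y1);
      return inter / (area_a + area_b - inter);
    }

    // Greedy NMS over score-sorted boxes: keep a box unless an earlier kept
    // box already overlaps it beyond the threshold.
    std::vector<std::int64_t> nms_reference(const std::vector<Box>& sorted_boxes,
                                            float iou_threshold) {
      std::vector<std::int64_t> keep;
      std::vector<bool> suppressed(sorted_boxes.size(), false);
      for (std::size_t i = 0; i < sorted_boxes.size(); ++i) {
        if (suppressed[i]) continue;
        keep.push_back(static_cast<std::int64_t>(i));
        for (std::size_t j = i + 1; j < sorted_boxes.size(); ++j) {
          if (iou(sorted_boxes[i], sorted_boxes[j]) > iou_threshold) {
            suppressed[j] = true;
          }
        }
      }
      return keep;
    }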
- NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > nmsThreadsPerBlock) { - tgSize = nmsThreadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - - int64_t num_to_keep = 0; - - at::Tensor mask_cpu = mask.to(at::kCPU); - unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr(); - - std::vector remv(col_blocks); - memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); - - at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); - int64_t* keep_out = keep.data_ptr(); - - for (int64_t i = 0; i < dets_num; i++) { - int64_t nblock = i / nmsThreadsPerBlock; - int64_t inblock = i % nmsThreadsPerBlock; - - if (!(remv[nblock] & (1ULL << inblock))) { - keep_out[num_to_keep++] = i; - unsigned long long* p = mask_host + i * col_blocks; - for (int64_t j = nblock; j < col_blocks; j++) { - remv[j] |= p[j]; - } - } - } - - return order_t.index( - {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())}); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/mps/ps_roi_align_kernel.mm b/product/include/torchvision/ops/mps/ps_roi_align_kernel.mm deleted file mode 100644 index 16b711ad5ef..00000000000 --- a/product/include/torchvision/ops/mps/ps_roi_align_kernel.mm +++ /dev/null @@ -1,205 +0,0 @@ -#include -#include -#include "mps_helpers.h" -#include "mps_kernels.h" - -namespace vision { -namespace ops { - -namespace { - -std::tuple ps_roi_align_forward_kernel(const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio) { - using namespace at::native::mps; - TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_align_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - int64_t num_rois = rois.size(0); - int64_t channels = input.size(1); - int64_t height = input.size(2); - int64_t width = input.size(3); - float spatial_scale_f = static_cast(spatial_scale); - - TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, - "input channels must be a multiple of pooling height * pooling width"); - - int64_t channels_out = channels / (pooled_height * pooled_width); - - auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); - - int64_t output_size = output.numel(); - - if (output_size == 0) { - return std::make_tuple(output, channel_mapping); - } - - auto input_ = input.contiguous(); - auto rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(input_); - id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(output); - id channelMappingBuffer = getMTLBufferStorage(channel_mapping); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - 
dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - [computeEncoder setBuffer:channelMappingBuffer - offset:channel_mapping.storage_offset() * channel_mapping.element_size() - atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; - - // A threadGroup is equivalent to a cuda's block. 
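Every launch site in these .mm files sizes the dispatch the same way: one thread per output element, at most 512 threads per threadgroup (threadsPerBlock from mps_helpers.h), at most 4096 threadgroups, with MPS_1D_KERNEL_LOOP striding over any remainder inside the kernel. A small C++ sketch of that sizing rule, with the constants taken from the code above and illustrative names:

    #include <algorithm>
    #include <cstdint>

    constexpr std::int64_t kThreadsPerBlock = 512;   // threadsPerBlock in mps_helpers.h
    constexpr std::int64_t kMaxThreadgroups = 4096;  // cap used at every launch site

    constexpr std::int64_t ceil_div(std::int64_t n, std::int64_t m) {
      return (n + m - 1) / m;
    }

    struct LaunchShape {
      std::int64_t threadgroups;        // grid width passed to MTLSizeMake
      std::int64_t threads_per_group;   // clamped threadgroup width
    };

    // One thread per output element, capped so very large outputs fall back to
    // the grid-stride loop (MPS_1D_KERNEL_LOOP) inside the kernel.
    LaunchShape plan_1d_launch(std::int64_t output_size,
                               std::int64_t max_threads_per_threadgroup) {
      LaunchShape shape;
      shape.threadgroups =
          std::min(ceil_div(output_size, kThreadsPerBlock), kMaxThreadgroups);
      shape.threads_per_group =
          std::min(max_threads_per_threadgroup, kThreadsPerBlock);
      return shape;
    }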
- NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return std::make_tuple(output, channel_mapping); -} - -at::Tensor ps_roi_align_backward_kernel(const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - using namespace at::native::mps; - TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_align backward with float16 inputs."); - TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); - - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = "ps_roi_align_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); - at::checkAllSameType(c, {grad_t, rois_t}); - - float spatial_scale_f = static_cast(spatial_scale); - - auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); - - if (grad.numel() == 0) { - return grad_input; - } - - int64_t output_size = grad.numel(); - int64_t channels_out = channels / (pooled_height * pooled_width); - - at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(grad_); - id roisBuffer = getMTLBufferStorage(rois_); - id channelMappingBuffer = getMTLBufferStorage(channel_mapping); - id outputBuffer = getMTLBufferStorage(grad_input); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:channelMappingBuffer - offset:channel_mapping.storage_offset() * channel_mapping.element_size() - atIndex:2]; - [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height 
length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel)); - m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/mps/ps_roi_pool_kernel.mm b/product/include/torchvision/ops/mps/ps_roi_pool_kernel.mm deleted file mode 100644 index fc24f6990fa..00000000000 --- a/product/include/torchvision/ops/mps/ps_roi_pool_kernel.mm +++ /dev/null @@ -1,200 +0,0 @@ -#include -#include -#include "mps_helpers.h" -#include "mps_kernels.h" - -namespace vision { -namespace ops { - -namespace { - -std::tuple ps_roi_pool_forward_kernel(const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - using namespace at::native::mps; - TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_pool_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - int64_t num_rois = rois.size(0); - int64_t channels = input.size(1); - int64_t height = input.size(2); - int64_t width = input.size(3); - float spatial_scale_f = static_cast(spatial_scale); - - TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, - "input channels must be a multiple of pooling height * pooling width"); - int64_t channels_out = channels / (pooled_height * pooled_width); - - auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); - auto output_size = output.numel(); - - if (output_size == 0) { - return std::make_tuple(output, channel_mapping); - } - - auto input_ = input.contiguous(); - auto rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(input_); - id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(output); - id channelMappingBuffer = getMTLBufferStorage(channel_mapping); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type()); - id 
visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - [computeEncoder setBuffer:channelMappingBuffer - offset:channel_mapping.storage_offset() * channel_mapping.element_size() - atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return std::make_tuple(output, channel_mapping); -} - -at::Tensor ps_roi_pool_backward_kernel(const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - using namespace at::native::mps; - TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_pool backward with float16 inputs."); - TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = "ps_roi_pool_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); - at::checkAllSameType(c, {grad_t, rois_t}); - - float spatial_scale_f = static_cast(spatial_scale); - - auto num_rois = rois.size(0); - auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); - - if (grad.numel() == 0) { - return grad_input; - } - - int64_t channels_out = channels / (pooled_height * pooled_width); - int64_t output_size = grad.numel(); - - at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel"); - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(grad_); - id roisBuffer = getMTLBufferStorage(rois_); - id channelMappingBuffer = getMTLBufferStorage(channel_mapping); - id outputBuffer = getMTLBufferStorage(grad_input); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id 
computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad_, rois_, channel_mapping}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:channelMappingBuffer - offset:channel_mapping.storage_offset() * channel_mapping.element_size() - atIndex:2]; - [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), TORCH_FN(ps_roi_pool_forward_kernel)); - m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/mps/roi_align_kernel.mm b/product/include/torchvision/ops/mps/roi_align_kernel.mm deleted file mode 100644 index d4ed8b43fd2..00000000000 --- a/product/include/torchvision/ops/mps/roi_align_kernel.mm +++ /dev/null @@ -1,197 +0,0 @@ -#include -#include -#include "mps_helpers.h" -#include "mps_kernels.h" - -namespace vision { -namespace ops { - -namespace { - -at::Tensor roi_align_forward_kernel(const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned) { - using namespace at::native::mps; - TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_align_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - int64_t num_rois = rois.size(0); - int64_t channels = input.size(1); - int64_t height = input.size(2); - int64_t 
width = input.size(3); - float spatial_scale_f = static_cast(spatial_scale); - - at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); - - int64_t output_size = num_rois * pooled_height * pooled_width * channels; - - if (output.numel() == 0) { - return output; - } - - auto input_ = input.contiguous(); - auto rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(input_); - id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(output); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - - // A threadGroup is equivalent to a cuda's block. 
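A minimal stand-alone sketch of the launch-geometry arithmetic these MPS kernels repeat (plain C++; kThreadsPerBlock and kMaxThreadgroups are illustrative stand-ins for the threadsPerBlock constant and the 4096 grid cap used above, not names taken from the real headers):

#include <algorithm>
#include <cstdint>
#include <utility>

constexpr int64_t kThreadsPerBlock = 512;   // assumed value of threadsPerBlock above
constexpr int64_t kMaxThreadgroups = 4096;  // assumed cap on the 1-D grid above

// Ceiling division, mirroring the ceil_div helper used in these kernels.
inline int64_t ceil_div_sketch(int64_t n, int64_t d) {
  return (n + d - 1) / d;
}

// Returns {threadgroups per grid, threads per threadgroup} for a 1-D dispatch,
// clamping the per-group size to the pipeline's maxTotalThreadsPerThreadgroup.
inline std::pair<int64_t, int64_t> launch_dims(int64_t output_size,
                                               int64_t max_threads_per_group) {
  const int64_t groups =
      std::min(ceil_div_sketch(output_size, kThreadsPerBlock), kMaxThreadgroups);
  const int64_t threads = std::min(max_threads_per_group, kThreadsPerBlock);
  return {groups, threads};
}

The Objective-C++ that follows applies the same clamping before dispatching the threadgroups.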
- NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return output; -} - -at::Tensor roi_align_backward_kernel(const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { - using namespace at::native::mps; - TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs."); - - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_align_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t}); - at::checkAllSameType(c, {grad_t, rois_t}); - - float spatial_scale_f = static_cast(spatial_scale); - - at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); - - if (grad.numel() == 0) { - return grad_input; - } - - int64_t n_stride = grad.stride(0); - int64_t c_stride = grad.stride(1); - int64_t h_stride = grad.stride(2); - int64_t w_stride = grad.stride(3); - int64_t output_size = grad.numel(); - - at::globalContext().alertNotDeterministic("roi_align_backward_kernel"); - auto rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(grad); - id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(grad_input); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - [computeEncoder setBytes:&n_stride length:sizeof(int64_t) 
atIndex:12]; - [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; - [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel)); - m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/mps/roi_pool_kernel.mm b/product/include/torchvision/ops/mps/roi_pool_kernel.mm deleted file mode 100644 index 816d8d70863..00000000000 --- a/product/include/torchvision/ops/mps/roi_pool_kernel.mm +++ /dev/null @@ -1,196 +0,0 @@ -#include -#include -#include "mps_helpers.h" -#include "mps_kernels.h" - -namespace vision { -namespace ops { - -namespace { - -std::tuple roi_pool_forward_kernel(const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - using namespace at::native::mps; - TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_pool_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - int64_t num_rois = rois.size(0); - int64_t channels = input.size(1); - int64_t height = input.size(2); - int64_t width = input.size(3); - float spatial_scale_f = static_cast(spatial_scale); - - at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); - at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong)); - - int64_t output_size = num_rois * pooled_height * pooled_width * channels; - - if (output.numel() == 0) { - return std::make_tuple(output, argmax); - } - - auto input_ = input.contiguous(); - auto rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(input_); - id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(output); - id argmaxBuffer = getMTLBufferStorage(argmax); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder 
setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - [computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return std::make_tuple(output, argmax); -} - -at::Tensor roi_pool_backward_kernel(const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - using namespace at::native::mps; - TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs."); - TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor"); - - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; - - at::CheckedFrom c = "roi_pool_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); - at::checkAllSameType(c, {grad_t, rois_t}); - - float spatial_scale_f = static_cast(spatial_scale); - - at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); - - if (grad.numel() == 0) { - return grad_input; - } - - int64_t n_stride = grad.stride(0); - int64_t c_stride = grad.stride(1); - int64_t h_stride = grad.stride(2); - int64_t w_stride = grad.stride(3); - int64_t output_size = grad.numel(); - - at::globalContext().alertNotDeterministic("roi_pool_backward_kernel"); - auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(grad); - id roisBuffer = getMTLBufferStorage(rois_); - id argmaxBuffer = getMTLBufferStorage(argmax_); - id outputBuffer = getMTLBufferStorage(grad_input); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - 
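As a hedged aside, the begin/end profiling calls that bracket every dispatch in these files could also be expressed as a scope guard; the helper below is purely illustrative and is not part of the ATen MPS profiler API:

#include <functional>
#include <utility>

// Hypothetical RAII wrapper: runs a begin hook immediately and an end hook on
// scope exit, mirroring the beginProfileKernel/endProfileKernel pairing here.
class ScopedKernelProfile {
 public:
  ScopedKernelProfile(std::function<void()> begin, std::function<void()> end)
      : end_(std::move(end)) {
    begin();
  }
  ~ScopedKernelProfile() {
    end_();
  }

 private:
  std::function<void()> end_;
};

The original code keeps the explicit begin/end pair, as the next call shows.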
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_, argmax_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2]; - [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; - [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11]; - [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12]; - [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel)); - m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/nms.cpp b/product/include/torchvision/ops/nms.cpp deleted file mode 100644 index 5ecf8812f1b..00000000000 --- a/product/include/torchvision/ops/nms.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "nms.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -at::Tensor nms( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms.nms"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::nms", "") - .typed(); - return op.call(dets, scores, iou_threshold); -} - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.set_python_module("torchvision._meta_registrations"); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/nms.h b/product/include/torchvision/ops/nms.h deleted file mode 100644 index 8c75a242bff..00000000000 --- a/product/include/torchvision/ops/nms.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor nms( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold); - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/ops.h 
b/product/include/torchvision/ops/ops.h deleted file mode 100644 index 77995e44197..00000000000 --- a/product/include/torchvision/ops/ops.h +++ /dev/null @@ -1,8 +0,0 @@ -#pragma once - -#include "deform_conv2d.h" -#include "nms.h" -#include "ps_roi_align.h" -#include "ps_roi_pool.h" -#include "roi_align.h" -#include "roi_pool.h" diff --git a/product/include/torchvision/ops/ps_roi_align.cpp b/product/include/torchvision/ops/ps_roi_align.cpp deleted file mode 100644 index de458c0d62d..00000000000 --- a/product/include/torchvision/ops/ps_roi_align.cpp +++ /dev/null @@ -1,112 +0,0 @@ -#include "ps_roi_align.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -std::tuple ps_roi_align( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_align", "") - .typed(); - return op.call( - input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); -} - -std::tuple ps_roi_align_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_align", "") - .typed(); - return op.call( - input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); -} - -namespace detail { - -at::Tensor _ps_roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width); -} - -at::Tensor _ps_roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/ps_roi_align.h 
b/product/include/torchvision/ops/ps_roi_align.h deleted file mode 100644 index 75650586bc6..00000000000 --- a/product/include/torchvision/ops/ps_roi_align.h +++ /dev/null @@ -1,56 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple ps_roi_align( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio); - -VISION_API std::tuple ps_roi_align_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio); - -namespace detail { - -at::Tensor _ps_roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -at::Tensor _ps_roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width); - -} // namespace detail - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/ps_roi_pool.cpp b/product/include/torchvision/ops/ps_roi_pool.cpp deleted file mode 100644 index 92469d5e380..00000000000 --- a/product/include/torchvision/ops/ps_roi_pool.cpp +++ /dev/null @@ -1,104 +0,0 @@ -#include "ps_roi_pool.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -std::tuple ps_roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -std::tuple ps_roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -namespace detail { - -at::Tensor _ps_roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -at::Tensor _ps_roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - static auto op = - c10::Dispatcher::singleton() - 
.findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/ps_roi_pool.h b/product/include/torchvision/ops/ps_roi_pool.h deleted file mode 100644 index 4a3cc54e0e5..00000000000 --- a/product/include/torchvision/ops/ps_roi_pool.h +++ /dev/null @@ -1,52 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple ps_roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API std::tuple ps_roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width); - -namespace detail { - -at::Tensor _ps_roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -at::Tensor _ps_roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width); - -} // namespace detail - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/roi_align.cpp b/product/include/torchvision/ops/roi_align.cpp deleted file mode 100644 index aa6dccb44f2..00000000000 --- a/product/include/torchvision/ops/roi_align.cpp +++ /dev/null @@ -1,132 +0,0 @@ -#include "roi_align.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -at::Tensor roi_align( - const at::Tensor& input, // Input feature map. - const at::Tensor& rois, // List of ROIs to pool over. - double spatial_scale, // The scale of the image features. ROIs will be - // scaled to this. - int64_t pooled_height, // The height of the pooled feature map. - int64_t pooled_width, // The width of the pooled feature - int64_t sampling_ratio, // The number of points to sample in each bin - bool aligned) // The flag for pixel shift -// along each axis. -{ - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_align", "") - .typed(); - return op.call( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); -} - -at::Tensor roi_align_symint( - const at::Tensor& input, // Input feature map. - const at::Tensor& rois, // List of ROIs to pool over. - double spatial_scale, // The scale of the image features. ROIs will be - // scaled to this. 
- c10::SymInt pooled_height, // The height of the pooled feature map. - c10::SymInt pooled_width, // The width of the pooled feature - int64_t sampling_ratio, // The number of points to sample in each bin - bool aligned) // The flag for pixel shift -// along each axis. -{ - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_align", "") - .typed(); - return op.call( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); -} - -namespace detail { - -at::Tensor _roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned); -} - -at::Tensor _roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor")); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/roi_align.h b/product/include/torchvision/ops/roi_align.h deleted file mode 100644 index 072d6d4231c..00000000000 --- a/product/include/torchvision/ops/roi_align.h +++ /dev/null @@ -1,58 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor roi_align( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned); - -VISION_API at::Tensor roi_align_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - bool aligned); - -namespace detail { - -at::Tensor _roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned); - -at::Tensor _roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - 
c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned); - -} // namespace detail - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/roi_pool.cpp b/product/include/torchvision/ops/roi_pool.cpp deleted file mode 100644 index 20ca3ca91e7..00000000000 --- a/product/include/torchvision/ops/roi_pool.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#include "roi_pool.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -std::tuple roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -std::tuple roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -namespace detail { - -at::Tensor _roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -at::Tensor _roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); -} - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/ops/roi_pool.h b/product/include/torchvision/ops/roi_pool.h deleted file mode 100644 index e2133240f4f..00000000000 --- a/product/include/torchvision/ops/roi_pool.h +++ /dev/null @@ -1,52 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API 
std::tuple roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width); - -namespace detail { - -at::Tensor _roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -at::Tensor _roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width); - -} // namespace detail - -} // namespace ops -} // namespace vision diff --git a/product/include/torchvision/vision.cpp b/product/include/torchvision/vision.cpp deleted file mode 100644 index 806e870a83f..00000000000 --- a/product/include/torchvision/vision.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include "vision.h" - -#include - -#ifdef WITH_CUDA -#include -#endif -#ifdef WITH_HIP -#include -#endif - -// If we are in a Windows environment, we need to define -// initialization functions for the _custom_ops extension. -#if !defined(MOBILE) && defined(_WIN32) -void* PyInit__C(void) { - return nullptr; -} -#endif // !defined(MOBILE) && defined(_WIN32) - -namespace vision { -int64_t cuda_version() { -#ifdef WITH_CUDA - return CUDA_VERSION; -#else - return -1; -#endif -} - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def("_cuda_version", &cuda_version); -} -} // namespace vision diff --git a/product/include/torchvision/vision.h b/product/include/torchvision/vision.h deleted file mode 100644 index 651ef3ca143..00000000000 --- a/product/include/torchvision/vision.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include -#include "macros.h" - -namespace vision { -VISION_API int64_t cuda_version(); - -namespace detail { -extern "C" inline auto _register_ops = &cuda_version; -} // namespace detail -} // namespace vision diff --git a/product/share/cmake/TorchVision/TorchVisionConfig.cmake b/product/share/cmake/TorchVision/TorchVisionConfig.cmake deleted file mode 100644 index 57b2b6caab7..00000000000 --- a/product/share/cmake/TorchVision/TorchVisionConfig.cmake +++ /dev/null @@ -1,74 +0,0 @@ -# TorchVisionConfig.cmake -# -------------------- -# -# Exported targets:: Vision -# - - -####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### -####### Any changes to this file will be overwritten by the next CMake run #### -####### The input file was TorchVisionConfig.cmake.in ######## - -get_filename_component(PACKAGE_${CMAKE_FIND_PACKAGE_NAME}_COUNTER_1 "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) - -macro(set_and_check _var _file) - set(${_var} "${_file}") - if(NOT EXISTS "${_file}") - message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") - endif() -endmacro() - -macro(check_required_components _NAME) - foreach(comp ${${_NAME}_FIND_COMPONENTS}) - if(NOT ${_NAME}_${comp}_FOUND) - if(${_NAME}_FIND_REQUIRED_${comp}) - set(${_NAME}_FOUND FALSE) - endif() - endif() - endforeach() -endmacro() - -#################################################################################### - -set(PN TorchVision) - -# location of include/torchvision -set(${PN}_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/include") - -set(${PN}_LIBRARY "") -set(${PN}_DEFINITIONS USING_${PN}) - -check_required_components(${PN}) - - -if(NOT 
(CMAKE_VERSION VERSION_LESS 3.0)) -#----------------------------------------------------------------------------- -# Don't include targets if this file is being picked up by another -# project which has already built this as a subproject -#----------------------------------------------------------------------------- -if(NOT TARGET ${PN}::${PN}) -include("${CMAKE_CURRENT_LIST_DIR}/${PN}Targets.cmake") - -target_include_directories(${PN}::${PN} INTERFACE "${${PN}_INCLUDE_DIR}") - -if(OFF) - target_compile_definitions(${PN}::${PN} INTERFACE WITH_CUDA) -endif() - -find_package(Torch REQUIRED) -target_link_libraries(${PN}::${PN} INTERFACE torch) - -if(ON) - find_package(PNG REQUIRED) - target_link_libraries(${PN}::${PN} INTERFACE ${PNG_LIBRARY}) - target_compile_definitions(${PN}::${PN} INTERFACE PNG_FOUND) -endif() - -if(ON) - find_package(JPEG REQUIRED) - target_link_libraries(${PN}::${PN} INTERFACE ${JPEG_LIBRARIES}) - target_compile_definitions(${PN}::${PN} INTERFACE JPEG_FOUND) -endif() - -endif() -endif() diff --git a/product/share/cmake/TorchVision/TorchVisionConfigVersion.cmake b/product/share/cmake/TorchVision/TorchVisionConfigVersion.cmake deleted file mode 100644 index 94b7114a138..00000000000 --- a/product/share/cmake/TorchVision/TorchVisionConfigVersion.cmake +++ /dev/null @@ -1,43 +0,0 @@ -# This is a basic version file for the Config-mode of find_package(). -# It is used by write_basic_package_version_file() as input file for configure_file() -# to create a version-file which can be installed along a config.cmake file. -# -# The created file sets PACKAGE_VERSION_EXACT if the current version string and -# the requested version string are exactly the same and it sets -# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version. -# The variable CVF_VERSION must be set before calling configure_file(). 
-
-set(PACKAGE_VERSION "0.20.0a0")
-
-if (PACKAGE_FIND_VERSION_RANGE)
-  # Package version must be in the requested version range
-  if ((PACKAGE_FIND_VERSION_RANGE_MIN STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MIN)
-      OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_GREATER PACKAGE_FIND_VERSION_MAX)
-        OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_GREATER_EQUAL PACKAGE_FIND_VERSION_MAX)))
-    set(PACKAGE_VERSION_COMPATIBLE FALSE)
-  else()
-    set(PACKAGE_VERSION_COMPATIBLE TRUE)
-  endif()
-else()
-  if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION)
-    set(PACKAGE_VERSION_COMPATIBLE FALSE)
-  else()
-    set(PACKAGE_VERSION_COMPATIBLE TRUE)
-    if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION)
-      set(PACKAGE_VERSION_EXACT TRUE)
-    endif()
-  endif()
-endif()
-
-
-# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
-if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "")
-  return()
-endif()
-
-# check that the installed version has the same 32/64bit-ness as the one which is currently searching:
-if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8")
-  math(EXPR installedBits "8 * 8")
-  set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)")
-  set(PACKAGE_VERSION_UNSUITABLE TRUE)
-endif()
diff --git a/product/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake b/product/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake
deleted file mode 100644
index 91aa482bb9c..00000000000
--- a/product/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake
+++ /dev/null
@@ -1,20 +0,0 @@
-#----------------------------------------------------------------
-# Generated CMake target import file.
-#----------------------------------------------------------------
-
-# Commands may need to know the format version.
-set(CMAKE_IMPORT_FILE_VERSION 1)
-
-# Import target "TorchVision::TorchVision" for configuration ""
-set_property(TARGET TorchVision::TorchVision APPEND PROPERTY IMPORTED_CONFIGURATIONS NOCONFIG)
-set_target_properties(TorchVision::TorchVision PROPERTIES
-  IMPORTED_LINK_DEPENDENT_LIBRARIES_NOCONFIG "torch"
-  IMPORTED_LOCATION_NOCONFIG "${_IMPORT_PREFIX}/lib/libtorchvision.dylib"
-  IMPORTED_SONAME_NOCONFIG "@rpath/libtorchvision.dylib"
-  )
-
-list(APPEND _cmake_import_check_targets TorchVision::TorchVision )
-list(APPEND _cmake_import_check_files_for_TorchVision::TorchVision "${_IMPORT_PREFIX}/lib/libtorchvision.dylib" )
-
-# Commands beyond this point should not need to know the version.
-set(CMAKE_IMPORT_FILE_VERSION)
diff --git a/product/share/cmake/TorchVision/TorchVisionTargets.cmake b/product/share/cmake/TorchVision/TorchVisionTargets.cmake
deleted file mode 100644
index 1e07b7fc626..00000000000
--- a/product/share/cmake/TorchVision/TorchVisionTargets.cmake
+++ /dev/null
@@ -1,102 +0,0 @@
-# Generated by CMake
-
-if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
-   message(FATAL_ERROR "CMake >= 2.8.0 required")
-endif()
-if(CMAKE_VERSION VERSION_LESS "2.8.3")
-   message(FATAL_ERROR "CMake >= 2.8.3 required")
-endif()
-cmake_policy(PUSH)
-cmake_policy(VERSION 2.8.3...3.27)
-#----------------------------------------------------------------
-# Generated CMake target import file.
-#----------------------------------------------------------------
-
-# Commands may need to know the format version.
-set(CMAKE_IMPORT_FILE_VERSION 1)
-
-# Protect against multiple inclusion, which would fail when already imported targets are added once more.
-set(_cmake_targets_defined "") -set(_cmake_targets_not_defined "") -set(_cmake_expected_targets "") -foreach(_cmake_expected_target IN ITEMS TorchVision::TorchVision) - list(APPEND _cmake_expected_targets "${_cmake_expected_target}") - if(TARGET "${_cmake_expected_target}") - list(APPEND _cmake_targets_defined "${_cmake_expected_target}") - else() - list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}") - endif() -endforeach() -unset(_cmake_expected_target) -if(_cmake_targets_defined STREQUAL _cmake_expected_targets) - unset(_cmake_targets_defined) - unset(_cmake_targets_not_defined) - unset(_cmake_expected_targets) - unset(CMAKE_IMPORT_FILE_VERSION) - cmake_policy(POP) - return() -endif() -if(NOT _cmake_targets_defined STREQUAL "") - string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}") - string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}") - message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n") -endif() -unset(_cmake_targets_defined) -unset(_cmake_targets_not_defined) -unset(_cmake_expected_targets) - - -# Compute the installation prefix relative to this file. -get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) -get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) -get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) -get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) -if(_IMPORT_PREFIX STREQUAL "/") - set(_IMPORT_PREFIX "") -endif() - -# Create imported target TorchVision::TorchVision -add_library(TorchVision::TorchVision SHARED IMPORTED) - -# Load information for each installed configuration. -file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/TorchVisionTargets-*.cmake") -foreach(_cmake_config_file IN LISTS _cmake_config_files) - include("${_cmake_config_file}") -endforeach() -unset(_cmake_config_file) -unset(_cmake_config_files) - -# Cleanup temporary variables. -set(_IMPORT_PREFIX) - -# Loop over all imported files and verify that they actually exist -foreach(_cmake_target IN LISTS _cmake_import_check_targets) - if(CMAKE_VERSION VERSION_LESS "3.28" - OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target} - OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}") - foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}") - if(NOT EXISTS "${_cmake_file}") - message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file - \"${_cmake_file}\" -but this file does not exist. Possible reasons include: -* The file was deleted, renamed, or moved to another location. -* An install or uninstall procedure did not complete successfully. -* The installation package was faulty and contained - \"${CMAKE_CURRENT_LIST_FILE}\" -but not all the files it references. -") - endif() - endforeach() - endif() - unset(_cmake_file) - unset("_cmake_import_check_files_for_${_cmake_target}") -endforeach() -unset(_cmake_target) -unset(_cmake_import_check_targets) - -# This file does not depend on other imported targets which have -# been exported from the same project but in a separate export set. - -# Commands beyond this point should not need to know the version. 
-set(CMAKE_IMPORT_FILE_VERSION) -cmake_policy(POP) diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index 7295fc2caa5..3a4ee6624b5 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -1,7 +1,8 @@ // vision::ops:: -// deform_conv2d_kernal.mm +// deform_conv2d_kernel.mm // +#include #include #include #include @@ -40,7 +41,7 @@ void deformable_im2col(const at::Tensor& input, at::Tensor data_col) { using namespace at::native::mps; - // Validate tensors as of type mps. + // Validate tensors as of type mps. TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); TORCH_CHECK(data_offset.is_mps(), "data_offset must be a MPS tensor"); TORCH_CHECK(data_mask.is_mps(), "data_mask must be a MPS tensor"); From 970183d062c58a1079db8f4ed3b8c1504b85ef68 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sat, 16 Nov 2024 13:52:05 +0100 Subject: [PATCH 08/31] Removing framework dir and included files. --- framework/.DS_Store | Bin 6148 -> 0 bytes framework/include/.DS_Store | Bin 6148 -> 0 bytes framework/include/torchvision/.DS_Store | Bin 6148 -> 0 bytes .../torchvision/io/image/cpu/common_jpeg.cpp | 26 - .../torchvision/io/image/cpu/common_jpeg.h | 27 - .../torchvision/io/image/cpu/common_png.h | 6 - .../torchvision/io/image/cpu/decode_image.cpp | 41 - .../torchvision/io/image/cpu/decode_image.h | 15 - .../torchvision/io/image/cpu/decode_jpeg.cpp | 271 ---- .../torchvision/io/image/cpu/decode_jpeg.h | 18 - .../torchvision/io/image/cpu/decode_png.cpp | 259 ---- .../torchvision/io/image/cpu/decode_png.h | 16 - .../torchvision/io/image/cpu/encode_jpeg.cpp | 113 -- .../torchvision/io/image/cpu/encode_jpeg.h | 13 - .../torchvision/io/image/cpu/encode_png.cpp | 180 --- .../torchvision/io/image/cpu/encode_png.h | 13 - .../include/torchvision/io/image/cpu/exif.h | 264 ---- .../io/image/cpu/read_write_file.cpp | 108 -- .../io/image/cpu/read_write_file.h | 13 - .../io/image/cuda/decode_jpeg_cuda.cpp | 208 --- .../io/image/cuda/decode_jpeg_cuda.h | 15 - .../include/torchvision/io/image/image.cpp | 39 - .../include/torchvision/io/image/image.h | 9 - .../torchvision/io/image/image_read_mode.h | 17 - framework/include/torchvision/macros.h | 22 - framework/include/torchvision/ops/.DS_Store | Bin 6148 -> 0 bytes .../ops/autograd/deform_conv2d_kernel.cpp | 266 ---- .../ops/autograd/ps_roi_align_kernel.cpp | 167 --- .../ops/autograd/ps_roi_pool_kernel.cpp | 152 --- .../ops/autograd/roi_align_kernel.cpp | 167 --- .../ops/autograd/roi_pool_kernel.cpp | 152 --- .../ops/cpu/deform_conv2d_kernel.cpp | 1172 ----------------- .../torchvision/ops/cpu/nms_kernel.cpp | 117 -- .../ops/cpu/ps_roi_align_kernel.cpp | 429 ------ .../ops/cpu/ps_roi_pool_kernel.cpp | 273 ---- .../torchvision/ops/cpu/roi_align_common.h | 128 -- .../torchvision/ops/cpu/roi_align_kernel.cpp | 400 ------ .../torchvision/ops/cpu/roi_pool_kernel.cpp | 249 ---- .../include/torchvision/ops/deform_conv2d.cpp | 172 --- .../include/torchvision/ops/deform_conv2d.h | 82 -- .../include/torchvision/ops/mps/mps_helpers.h | 6 - .../include/torchvision/ops/mps/mps_kernels.h | 1102 ---------------- .../include/torchvision/ops/mps/nms_kernel.mm | 109 -- .../ops/mps/ps_roi_align_kernel.mm | 205 --- .../torchvision/ops/mps/ps_roi_pool_kernel.mm | 200 --- .../torchvision/ops/mps/roi_align_kernel.mm | 197 --- .../torchvision/ops/mps/roi_pool_kernel.mm | 196 --- framework/include/torchvision/ops/nms.cpp | 27 - 
framework/include/torchvision/ops/nms.h | 15 - framework/include/torchvision/ops/ops.h | 8 - .../include/torchvision/ops/ps_roi_align.cpp | 112 -- .../include/torchvision/ops/ps_roi_align.h | 56 - .../include/torchvision/ops/ps_roi_pool.cpp | 104 -- .../include/torchvision/ops/ps_roi_pool.h | 52 - .../include/torchvision/ops/roi_align.cpp | 132 -- framework/include/torchvision/ops/roi_align.h | 58 - .../include/torchvision/ops/roi_pool.cpp | 102 -- framework/include/torchvision/ops/roi_pool.h | 52 - framework/include/torchvision/vision.cpp | 41 - framework/include/torchvision/vision.h | 16 - .../cmake/TorchVision/TorchVisionConfig.cmake | 82 -- .../TorchVisionConfigVersion.cmake | 43 - .../TorchVisionTargets-noconfig.cmake | 20 - .../TorchVision/TorchVisionTargets.cmake | 102 -- 64 files changed, 8656 deletions(-) delete mode 100644 framework/.DS_Store delete mode 100644 framework/include/.DS_Store delete mode 100644 framework/include/torchvision/.DS_Store delete mode 100644 framework/include/torchvision/io/image/cpu/common_jpeg.cpp delete mode 100644 framework/include/torchvision/io/image/cpu/common_jpeg.h delete mode 100644 framework/include/torchvision/io/image/cpu/common_png.h delete mode 100644 framework/include/torchvision/io/image/cpu/decode_image.cpp delete mode 100644 framework/include/torchvision/io/image/cpu/decode_image.h delete mode 100644 framework/include/torchvision/io/image/cpu/decode_jpeg.cpp delete mode 100644 framework/include/torchvision/io/image/cpu/decode_jpeg.h delete mode 100644 framework/include/torchvision/io/image/cpu/decode_png.cpp delete mode 100644 framework/include/torchvision/io/image/cpu/decode_png.h delete mode 100644 framework/include/torchvision/io/image/cpu/encode_jpeg.cpp delete mode 100644 framework/include/torchvision/io/image/cpu/encode_jpeg.h delete mode 100644 framework/include/torchvision/io/image/cpu/encode_png.cpp delete mode 100644 framework/include/torchvision/io/image/cpu/encode_png.h delete mode 100644 framework/include/torchvision/io/image/cpu/exif.h delete mode 100644 framework/include/torchvision/io/image/cpu/read_write_file.cpp delete mode 100644 framework/include/torchvision/io/image/cpu/read_write_file.h delete mode 100644 framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.cpp delete mode 100644 framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.h delete mode 100644 framework/include/torchvision/io/image/image.cpp delete mode 100644 framework/include/torchvision/io/image/image.h delete mode 100644 framework/include/torchvision/io/image/image_read_mode.h delete mode 100644 framework/include/torchvision/macros.h delete mode 100644 framework/include/torchvision/ops/.DS_Store delete mode 100644 framework/include/torchvision/ops/autograd/deform_conv2d_kernel.cpp delete mode 100644 framework/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp delete mode 100644 framework/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp delete mode 100644 framework/include/torchvision/ops/autograd/roi_align_kernel.cpp delete mode 100644 framework/include/torchvision/ops/autograd/roi_pool_kernel.cpp delete mode 100644 framework/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp delete mode 100644 framework/include/torchvision/ops/cpu/nms_kernel.cpp delete mode 100644 framework/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp delete mode 100644 framework/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp delete mode 100644 framework/include/torchvision/ops/cpu/roi_align_common.h delete mode 100644 
framework/include/torchvision/ops/cpu/roi_align_kernel.cpp delete mode 100644 framework/include/torchvision/ops/cpu/roi_pool_kernel.cpp delete mode 100644 framework/include/torchvision/ops/deform_conv2d.cpp delete mode 100644 framework/include/torchvision/ops/deform_conv2d.h delete mode 100644 framework/include/torchvision/ops/mps/mps_helpers.h delete mode 100644 framework/include/torchvision/ops/mps/mps_kernels.h delete mode 100644 framework/include/torchvision/ops/mps/nms_kernel.mm delete mode 100644 framework/include/torchvision/ops/mps/ps_roi_align_kernel.mm delete mode 100644 framework/include/torchvision/ops/mps/ps_roi_pool_kernel.mm delete mode 100644 framework/include/torchvision/ops/mps/roi_align_kernel.mm delete mode 100644 framework/include/torchvision/ops/mps/roi_pool_kernel.mm delete mode 100644 framework/include/torchvision/ops/nms.cpp delete mode 100644 framework/include/torchvision/ops/nms.h delete mode 100644 framework/include/torchvision/ops/ops.h delete mode 100644 framework/include/torchvision/ops/ps_roi_align.cpp delete mode 100644 framework/include/torchvision/ops/ps_roi_align.h delete mode 100644 framework/include/torchvision/ops/ps_roi_pool.cpp delete mode 100644 framework/include/torchvision/ops/ps_roi_pool.h delete mode 100644 framework/include/torchvision/ops/roi_align.cpp delete mode 100644 framework/include/torchvision/ops/roi_align.h delete mode 100644 framework/include/torchvision/ops/roi_pool.cpp delete mode 100644 framework/include/torchvision/ops/roi_pool.h delete mode 100644 framework/include/torchvision/vision.cpp delete mode 100644 framework/include/torchvision/vision.h delete mode 100644 framework/share/cmake/TorchVision/TorchVisionConfig.cmake delete mode 100644 framework/share/cmake/TorchVision/TorchVisionConfigVersion.cmake delete mode 100644 framework/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake delete mode 100644 framework/share/cmake/TorchVision/TorchVisionTargets.cmake diff --git a/framework/.DS_Store b/framework/.DS_Store deleted file mode 100644 index d0ccba84354fbb07f0ee1c0ad907b69509dc98c5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHLJ5Iwu5S>i|F@ho`r3+e$w1HxYOiPmsKw?CKY#b4#qVc7;0tI&fS&0=jUGC8czm-StP4 zqLFQmzo-DeyDrV>fhLqu<^ARJaJ!vFx0#U_NixogF(T3D>gxUZ?B!xxR{D#q`IJ|S zvh9B`X&ZMkqp_?bjSSw7H<2p;{ZsO){92b68{THL#y40o-ySV!N;mZQy>^w8PR{Gm zD(`8w=B?CMxwmK3_{lTBnJHijm;$?206m*6Iux|o6fgx$fwcnseTXp_y<#aCJ{_3C z7690RI~b06FTpjjqE{>h5rH`=1xl&YEryeF_+!oUilv~GlMBqajxsylp}4>fe;m@u zc|n^^0aKu^KwD0G-2X3s-v8H&?93D}1^$%+u9J+CAs$J4YvbX#*Txvv7;KzZ3a(3V j5Lz*ExfS!jHNoJI`2gq@OF?*G_Cp}ZV3R4ZQw6>NB|~WX diff --git a/framework/include/.DS_Store b/framework/include/.DS_Store deleted file mode 100644 index 784dc5a7f8d658bf3890dc86f21b072059f4b66d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKOG*Pl5UtWd2C~W0WnUp{H*FYCkfocD1dNb4A%50<5s%>sgsgJ`?~+%aNP?pv zf{0W>)vM{Q>Zy4!-Cab)!^hcxs82*SRFK8#5E<^AI`QBN$hyXyp6HhDp>4-Pe{o1^ z@6iM5UDK2f>%Uvv3`Wy-xze($>(kF8mu=Hb7VQKP_4DiI?d4)~`Bn7n8_}!zK^13P zEhbDb5DWwZ!9XzZ0|szri_|^Cu)#nu5DdICAp1jt3TDS*s9Oh&wg5o6MytS=UP5w` zV|FZtn1QgR0xgxj#b8Uvc=EXHSPU(l*qaabH}9Jl*0*E*q~XNbFl;ao3>-5c9WA6U z=l>Ocna(C(L&5|D!N7lIfU9QI4Doe-wto0dIcpPi3>A^MA`S%p(Io&6vX7kWq|GPs Y5tki{p{yd~nhuPMfD#fW82AMScC|Y_XCp zEkekaW}e5Naq^_h#6-l)hh{-EC!z{NkVTmgF;BW?7A#JdJ#JFlsqOjFe19F}*=KZ5 zX>+?+9Xx-wx2dkzo3ut;RXx`4Pv@@}`;=+a9_PB4syy&^mi)>NRTvb7kj>97Zji-x_TrW5DN*vBgGn-|X2VGlW+ zI4b(=3^)U01`1uy<^I3ICo|dPk5hc*3^)V-i~%m|RlUSZ+1+~ddUDqWjB5-LiR(py 
mKp*`CU?JzoX?ChVh>o~u*el8|V$bP7{}IT9_~Z=y0t27H*+Jp} diff --git a/framework/include/torchvision/io/image/cpu/common_jpeg.cpp b/framework/include/torchvision/io/image/cpu/common_jpeg.cpp deleted file mode 100644 index 4c993106b45..00000000000 --- a/framework/include/torchvision/io/image/cpu/common_jpeg.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include "common_jpeg.h" - -namespace vision { -namespace image { -namespace detail { - -#if JPEG_FOUND -void torch_jpeg_error_exit(j_common_ptr cinfo) { - /* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce - * pointer */ - torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err; - - /* Always display the message. */ - /* We could postpone this until after returning, if we chose. */ - // (*cinfo->err->output_message)(cinfo); - /* Create the message */ - (*(cinfo->err->format_message))(cinfo, myerr->jpegLastErrorMsg); - - /* Return control to the setjmp point */ - longjmp(myerr->setjmp_buffer, 1); -} -#endif - -} // namespace detail -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/common_jpeg.h b/framework/include/torchvision/io/image/cpu/common_jpeg.h deleted file mode 100644 index 7f7f9f0ccf1..00000000000 --- a/framework/include/torchvision/io/image/cpu/common_jpeg.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#if JPEG_FOUND -#include - -#include -#include - -namespace vision { -namespace image { -namespace detail { - -static const JOCTET EOI_BUFFER[1] = {JPEG_EOI}; -struct torch_jpeg_error_mgr { - struct jpeg_error_mgr pub; /* "public" fields */ - char jpegLastErrorMsg[JMSG_LENGTH_MAX]; /* error messages */ - jmp_buf setjmp_buffer; /* for return to caller */ -}; - -using torch_jpeg_error_ptr = struct torch_jpeg_error_mgr*; -void torch_jpeg_error_exit(j_common_ptr cinfo); - -} // namespace detail -} // namespace image -} // namespace vision - -#endif diff --git a/framework/include/torchvision/io/image/cpu/common_png.h b/framework/include/torchvision/io/image/cpu/common_png.h deleted file mode 100644 index 68400d48e05..00000000000 --- a/framework/include/torchvision/io/image/cpu/common_png.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -#if PNG_FOUND -#include -#include -#endif diff --git a/framework/include/torchvision/io/image/cpu/decode_image.cpp b/framework/include/torchvision/io/image/cpu/decode_image.cpp deleted file mode 100644 index dbf349b06ca..00000000000 --- a/framework/include/torchvision/io/image/cpu/decode_image.cpp +++ /dev/null @@ -1,41 +0,0 @@ -#include "decode_image.h" - -#include "decode_jpeg.h" -#include "decode_png.h" - -namespace vision { -namespace image { - -torch::Tensor decode_image( - const torch::Tensor& data, - ImageReadMode mode, - bool apply_exif_orientation) { - // Check that tensor is a CPU tensor - TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor"); - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); - - auto datap = data.data_ptr(); - - const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF" - const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" - - if (memcmp(jpeg_signature, datap, 3) == 0) { - return decode_jpeg(data, mode, apply_exif_orientation); - } else if (memcmp(png_signature, datap, 4) == 0) { - return decode_png( - data, mode, /*allow_16_bits=*/false, 
apply_exif_orientation); - } else { - TORCH_CHECK( - false, - "Unsupported image file. Only jpeg and png ", - "are currently supported."); - } -} - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/decode_image.h b/framework/include/torchvision/io/image/cpu/decode_image.h deleted file mode 100644 index f0e66d397ac..00000000000 --- a/framework/include/torchvision/io/image/cpu/decode_image.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include -#include "../image_read_mode.h" - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_image( - const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, - bool apply_exif_orientation = false); - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/decode_jpeg.cpp b/framework/include/torchvision/io/image/cpu/decode_jpeg.cpp deleted file mode 100644 index ec5953e4106..00000000000 --- a/framework/include/torchvision/io/image/cpu/decode_jpeg.cpp +++ /dev/null @@ -1,271 +0,0 @@ -#include "decode_jpeg.h" -#include "common_jpeg.h" -#include "exif.h" - -namespace vision { -namespace image { - -#if !JPEG_FOUND -torch::Tensor decode_jpeg( - const torch::Tensor& data, - ImageReadMode mode, - bool apply_exif_orientation) { - TORCH_CHECK( - false, "decode_jpeg: torchvision not compiled with libjpeg support"); -} -#else - -using namespace detail; -using namespace exif_private; - -namespace { - -struct torch_jpeg_mgr { - struct jpeg_source_mgr pub; - const JOCTET* data; - size_t len; -}; - -static void torch_jpeg_init_source(j_decompress_ptr cinfo) {} - -static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) { - // No more data. Probably an incomplete image; Raise exception. - torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err; - strcpy(myerr->jpegLastErrorMsg, "Image is incomplete or truncated"); - longjmp(myerr->setjmp_buffer, 1); -} - -static void torch_jpeg_skip_input_data(j_decompress_ptr cinfo, long num_bytes) { - torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src; - if (src->pub.bytes_in_buffer < (size_t)num_bytes) { - // Skipping over all of remaining data; output EOI. - src->pub.next_input_byte = EOI_BUFFER; - src->pub.bytes_in_buffer = 1; - } else { - // Skipping over only some of the remaining data. 
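// Editor's sketch (illustrative only; not code from any deleted file above):
// decode_image dispatches on the buffer's leading magic bytes shown earlier
// (FF D8 FF for JPEG, "\211PNG" for PNG). Pulled out on its own, with made-up
// names (Format, detect_format), that check is just:
#include <cstdint>
#include <cstring>

enum class Format { Jpeg, Png, Unknown };

inline Format detect_format(const uint8_t* p, std::size_t n) {
  static const uint8_t kJpeg[3] = {0xFF, 0xD8, 0xFF};   // JPEG SOI marker
  static const uint8_t kPng[4] = {137, 'P', 'N', 'G'};  // "\211PNG"
  if (n >= 3 && std::memcmp(p, kJpeg, 3) == 0) {
    return Format::Jpeg;
  }
  if (n >= 4 && std::memcmp(p, kPng, 4) == 0) {
    return Format::Png;
  }
  return Format::Unknown;
}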
- src->pub.next_input_byte += num_bytes; - src->pub.bytes_in_buffer -= num_bytes; - } -} - -static void torch_jpeg_term_source(j_decompress_ptr cinfo) {} - -static void torch_jpeg_set_source_mgr( - j_decompress_ptr cinfo, - const unsigned char* data, - size_t len) { - torch_jpeg_mgr* src; - if (cinfo->src == 0) { // if this is first time; allocate memory - cinfo->src = (struct jpeg_source_mgr*)(*cinfo->mem->alloc_small)( - (j_common_ptr)cinfo, JPOOL_PERMANENT, sizeof(torch_jpeg_mgr)); - } - src = (torch_jpeg_mgr*)cinfo->src; - src->pub.init_source = torch_jpeg_init_source; - src->pub.fill_input_buffer = torch_jpeg_fill_input_buffer; - src->pub.skip_input_data = torch_jpeg_skip_input_data; - src->pub.resync_to_restart = jpeg_resync_to_restart; // default - src->pub.term_source = torch_jpeg_term_source; - // fill the buffers - src->data = (const JOCTET*)data; - src->len = len; - src->pub.bytes_in_buffer = len; - src->pub.next_input_byte = src->data; - - jpeg_save_markers(cinfo, APP1, 0xffff); -} - -inline unsigned char clamped_cmyk_rgb_convert( - unsigned char k, - unsigned char cmy) { - // Inspired from Pillow: - // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569 - int v = k * cmy + 128; - v = ((v >> 8) + v) >> 8; - return std::clamp(k - v, 0, 255); -} - -void convert_line_cmyk_to_rgb( - j_decompress_ptr cinfo, - const unsigned char* cmyk_line, - unsigned char* rgb_line) { - int width = cinfo->output_width; - for (int i = 0; i < width; ++i) { - int c = cmyk_line[i * 4 + 0]; - int m = cmyk_line[i * 4 + 1]; - int y = cmyk_line[i * 4 + 2]; - int k = cmyk_line[i * 4 + 3]; - - rgb_line[i * 3 + 0] = clamped_cmyk_rgb_convert(k, 255 - c); - rgb_line[i * 3 + 1] = clamped_cmyk_rgb_convert(k, 255 - m); - rgb_line[i * 3 + 2] = clamped_cmyk_rgb_convert(k, 255 - y); - } -} - -inline unsigned char rgb_to_gray(int r, int g, int b) { - // Inspired from Pillow: - // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226 - return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16; -} - -void convert_line_cmyk_to_gray( - j_decompress_ptr cinfo, - const unsigned char* cmyk_line, - unsigned char* gray_line) { - int width = cinfo->output_width; - for (int i = 0; i < width; ++i) { - int c = cmyk_line[i * 4 + 0]; - int m = cmyk_line[i * 4 + 1]; - int y = cmyk_line[i * 4 + 2]; - int k = cmyk_line[i * 4 + 3]; - - int r = clamped_cmyk_rgb_convert(k, 255 - c); - int g = clamped_cmyk_rgb_convert(k, 255 - m); - int b = clamped_cmyk_rgb_convert(k, 255 - y); - - gray_line[i] = rgb_to_gray(r, g, b); - } -} - -} // namespace - -torch::Tensor decode_jpeg( - const torch::Tensor& data, - ImageReadMode mode, - bool apply_exif_orientation) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg"); - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); - - struct jpeg_decompress_struct cinfo; - struct torch_jpeg_error_mgr jerr; - - auto datap = data.data_ptr(); - // Setup decompression structure - cinfo.err = jpeg_std_error(&jerr.pub); - jerr.pub.error_exit = torch_jpeg_error_exit; - /* Establish the setjmp return context for my_error_exit to use. */ - if (setjmp(jerr.setjmp_buffer)) { - /* If we get here, the JPEG code has signaled an error. 
- * We need to clean up the JPEG object. - */ - jpeg_destroy_decompress(&cinfo); - TORCH_CHECK(false, jerr.jpegLastErrorMsg); - } - - jpeg_create_decompress(&cinfo); - torch_jpeg_set_source_mgr(&cinfo, datap, data.numel()); - - // read info from header. - jpeg_read_header(&cinfo, TRUE); - - int channels = cinfo.num_components; - bool cmyk_to_rgb_or_gray = false; - - if (mode != IMAGE_READ_MODE_UNCHANGED) { - switch (mode) { - case IMAGE_READ_MODE_GRAY: - if (cinfo.jpeg_color_space == JCS_CMYK || - cinfo.jpeg_color_space == JCS_YCCK) { - cinfo.out_color_space = JCS_CMYK; - cmyk_to_rgb_or_gray = true; - } else { - cinfo.out_color_space = JCS_GRAYSCALE; - } - channels = 1; - break; - case IMAGE_READ_MODE_RGB: - if (cinfo.jpeg_color_space == JCS_CMYK || - cinfo.jpeg_color_space == JCS_YCCK) { - cinfo.out_color_space = JCS_CMYK; - cmyk_to_rgb_or_gray = true; - } else { - cinfo.out_color_space = JCS_RGB; - } - channels = 3; - break; - /* - * Libjpeg does not support converting from CMYK to grayscale etc. There - * is a way to do this but it involves converting it manually to RGB: - * https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313 - */ - default: - jpeg_destroy_decompress(&cinfo); - TORCH_CHECK(false, "The provided mode is not supported for JPEG files"); - } - - jpeg_calc_output_dimensions(&cinfo); - } - - int exif_orientation = -1; - if (apply_exif_orientation) { - exif_orientation = fetch_jpeg_exif_orientation(&cinfo); - } - - jpeg_start_decompress(&cinfo); - - int height = cinfo.output_height; - int width = cinfo.output_width; - - int stride = width * channels; - auto tensor = - torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); - auto ptr = tensor.data_ptr(); - torch::Tensor cmyk_line_tensor; - if (cmyk_to_rgb_or_gray) { - cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8); - } - - while (cinfo.output_scanline < cinfo.output_height) { - /* jpeg_read_scanlines expects an array of pointers to scanlines. - * Here the array is only one element long, but you could ask for - * more than one scanline at a time if that's more convenient. 
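// Editor's note (an illustration, not taken from the deleted file): the
// constants in rgb_to_gray above appear to be the Rec. 601 luma weights
// scaled by 2^16 (19595 ~ 0.299*65536, 38470 ~ 0.587*65536,
// 7471 ~ 0.114*65536), with 0x8000 added so the final >> 16 rounds instead
// of truncating. A self-contained check of that reading:
#include <cstdio>

int main() {
  int r = 200, g = 120, b = 40;  // arbitrary sample pixel
  int fixed_point = (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16;
  double floating = 0.299 * r + 0.587 * g + 0.114 * b;
  // Prints 135 vs 134.800: the integer form matches the rounded float form.
  std::printf("%d vs %.3f\n", fixed_point, floating);
  return 0;
}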
- */ - if (cmyk_to_rgb_or_gray) { - auto cmyk_line_ptr = cmyk_line_tensor.data_ptr(); - jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1); - - if (channels == 3) { - convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr); - } else if (channels == 1) { - convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr); - } - } else { - jpeg_read_scanlines(&cinfo, &ptr, 1); - } - ptr += stride; - } - - jpeg_finish_decompress(&cinfo); - jpeg_destroy_decompress(&cinfo); - auto output = tensor.permute({2, 0, 1}); - - if (apply_exif_orientation) { - return exif_orientation_transform(output, exif_orientation); - } - return output; -} -#endif // #if !JPEG_FOUND - -int64_t _jpeg_version() { -#if JPEG_FOUND - return JPEG_LIB_VERSION; -#else - return -1; -#endif -} - -bool _is_compiled_against_turbo() { -#ifdef LIBJPEG_TURBO_VERSION - return true; -#else - return false; -#endif -} - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/decode_jpeg.h b/framework/include/torchvision/io/image/cpu/decode_jpeg.h deleted file mode 100644 index e0c9a24c846..00000000000 --- a/framework/include/torchvision/io/image/cpu/decode_jpeg.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include -#include "../image_read_mode.h" - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_jpeg( - const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, - bool apply_exif_orientation = false); - -C10_EXPORT int64_t _jpeg_version(); -C10_EXPORT bool _is_compiled_against_turbo(); - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/decode_png.cpp b/framework/include/torchvision/io/image/cpu/decode_png.cpp deleted file mode 100644 index ab4087fdfe2..00000000000 --- a/framework/include/torchvision/io/image/cpu/decode_png.cpp +++ /dev/null @@ -1,259 +0,0 @@ -#include "decode_png.h" -#include "common_png.h" -#include "exif.h" - -namespace vision { -namespace image { - -using namespace exif_private; - -#if !PNG_FOUND -torch::Tensor decode_png( - const torch::Tensor& data, - ImageReadMode mode, - bool allow_16_bits, - bool apply_exif_orientation) { - TORCH_CHECK( - false, "decode_png: torchvision not compiled with libPNG support"); -} -#else - -bool is_little_endian() { - uint32_t x = 1; - return *(uint8_t*)&x; -} - -torch::Tensor decode_png( - const torch::Tensor& data, - ImageReadMode mode, - bool allow_16_bits, - bool apply_exif_orientation) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png"); - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - // Check that the input tensor is 1-dimensional - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); - - auto png_ptr = - png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); - TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") - auto info_ptr = png_create_info_struct(png_ptr); - if (!info_ptr) { - png_destroy_read_struct(&png_ptr, nullptr, nullptr); - // Seems redundant with the if statement. done here to avoid leaking memory. 
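// Editor's aside, a minimal standalone sketch (not taken from the patch): the
// JPEG routines above and the PNG decoder beginning just below share one
// error-handling shape, because libjpeg and libpng report fatal errors
// through a hook that must not return. The hook longjmps back to a setjmp
// established at the top of the entry point, which then cleans up and raises.
// Stripped of the library specifics, the pattern is:
#include <csetjmp>
#include <cstdio>

static std::jmp_buf error_return_point;

// Stand-in for torch_jpeg_error_exit / torch_png_error above.
static void fatal_error_hook(const char* msg) {
  std::puts(msg);
  std::longjmp(error_return_point, 1);
}

int main() {
  if (setjmp(error_return_point) != 0) {
    // Mirrors the jpeg_destroy_* / png_destroy_* cleanup before TORCH_CHECK.
    std::puts("cleanup after library failure");
    return 1;
  }
  fatal_error_hook("simulated fatal decode error");  // never returns normally
}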
- TORCH_CHECK(info_ptr, "libpng info structure allocation failed!") - } - - auto accessor = data.accessor(); - auto datap = accessor.data(); - auto datap_len = accessor.size(0); - - if (setjmp(png_jmpbuf(png_ptr)) != 0) { - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "Internal error."); - } - TORCH_CHECK(datap_len >= 8, "Content is too small for png!") - auto is_png = !png_sig_cmp(datap, 0, 8); - TORCH_CHECK(is_png, "Content is not png!") - - struct Reader { - png_const_bytep ptr; - png_size_t count; - } reader; - reader.ptr = png_const_bytep(datap) + 8; - reader.count = datap_len - 8; - - auto read_callback = [](png_structp png_ptr, - png_bytep output, - png_size_t bytes) { - auto reader = static_cast(png_get_io_ptr(png_ptr)); - TORCH_CHECK( - reader->count >= bytes, - "Out of bound read in decode_png. Probably, the input image is corrupted"); - std::copy(reader->ptr, reader->ptr + bytes, output); - reader->ptr += bytes; - reader->count -= bytes; - }; - png_set_sig_bytes(png_ptr, 8); - png_set_read_fn(png_ptr, &reader, read_callback); - png_read_info(png_ptr, info_ptr); - - png_uint_32 width, height; - int bit_depth, color_type; - int interlace_type; - auto retval = png_get_IHDR( - png_ptr, - info_ptr, - &width, - &height, - &bit_depth, - &color_type, - &interlace_type, - nullptr, - nullptr); - - if (retval != 1) { - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(retval == 1, "Could read image metadata from content.") - } - - auto max_bit_depth = allow_16_bits ? 16 : 8; - auto err_msg = "At most " + std::to_string(max_bit_depth) + - "-bit PNG images are supported currently."; - if (bit_depth > max_bit_depth) { - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, err_msg) - } - - int channels = png_get_channels(png_ptr, info_ptr); - - if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8) - png_set_expand_gray_1_2_4_to_8(png_ptr); - - int number_of_passes; - if (interlace_type == PNG_INTERLACE_ADAM7) { - number_of_passes = png_set_interlace_handling(png_ptr); - } else { - number_of_passes = 1; - } - - if (mode != IMAGE_READ_MODE_UNCHANGED) { - // TODO: consider supporting PNG_INFO_tRNS - bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0; - bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0; - bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0; - - switch (mode) { - case IMAGE_READ_MODE_GRAY: - if (color_type != PNG_COLOR_TYPE_GRAY) { - if (is_palette) { - png_set_palette_to_rgb(png_ptr); - has_alpha = true; - } - - if (has_alpha) { - png_set_strip_alpha(png_ptr); - } - - if (has_color) { - png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); - } - channels = 1; - } - break; - case IMAGE_READ_MODE_GRAY_ALPHA: - if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) { - if (is_palette) { - png_set_palette_to_rgb(png_ptr); - has_alpha = true; - } - - if (!has_alpha) { - png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); - } - - if (has_color) { - png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); - } - channels = 2; - } - break; - case IMAGE_READ_MODE_RGB: - if (color_type != PNG_COLOR_TYPE_RGB) { - if (is_palette) { - png_set_palette_to_rgb(png_ptr); - has_alpha = true; - } else if (!has_color) { - png_set_gray_to_rgb(png_ptr); - } - - if (has_alpha) { - png_set_strip_alpha(png_ptr); - } - channels = 3; - } - break; - case IMAGE_READ_MODE_RGB_ALPHA: - if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) { - if (is_palette) { - png_set_palette_to_rgb(png_ptr); - has_alpha = true; - } else if 
(!has_color) { - png_set_gray_to_rgb(png_ptr); - } - - if (!has_alpha) { - png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); - } - channels = 4; - } - break; - default: - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "The provided mode is not supported for PNG files"); - } - - png_read_update_info(png_ptr, info_ptr); - } - - auto num_pixels_per_row = width * channels; - auto tensor = torch::empty( - {int64_t(height), int64_t(width), channels}, - bit_depth <= 8 ? torch::kU8 : torch::kI32); - - if (bit_depth <= 8) { - auto t_ptr = tensor.accessor().data(); - for (int pass = 0; pass < number_of_passes; pass++) { - for (png_uint_32 i = 0; i < height; ++i) { - png_read_row(png_ptr, t_ptr, nullptr); - t_ptr += num_pixels_per_row; - } - t_ptr = tensor.accessor().data(); - } - } else { - // We're reading a 16bits png, but pytorch doesn't support uint16. - // So we read each row in a 16bits tmp_buffer which we then cast into - // a int32 tensor instead. - if (is_little_endian()) { - png_set_swap(png_ptr); - } - int32_t* t_ptr = tensor.accessor().data(); - - // We create a tensor instead of malloc-ing for automatic memory management - auto tmp_buffer_tensor = torch::empty( - {int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8); - uint16_t* tmp_buffer = - (uint16_t*)tmp_buffer_tensor.accessor().data(); - - for (int pass = 0; pass < number_of_passes; pass++) { - for (png_uint_32 i = 0; i < height; ++i) { - png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr); - // Now we copy the uint16 values into the int32 tensor. - for (size_t j = 0; j < num_pixels_per_row; ++j) { - t_ptr[j] = (int32_t)tmp_buffer[j]; - } - t_ptr += num_pixels_per_row; - } - t_ptr = tensor.accessor().data(); - } - } - - int exif_orientation = -1; - if (apply_exif_orientation) { - exif_orientation = fetch_png_exif_orientation(png_ptr, info_ptr); - } - - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - - auto output = tensor.permute({2, 0, 1}); - if (apply_exif_orientation) { - return exif_orientation_transform(output, exif_orientation); - } - return output; -} -#endif - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/decode_png.h b/framework/include/torchvision/io/image/cpu/decode_png.h deleted file mode 100644 index b091f15e35f..00000000000 --- a/framework/include/torchvision/io/image/cpu/decode_png.h +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include -#include "../image_read_mode.h" - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_png( - const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, - bool allow_16_bits = false, - bool apply_exif_orientation = false); - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/encode_jpeg.cpp b/framework/include/torchvision/io/image/cpu/encode_jpeg.cpp deleted file mode 100644 index d2ed73071a2..00000000000 --- a/framework/include/torchvision/io/image/cpu/encode_jpeg.cpp +++ /dev/null @@ -1,113 +0,0 @@ -#include "encode_jpeg.h" - -#include "common_jpeg.h" - -namespace vision { -namespace image { - -#if !JPEG_FOUND - -torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { - TORCH_CHECK( - false, "encode_jpeg: torchvision not compiled with libjpeg support"); -} - -#else -// For libjpeg version <= 9b, the out_size parameter in jpeg_mem_dest() is -// defined as unsigned long, whereas in later version, it is defined as size_t. 
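// Editor's illustration (not from the deleted file): the 16-bit branch of
// decode_png above reads each row into a uint16 scratch buffer and then
// widens it element by element into an int32 tensor, since uint16 is not an
// available tensor dtype here. The widening step on its own (widen_row is a
// made-up name):
#include <cstddef>
#include <cstdint>
#include <vector>

inline std::vector<int32_t> widen_row(const std::vector<uint16_t>& row) {
  std::vector<int32_t> out(row.size());
  for (std::size_t i = 0; i < row.size(); ++i) {
    out[i] = static_cast<int32_t>(row[i]);  // lossless: 0..65535 fits in int32
  }
  return out;
}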
-#if !defined(JPEG_LIB_VERSION_MAJOR) || JPEG_LIB_VERSION_MAJOR < 9 || \ - (JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR <= 2) -using JpegSizeType = unsigned long; -#else -using JpegSizeType = size_t; -#endif - -using namespace detail; - -torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.encode_jpeg.encode_jpeg"); - // Define compression structures and error handling - struct jpeg_compress_struct cinfo {}; - struct torch_jpeg_error_mgr jerr {}; - - // Define buffer to write JPEG information to and its size - JpegSizeType jpegSize = 0; - uint8_t* jpegBuf = nullptr; - - cinfo.err = jpeg_std_error(&jerr.pub); - jerr.pub.error_exit = torch_jpeg_error_exit; - - /* Establish the setjmp return context for my_error_exit to use. */ - if (setjmp(jerr.setjmp_buffer)) { - /* If we get here, the JPEG code has signaled an error. - * We need to clean up the JPEG object and the buffer. - */ - jpeg_destroy_compress(&cinfo); - if (jpegBuf != nullptr) { - free(jpegBuf); - } - - TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg); - } - - // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); - - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); - - // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); - - // Get image info - int channels = data.size(0); - int height = data.size(1); - int width = data.size(2); - auto input = data.permute({1, 2, 0}).contiguous(); - - TORCH_CHECK( - channels == 1 || channels == 3, - "The number of channels should be 1 or 3, got: ", - channels); - - // Initialize JPEG structure - jpeg_create_compress(&cinfo); - - // Set output image information - cinfo.image_width = width; - cinfo.image_height = height; - cinfo.input_components = channels; - cinfo.in_color_space = channels == 1 ? 
JCS_GRAYSCALE : JCS_RGB; - - jpeg_set_defaults(&cinfo); - jpeg_set_quality(&cinfo, quality, TRUE); - - // Save JPEG output to a buffer - jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize); - - // Start JPEG compression - jpeg_start_compress(&cinfo, TRUE); - - auto stride = width * channels; - auto ptr = input.data_ptr(); - - // Encode JPEG file - while (cinfo.next_scanline < cinfo.image_height) { - jpeg_write_scanlines(&cinfo, &ptr, 1); - ptr += stride; - } - - jpeg_finish_compress(&cinfo); - jpeg_destroy_compress(&cinfo); - - torch::TensorOptions options = torch::TensorOptions{torch::kU8}; - auto out_tensor = - torch::from_blob(jpegBuf, {(long)jpegSize}, ::free, options); - jpegBuf = nullptr; - return out_tensor; -} -#endif - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/encode_jpeg.h b/framework/include/torchvision/io/image/cpu/encode_jpeg.h deleted file mode 100644 index 25084e154d6..00000000000 --- a/framework/include/torchvision/io/image/cpu/encode_jpeg.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor encode_jpeg( - const torch::Tensor& data, - int64_t quality); - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/encode_png.cpp b/framework/include/torchvision/io/image/cpu/encode_png.cpp deleted file mode 100644 index a9b7d76ff61..00000000000 --- a/framework/include/torchvision/io/image/cpu/encode_png.cpp +++ /dev/null @@ -1,180 +0,0 @@ -#include "encode_jpeg.h" - -#include "common_png.h" - -namespace vision { -namespace image { - -#if !PNG_FOUND - -torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { - TORCH_CHECK( - false, "encode_png: torchvision not compiled with libpng support"); -} - -#else - -namespace { - -struct torch_mem_encode { - char* buffer; - size_t size; -}; - -struct torch_png_error_mgr { - const char* pngLastErrorMsg; /* error messages */ - jmp_buf setjmp_buffer; /* for return to caller */ -}; - -using torch_png_error_mgr_ptr = torch_png_error_mgr*; - -void torch_png_error(png_structp png_ptr, png_const_charp error_msg) { - /* png_ptr->err really points to a torch_png_error_mgr struct, so coerce - * pointer */ - auto error_ptr = (torch_png_error_mgr_ptr)png_get_error_ptr(png_ptr); - /* Replace the error message on the error structure */ - error_ptr->pngLastErrorMsg = error_msg; - /* Return control to the setjmp point */ - longjmp(error_ptr->setjmp_buffer, 1); -} - -void torch_png_write_data( - png_structp png_ptr, - png_bytep data, - png_size_t length) { - struct torch_mem_encode* p = - (struct torch_mem_encode*)png_get_io_ptr(png_ptr); - size_t nsize = p->size + length; - - /* allocate or grow buffer */ - if (p->buffer) - p->buffer = (char*)realloc(p->buffer, nsize); - else - p->buffer = (char*)malloc(nsize); - - if (!p->buffer) - png_error(png_ptr, "Write Error"); - - /* copy new bytes to end of buffer */ - memcpy(p->buffer + p->size, data, length); - p->size += length; -} - -} // namespace - -torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png"); - // Define compression structures and error handling - png_structp png_write; - png_infop info_ptr; - struct torch_png_error_mgr err_ptr; - - // Define output buffer - struct torch_mem_encode buf_info; - buf_info.buffer = NULL; - buf_info.size = 0; - - /* Establish the setjmp return context for my_error_exit to use. 
*/ - if (setjmp(err_ptr.setjmp_buffer)) { - /* If we get here, the PNG code has signaled an error. - * We need to clean up the PNG object and the buffer. - */ - if (info_ptr != NULL) { - png_destroy_info_struct(png_write, &info_ptr); - } - - if (png_write != NULL) { - png_destroy_write_struct(&png_write, NULL); - } - - if (buf_info.buffer != NULL) { - free(buf_info.buffer); - } - - TORCH_CHECK(false, err_ptr.pngLastErrorMsg); - } - - // Check that the compression level is between 0 and 9 - TORCH_CHECK( - compression_level >= 0 && compression_level <= 9, - "Compression level should be between 0 and 9"); - - // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); - - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); - - // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor"); - - // Get image info - int channels = data.size(0); - int height = data.size(1); - int width = data.size(2); - auto input = data.permute({1, 2, 0}).contiguous(); - - TORCH_CHECK( - channels == 1 || channels == 3, - "The number of channels should be 1 or 3, got: ", - channels); - - // Initialize PNG structures - png_write = png_create_write_struct( - PNG_LIBPNG_VER_STRING, &err_ptr, torch_png_error, NULL); - - info_ptr = png_create_info_struct(png_write); - - // Define custom buffer output - png_set_write_fn(png_write, &buf_info, torch_png_write_data, NULL); - - // Set output image information - auto color_type = channels == 1 ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB; - png_set_IHDR( - png_write, - info_ptr, - width, - height, - 8, - color_type, - PNG_INTERLACE_NONE, - PNG_COMPRESSION_TYPE_DEFAULT, - PNG_FILTER_TYPE_DEFAULT); - - // Set image compression level - png_set_compression_level(png_write, compression_level); - - // Write file header - png_write_info(png_write, info_ptr); - - auto stride = width * channels; - auto ptr = input.data_ptr(); - - // Encode PNG file - for (int y = 0; y < height; ++y) { - png_write_row(png_write, ptr); - ptr += stride; - } - - // Write EOF - png_write_end(png_write, info_ptr); - - // Destroy structures - png_destroy_write_struct(&png_write, &info_ptr); - - torch::TensorOptions options = torch::TensorOptions{torch::kU8}; - auto outTensor = torch::empty({(long)buf_info.size}, options); - - // Copy memory from png buffer, since torch cannot get ownership of it via - // `from_blob` - auto outPtr = outTensor.data_ptr(); - std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel()); - free(buf_info.buffer); - - return outTensor; -} - -#endif - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/encode_png.h b/framework/include/torchvision/io/image/cpu/encode_png.h deleted file mode 100644 index 86a67c8706e..00000000000 --- a/framework/include/torchvision/io/image/cpu/encode_png.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor encode_png( - const torch::Tensor& data, - int64_t compression_level); - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/exif.h b/framework/include/torchvision/io/image/cpu/exif.h deleted file mode 100644 index 0f9a59417db..00000000000 --- a/framework/include/torchvision/io/image/cpu/exif.h +++ /dev/null @@ -1,264 +0,0 @@ 
-/*M/////////////////////////////////////////////////////////////////////////////////////// -// -// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. -// -// By downloading, copying, installing or using the software you agree to this -license. -// If you do not agree to this license, do not download, install, -// copy or use the software. -// -// -// License Agreement -// For Open Source Computer Vision Library -// -// Copyright (C) 2000-2008, Intel Corporation, all rights reserved. -// Copyright (C) 2009, Willow Garage Inc., all rights reserved. -// Third party copyrights are property of their respective owners. -// -// Redistribution and use in source and binary forms, with or without -modification, -// are permitted provided that the following conditions are met: -// -// * Redistribution's of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// -// * Redistribution's in binary form must reproduce the above copyright -notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// * The name of the copyright holders may not be used to endorse or promote -products -// derived from this software without specific prior written permission. -// -// This software is provided by the copyright holders and contributors "as is" -and -// any express or implied warranties, including, but not limited to, the implied -// warranties of merchantability and fitness for a particular purpose are -disclaimed. -// In no event shall the Intel Corporation or contributors be liable for any -direct, -// indirect, incidental, special, exemplary, or consequential damages -// (including, but not limited to, procurement of substitute goods or services; -// loss of use, data, or profits; or business interruption) however caused -// and on any theory of liability, whether in contract, strict liability, -// or tort (including negligence or otherwise) arising in any way out of -// the use of this software, even if advised of the possibility of such damage. 
-// -//M*/ -#pragma once -// Functions in this module are taken from OpenCV -// https://github.com/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/imgcodecs/src/exif.cpp - -#if JPEG_FOUND -#include -#endif -#if PNG_FOUND -#include -#endif - -#include - -namespace vision { -namespace image { -namespace exif_private { - -constexpr uint16_t APP1 = 0xe1; -constexpr uint16_t ENDIANNESS_INTEL = 0x49; -constexpr uint16_t ENDIANNESS_MOTO = 0x4d; -constexpr uint16_t REQ_EXIF_TAG_MARK = 0x2a; -constexpr uint16_t ORIENTATION_EXIF_TAG = 0x0112; -constexpr uint16_t INCORRECT_TAG = -1; - -class ExifDataReader { - public: - ExifDataReader(unsigned char* p, size_t s) : _ptr(p), _size(s) {} - size_t size() const { - return _size; - } - const unsigned char& operator[](size_t index) const { - TORCH_CHECK(index >= 0 && index < _size); - return _ptr[index]; - } - - protected: - unsigned char* _ptr; - size_t _size; -}; - -inline uint16_t get_endianness(const ExifDataReader& exif_data) { - if ((exif_data.size() < 1) || - (exif_data.size() > 1 && exif_data[0] != exif_data[1])) { - return 0; - } - if (exif_data[0] == 'I') { - return ENDIANNESS_INTEL; - } - if (exif_data[0] == 'M') { - return ENDIANNESS_MOTO; - } - return 0; -} - -inline uint16_t get_uint16( - const ExifDataReader& exif_data, - uint16_t endianness, - const size_t offset) { - if (offset + 1 >= exif_data.size()) { - return INCORRECT_TAG; - } - - if (endianness == ENDIANNESS_INTEL) { - return exif_data[offset] + (exif_data[offset + 1] << 8); - } - return (exif_data[offset] << 8) + exif_data[offset + 1]; -} - -inline uint32_t get_uint32( - const ExifDataReader& exif_data, - uint16_t endianness, - const size_t offset) { - if (offset + 3 >= exif_data.size()) { - return INCORRECT_TAG; - } - - if (endianness == ENDIANNESS_INTEL) { - return exif_data[offset] + (exif_data[offset + 1] << 8) + - (exif_data[offset + 2] << 16) + (exif_data[offset + 3] << 24); - } - return (exif_data[offset] << 24) + (exif_data[offset + 1] << 16) + - (exif_data[offset + 2] << 8) + exif_data[offset + 3]; -} - -inline int fetch_exif_orientation(unsigned char* exif_data_ptr, size_t size) { - int exif_orientation = -1; - - // Exif binary structure looks like this - // First 6 bytes: [E, x, i, f, 0, 0] - // Endianness, 2 bytes : [M, M] or [I, I] - // Tag mark, 2 bytes: [0, 0x2a] - // Offset, 4 bytes - // Num entries, 2 bytes - // Tag entries and data, tag has 2 bytes and its data has 10 bytes - // For more details: - // http://www.media.mit.edu/pia/Research/deepview/exif.html - - ExifDataReader exif_data(exif_data_ptr, size); - auto endianness = get_endianness(exif_data); - - // Checking whether Tag Mark (0x002A) correspond to one contained in the - // Jpeg file - uint16_t tag_mark = get_uint16(exif_data, endianness, 2); - if (tag_mark == REQ_EXIF_TAG_MARK) { - auto offset = get_uint32(exif_data, endianness, 4); - size_t num_entry = get_uint16(exif_data, endianness, offset); - offset += 2; // go to start of tag fields - constexpr size_t tiff_field_size = 12; - for (size_t entry = 0; entry < num_entry; entry++) { - // Here we just search for orientation tag and parse it - auto tag_num = get_uint16(exif_data, endianness, offset); - if (tag_num == INCORRECT_TAG) { - break; - } - if (tag_num == ORIENTATION_EXIF_TAG) { - exif_orientation = get_uint16(exif_data, endianness, offset + 8); - break; - } - offset += tiff_field_size; - } - } - return exif_orientation; -} - -#if JPEG_FOUND -inline int fetch_jpeg_exif_orientation(j_decompress_ptr cinfo) { - // Check for Exif 
marker APP1 - jpeg_saved_marker_ptr exif_marker = 0; - jpeg_saved_marker_ptr cmarker = cinfo->marker_list; - while (cmarker && exif_marker == 0) { - if (cmarker->marker == APP1) { - exif_marker = cmarker; - } - cmarker = cmarker->next; - } - - if (!exif_marker) { - return -1; - } - - constexpr size_t start_offset = 6; - if (exif_marker->data_length <= start_offset) { - return -1; - } - - auto* exif_data_ptr = exif_marker->data + start_offset; - auto size = exif_marker->data_length - start_offset; - - return fetch_exif_orientation(exif_data_ptr, size); -} -#else // #if JPEG_FOUND -inline int fetch_jpeg_exif_orientation(j_decompress_ptr cinfo) { - return -1; -} -#endif // #if JPEG_FOUND - -#if PNG_FOUND && defined(PNG_eXIf_SUPPORTED) -inline int fetch_png_exif_orientation(png_structp png_ptr, png_infop info_ptr) { - png_uint_32 num_exif = 0; - png_bytep exif = 0; - - // Exif info could be in info_ptr - if (png_get_valid(png_ptr, info_ptr, PNG_INFO_eXIf)) { - png_get_eXIf_1(png_ptr, info_ptr, &num_exif, &exif); - } - - if (exif && num_exif > 0) { - return fetch_exif_orientation(exif, num_exif); - } - return -1; -} -#else // #if PNG_FOUND && defined(PNG_eXIf_SUPPORTED) -inline int fetch_png_exif_orientation(png_structp png_ptr, png_infop info_ptr) { - return -1; -} -#endif // #if PNG_FOUND && defined(PNG_eXIf_SUPPORTED) - -constexpr uint16_t IMAGE_ORIENTATION_TL = 1; // normal orientation -constexpr uint16_t IMAGE_ORIENTATION_TR = 2; // needs horizontal flip -constexpr uint16_t IMAGE_ORIENTATION_BR = 3; // needs 180 rotation -constexpr uint16_t IMAGE_ORIENTATION_BL = 4; // needs vertical flip -constexpr uint16_t IMAGE_ORIENTATION_LT = - 5; // mirrored horizontal & rotate 270 CW -constexpr uint16_t IMAGE_ORIENTATION_RT = 6; // rotate 90 CW -constexpr uint16_t IMAGE_ORIENTATION_RB = - 7; // mirrored horizontal & rotate 90 CW -constexpr uint16_t IMAGE_ORIENTATION_LB = 8; // needs 270 CW rotation - -inline torch::Tensor exif_orientation_transform( - const torch::Tensor& image, - int orientation) { - if (orientation == IMAGE_ORIENTATION_TL) { - return image; - } else if (orientation == IMAGE_ORIENTATION_TR) { - return image.flip(-1); - } else if (orientation == IMAGE_ORIENTATION_BR) { - // needs 180 rotation equivalent to - // flip both horizontally and vertically - return image.flip({-2, -1}); - } else if (orientation == IMAGE_ORIENTATION_BL) { - return image.flip(-2); - } else if (orientation == IMAGE_ORIENTATION_LT) { - return image.transpose(-1, -2); - } else if (orientation == IMAGE_ORIENTATION_RT) { - return image.transpose(-1, -2).flip(-1); - } else if (orientation == IMAGE_ORIENTATION_RB) { - return image.transpose(-1, -2).flip({-2, -1}); - } else if (orientation == IMAGE_ORIENTATION_LB) { - return image.transpose(-1, -2).flip(-2); - } - return image; -} - -} // namespace exif_private -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/read_write_file.cpp b/framework/include/torchvision/io/image/cpu/read_write_file.cpp deleted file mode 100644 index def74c6721a..00000000000 --- a/framework/include/torchvision/io/image/cpu/read_write_file.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include "read_write_file.h" - -#include - -#ifdef _WIN32 -#define WIN32_LEAN_AND_MEAN -#include -#endif - -namespace vision { -namespace image { - -#ifdef _WIN32 -namespace { -std::wstring utf8_decode(const std::string& str) { - if (str.empty()) { - return std::wstring(); - } - int size_needed = MultiByteToWideChar( - CP_UTF8, 0, str.c_str(), static_cast(str.size()), 
NULL, 0); - TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode"); - std::wstring wstrTo(size_needed, 0); - MultiByteToWideChar( - CP_UTF8, - 0, - str.c_str(), - static_cast(str.size()), - &wstrTo[0], - size_needed); - return wstrTo; -} -} // namespace -#endif - -torch::Tensor read_file(const std::string& filename) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.read_write_file.read_file"); -#ifdef _WIN32 - // According to - // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019, - // we should use struct __stat64 and _wstat64 for 64-bit file size on Windows. - struct __stat64 stat_buf; - auto fileW = utf8_decode(filename); - int rc = _wstat64(fileW.c_str(), &stat_buf); -#else - struct stat stat_buf; - int rc = stat(filename.c_str(), &stat_buf); -#endif - // errno is a variable defined in errno.h - TORCH_CHECK( - rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'"); - - int64_t size = stat_buf.st_size; - - TORCH_CHECK(size > 0, "Expected a non empty file"); - -#ifdef _WIN32 - // TODO: Once torch::from_file handles UTF-8 paths correctly, we should move - // back to use the following implementation since it uses file mapping. - // auto data = - // torch::from_file(filename, /*shared=*/false, /*size=*/size, - // torch::kU8).clone() - FILE* infile = _wfopen(fileW.c_str(), L"rb"); - - TORCH_CHECK(infile != nullptr, "Error opening input file"); - - auto data = torch::empty({size}, torch::kU8); - auto dataBytes = data.data_ptr(); - - fread(dataBytes, sizeof(uint8_t), size, infile); - fclose(infile); -#else - auto data = - torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8); -#endif - - return data; -} - -void write_file(const std::string& filename, torch::Tensor& data) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cpu.read_write_file.write_file"); - // Check that the input tensor is on CPU - TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); - - // Check that the input tensor dtype is uint8 - TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8"); - - // Check that the input tensor is 3-dimensional - TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor"); - - auto fileBytes = data.data_ptr(); - auto fileCStr = filename.c_str(); -#ifdef _WIN32 - auto fileW = utf8_decode(filename); - FILE* outfile = _wfopen(fileW.c_str(), L"wb"); -#else - FILE* outfile = fopen(fileCStr, "wb"); -#endif - - TORCH_CHECK(outfile != nullptr, "Error opening output file"); - - fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile); - fclose(outfile); -} - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cpu/read_write_file.h b/framework/include/torchvision/io/image/cpu/read_write_file.h deleted file mode 100644 index a5a712dd8e2..00000000000 --- a/framework/include/torchvision/io/image/cpu/read_write_file.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor read_file(const std::string& filename); - -C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data); - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.cpp b/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.cpp deleted file mode 100644 index ee7d432f30d..00000000000 --- a/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.cpp +++ /dev/null @@ 
-1,208 +0,0 @@ -#include "decode_jpeg_cuda.h" - -#include - -#if NVJPEG_FOUND -#include -#include -#include -#endif - -#include - -namespace vision { -namespace image { - -#if !NVJPEG_FOUND - -torch::Tensor decode_jpeg_cuda( - const torch::Tensor& data, - ImageReadMode mode, - torch::Device device) { - TORCH_CHECK( - false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); -} - -#else - -namespace { -static nvjpegHandle_t nvjpeg_handle = nullptr; -} - -torch::Tensor decode_jpeg_cuda( - const torch::Tensor& data, - ImageReadMode mode, - torch::Device device) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cuda.decode_jpeg_cuda.decode_jpeg_cuda"); - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - - TORCH_CHECK( - !data.is_cuda(), - "The input tensor must be on CPU when decoding with nvjpeg") - - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected a non empty 1-dimensional tensor"); - - TORCH_CHECK(device.is_cuda(), "Expected a cuda device") - - int major_version; - int minor_version; - nvjpegStatus_t get_major_property_status = - nvjpegGetProperty(MAJOR_VERSION, &major_version); - nvjpegStatus_t get_minor_property_status = - nvjpegGetProperty(MINOR_VERSION, &minor_version); - - TORCH_CHECK( - get_major_property_status == NVJPEG_STATUS_SUCCESS, - "nvjpegGetProperty failed: ", - get_major_property_status); - TORCH_CHECK( - get_minor_property_status == NVJPEG_STATUS_SUCCESS, - "nvjpegGetProperty failed: ", - get_minor_property_status); - if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) { - TORCH_WARN_ONCE( - "There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. " - "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); - } - - at::cuda::CUDAGuard device_guard(device); - - // Create global nvJPEG handle - static std::once_flag nvjpeg_handle_creation_flag; - std::call_once(nvjpeg_handle_creation_flag, []() { - if (nvjpeg_handle == nullptr) { - nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); - - if (create_status != NVJPEG_STATUS_SUCCESS) { - // Reset handle so that one can still call the function again in the - // same process if there was a failure - free(nvjpeg_handle); - nvjpeg_handle = nullptr; - } - TORCH_CHECK( - create_status == NVJPEG_STATUS_SUCCESS, - "nvjpegCreateSimple failed: ", - create_status); - } - }); - - // Create the jpeg state - nvjpegJpegState_t jpeg_state; - nvjpegStatus_t state_status = - nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state); - - TORCH_CHECK( - state_status == NVJPEG_STATUS_SUCCESS, - "nvjpegJpegStateCreate failed: ", - state_status); - - auto datap = data.data_ptr(); - - // Get the image information - int num_channels; - nvjpegChromaSubsampling_t subsampling; - int widths[NVJPEG_MAX_COMPONENT]; - int heights[NVJPEG_MAX_COMPONENT]; - nvjpegStatus_t info_status = nvjpegGetImageInfo( - nvjpeg_handle, - datap, - data.numel(), - &num_channels, - &subsampling, - widths, - heights); - - if (info_status != NVJPEG_STATUS_SUCCESS) { - nvjpegJpegStateDestroy(jpeg_state); - TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status); - } - - if (subsampling == NVJPEG_CSS_UNKNOWN) { - nvjpegJpegStateDestroy(jpeg_state); - TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling"); - } - - int width = widths[0]; - int height = heights[0]; - - nvjpegOutputFormat_t ouput_format; - int num_channels_output; - - switch (mode) { - case IMAGE_READ_MODE_UNCHANGED: - num_channels_output = num_channels; - // For 
some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will - // not properly decode RGB images (it's fine for grayscale), so we set - // output_format manually here - if (num_channels == 1) { - ouput_format = NVJPEG_OUTPUT_Y; - } else if (num_channels == 3) { - ouput_format = NVJPEG_OUTPUT_RGB; - } else { - nvjpegJpegStateDestroy(jpeg_state); - TORCH_CHECK( - false, - "When mode is UNCHANGED, only 1 or 3 input channels are allowed."); - } - break; - case IMAGE_READ_MODE_GRAY: - ouput_format = NVJPEG_OUTPUT_Y; - num_channels_output = 1; - break; - case IMAGE_READ_MODE_RGB: - ouput_format = NVJPEG_OUTPUT_RGB; - num_channels_output = 3; - break; - default: - nvjpegJpegStateDestroy(jpeg_state); - TORCH_CHECK( - false, "The provided mode is not supported for JPEG decoding on GPU"); - } - - auto out_tensor = torch::empty( - {int64_t(num_channels_output), int64_t(height), int64_t(width)}, - torch::dtype(torch::kU8).device(device)); - - // nvjpegImage_t is a struct with - // - an array of pointers to each channel - // - the pitch for each channel - // which must be filled in manually - nvjpegImage_t out_image; - - for (int c = 0; c < num_channels_output; c++) { - out_image.channel[c] = out_tensor[c].data_ptr(); - out_image.pitch[c] = width; - } - for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) { - out_image.channel[c] = nullptr; - out_image.pitch[c] = 0; - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()); - - nvjpegStatus_t decode_status = nvjpegDecode( - nvjpeg_handle, - jpeg_state, - datap, - data.numel(), - ouput_format, - &out_image, - stream); - - nvjpegJpegStateDestroy(jpeg_state); - - TORCH_CHECK( - decode_status == NVJPEG_STATUS_SUCCESS, - "nvjpegDecode failed: ", - decode_status); - - return out_tensor; -} - -#endif // NVJPEG_FOUND - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.h b/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.h deleted file mode 100644 index 496b355e9b7..00000000000 --- a/framework/include/torchvision/io/image/cuda/decode_jpeg_cuda.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include -#include "../image_read_mode.h" - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_jpeg_cuda( - const torch::Tensor& data, - ImageReadMode mode, - torch::Device device); - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/image.cpp b/framework/include/torchvision/io/image/image.cpp deleted file mode 100644 index 53d588e4746..00000000000 --- a/framework/include/torchvision/io/image/image.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include "image.h" - -#include -#ifdef USE_PYTHON -#include -#endif - -// If we are in a Windows environment, we need to define -// initialization functions for the _custom_ops extension -#ifdef USE_PYTHON -#ifdef _WIN32 -PyMODINIT_FUNC PyInit_image(void) { - // No need to do anything. 
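// Editor's usage sketch on the decoding side (illustrative; assumes the same
// linkage and install layout as the encoding sketch earlier, and a made-up
// input file name): read_file yields a 1-D uint8 tensor of raw bytes, and
// decode_image / decode_jpeg / decode_png turn it into a CHW uint8 tensor
// (int32 for 16-bit PNGs when allow_16_bits is set). A CUDA-side variant,
// decode_jpeg_cuda, is declared above and takes an additional torch::Device.
#include <torchvision/io/image/image.h>
#include <iostream>

int main() {
  auto raw = vision::image::read_file("sample.png");  // 1-D uint8 byte tensor
  auto img = vision::image::decode_image(
      raw, vision::image::IMAGE_READ_MODE_RGB);        // 3 x H x W, uint8
  std::cout << img.sizes() << std::endl;               // e.g. [3, 480, 640]
  return 0;
}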
- return NULL; -} -#endif -#endif // USE_PYTHON - -namespace vision { -namespace image { - -static auto registry = - torch::RegisterOperators() - .op("image::decode_png(Tensor data, int mode, bool allow_16_bits = False, bool apply_exif_orientation=False) -> Tensor", - &decode_png) - .op("image::encode_png", &encode_png) - .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_jpeg) - .op("image::encode_jpeg", &encode_jpeg) - .op("image::read_file", &read_file) - .op("image::write_file", &write_file) - .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", - &decode_image) - .op("image::decode_jpeg_cuda", &decode_jpeg_cuda) - .op("image::_jpeg_version", &_jpeg_version) - .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo); - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/io/image/image.h b/framework/include/torchvision/io/image/image.h deleted file mode 100644 index 05bac44c77d..00000000000 --- a/framework/include/torchvision/io/image/image.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -#include "cpu/decode_image.h" -#include "cpu/decode_jpeg.h" -#include "cpu/decode_png.h" -#include "cpu/encode_jpeg.h" -#include "cpu/encode_png.h" -#include "cpu/read_write_file.h" -#include "cuda/decode_jpeg_cuda.h" diff --git a/framework/include/torchvision/io/image/image_read_mode.h b/framework/include/torchvision/io/image/image_read_mode.h deleted file mode 100644 index 84425265c34..00000000000 --- a/framework/include/torchvision/io/image/image_read_mode.h +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace image { - -/* Should be kept in-sync with Python ImageReadMode enum */ -using ImageReadMode = int64_t; -const ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0; -const ImageReadMode IMAGE_READ_MODE_GRAY = 1; -const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2; -const ImageReadMode IMAGE_READ_MODE_RGB = 3; -const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4; - -} // namespace image -} // namespace vision diff --git a/framework/include/torchvision/macros.h b/framework/include/torchvision/macros.h deleted file mode 100644 index 64ca89429a9..00000000000 --- a/framework/include/torchvision/macros.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#if defined(_WIN32) && !defined(TORCHVISION_BUILD_STATIC_LIBS) -#if defined(torchvision_EXPORTS) -#define VISION_API __declspec(dllexport) -#else -#define VISION_API __declspec(dllimport) -#endif -#else -#define VISION_API -#endif - -#if (defined __cpp_inline_variables) || __cplusplus >= 201703L -#define VISION_INLINE_VARIABLE inline -#else -#ifdef _MSC_VER -#define VISION_INLINE_VARIABLE __declspec(selectany) -#define HINT_MSVC_LINKER_INCLUDE_SYMBOL -#else -#define VISION_INLINE_VARIABLE __attribute__((weak)) -#endif -#endif diff --git a/framework/include/torchvision/ops/.DS_Store b/framework/include/torchvision/ops/.DS_Store deleted file mode 100644 index ba279f808b2887b459c9e79f767d8973c56e4b0c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKy-veG47S@2MKE+=U?Z=Pxl5?hiLomKw2{h?C>78>FM`xZVBr-Qcmy7R=iu|% z%8wExCa91t`M%59KIgux;+lwfvCm?n84*pPf};}*10r_Oo{Vhe49H=REj`d3W%)K< z^}K!I7#Wbeo8r#)^h71K?r&2TJAL8J^N>$brs+B_*JAs;UcSH1->!!JvLE^F?&DRH z^MI{rOdHzKUVcrrg_C(>X}`mHUcZCI6=jpTZ|&scaMUps$oJUIck>bQ1xq{M<)M|& ztWCbYwVUrn&gbF`I0MeWUon81Eiye-^wAk`2AqMK0r@^eP{G8oQVgFC450-8POuyV zbLk}{CNfM6D@9l!tf4>+Wot25!?7NjUt(A(YB;eqA8eV~Ius7sv40fYi4#R1odIW{ 
z%fMJ4XHx%{-}nFBAb)ZOoPmGE08i3Yy1 -#include - -namespace vision { -namespace ops { - -namespace { - -class DeformConv2dFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& weight, - const torch::autograd::Variable& offset, - const torch::autograd::Variable& mask, - const torch::autograd::Variable& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - at::AutoDispatchBelowADInplaceOrView g; - auto output = deform_conv2d_symint( - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); - - ctx->save_for_backward({input, weight, offset, mask, bias}); - ctx->saved_data["stride_h"] = stride_h; - ctx->saved_data["stride_w"] = stride_w; - ctx->saved_data["pad_h"] = pad_h; - ctx->saved_data["pad_w"] = pad_w; - ctx->saved_data["dilation_h"] = dilation_h; - ctx->saved_data["dilation_w"] = dilation_w; - ctx->saved_data["groups"] = groups; - ctx->saved_data["offset_groups"] = offset_groups; - ctx->saved_data["use_mask"] = use_mask; - - return { - output, - }; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - auto saved = ctx->get_saved_variables(); - auto input = saved[0]; - auto weight = saved[1]; - auto offset = saved[2]; - auto mask = saved[3]; - auto bias = saved[4]; - - auto stride_h = ctx->saved_data["stride_h"].toSymInt(); - auto stride_w = ctx->saved_data["stride_w"].toSymInt(); - auto pad_h = ctx->saved_data["pad_h"].toSymInt(); - auto pad_w = ctx->saved_data["pad_w"].toSymInt(); - auto dilation_h = ctx->saved_data["dilation_h"].toSymInt(); - auto dilation_w = ctx->saved_data["dilation_w"].toSymInt(); - auto groups = ctx->saved_data["groups"].toSymInt(); - auto offset_groups = ctx->saved_data["offset_groups"].toSymInt(); - auto use_mask = ctx->saved_data["use_mask"].toBool(); - - auto grads = detail::_deform_conv2d_backward_symint( - grad_output[0], - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); - auto grad_input = std::get<0>(grads); - auto grad_weight = std::get<1>(grads); - auto grad_offset = std::get<2>(grads); - auto grad_mask = std::get<3>(grads); - auto grad_bias = std::get<4>(grads); - - return { - grad_input, - grad_weight, - grad_offset, - grad_mask, - grad_bias, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - }; - } -}; - -// TODO: There should be an easier way to do this -class DeformConv2dBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& input, - const torch::autograd::Variable& weight, - const torch::autograd::Variable& offset, - const torch::autograd::Variable& mask, - const torch::autograd::Variable& bias, - c10::SymInt stride_h, - 
c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - at::AutoDispatchBelowADInplaceOrView g; - auto result = detail::_deform_conv2d_backward_symint( - grad, - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); - - auto grad_input = std::get<0>(result); - auto grad_weight = std::get<1>(result); - auto grad_offset = std::get<2>(result); - auto grad_mask = std::get<3>(result); - auto grad_bias = std::get<4>(result); - - return { - grad_input, - grad_weight, - grad_offset, - grad_mask, - grad_bias, - }; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on deform_conv2d not supported"); - } -}; - -at::Tensor deform_conv2d_autograd( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - return DeformConv2dFunction::apply( - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask)[0]; -} - -std::tuple -deform_conv2d_backward_autograd( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - auto result = DeformConv2dBackwardFunction::apply( - grad, - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); - - return std::make_tuple(result[0], result[1], result[2], result[3], result[4]); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), - TORCH_FN(deform_conv2d_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), - TORCH_FN(deform_conv2d_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp b/framework/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp deleted file mode 100644 index 7205e9b15db..00000000000 --- a/framework/include/torchvision/ops/autograd/ps_roi_align_kernel.cpp +++ /dev/null @@ -1,167 +0,0 @@ -#include "../ps_roi_align.h" - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -class PSROIAlignFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - 
ctx->saved_data["sampling_ratio"] = sampling_ratio; - ctx->saved_data["input_shape"] = input.sym_sizes(); - at::AutoDispatchBelowADInplaceOrView g; - auto result = ps_roi_align_symint( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio); - - auto output = std::get<0>(result); - auto channel_mapping = std::get<1>(result); - ctx->save_for_backward({rois, channel_mapping}); - ctx->mark_non_differentiable({channel_mapping}); - - return {output, channel_mapping}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto channel_mapping = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_ps_roi_align_backward_symint( - grad_output[0], - rois, - channel_mapping, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - ctx->saved_data["sampling_ratio"].toInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt()); - - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class PSROIAlignBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - const torch::autograd::Variable& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_ps_roi_align_backward_symint( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width); - - return {grad_in}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on ps_roi_align not supported"); - } -}; - -std::tuple ps_roi_align_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio) { - auto result = PSROIAlignFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); - - return std::make_tuple(result[0], result[1]); -} - -at::Tensor ps_roi_align_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - return PSROIAlignBackwardFunction::apply( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), 
- TORCH_FN(ps_roi_align_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), - TORCH_FN(ps_roi_align_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp b/framework/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp deleted file mode 100644 index 39b83819f94..00000000000 --- a/framework/include/torchvision/ops/autograd/ps_roi_pool_kernel.cpp +++ /dev/null @@ -1,152 +0,0 @@ -#include "../ps_roi_pool.h" - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -class PSROIPoolFunction : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["input_shape"] = input.sym_sizes(); - at::AutoDispatchBelowADInplaceOrView g; - auto result = ps_roi_pool_symint( - input, rois, spatial_scale, pooled_height, pooled_width); - - auto output = std::get<0>(result); - auto channel_mapping = std::get<1>(result); - ctx->save_for_backward({rois, channel_mapping}); - ctx->mark_non_differentiable({channel_mapping}); - - return {output, channel_mapping}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto channel_mapping = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_ps_roi_pool_backward_symint( - grad_output[0], - rois, - channel_mapping, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt()); - - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class PSROIPoolBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - const torch::autograd::Variable& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_ps_roi_pool_backward_symint( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); - - return {grad_in}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on ps_roi_pool not supported"); - } -}; - -std::tuple ps_roi_pool_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, 
- c10::SymInt pooled_width) { - auto result = PSROIPoolFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width); - - return std::make_tuple(result[0], result[1]); -} - -at::Tensor ps_roi_pool_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - return PSROIPoolBackwardFunction::apply( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), - TORCH_FN(ps_roi_pool_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), - TORCH_FN(ps_roi_pool_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/autograd/roi_align_kernel.cpp b/framework/include/torchvision/ops/autograd/roi_align_kernel.cpp deleted file mode 100644 index 6d792fe09d9..00000000000 --- a/framework/include/torchvision/ops/autograd/roi_align_kernel.cpp +++ /dev/null @@ -1,167 +0,0 @@ -#include "../roi_align.h" - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -class ROIAlignFunction : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - bool aligned) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["sampling_ratio"] = sampling_ratio; - ctx->saved_data["aligned"] = aligned; - ctx->saved_data["input_shape"] = input.sym_sizes(); - ctx->save_for_backward({rois}); - at::AutoDispatchBelowADInplaceOrView g; - auto result = roi_align_symint( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); - return {result}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_roi_align_backward_symint( - grad_output[0], - rois, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt(), - ctx->saved_data["sampling_ratio"].toInt(), - ctx->saved_data["aligned"].toBool()); - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class ROIAlignBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const 
torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned) { - at::AutoDispatchBelowADInplaceOrView g; - auto result = detail::_roi_align_backward_symint( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned); - return {result}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on roi_align not supported"); - } -}; - -at::Tensor roi_align_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - bool aligned) { - return ROIAlignFunction::apply( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned)[0]; -} - -at::Tensor roi_align_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned) { - return ROIAlignBackwardFunction::apply( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN(roi_align_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), - TORCH_FN(roi_align_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/autograd/roi_pool_kernel.cpp b/framework/include/torchvision/ops/autograd/roi_pool_kernel.cpp deleted file mode 100644 index 508bafb2b1e..00000000000 --- a/framework/include/torchvision/ops/autograd/roi_pool_kernel.cpp +++ /dev/null @@ -1,152 +0,0 @@ -#include "../roi_pool.h" - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -class ROIPoolFunction : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["input_shape"] = input.sym_sizes(); - at::AutoDispatchBelowADInplaceOrView g; - auto result = roi_pool_symint( - input, rois, spatial_scale, pooled_height, pooled_width); - - auto output = std::get<0>(result); - auto argmax = std::get<1>(result); - ctx->save_for_backward({rois, argmax}); - ctx->mark_non_differentiable({argmax}); - - return {output, argmax}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - // Use data saved in forward - auto saved = ctx->get_saved_variables(); - auto rois = saved[0]; - auto argmax = saved[1]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_roi_pool_backward_symint( - 
grad_output[0], - rois, - argmax, - ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt()); - - return { - grad_in, - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable(), - torch::autograd::Variable()}; - } -}; - -// TODO: There should be an easier way to do this -class ROIPoolBackwardFunction - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& grad, - const torch::autograd::Variable& rois, - const torch::autograd::Variable& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_roi_pool_backward_symint( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); - - return {grad_in}; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::variable_list& grad_output) { - TORCH_CHECK(0, "double backwards on roi_pool not supported"); - } -}; - -std::tuple roi_pool_autograd( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - auto result = ROIPoolFunction::apply( - input, rois, spatial_scale, pooled_height, pooled_width); - - return std::make_tuple(result[0], result[1]); -} - -at::Tensor roi_pool_backward_autograd( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - return ROIPoolBackwardFunction::apply( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width)[0]; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_pool"), - TORCH_FN(roi_pool_autograd)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), - TORCH_FN(roi_pool_backward_autograd)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp b/framework/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp deleted file mode 100644 index c5e59077aa6..00000000000 --- a/framework/include/torchvision/ops/cpu/deform_conv2d_kernel.cpp +++ /dev/null @@ -1,1172 +0,0 @@ -/*! - ******************* BEGIN Caffe Copyright Notice and Disclaimer - ***************** - * - * COPYRIGHT - * - * All contributions by the University of California: - * Copyright (c) 2014-2017 The Regents of the University of California (Regents) - * All rights reserved. - * - * All other contributions: - * Copyright (c) 2014-2017, the respective contributors - * All rights reserved. - * - * Caffe uses a shared copyright model: each contributor holds copyright over - * their contributions to Caffe. The project versioning records all such - * contribution and copyright details. 
If a contributor wants to further mark - * their specific copyright on a particular contribution, they should indicate - * their copyright solely in the commit message of the change when it is - * committed. - * - * LICENSE - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE - *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * CONTRIBUTION AGREEMENT - * - * By contributing to the BVLC/caffe repository through pull-request, comment, - * or otherwise, the contributor releases their content to the - * license and copyright terms herein. - * - ***************** END Caffe Copyright Notice and Disclaimer - ********************* - * - * Copyright (c) 2018 Microsoft - * Licensed under The MIT License [see LICENSE for details] - * \file modulated_deformable_im2col.cuh - * \brief Function definitions of converting an image to - * column matrix based on kernel, padding, dilation, and offset. - * These functions are mainly used in deformable convolution operators. 
- * \ref: https://arxiv.org/abs/1703.06211 - * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng - */ - -// modified from -// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu - -// modified from -// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -const int kMaxParallelImgs = 32; - -template -scalar_t bilinear_interpolate( - const scalar_t* in, - int height, - int width, - scalar_t h, - scalar_t w) { - if (h <= -1 || height <= h || w <= -1 || width <= w) { - return 0; - } - - int h_low = floor(h); - int w_low = floor(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - scalar_t lh = h - h_low; - scalar_t lw = w - w_low; - scalar_t hh = 1 - lh, hw = 1 - lw; - - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) - v1 = in[h_low * width + w_low]; - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - v2 = in[h_low * width + w_high]; - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - v3 = in[h_high * width + w_low]; - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - v4 = in[h_high * width + w_high]; - - scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; -} - -template -void deformable_im2col_kernel( - int n, - const scalar_t* input, - const scalar_t* offset, - const scalar_t* mask, - int height, - int width, - int weight_h, - int weight_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int batch_sz, - int n_in_channels, - int n_offset_grps, - int out_h, - int out_w, - bool use_mask, - scalar_t* columns) { - for (int index = 0; index != n; ++index) { - const int out_x = index % out_w; - const int out_y = (index / out_w) % out_h; - const int out_b = (index / (out_w * out_h)) % batch_sz; - const int in_c = index / (out_w * out_h * batch_sz); - const int out_c = in_c * weight_h * weight_w; - - int c_per_offset_grp = n_in_channels / n_offset_grps; - const int grp_idx = in_c / c_per_offset_grp; - - auto columns_ptr = columns + - (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + - out_y * out_w + out_x); - - auto input_ptr = input + - (out_b * (n_in_channels * height * width) + in_c * (height * width)); - - auto offset_ptr = offset + - (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * - out_w; - - auto mask_ptr = mask; - if (use_mask) { - mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * - out_h * out_w; - } - - for (int i = 0; i < weight_h; ++i) { - for (int j = 0; j < weight_w; ++j) { - const int mask_idx = i * weight_w + j; - const int offset_idx = 2 * mask_idx; - - scalar_t mask_value = 1; - if (use_mask) { - mask_value = - mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; - } - - const scalar_t offset_h = - offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t offset_w = offset_ptr - [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t y = - (out_y * stride_h - pad_h) + i * dilation_h + offset_h; - const scalar_t x = - (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - *columns_ptr = - mask_value * bilinear_interpolate(input_ptr, height, width, y, x); - columns_ptr += batch_sz * out_h * out_w; - } - } - } -} - -void deformable_im2col( - const at::Tensor& input, - const 
at::Tensor& data_offset, - const at::Tensor& data_mask, - int n_in_channels, - int height, - int width, - int weight_h, - int weight_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int out_h, - int out_w, - int parallel_imgs, - int deformable_group, - bool use_mask, - at::Tensor data_col) { - int num_kernels = n_in_channels * out_h * out_w * parallel_imgs; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "deformable_im2col", ([&] { - deformable_im2col_kernel( - num_kernels, - input.data_ptr(), - data_offset.data_ptr(), - data_mask.data_ptr(), - height, - width, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - parallel_imgs, - n_in_channels, - deformable_group, - out_h, - out_w, - use_mask, - data_col.data_ptr()); - })); -} - -int get_greatest_divisor_below_bound(int n, int bound) { - for (int k = bound; k > 1; --k) { - if (n % k == 0) { - return k; - } - } - return 1; -} - -template -void deformable_col2im_kernel( - int n, - const scalar_t* col, - const scalar_t* offset, - const scalar_t* mask, - int channels, - int height, - int width, - int kernel_h, - int kernel_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int batch_sz, - int n_offset_grps, - int out_h, - int out_w, - bool use_mask, - scalar_t* grad_im) { - for (int index = 0; index != n; ++index) { - const int out_x = index % out_w; - const int out_y = (index / out_w) % out_h; - const int b = (index / (out_w * out_h)) % batch_sz; - const int j = (index / (out_w * out_h * batch_sz)) % kernel_w; - const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; - const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); - - int c_per_offset_grp = channels / n_offset_grps; - const int offset_grp = c / c_per_offset_grp; - - auto offset_ptr = offset + - (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h * - out_w; - - auto mask_ptr = mask; - if (use_mask) { - mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * - out_h * out_w; - } - - const int mask_idx = i * kernel_w + j; - const int offset_idx = 2 * mask_idx; - - const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; - const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; - - const scalar_t offset_h = offset_ptr[offset_h_ptr]; - const scalar_t offset_w = offset_ptr[offset_w_ptr]; - - scalar_t mask_value = 1; - if (use_mask) { - mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; - } - - const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; - const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - - for (int dy = -1; dy <= 1; dy++) { - for (int dx = -1; dx <= 1; dx++) { - int yp = int(y) + dy; - int xp = int(x) + dx; - if (0 <= yp && yp < height && 0 <= xp && xp < width && - std::abs(y - yp) < 1 && std::abs(x - xp) < 1) { - int grad_pos = ((b * channels + c) * height + yp) * width + xp; - scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp)); - grad_im[grad_pos] += mask_value * weight * col[index]; - } - } - } - } -} - -void compute_grad_input( - const at::Tensor& columns, - const at::Tensor& offset, - const at::Tensor& mask, - int channels, - int height, - int width, - int weight_h, - int weight_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int parallel_imgs, - int n_offset_grps, - bool use_mask, - at::Tensor 
grad_im) { - int out_h = - (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = - (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int num_kernels = - channels * weight_h * weight_w * out_h * out_w * parallel_imgs; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - columns.scalar_type(), "compute_grad_input", ([&] { - deformable_col2im_kernel( - num_kernels, - columns.data_ptr(), - offset.data_ptr(), - mask.data_ptr(), - channels, - height, - width, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - parallel_imgs, - n_offset_grps, - out_h, - out_w, - use_mask, - grad_im.data_ptr()); - })); -} - -template -scalar_t get_coordinate_weight( - const scalar_t* im_data, - int height, - int width, - scalar_t y, - scalar_t x, - bool is_y_direction) { - int y_l = floor(y); - int x_l = floor(x); - int y_h = y_l + 1; - int x_h = x_l + 1; - - bool valid_y_l = 0 <= y_l && y_l < height; - bool valid_y_h = 0 <= y_h && y_h < height; - bool valid_x_l = 0 <= x_l && x_l < width; - bool valid_x_h = 0 <= x_h && x_h < width; - - scalar_t zero = 0; - scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero; - scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero; - scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero; - scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero; - - if (is_y_direction) { - scalar_t dx = x - x_l; - return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx); - } else { - scalar_t dy = y - y_l; - return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx); - } -} - -template -void deformable_col2im_coord_kernel( - int n, - const scalar_t* col, - const scalar_t* im, - const scalar_t* offset, - const scalar_t* mask, - int channels, - int height, - int width, - int weight_h, - int weight_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int batch_sz, - int offset_channels, - int n_offset_grps, - int out_h, - int out_w, - bool use_mask, - scalar_t* grad_offset, - scalar_t* grad_mask) { - for (int index = 0; index != n; ++index) { - scalar_t grad_offset_val = 0; - scalar_t grad_mask_val = 0; - - int w = index % out_w; - int h = (index / out_w) % out_h; - int w_w = (index / (out_w * out_h * 2)) % weight_w; - int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; - int c = (index / (out_w * out_h)) % offset_channels; - int b = index / (out_w * out_h * offset_channels); - - const int offset_grp = c / (2 * weight_h * weight_w); - const int col_step = weight_h * weight_w; - - int c_per_offset_grp = channels / n_offset_grps; - - auto col_ptr = col + - offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w * - out_h; - auto im_ptr = im + - (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; - auto offset_ptr = offset + - (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h * - out_w; - - auto mask_ptr = mask; - if (use_mask) { - mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * - out_h * out_w; - } - - const int offset_c = c - offset_grp * 2 * weight_h * weight_w; - const bool is_y_direction = offset_c % 2 == 0; - - const int c_bound = c_per_offset_grp * weight_h * weight_w; - for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { - const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w; - - int out_x = col_pos % out_w; - int out_y = (col_pos / out_w) % out_h; - int j = 
(col_pos / (out_w * out_h * batch_sz)) % weight_w; - int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; - - const int mask_idx = i * weight_w + j; - - const int offset_h_idx = - (((2 * mask_idx) * out_h + out_y) * out_w + out_x); - const int offset_w_idx = - (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); - const scalar_t offset_h = offset_ptr[offset_h_idx]; - const scalar_t offset_w = offset_ptr[offset_w_idx]; - - scalar_t mask_value = 1; - if (use_mask) { - mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; - } - - scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; - scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - - const scalar_t weight = - get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); - grad_offset_val += mask_value * weight * col_ptr[col_pos]; - - if (use_mask && is_y_direction) { - grad_mask_val += col_ptr[col_pos] * - bilinear_interpolate(im_ptr, height, width, y, x); - } - - im_ptr += height * width; - } - - grad_offset[index] = grad_offset_val; - - if (use_mask && is_y_direction) { - const int idx = - ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + - w_w) * - out_h + - h) * - out_w + - w; - grad_mask[idx] = grad_mask_val; - } - } -} - -void compute_grad_offset_and_mask( - const at::Tensor& columns, - const at::Tensor& input, - const at::Tensor& offset, - const at::Tensor& mask, - int channels, - int height, - int width, - int weight_h, - int weight_w, - int pad_h, - int pad_w, - int stride_h, - int stride_w, - int dilation_h, - int dilation_w, - int parallel_imgs, - int n_offset_grps, - bool use_mask, - at::Tensor grad_offset, - at::Tensor grad_mask) { - int out_h = - (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = - (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int num_kernels = - out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - columns.scalar_type(), "compute_grad_offset_and_mask", ([&] { - deformable_col2im_coord_kernel( - num_kernels, - columns.data_ptr(), - input.data_ptr(), - offset.data_ptr(), - mask.data_ptr(), - channels, - height, - width, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - parallel_imgs, - 2 * weight_h * weight_w * n_offset_grps, - n_offset_grps, - out_h, - out_w, - use_mask, - grad_offset.data_ptr(), - grad_mask.data_ptr()); - })); -} - -std::tuple backward_gradient_inputs( - at::Tensor input, - at::Tensor weight, - at::Tensor offset, - at::Tensor mask, - at::Tensor grad_out, - int stride_h, - int stride_w, - int pad_h, - int pad_w, - int dilation_h, - int dilation_w, - int n_weight_grps, - int n_offset_grps, - int n_parallel_imgs, - bool use_mask) { - int batch_sz = input.size(0); - int n_in_channels = input.size(1); - int in_h = input.size(2); - int in_w = input.size(3); - - n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); - - long n_out_channels = weight.size(0); - int weight_h = weight.size(2); - int weight_w = weight.size(3); - - long out_h = - (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - long out_w = - (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - - auto grad_input = at::zeros_like(input); - auto grad_offset = at::zeros_like(offset); - auto grad_mask = at::zeros_like(mask); - - if (batch_sz == 0) { - return std::make_tuple(grad_input, grad_offset, grad_mask); - } - - auto columns = 
at::empty( - {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, - input.options()); - - // Separate into blocks - grad_input = grad_input.reshape( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - input = input.reshape( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - - grad_offset = grad_offset.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - offset = offset.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - grad_mask = grad_mask.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - mask = mask.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - grad_out = grad_out - .reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_weight_grps, - n_out_channels / n_weight_grps, - out_h, - out_w}) - .permute({0, 2, 3, 1, 4, 5}); - - weight = weight.reshape( - {n_weight_grps, - weight.size(0) / n_weight_grps, - weight.size(1), - weight.size(2), - weight.size(3)}); - - columns = columns.view( - {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - - for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { - columns.zero_(); - // Separate into weight groups - for (int g = 0; g < n_weight_grps; g++) { - columns[g] = columns[g].addmm_( - weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); - } - - compute_grad_offset_and_mask( - columns, - input[elt], - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - grad_offset[elt], - grad_mask[elt]); - - compute_grad_input( - columns, - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - grad_input[elt]); - } - - grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); - grad_offset = grad_offset.view( - {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - - if (use_mask) { - grad_mask = grad_mask.view( - {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); - } - - return std::make_tuple(grad_input, grad_offset, grad_mask); -} - -at::Tensor backward_gradient_parameters( - at::Tensor input, - const at::Tensor& weight, - at::Tensor offset, - at::Tensor mask, - const at::Tensor& grad_out, - int stride_h, - int stride_w, - int pad_h, - int pad_w, - int dilation_h, - int dilation_w, - int n_weight_grps, - int n_offset_grps, - int n_parallel_imgs, - bool use_mask) { - int batch_sz = input.size(0); - int n_in_channels = input.size(1); - int in_h = input.size(2); - int in_w = input.size(3); - - n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); - - long n_out_channels = weight.size(0); - int weight_h = weight.size(2); - int weight_w = weight.size(3); - - long out_h = grad_out.size(2); - long out_w = grad_out.size(3); - - auto grad_weight = at::zeros_like(weight); - if (batch_sz == 0) { - return grad_weight; - } - - at::Tensor grad_out_buf = grad_out - .reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_weight_grps, - n_out_channels / n_weight_grps, - out_h, - out_w}) - .permute({0, 2, 3, 1, 4, 5}) 
- .contiguous(); - - input = input.reshape( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - - offset = offset.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - mask = mask.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - grad_weight = grad_weight.view( - {n_weight_grps, - grad_weight.size(0) / n_weight_grps, - grad_weight.size(1), - grad_weight.size(2), - grad_weight.size(3)}); - - auto columns = at::empty( - {n_weight_grps, - n_in_channels * weight_w * weight_h / n_weight_grps, - n_parallel_imgs * out_h * out_w}, - input.options()); - - for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { - deformable_im2col( - input[elt], - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - out_h, - out_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - columns); - - for (int g = 0; g < n_weight_grps; g++) { - grad_weight[g] = - grad_weight[g] - .flatten(1) - .addmm_( - grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) - .view_as(grad_weight[g]); - } - } - - grad_weight = grad_weight.view( - {grad_weight.size(0) * grad_weight.size(1), - grad_weight.size(2), - grad_weight.size(3), - grad_weight.size(4)}); - return grad_weight; -} - -at::Tensor deform_conv2d_forward_kernel( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask) { - at::Tensor input_c = input.contiguous(); - at::Tensor offset_c = offset.contiguous(); - at::Tensor weight_c = weight.contiguous(); - at::Tensor mask_c = mask.contiguous(); - at::Tensor bias_c = bias.contiguous(); - - TORCH_CHECK(input_c.ndimension() == 4); - TORCH_CHECK(offset_c.ndimension() == 4); - TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); - TORCH_CHECK(weight_c.ndimension() == 4); - TORCH_CHECK(input_c.device().is_cpu(), "input must be a CPU tensor"); - - int batch_sz = input_c.size(0); - int n_in_channels = input_c.size(1); - int in_h = input_c.size(2); - int in_w = input_c.size(3); - - int n_parallel_imgs = - get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - - // Unpack shapes and args - int out_channels = weight_c.size(0); - int weight_h = weight_c.size(2); - int weight_w = weight_c.size(3); - - int ker_h = dilation_h * (weight_h - 1) + 1; - int ker_w = dilation_w * (weight_w - 1) + 1; - int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; - int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; - - TORCH_CHECK( - weight_h > 0 && weight_w > 0, - "weight_h: ", - weight_h, - " weight_w: ", - weight_w); - TORCH_CHECK( - stride_h > 0 && stride_w > 0, - "stride_h: ", - stride_h, - " stride_w: ", - stride_w); - TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); - TORCH_CHECK( - dilation_h > 0 && dilation_w > 0, - "dilation_h: ", - dilation_h, - " dilation_w: ", - dilation_w); - - TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); - TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); - TORCH_CHECK( - (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), - "offset.shape[1] is not valid: got: ", - offset_c.size(1), - " 
expected: ", - n_offset_grps * 2 * weight_h * weight_w); - TORCH_CHECK( - (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), - "mask.shape[1] is not valid: got: ", - mask_c.size(1), - " expected: ", - n_offset_grps * weight_h * weight_w); - TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); - - TORCH_CHECK( - (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); - TORCH_CHECK( - (offset_c.size(2) == out_h && offset_c.size(3) == out_w), - "offset output dims: (", - offset_c.size(2), - ", ", - offset_c.size(3), - ") - ", - "computed output dims: (", - out_h, - ", ", - out_w, - ")"); - TORCH_CHECK( - (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask"); - TORCH_CHECK( - (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), - "mask output dims: (", - mask_c.size(2), - ", ", - mask_c.size(3), - ") - ", - "computed output dims: (", - out_h, - ", ", - out_w, - ")"); - TORCH_CHECK( - out_h > 0 && out_w > 0, - "Calculated output size too small - out_h: ", - out_h, - " out_w: ", - out_w); - - auto out = - at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); - if (batch_sz == 0) { - return out; - } - - // Separate batches into blocks - out = out.view( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - out_channels, - out_h, - out_w}); - input_c = input_c.view( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - - offset_c = offset_c.view( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - mask_c = mask_c.view( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - at::Tensor out_buf = at::zeros( - {batch_sz / n_parallel_imgs, - out_channels, - n_parallel_imgs * out_h, - out_w}, - out.options()); - - // Separate channels into convolution groups - out_buf = out_buf.view( - {out_buf.size(0), - n_weight_grps, - out_buf.size(1) / n_weight_grps, - out_buf.size(2), - out_buf.size(3)}); - weight_c = weight_c.view( - {n_weight_grps, - weight_c.size(0) / n_weight_grps, - weight_c.size(1), - weight_c.size(2), - weight_c.size(3)}); - - // Sample points and perform convolution - auto columns = at::zeros( - {n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, - input_c.options()); - for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { - deformable_im2col( - input_c[b], - offset_c[b], - mask_c[b], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - out_h, - out_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - columns); - - columns = columns.view( - {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - for (int g = 0; g < n_weight_grps; g++) { - out_buf[b][g] = out_buf[b][g] - .flatten(1) - .addmm_(weight_c[g].flatten(1), columns[g]) - .view_as(out_buf[b][g]); - } - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - } - - out_buf = out_buf.view( - {batch_sz / n_parallel_imgs, - out_channels, - n_parallel_imgs, - out_h, - out_w}); - out_buf.transpose_(1, 2); - out.copy_(out_buf); - out = out.view({batch_sz, out_channels, out_h, out_w}); - - return out + bias_c.view({1, out_channels, 1, 1}); -} - -std::tuple -deform_conv2d_backward_kernel( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t 
stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask) { - at::Tensor grad_out_c = grad_out.contiguous(); - at::Tensor input_c = input.contiguous(); - at::Tensor weight_c = weight.contiguous(); - at::Tensor offset_c = offset.contiguous(); - at::Tensor mask_c = mask.contiguous(); - at::Tensor bias_c = bias.contiguous(); - - const int batch_sz = input_c.size(0); - const int n_parallel_imgs = - get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); - - auto grad_input_and_offset_and_mask = backward_gradient_inputs( - input_c, - weight_c, - offset_c, - mask_c, - grad_out_c, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - n_weight_grps, - n_offset_grps, - n_parallel_imgs, - use_mask); - - auto grad_input = std::get<0>(grad_input_and_offset_and_mask); - auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); - auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); - - auto grad_weight = backward_gradient_parameters( - input_c, - weight_c, - offset_c, - mask_c, - grad_out_c, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - n_weight_grps, - n_offset_grps, - n_parallel_imgs, - use_mask); - - auto grad_bias = at::ones_like(bias_c) * grad_out_c.sum({0, 2, 3}); - - return std::make_tuple( - grad_input, grad_weight, grad_offset, grad_mask, grad_bias); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), - TORCH_FN(deform_conv2d_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), - TORCH_FN(deform_conv2d_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/nms_kernel.cpp b/framework/include/torchvision/ops/cpu/nms_kernel.cpp deleted file mode 100644 index 50479066cbd..00000000000 --- a/framework/include/torchvision/ops/cpu/nms_kernel.cpp +++ /dev/null @@ -1,117 +0,0 @@ -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -at::Tensor nms_kernel_impl( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold) { - TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); - TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); - TORCH_CHECK( - dets.scalar_type() == scores.scalar_type(), - "dets should have the same type as scores"); - - if (dets.numel() == 0) - return at::empty({0}, dets.options().dtype(at::kLong)); - - auto x1_t = dets.select(1, 0).contiguous(); - auto y1_t = dets.select(1, 1).contiguous(); - auto x2_t = dets.select(1, 2).contiguous(); - auto y2_t = dets.select(1, 3).contiguous(); - - at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); - - auto order_t = std::get<1>( - scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); - - auto ndets = dets.size(0); - at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); - at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); - - auto suppressed = suppressed_t.data_ptr(); - auto keep = keep_t.data_ptr(); - auto order = order_t.data_ptr(); - auto x1 = x1_t.data_ptr(); - auto y1 = y1_t.data_ptr(); - auto x2 = x2_t.data_ptr(); - auto y2 = y2_t.data_ptr(); - auto areas = areas_t.data_ptr(); - - int64_t num_to_keep = 0; - - for (int64_t _i = 0; _i < ndets; _i++) { - auto i = order[_i]; - if (suppressed[i] == 1) - continue; - keep[num_to_keep++] = i; - auto ix1 = x1[i]; - auto iy1 = 
y1[i]; - auto ix2 = x2[i]; - auto iy2 = y2[i]; - auto iarea = areas[i]; - - for (int64_t _j = _i + 1; _j < ndets; _j++) { - auto j = order[_j]; - if (suppressed[j] == 1) - continue; - auto xx1 = std::max(ix1, x1[j]); - auto yy1 = std::max(iy1, y1[j]); - auto xx2 = std::min(ix2, x2[j]); - auto yy2 = std::min(iy2, y2[j]); - - auto w = std::max(static_cast(0), xx2 - xx1); - auto h = std::max(static_cast(0), yy2 - yy1); - auto inter = w * h; - auto ovr = inter / (iarea + areas[j] - inter); - if (ovr > iou_threshold) - suppressed[j] = 1; - } - } - return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); -} - -at::Tensor nms_kernel( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold) { - TORCH_CHECK( - dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK( - dets.size(1) == 4, - "boxes should have 4 elements in dimension 1, got ", - dets.size(1)); - TORCH_CHECK( - scores.dim() == 1, - "scores should be a 1d tensor, got ", - scores.dim(), - "D"); - TORCH_CHECK( - dets.size(0) == scores.size(0), - "boxes and scores should have same number of elements in ", - "dimension 0, got ", - dets.size(0), - " and ", - scores.size(0)); - - auto result = at::empty({0}, dets.options()); - - AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { - result = nms_kernel_impl(dets, scores, iou_threshold); - }); - return result; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp b/framework/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp deleted file mode 100644 index 1c272427d3f..00000000000 --- a/framework/include/torchvision/ops/cpu/ps_roi_align_kernel.cpp +++ /dev/null @@ -1,429 +0,0 @@ -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -T bilinear_interpolate( - const T* input, - int height, - int width, - T y, - T x, - int index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - return 0; - } - - if (y <= 0) - y = 0; - if (x <= 0) - x = 0; - - int y_low = (int)y; - int x_low = (int)x; - int y_high; - int x_high; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. 
- lx; - - // do bilinear interpolation - T v1 = input[y_low * width + x_low]; - T v2 = input[y_low * width + x_high]; - T v3 = input[y_high * width + x_low]; - T v4 = input[y_high * width + x_high]; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - return val; -} - -template -void ps_roi_align_forward_kernel_impl( - int num_rois, - const T* input, - const T spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - int sampling_ratio, - const T* rois, - int channels_out, - T* output, - int* channel_mapping) { - for (int n = 0; n < num_rois; n++) { - // [start, end) interval for spatial sampling - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); - T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); - T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); - T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - int c_in = 0; - for (int c_out = 0; c_out < channels_out; ++c_out) { - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int index = - ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + - pw; - - // Do not using floor/ceil; this implementation detail is critical - T hstart = static_cast(ph) * bin_size_h + roi_start_h; - T wstart = static_cast(pw) * bin_size_w + roi_start_w; - - // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); - int roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_width / pooled_width); - const T count = roi_bin_grid_h * roi_bin_grid_w; - - const T* offset_input = - input + (roi_batch_ind * channels + c_in) * height * width; - - T out_sum = 0; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = hstart + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = wstart + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - T val = bilinear_interpolate( - offset_input, height, width, y, x, index); - out_sum += val; - } - } - - out_sum /= count; - output[index] = out_sum; - channel_mapping[index] = c_in; - c_in++; - } - } - } - } -} - -template -void bilinear_interpolate_gradient( - int height, - int width, - T y, - T x, - T& w1, - T& w2, - T& w3, - T& w4, - int& x_low, - int& x_high, - int& y_low, - int& y_high, - int index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - w1 = w2 = w3 = w4 = 0.; - x_low = x_high = y_low = y_high = -1; - return; - } - - if (y <= 0) - y = 0; - if (x <= 0) - x = 0; - - y_low = (int)y; - x_low = (int)x; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. 
- lx; - - // reference in forward - // T v1 = input[y_low * width + x_low]; - // T v2 = input[y_low * width + x_high]; - // T v3 = input[y_high * width + x_low]; - // T v4 = input[y_high * width + x_high]; - // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; -} - -template -inline void add(T* address, const T& val) { - *address += val; -} - -template -void ps_roi_align_backward_kernel_impl( - int nthreads, - const T* grad_output, - const int* channel_mapping, - const T spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - int sampling_ratio, - int channels_out, - T* grad_input, - const T* rois) { - for (int index = 0; index < nthreads; index++) { - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int n = index / pooled_width / pooled_height / channels_out; - - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); - T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); - T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); - T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); - - // Force too small ROIs to be 1x1 - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - T bin_size_h = roi_height / static_cast(pooled_height); - T bin_size_w = roi_width / static_cast(pooled_width); - - int c_in = channel_mapping[index]; - T* grad_input_offset = - grad_input + (roi_batch_ind * channels + c_in) * height * width; - - // Do not using floor/ceil; this implementation detail is critical - T hstart = static_cast(ph) * bin_size_h + roi_start_h; - T wstart = static_cast(pw) * bin_size_w + roi_start_w; - - const T grad_output_this_bin = grad_output[index]; - - // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = - (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); - const T count = roi_bin_grid_h * roi_bin_grid_w; - - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = hstart + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = wstart + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - - bilinear_interpolate_gradient( - height, - width, - y, - x, - w1, - w2, - w3, - w4, - x_low, - x_high, - y_low, - y_high, - index); - - T g1 = grad_output_this_bin * w1 / count; - T g2 = grad_output_this_bin * w2 / count; - T g3 = grad_output_this_bin * w3 / count; - T g4 = grad_output_this_bin * w4 / count; - - if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - add(grad_input_offset + y_low * width + x_low, g1); - add(grad_input_offset + y_low * width + x_high, g2); - add(grad_input_offset + y_high * width + x_low, g3); - add(grad_input_offset + y_high * width + x_high, g4); - } // if - } // ix - } // iy - } -} - -std::tuple ps_roi_align_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio) { - // Check if input tensors are CPU tensors - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( - rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_align_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); - - int num_rois = rois.size(0); - int channels = input.size(1); - int height = input.size(2); - int width = input.size(3); - - TORCH_CHECK( - channels % (pooled_height * pooled_width) == 0, - "input channels must be a multiple of pooling height * pooling width"); - int channels_out = channels / (pooled_height * pooled_width); - - auto output = at::zeros( - {num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = - at::zeros(output.sizes(), input.options().dtype(at::kInt)); - - if (output.numel() == 0) { - return std::make_tuple(output, channel_mapping); - } - - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_align_forward_kernel", [&] { - ps_roi_align_forward_kernel_impl( - num_rois, - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - rois_.data_ptr(), - channels_out, - output.data_ptr(), - channel_mapping.data_ptr()); - }); - return std::make_tuple(output, channel_mapping); -} - -at::Tensor ps_roi_align_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - // Check if input tensors are CPU tensors - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( - channel_mapping.device().is_cpu(), - "channel_mapping must be a CPU tensor"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = 
"ps_roi_align_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); - - auto grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); - - // handle possibly empty gradients - if (grad.numel() == 0) { - return grad_input; - } - - int channels_out = channels / (pooled_height * pooled_width); - - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_align_backward_kernel", [&] { - ps_roi_align_backward_kernel_impl( - grad.numel(), - grad_.data_ptr(), - channel_mapping.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - channels_out, - grad_input.data_ptr(), - rois_.data_ptr()); - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), - TORCH_FN(ps_roi_align_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), - TORCH_FN(ps_roi_align_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp b/framework/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp deleted file mode 100644 index 607cbe4bab6..00000000000 --- a/framework/include/torchvision/ops/cpu/ps_roi_pool_kernel.cpp +++ /dev/null @@ -1,273 +0,0 @@ -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -inline void add(T* address, const T& val) { - *address += val; -} - -template -void ps_roi_pool_forward_kernel_impl( - const T* input, - const T spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - const T* rois, - int channels_out, - int num_rois, - T* output, - int* channel_mapping) { - for (int n = 0; n < num_rois; ++n) { - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - int roi_start_w = round(offset_rois[1] * spatial_scale); - int roi_start_h = round(offset_rois[2] * spatial_scale); - int roi_end_w = round(offset_rois[3] * spatial_scale); - int roi_end_h = round(offset_rois[4] * spatial_scale); - - // Force too small ROIs to be 1x1 - int roi_width = std::max(roi_end_w - roi_start_w, 1); - int roi_height = std::max(roi_end_h - roi_start_h, 1); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - int c_in = 0; - for (int c_out = 0; c_out < channels_out; ++c_out) { - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = - static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = - static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = std::min(std::max(hstart + roi_start_h, 0), height - 1); - hend = std::min(std::max(hend + roi_start_h, 0), height - 1); - wstart = std::min(std::max(wstart + roi_start_w, 0), width - 1); - wend = std::min(std::max(wend + roi_start_w, 0), width - 1); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - const T* offset_input = - input + (roi_batch_ind * channels + c_in) * height * width; - - T out_sum = 0; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int input_index = h * width + w; - out_sum += offset_input[input_index]; - } - } - - int index 
= - ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + - pw; - T bin_area = (hend - hstart) * (wend - wstart); - output[index] = is_empty ? static_cast(0) : out_sum / bin_area; - channel_mapping[index] = c_in; - c_in++; - } - } - } - } -} - -template -void ps_roi_pool_backward_kernel_impl( - const T* grad_output, - const int* channel_mapping, - int num_rois, - const T spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - int channels_out, - T* grad_input, - const T* rois) { - for (int n = 0; n < num_rois; ++n) { - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - int roi_start_w = roundf(offset_rois[1] * spatial_scale); - int roi_start_h = roundf(offset_rois[2] * spatial_scale); - int roi_end_w = roundf(offset_rois[3] * spatial_scale); - int roi_end_h = roundf(offset_rois[4] * spatial_scale); - - // Force too small ROIs to be 1x1 - int roi_width = std::max(roi_end_w - roi_start_w, 1); - int roi_height = std::max(roi_end_h - roi_start_h, 1); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = std::min(std::max(hstart + roi_start_h, 0), height); - hend = std::min(std::max(hend + roi_start_h, 0), height); - wstart = std::min(std::max(wstart + roi_start_w, 0), width); - wend = std::min(std::max(wend + roi_start_w, 0), width); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - for (int c_out = 0; c_out < channels_out; ++c_out) { - int index = - ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + - pw; - int c_in = channel_mapping[index]; - - T* grad_input_offset = - grad_input + (roi_batch_ind * channels + c_in) * height * width; - T bin_area = (hend - hstart) * (wend - wstart); - T diff_val = - is_empty ? 
static_cast(0) : grad_output[index] / bin_area; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int grad_input_index = h * width + w; - add(grad_input_offset + grad_input_index, diff_val); - } - } - } - } - } - } -} - -std::tuple ps_roi_pool_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - // Check if input tensors are CPU tensors - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( - rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_pool_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); - - int num_rois = rois.size(0); - int channels = input.size(1); - int height = input.size(2); - int width = input.size(3); - - TORCH_CHECK( - channels % (pooled_height * pooled_width) == 0, - "input channels must be a multiple of pooling height * pooling width"); - int channels_out = channels / (pooled_height * pooled_width); - - auto output = at::zeros( - {num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = - at::zeros(output.sizes(), input.options().dtype(at::kInt)); - - auto output_size = output.numel(); - if (output_size == 0) { - return std::make_tuple(output, channel_mapping); - } - - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ps_roi_pool_forward_kernel", [&] { - ps_roi_pool_forward_kernel_impl( - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - rois_.data_ptr(), - channels_out, - num_rois, - output.data_ptr(), - channel_mapping.data_ptr()); - }); - return std::make_tuple(output, channel_mapping); -} - -at::Tensor ps_roi_pool_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - // Check if input tensors are CPU tensors - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK( - channel_mapping.device().is_cpu(), - "channel_mapping must be a CPU tensor"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = "ps_roi_pool_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); - - auto num_rois = rois.size(0); - auto grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); - - // handle possibly empty gradients - if (grad.numel() == 0) { - return grad_input; - } - - int channels_out = channels / (pooled_height * pooled_width); - - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] { - ps_roi_pool_backward_kernel_impl( - grad_.data_ptr(), - channel_mapping.data_ptr(), - num_rois, - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - channels_out, - grad_input.data_ptr(), - rois_.data_ptr()); - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - 
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), - TORCH_FN(ps_roi_pool_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), - TORCH_FN(ps_roi_pool_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/roi_align_common.h b/framework/include/torchvision/ops/cpu/roi_align_common.h deleted file mode 100644 index e10c67b5b79..00000000000 --- a/framework/include/torchvision/ops/cpu/roi_align_common.h +++ /dev/null @@ -1,128 +0,0 @@ -#pragma once - -#include - -namespace vision { -namespace ops { -namespace detail { - -template -struct PreCalc { - int pos1; - int pos2; - int pos3; - int pos4; - T w1; - T w2; - T w3; - T w4; -}; - -// This helper computes the interpolation weights (w1, w2...) for every sampling -// point of a given box. There are pool_height * pool_width * roi_bin_grid_h * -// roi_bin_grid_w such sampling points. -// -// The weights (w1, w2...) are computed as the areas in this figure: -// https://en.wikipedia.org/wiki/Bilinear_interpolation#/media/File:Bilinear_interpolation_visualisation.svg -// and pos1, pos2 etc correspond to the indices of their respective pixels. -// -// Note: the weights and indices are shared across all channels, which is why -// they are pre-calculated prior to the main loop in the RoIAlign kernel. -// implementation taken from Caffe2 -template -void pre_calc_for_bilinear_interpolate( - int height, - int width, - int pooled_height, - int pooled_width, - T roi_start_h, - T roi_start_w, - T bin_size_h, - T bin_size_w, - int roi_bin_grid_h, - int roi_bin_grid_w, - std::vector>& pre_calc) { - int pre_calc_index = 0; - for (int ph = 0; ph < pooled_height; ph++) { - for (int pw = 0; pw < pooled_width; pw++) { - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T yy = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T xx = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T x = xx; - T y = yy; - // deal with: inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - PreCalc pc; - pc.pos1 = 0; - pc.pos2 = 0; - pc.pos3 = 0; - pc.pos4 = 0; - pc.w1 = 0; - pc.w2 = 0; - pc.w3 = 0; - pc.w4 = 0; - pre_calc[pre_calc_index] = pc; - pre_calc_index += 1; - continue; - } - - if (y <= 0) { - y = 0; - } - if (x <= 0) { - x = 0; - } - - int y_low = (int)y; - int x_low = (int)x; - int y_high; - int x_high; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. 
- lx; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - // save weights and indices - PreCalc pc; - pc.pos1 = y_low * width + x_low; - pc.pos2 = y_low * width + x_high; - pc.pos3 = y_high * width + x_low; - pc.pos4 = y_high * width + x_high; - pc.w1 = w1; - pc.w2 = w2; - pc.w3 = w3; - pc.w4 = w4; - pre_calc[pre_calc_index] = pc; - - pre_calc_index += 1; - } - } - } - } -} - -} // namespace detail -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/roi_align_kernel.cpp b/framework/include/torchvision/ops/cpu/roi_align_kernel.cpp deleted file mode 100644 index b787de6f6bb..00000000000 --- a/framework/include/torchvision/ops/cpu/roi_align_kernel.cpp +++ /dev/null @@ -1,400 +0,0 @@ -#include -#include - -#include "./roi_align_common.h" - -namespace vision { -namespace ops { - -namespace { - -template -void roi_align_forward_kernel_impl( - int n_rois, - const T* input, - const T& spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - int sampling_ratio, - bool aligned, - const T* rois, - T* output) { - // (n, c, ph, pw) is an element in the pooled output - // can be parallelized using omp - // #pragma omp parallel for num_threads(32) - for (int n = 0; n < n_rois; n++) { - int index_n = n * channels * pooled_width * pooled_height; - - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T offset = aligned ? (T)0.5 : (T)0.0; - T roi_start_w = offset_rois[1] * spatial_scale - offset; - T roi_start_h = offset_rois[2] * spatial_scale - offset; - T roi_end_w = offset_rois[3] * spatial_scale - offset; - T roi_end_h = offset_rois[4] * spatial_scale - offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = std::max(roi_width, (T)1.); - roi_height = std::max(roi_height, (T)1.); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - // When the grid is empty, output zeros. - const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 - - // we want to precalculate indices and weights shared by all channels, - // this is the key point of optimization - std::vector> pre_calc( - roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); - detail::pre_calc_for_bilinear_interpolate( - height, - width, - pooled_height, - pooled_width, - roi_start_h, - roi_start_w, - bin_size_h, - bin_size_w, - roi_bin_grid_h, - roi_bin_grid_w, - pre_calc); - - for (int c = 0; c < channels; c++) { - int index_n_c = index_n + c * pooled_width * pooled_height; - const T* offset_input = - input + (roi_batch_ind * channels + c) * height * width; - int pre_calc_index = 0; - - for (int ph = 0; ph < pooled_height; ph++) { - for (int pw = 0; pw < pooled_width; pw++) { - int index = index_n_c + ph * pooled_width + pw; - - T output_val = 0.; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - detail::PreCalc pc = pre_calc[pre_calc_index]; - output_val += pc.w1 * offset_input[pc.pos1] + - pc.w2 * offset_input[pc.pos2] + - pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4]; - - pre_calc_index += 1; - } - } - output_val /= count; // Average pooling - - output[index] = output_val; - } // for pw - } // for ph - } // for c - } // for n -} - -template -void bilinear_interpolate_gradient( - int height, - int width, - T y, - T x, - T& w1, - T& w2, - T& w3, - T& w4, - int& x_low, - int& x_high, - int& y_low, - int& y_high, - int index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - w1 = w2 = w3 = w4 = 0.; - x_low = x_high = y_low = y_high = -1; - return; - } - - if (y <= 0) - y = 0; - if (x <= 0) - x = 0; - - y_low = (int)y; - x_low = (int)x; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - - // reference in forward - // T v1 = input[y_low * width + x_low]; - // T v2 = input[y_low * width + x_high]; - // T v3 = input[y_high * width + x_low]; - // T v4 = input[y_high * width + x_high]; - // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; -} - -template -inline void add(T* address, const T& val) { - *address += val; -} - -template -void roi_align_backward_kernel_impl( - int nthreads, - const T* grad_output, - const T& spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - int sampling_ratio, - bool aligned, - T* grad_input, - const T* rois, - int n_stride, - int c_stride, - int h_stride, - int w_stride) { - for (int index = 0; index < nthreads; index++) { - // (n, c, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; - - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T offset = aligned ? 
(T)0.5 : (T)0.0; - T roi_start_w = offset_rois[1] * spatial_scale - offset; - T roi_start_h = offset_rois[2] * spatial_scale - offset; - T roi_end_w = offset_rois[3] * spatial_scale - offset; - T roi_end_h = offset_rois[4] * spatial_scale - offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = std::max(roi_width, (T)1.); - roi_height = std::max(roi_height, (T)1.); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - T* offset_grad_input = - grad_input + ((roi_batch_ind * channels + c) * height * width); - - int output_offset = n * n_stride + c * c_stride; - const T* offset_grad_output = grad_output + output_offset; - const T grad_output_this_bin = - offset_grad_output[ph * h_stride + pw * w_stride]; - - // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 - - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - - bilinear_interpolate_gradient( - height, - width, - y, - x, - w1, - w2, - w3, - w4, - x_low, - x_high, - y_low, - y_high, - index); - - T g1 = grad_output_this_bin * w1 / count; - T g2 = grad_output_this_bin * w2 / count; - T g3 = grad_output_this_bin * w3 / count; - T g4 = grad_output_this_bin * w4 / count; - - if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - // atomic add is not needed for now since it is single threaded - add(offset_grad_input + y_low * width + x_low, static_cast(g1)); - add(offset_grad_input + y_low * width + x_high, static_cast(g2)); - add(offset_grad_input + y_high * width + x_low, static_cast(g3)); - add(offset_grad_input + y_high * width + x_high, static_cast(g4)); - } // if - } // ix - } // iy - } // for -} - -at::Tensor roi_align_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned) { - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_align_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); - - auto num_rois = rois.size(0); - auto channels = input.size(1); - auto height = input.size(2); - auto width = input.size(3); - - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); - - if (output.numel() == 0) - return output; - - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_align_forward_kernel", [&] { - roi_align_forward_kernel_impl( - num_rois, - 
input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - aligned, - rois_.data_ptr(), - output.data_ptr()); - }); - return output; -} - -at::Tensor roi_align_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_align_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); - - at::Tensor grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); - - // handle possibly empty gradients - if (grad.numel() == 0) { - return grad_input; - } - - // get stride values to ensure indexing into gradients is correct. - int n_stride = grad.stride(0); - int c_stride = grad.stride(1); - int h_stride = grad.stride(2); - int w_stride = grad.stride(3); - - auto rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_align_backward_kernel", [&] { - roi_align_backward_kernel_impl( - grad.numel(), - grad.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - sampling_ratio, - aligned, - grad_input.data_ptr(), - rois_.data_ptr(), - n_stride, - c_stride, - h_stride, - w_stride); - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN(roi_align_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), - TORCH_FN(roi_align_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/cpu/roi_pool_kernel.cpp b/framework/include/torchvision/ops/cpu/roi_pool_kernel.cpp deleted file mode 100644 index b099523896a..00000000000 --- a/framework/include/torchvision/ops/cpu/roi_pool_kernel.cpp +++ /dev/null @@ -1,249 +0,0 @@ -#include - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -inline void add(T* address, const T& val) { - *address += val; -} - -template -void roi_pool_forward_kernel_impl( - const T* input, - const T spatial_scale, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - const T* rois, - int num_rois, - T* output, - int* argmax_data) { - for (int n = 0; n < num_rois; ++n) { - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - int roi_start_w = round(offset_rois[1] * spatial_scale); - int roi_start_h = round(offset_rois[2] * spatial_scale); - int roi_end_w = round(offset_rois[3] * spatial_scale); - int roi_end_h = round(offset_rois[4] * spatial_scale); - - // Force malformed ROIs to be 1x1 - int roi_width = std::max(roi_end_w - roi_start_w + 1, 1); - int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * 
bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = std::min(std::max(hstart + roi_start_h, 0), height); - hend = std::min(std::max(hend + roi_start_h, 0), height); - wstart = std::min(std::max(wstart + roi_start_w, 0), width); - wend = std::min(std::max(wend + roi_start_w, 0), width); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - for (int c = 0; c < channels; ++c) { - // Define an empty pooling region to be zero - T maxval = is_empty ? 0 : -FLT_MAX; - // If nothing is pooled, argmax = -1 causes nothing to be backprop'd - int maxidx = -1; - - const T* input_offset = - input + (roi_batch_ind * channels + c) * height * width; - - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int input_index = h * width + w; - if (input_offset[input_index] > maxval) { - maxval = input_offset[input_index]; - maxidx = input_index; - } - } - } - int index = - ((n * channels + c) * pooled_height + ph) * pooled_width + pw; - output[index] = maxval; - argmax_data[index] = maxidx; - } // channels - } // pooled_width - } // pooled_height - } // num_rois -} - -template -void roi_pool_backward_kernel_impl( - const T* grad_output, - const int* argmax_data, - int num_rois, - int channels, - int height, - int width, - int pooled_height, - int pooled_width, - T* grad_input, - const T* rois, - int n_stride, - int c_stride, - int h_stride, - int w_stride) { - for (int n = 0; n < num_rois; ++n) { - const T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - - for (int c = 0; c < channels; ++c) { - T* grad_input_offset = - grad_input + ((roi_batch_ind * channels + c) * height * width); - const int* argmax_data_offset = - argmax_data + (n * channels + c) * pooled_height * pooled_width; - - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int output_offset = n * n_stride + c * c_stride; - int argmax = argmax_data_offset[ph * pooled_width + pw]; - - if (argmax != -1) { - add(grad_input_offset + argmax, - static_cast( - grad_output - [output_offset + ph * h_stride + pw * w_stride])); - } - } // pooled_width - } // pooled_height - } // channels - } // num_rois -} - -std::tuple roi_pool_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_pool_forward_kernel"; - at::checkAllSameType(c, {input_t, rois_t}); - - int num_rois = rois.size(0); - int channels = input.size(1); - int height = input.size(2); - int width = input.size(3); - - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); - at::Tensor argmax = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, - input.options().dtype(at::kInt)); - - if (output.numel() == 0) { - return std::make_tuple(output, argmax); - } - - auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "roi_pool_forward_kernel", [&] { - roi_pool_forward_kernel_impl( - input_.data_ptr(), - spatial_scale, - channels, - height, - width, - pooled_height, - pooled_width, - rois_.data_ptr(), - num_rois, - output.data_ptr(), - argmax.data_ptr()); - }); - return 
std::make_tuple(output, argmax); -} - -at::Tensor roi_pool_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - // Check if input tensors are CPU tensors - TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); - TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); - TORCH_CHECK(argmax.device().is_cpu(), "argmax must be a CPU tensor"); - TORCH_CHECK( - rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_pool_backward_kernel"; - at::checkAllSameType(c, {grad_t, rois_t}); - - auto num_rois = rois.size(0); - - at::Tensor grad_input = - at::zeros({batch_size, channels, height, width}, grad.options()); - - // handle possibly empty gradients - if (grad.numel() == 0) { - return grad_input; - } - - // get stride values to ensure indexing into gradients is correct. - int n_stride = grad.stride(0); - int c_stride = grad.stride(1); - int h_stride = grad.stride(2); - int w_stride = grad.stride(3); - - auto rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "roi_pool_backward_kernel", [&] { - roi_pool_backward_kernel_impl( - grad.data_ptr(), - argmax.data_ptr(), - num_rois, - channels, - height, - width, - pooled_height, - pooled_width, - grad_input.data_ptr(), - rois_.data_ptr(), - n_stride, - c_stride, - h_stride, - w_stride); - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_pool"), - TORCH_FN(roi_pool_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), - TORCH_FN(roi_pool_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/deform_conv2d.cpp b/framework/include/torchvision/ops/deform_conv2d.cpp deleted file mode 100644 index 3cda60fe0bc..00000000000 --- a/framework/include/torchvision/ops/deform_conv2d.cpp +++ /dev/null @@ -1,172 +0,0 @@ -#include "deform_conv2d.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -at::Tensor deform_conv2d( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::deform_conv2d", "") - .typed(); - return op.call( - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); -} - -at::Tensor deform_conv2d_symint( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); - static auto op = c10::Dispatcher::singleton() - 
.findSchemaOrThrow("torchvision::deform_conv2d", "") - .typed(); - return op.call( - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); -} - -namespace detail { - -std::tuple -_deform_conv2d_backward( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") - .typed(); - return op.call( - grad, - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); -} - -std::tuple -_deform_conv2d_backward_symint( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") - .typed(); - return op.call( - grad, - input, - weight, - offset, - mask, - bias, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - groups, - offset_groups, - use_mask); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/deform_conv2d.h b/framework/include/torchvision/ops/deform_conv2d.h deleted file mode 100644 index cf1f142e648..00000000000 --- a/framework/include/torchvision/ops/deform_conv2d.h +++ /dev/null @@ -1,82 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor deform_conv2d( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask); - -VISION_API at::Tensor deform_conv2d_symint( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask); - 
-namespace detail { - -std::tuple -_deform_conv2d_backward( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, - bool use_mask); - -std::tuple -_deform_conv2d_backward_symint( - const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - c10::SymInt stride_h, - c10::SymInt stride_w, - c10::SymInt pad_h, - c10::SymInt pad_w, - c10::SymInt dilation_h, - c10::SymInt dilation_w, - c10::SymInt groups, - c10::SymInt offset_groups, - bool use_mask); - -} // namespace detail - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/mps/mps_helpers.h b/framework/include/torchvision/ops/mps/mps_helpers.h deleted file mode 100644 index d3c0e8d94b7..00000000000 --- a/framework/include/torchvision/ops/mps/mps_helpers.h +++ /dev/null @@ -1,6 +0,0 @@ -constexpr int threadsPerBlock = 512; - -template -constexpr inline T ceil_div(T n, T m) { - return (n + m - 1) / m; -} diff --git a/framework/include/torchvision/ops/mps/mps_kernels.h b/framework/include/torchvision/ops/mps/mps_kernels.h deleted file mode 100644 index e720a1608f1..00000000000 --- a/framework/include/torchvision/ops/mps/mps_kernels.h +++ /dev/null @@ -1,1102 +0,0 @@ -#include - -namespace vision { -namespace ops { - -namespace mps { - -static const char* METAL_VISION = R"VISION_METAL( - -#include -#include -using namespace metal; - -/*----------Macros----------*/ - -#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \ - for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ - i += (tptg.x * n_tgs)) - -#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint) - -/*----------Helpers--------*/ - -template -inline T ceil_div(T n, T m) { - return (n + m - 1) / m; -} - -template -inline void atomic_add_float( device T* data_ptr, const T val) -{ -#if __METAL_VERSION__ >= 300 - // atomic_float is supported in Metal 3 (macOS Ventura) onward. - device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); -#else - // Custom atomic addition implementation - // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 - // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 - // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide) - - // Create an atomic uint pointer for atomic transaction. - device atomic_uint* atom_var = (device atomic_uint*)data_ptr; - // Create necessary storage. - uint fetched_uint, assigning_uint; - T fetched_float, assigning_float; - - // Replace the value in atom_var with 0 and return the previous value in atom_var. - fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed); - // Read out the previous value as float. - fetched_float = *( (thread T*) &fetched_uint ); - - // Do addition and represent the addition result in uint for atomic transaction. - assigning_float = fetched_float + val; - assigning_uint = *((thread uint*) &assigning_float); - - // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). 
- while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) { - // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. - // Try to assign 0 and get the previously assigned addition result. - uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); - T fetched_float_again = *( (thread T*) &fetched_uint_again ); - // Re-add again - fetched_float = *((thread T*) &(fetched_uint)); - // Previously assigned addition result + addition result from other threads. - assigning_float = fetched_float_again + fetched_float; - assigning_uint = *( (thread uint*) &assigning_float); - } -#endif -} - -template -inline T bilinear_interpolate( - constant T* input, - integer_t height, - integer_t width, - T y, - T x, - uint index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - return 0; - } - - if (y <= 0) - y = 0; - if (x <= 0) - x = 0; - - integer_t y_low = (integer_t)y; - integer_t x_low = (integer_t)x; - integer_t y_high; - integer_t x_high; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - - // do bilinear interpolation - T v1 = input[y_low * width + x_low]; - T v2 = input[y_low * width + x_high]; - T v3 = input[y_high * width + x_low]; - T v4 = input[y_high * width + x_high]; - T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - return val; -} - -template -inline void bilinear_interpolate_gradient( - integer_t height, - integer_t width, - T y, - T x, - thread T& w1, - thread T& w2, - thread T& w3, - thread T& w4, - thread integer_t& x_low, - thread integer_t& x_high, - thread integer_t& y_low, - thread integer_t& y_high, - uint index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - w1 = w2 = w3 = w4 = 0.; - x_low = x_high = y_low = y_high = -1; - return; - } - - if (y <= 0) - y = 0; - if (x <= 0) - x = 0; - - y_low = (integer_t)y; - x_low = (integer_t)x; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T)x_low; - } else { - x_high = x_low + 1; - } - - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - - // reference in forward - // T v1 = input[y_low * width + x_low]; - // T v2 = input[y_low * width + x_high]; - // T v3 = input[y_high * width + x_low]; - // T v4 = input[y_high * width + x_high]; - // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; -} - -template -inline bool IoU( - constant T & a, - threadgroup T & b, - const float threshold) { - auto xx1 = max(a.x, b.x); - auto yy1 = max(a.y, b.y); - auto xx2 = min(a.z, b.z); - auto yy2 = min(a.w, b.w); - auto w = max(static_cast(0), xx2 - xx1); - auto h = max(static_cast(0), yy2 - yy1); - // Upcast to float before multiplications to circumvent precision issues in half. 
- auto inter = static_cast(w) * static_cast(h); - auto area_b = static_cast(b.z - b.x) * static_cast(b.w - b.y); - auto area_a = static_cast(a.z - a.x) * static_cast(a.w - a.y); - return (inter / (area_a + area_b - inter)) > threshold; -} - -/*----------Kernels----------*/ - -// This should be in sync with the one in nms_kernel.mm. -// Since metal does not support dynamic array, -// we need to make it static instead of deriving it from [[threads_per_threadgroup]]. -constant int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; - -template -kernel void nms(constant T * dev_boxes [[buffer(0)]], - device uint64_t * mask [[buffer(1)]], - constant int64_t & n_boxes [[buffer(2)]], - constant float & iou_threshold [[buffer(3)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tid2 [[thread_position_in_threadgroup]]) { - - const uint row_start = tgid.y; - const uint col_start = tgid.x; - const uint tid = tid2.x; - const uint row_size = - min(n_boxes - row_start * nmsThreadsPerBlock, nmsThreadsPerBlock); - const uint col_size = - min(n_boxes - col_start * nmsThreadsPerBlock, nmsThreadsPerBlock); - - threadgroup T block_boxes[nmsThreadsPerBlock]; - block_boxes[tid] = dev_boxes[nmsThreadsPerBlock * col_start + tid]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tid < row_size) { - const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid; - uint64_t t = 0; - uint start = 0; - - if (row_start == col_start) { - start = tid + 1; - } - - for (uint i = start; i < col_size; i++){ - if (IoU(dev_boxes[cur_box_idx], block_boxes[i], iou_threshold)){ - t |= static_cast(1) << i; // discard 1 keep 0 - } - } - const uint col_blocks = ceil_div(n_boxes, nmsThreadsPerBlock); - mask[cur_box_idx * col_blocks + col_start] = t; - } -} - -#define REGISTER_NMS_OP(DTYPE) \ -template \ -[[host_name("nms_" #DTYPE)]] \ -kernel void nms( \ - constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \ - device uint64_t * mask [[buffer(1)]], \ - constant int64_t & n_boxes [[buffer(2)]], \ - constant float & iou_threshold [[buffer(3)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void roi_align( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - constant int64_t & output_size [[buffer(3)]], - constant int64_t & channels [[buffer(4)]], - constant int64_t & height [[buffer(5)]], - constant int64_t & width [[buffer(6)]], - constant int64_t & pooled_height [[buffer(7)]], - constant int64_t & pooled_width [[buffer(8)]], - constant int64_t & sampling_ratio [[buffer(9)]], - constant bool & aligned [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c = (index / pooled_width / pooled_height) % channels; - integer_t n = index / pooled_width / pooled_height / channels; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T offset = aligned ? 
(T)0.5 : (T)0.0; - T roi_start_w = offset_rois[1] * spatial_scale - offset; - T roi_start_h = offset_rois[2] * spatial_scale - offset; - T roi_end_w = offset_rois[3] * spatial_scale - offset; - T roi_end_h = offset_rois[4] * spatial_scale - offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = max(roi_width, (T)1.); - roi_height = max(roi_height, (T)1.); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - constant T* offset_input = - input + (roi_batch_ind * channels + c) * height * width; - - // We use roi_bin_grid to sample the grid and mimic integral - integer_t roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - integer_t roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - // When the grid is empty, output zeros. - const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast(1)); // e.g. = 4 - - T output_val = 0.; - for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 - { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T val = bilinear_interpolate(offset_input, height, width, y, x, index); - output_val += val; - } - } - output_val /= count; - - output[index] = output_val; - } -} - -#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_align_" #DTYPE)]] \ -kernel void roi_align( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - constant int64_t & output_size [[buffer(3)]], \ - constant int64_t & channels [[buffer(4)]], \ - constant int64_t & height [[buffer(5)]], \ - constant int64_t & width [[buffer(6)]], \ - constant int64_t & pooled_height [[buffer(7)]], \ - constant int64_t & pooled_width [[buffer(8)]], \ - constant int64_t & sampling_ratio [[buffer(9)]], \ - constant bool & aligned [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void roi_align_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * grad_input [[buffer(2)]], - constant int64_t & output_size [[buffer(3)]], - constant int64_t & channels [[buffer(4)]], - constant int64_t & height [[buffer(5)]], - constant int64_t & width [[buffer(6)]], - constant int64_t & pooled_height [[buffer(7)]], - constant int64_t & pooled_width [[buffer(8)]], - constant int64_t & sampling_ratio [[buffer(9)]], - constant bool & aligned [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], - constant int64_t & n_stride [[buffer(12)]], - constant int64_t & c_stride [[buffer(13)]], - constant int64_t & h_stride [[buffer(14)]], - constant int64_t & w_stride [[buffer(15)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c, ph, pw) is an element 
in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c = (index / pooled_width / pooled_height) % channels; - integer_t n = index / pooled_width / pooled_height / channels; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T offset = aligned ? (T)0.5 : (T)0.0; - T roi_start_w = offset_rois[1] * spatial_scale - offset; - T roi_start_h = offset_rois[2] * spatial_scale - offset; - T roi_end_w = offset_rois[3] * spatial_scale - offset; - T roi_end_h = offset_rois[4] * spatial_scale - offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = max(roi_width, (T)1.); - roi_height = max(roi_height, (T)1.); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - // We need to index the gradient using the tensor strides to access the - // correct values. - const integer_t output_offset = n * n_stride + c * c_stride; - constant T* offset_grad_output = grad_output + output_offset; - const T grad_output_this_bin = - offset_grad_output[ph * h_stride + pw * w_stride]; - - // We use roi_bin_grid to sample the grid and mimic integral - integer_t roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - integer_t roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 - - const integer_t input_offset = (roi_batch_ind * channels + c) * height * width; - - for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 - { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T w1, w2, w3, w4; - integer_t x_low, x_high, y_low, y_high; - - bilinear_interpolate_gradient( - height, - width, - y, - x, - w1, - w2, - w3, - w4, - x_low, - x_high, - y_low, - y_high, - index); - - T g1 = grad_output_this_bin * w1 / count; - T g2 = grad_output_this_bin * w2 / count; - T g3 = grad_output_this_bin * w3 / count; - T g4 = grad_output_this_bin * w4 / count; - - if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - atomic_add_float(grad_input + input_offset + y_low * width + x_low, static_cast(g1)); - atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast(g2)); - atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast(g3)); - atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast(g4)); - - } // if - } // ix - } // iy - } // MPS_1D_KERNEL_LOOP -} - -#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_align_backward_" #DTYPE)]] \ -kernel void roi_align_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * grad_input [[buffer(2)]], \ - constant int64_t & output_size [[buffer(3)]], \ - constant int64_t & channels [[buffer(4)]], \ - constant int64_t & height [[buffer(5)]], \ - constant int64_t & width [[buffer(6)]], \ 
- constant int64_t & pooled_height [[buffer(7)]], \ - constant int64_t & pooled_width [[buffer(8)]], \ - constant int64_t & sampling_ratio [[buffer(9)]], \ - constant bool & aligned [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - constant int64_t & n_stride [[buffer(12)]], \ - constant int64_t & c_stride [[buffer(13)]], \ - constant int64_t & h_stride [[buffer(14)]], \ - constant int64_t & w_stride [[buffer(15)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template <typename T, typename integer_t> -kernel void roi_pool( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - device int64_t * argmax [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant float & spatial_scale [[buffer(10)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c = (index / pooled_width / pooled_height) % channels; - integer_t n = index / pooled_width / pooled_height / channels; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - integer_t roi_start_w = round(offset_rois[1] * spatial_scale); - integer_t roi_start_h = round(offset_rois[2] * spatial_scale); - integer_t roi_end_w = round(offset_rois[3] * spatial_scale); - integer_t roi_end_h = round(offset_rois[4] * spatial_scale); - - // Force malformed ROIs to be 1x1 - integer_t roi_width = max(roi_end_w - roi_start_w + 1, static_cast<integer_t>(1)); - integer_t roi_height = max(roi_end_h - roi_start_h + 1, static_cast<integer_t>(1)); - T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); - T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); - - integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h)); - integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w)); - integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h)); - integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height)); - hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height)); - wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width)); - wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width)); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - // Define an empty pooling region to be zero - T maxval = is_empty ?
0 : -FLT_MAX; - // If nothing is pooled, argmax = -1 causes nothing to be backprop'd - integer_t maxidx = -1; - constant T* offset_input = - input + (roi_batch_ind * channels + c) * height * width; - for (integer_t h = hstart; h < hend; ++h) { - for (integer_t w = wstart; w < wend; ++w) { - integer_t input_index = h * width + w; - if (offset_input[input_index] > maxval) { - maxval = offset_input[input_index]; - maxidx = input_index; - } - } - } - output[index] = maxval; - argmax[index] = maxidx; - } -} - -#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_pool_" #DTYPE)]] \ -kernel void roi_pool( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * argmax_data [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant float & spatial_scale [[buffer(10)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void roi_pool_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - constant int64_t * argmax_data [[buffer(2)]], - device T * grad_input [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant float & spatial_scale [[buffer(10)]], - constant int64_t & n_stride [[buffer(11)]], - constant int64_t & c_stride [[buffer(12)]], - constant int64_t & h_stride [[buffer(13)]], - constant int64_t & w_stride [[buffer(14)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c = (index / pooled_width / pooled_height) % channels; - integer_t n = index / pooled_width / pooled_height / channels; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - const integer_t output_offset = n * n_stride + c * c_stride; - constant integer_t * argmax_data_offset = - argmax_data + (n * channels + c) * pooled_height * pooled_width; - const integer_t argmax = argmax_data_offset[ph * pooled_width + pw]; - const integer_t offset = (roi_batch_ind * channels + c) * height * width; - - if (argmax != -1) { - atomic_add_float(grad_input + offset + argmax, static_cast(grad_output[output_offset + ph * h_stride + pw * w_stride])); - } - - } // MPS_1D_KERNEL_LOOP -} - -#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_pool_backward_" #DTYPE)]] \ -kernel void roi_pool_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * argmax_data [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - 
constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant float & spatial_scale [[buffer(10)]], \ - constant int64_t & n_stride [[buffer(11)]], \ - constant int64_t & c_stride [[buffer(12)]], \ - constant int64_t & h_stride [[buffer(13)]], \ - constant int64_t & w_stride [[buffer(14)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void ps_roi_align( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - device int64_t * channel_mapping [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant int64_t & sampling_ratio [[buffer(10)]], - constant int64_t & channels_out [[buffer(11)]], - constant float & spatial_scale [[buffer(12)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c_out, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c_out = (index / pooled_width / pooled_height) % channels_out; - integer_t n = index / pooled_width / pooled_height / channels_out; - - // (n, c_in, ph, pw) is the associated element in the input - integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; - - // [start, end) interval for spatial sampling - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); - T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); - T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); - T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - // Do not using floor/ceil; this implementation detail is critical - T hstart = static_cast(ph) * bin_size_h + roi_start_h; - T wstart = static_cast(pw) * bin_size_w + roi_start_w; - - // We use roi_bin_grid to sample the grid and mimic integral - integer_t roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); - integer_t roi_bin_grid_w = - (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); - const T count = roi_bin_grid_h * roi_bin_grid_w; - - constant T* offset_input = - input + (roi_batch_ind * channels + c_in) * height * width; - T out_sum = 0; - for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = hstart + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = wstart + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - T val = bilinear_interpolate(offset_input, height, width, y, x, index); - out_sum += val; - } - } - - out_sum /= count; - output[index] = out_sum; - channel_mapping[index] = c_in; - } -} - -#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_align_" #DTYPE)]] \ -kernel void ps_roi_align( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * channel_mapping [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & sampling_ratio [[buffer(10)]], \ - constant int64_t & channels_out [[buffer(11)]], \ - constant float & spatial_scale [[buffer(12)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void ps_roi_align_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - constant int64_t * channel_mapping [[buffer(2)]], - device T * grad_input [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant int64_t & sampling_ratio [[buffer(10)]], - constant int64_t & channels_out [[buffer(11)]], - constant float & spatial_scale [[buffer(12)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, *, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t n = index / pooled_width / pooled_height / channels_out; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); - T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); - T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); - T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); - - // Force too small ROIs to be 1x1 - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - T bin_size_h = roi_height / static_cast(pooled_height); - T bin_size_w = roi_width / static_cast(pooled_width); - - integer_t c_in = channel_mapping[index]; - - // Do not using floor/ceil; this implementation detail is critical - T hstart = static_cast(ph) * bin_size_h + roi_start_h; - T wstart = static_cast(pw) * bin_size_w + roi_start_w; - - const T 
grad_output_this_bin = grad_output[index]; - - // We use roi_bin_grid to sample the grid and mimic integral - integer_t roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - integer_t roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - const T count = roi_bin_grid_h * roi_bin_grid_w; - - const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; - - for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = hstart + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = wstart + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); - - T w1, w2, w3, w4; - integer_t x_low, x_high, y_low, y_high; - - bilinear_interpolate_gradient( - height, - width, - y, - x, - w1, - w2, - w3, - w4, - x_low, - x_high, - y_low, - y_high, - index); - - T g1 = grad_output_this_bin * w1 / count; - T g2 = grad_output_this_bin * w2 / count; - T g3 = grad_output_this_bin * w3 / count; - T g4 = grad_output_this_bin * w4 / count; - - if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - atomic_add_float(grad_input + offset + y_low * width + x_low, static_cast(g1)); - atomic_add_float(grad_input + offset + y_low * width + x_high, static_cast(g2)); - atomic_add_float(grad_input + offset + y_high * width + x_low, static_cast(g3)); - atomic_add_float(grad_input + offset + y_high * width + x_high, static_cast(g4)); - } // if - } // ix - } // iy - } -} - -#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_align_backward_" #DTYPE)]] \ -kernel void ps_roi_align_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * channel_mapping [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & sampling_ratio [[buffer(10)]], \ - constant int64_t & channels_out [[buffer(11)]], \ - constant float & spatial_scale [[buffer(12)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void ps_roi_pool( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - device int64_t * channel_mapping [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant int64_t & channels_out [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c_out, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c_out = (index / (pooled_width * pooled_height)) % channels_out; - integer_t n = index / pooled_width / pooled_height 
/ channels_out; - - // (n, c_in, ph, pw) is the associated element in the input - integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; - - // [start, end) interval for spatial sampling - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - integer_t roi_start_w = round(offset_rois[1] * spatial_scale); - integer_t roi_start_h = round(offset_rois[2] * spatial_scale); - integer_t roi_end_w = round(offset_rois[3] * spatial_scale); - integer_t roi_end_h = round(offset_rois[4] * spatial_scale); - - // Force too small ROIs to be 1x1 - integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); - integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height - 1)); - hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height - 1)); - wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width - 1)); - wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width - 1)); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - constant T* offset_input = - input + (roi_batch_ind * channels + c_in) * height * width; - T out_sum = 0; - for (integer_t h = hstart; h < hend; ++h) { - for (integer_t w = wstart; w < wend; ++w) { - integer_t input_index = h * width + w; - out_sum += offset_input[input_index]; - } - } - - T bin_area = (hend - hstart) * (wend - wstart); - output[index] = is_empty ? 
static_cast(0) : out_sum / bin_area; - channel_mapping[index] = c_in; - } -} - -#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_pool_" #DTYPE)]] \ -kernel void ps_roi_pool( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * channel_mapping [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -template -kernel void ps_roi_pool_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - constant int64_t * channel_mapping [[buffer(2)]], - device T * grad_input [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant int64_t & channels_out [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, *, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t n = index / pooled_width / pooled_height / channels_out; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - integer_t roi_start_w = round(offset_rois[1] * spatial_scale); - integer_t roi_start_h = round(offset_rois[2] * spatial_scale); - integer_t roi_end_w = round(offset_rois[3] * spatial_scale); - integer_t roi_end_h = round(offset_rois[4] * spatial_scale); - - // Force too small ROIs to be 1x1 - integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); - integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - // Add roi offsets and clip to input boundaries - hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); - hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); - wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); - wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - integer_t c_in = channel_mapping[index]; - T bin_area = (hend - hstart) * (wend - wstart); - T diff_val = is_empty ? 
static_cast<T>(0) : grad_output[index] / bin_area; - - const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; - - for (integer_t h = hstart; h < hend; ++h) { - for (integer_t w = wstart; w < wend; ++w) { - integer_t grad_input_index = h * width + w; - atomic_add_float(grad_input + offset + grad_input_index, diff_val); - } - } - - } // MPS_1D_KERNEL_LOOP -} - -#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_pool_backward_" #DTYPE)]] \ -kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * channel_mapping [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - -REGISTER_NMS_OP(float); -REGISTER_NMS_OP(half); -REGISTER_ROI_ALIGN_OP(float, int64_t); -REGISTER_ROI_ALIGN_OP(half, int64_t); -REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t); -REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t); -REGISTER_ROI_POOL_OP(float, int64_t); -REGISTER_ROI_POOL_OP(half, int64_t); -REGISTER_ROI_POOL_BACKWARD_OP(float, int64_t); -REGISTER_ROI_POOL_BACKWARD_OP(half, int64_t); -REGISTER_PS_ROI_ALIGN_OP(float, int64_t); -REGISTER_PS_ROI_ALIGN_OP(half, int64_t); -REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float, int64_t); -REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half, int64_t); -REGISTER_PS_ROI_POOL_OP(float, int64_t); -REGISTER_PS_ROI_POOL_OP(half, int64_t); -REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t); -REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t); - -)VISION_METAL"; - -static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) { - static id<MTLLibrary> visionLibrary = nil; - if (visionLibrary) { - return visionLibrary; - } - - NSError* error = nil; - MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion:MTLLanguageVersion2_3]; - visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding] - options:options - error:&error]; - TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]); - return visionLibrary; -} - -static id<MTLComputePipelineState> visionPipelineState(id<MTLDevice> device, const std::string& kernel) { - static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache; - id<MTLComputePipelineState> pso = psoCache[kernel]; - if (pso) { - return pso; - } - - NSError* error = nil; - id<MTLLibrary> visionLib = compileVisionOpsLibrary(device); - id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; - TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel); - pso = [device newComputePipelineStateWithFunction:visionFunc error:&error]; - TORCH_CHECK(pso, "Failed to create pipeline state object, error: ", [[error description] UTF8String]); - - psoCache[kernel] = pso; - return pso; -} - -} // namespace mps -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/mps/nms_kernel.mm b/framework/include/torchvision/ops/mps/nms_kernel.mm deleted file mode 100644 index 5ee9b5cbeae..00000000000 --- 
a/framework/include/torchvision/ops/mps/nms_kernel.mm +++ /dev/null @@ -1,109 +0,0 @@ -#include <ATen/mps/MPSProfiler.h> -#include <ATen/native/mps/OperationUtils.h> -#include "mps_kernels.h" - -namespace vision { -namespace ops { - -namespace { - -// This should be in sync with `nmsThreadsPerBlock` in the metal kernel. -constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; - -at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { - using namespace at::native::mps; - TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor"); - TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor"); - - TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1)); - TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D"); - TORCH_CHECK(dets.size(0) == scores.size(0), - "boxes and scores should have same number of elements in ", - "dimension 0, got ", - dets.size(0), - " and ", - scores.size(0)) - - if (dets.numel() == 0) { - return at::empty({0}, dets.options().dtype(at::kLong)); - } - - auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); - auto dets_sorted = dets.index_select(0, order_t).contiguous(); - int64_t dets_num = dets.size(0); - float iou_threshold_f = static_cast<float>(iou_threshold); - - const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock; - at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); - - id<MTLBuffer> inputBuffer = getMTLBufferStorage(dets_sorted); - id<MTLBuffer> outputBuffer = getMTLBufferStorage(mask); - id<MTLDevice> device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1); - - const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type()); - id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores}); - - [computeEncoder setComputePipelineState:visionPSO]; - [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0]; - [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1]; - [computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2]; - [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3]; - - // A threadGroup is equivalent to a cuda's block. 
- NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > nmsThreadsPerBlock) { - tgSize = nmsThreadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - - int64_t num_to_keep = 0; - - at::Tensor mask_cpu = mask.to(at::kCPU); - unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr(); - - std::vector remv(col_blocks); - memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); - - at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); - int64_t* keep_out = keep.data_ptr(); - - for (int64_t i = 0; i < dets_num; i++) { - int64_t nblock = i / nmsThreadsPerBlock; - int64_t inblock = i % nmsThreadsPerBlock; - - if (!(remv[nblock] & (1ULL << inblock))) { - keep_out[num_to_keep++] = i; - unsigned long long* p = mask_host + i * col_blocks; - for (int64_t j = nblock; j < col_blocks; j++) { - remv[j] |= p[j]; - } - } - } - - return order_t.index( - {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())}); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/mps/ps_roi_align_kernel.mm b/framework/include/torchvision/ops/mps/ps_roi_align_kernel.mm deleted file mode 100644 index 16b711ad5ef..00000000000 --- a/framework/include/torchvision/ops/mps/ps_roi_align_kernel.mm +++ /dev/null @@ -1,205 +0,0 @@ -#include -#include -#include "mps_helpers.h" -#include "mps_kernels.h" - -namespace vision { -namespace ops { - -namespace { - -std::tuple ps_roi_align_forward_kernel(const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio) { - using namespace at::native::mps; - TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_align_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - int64_t num_rois = rois.size(0); - int64_t channels = input.size(1); - int64_t height = input.size(2); - int64_t width = input.size(3); - float spatial_scale_f = static_cast(spatial_scale); - - TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, - "input channels must be a multiple of pooling height * pooling width"); - - int64_t channels_out = channels / (pooled_height * pooled_width); - - auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); - - int64_t output_size = output.numel(); - - if (output_size == 0) { - return std::make_tuple(output, channel_mapping); - } - - auto input_ = input.contiguous(); - auto rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(input_); - id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(output); - id channelMappingBuffer = getMTLBufferStorage(channel_mapping); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - 
dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - [computeEncoder setBuffer:channelMappingBuffer - offset:channel_mapping.storage_offset() * channel_mapping.element_size() - atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; - - // A threadGroup is equivalent to a cuda's block. 
- NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return std::make_tuple(output, channel_mapping); -} - -at::Tensor ps_roi_align_backward_kernel(const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - using namespace at::native::mps; - TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_align backward with float16 inputs."); - TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); - - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = "ps_roi_align_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); - at::checkAllSameType(c, {grad_t, rois_t}); - - float spatial_scale_f = static_cast(spatial_scale); - - auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); - - if (grad.numel() == 0) { - return grad_input; - } - - int64_t output_size = grad.numel(); - int64_t channels_out = channels / (pooled_height * pooled_width); - - at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(grad_); - id roisBuffer = getMTLBufferStorage(rois_); - id channelMappingBuffer = getMTLBufferStorage(channel_mapping); - id outputBuffer = getMTLBufferStorage(grad_input); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:channelMappingBuffer - offset:channel_mapping.storage_offset() * channel_mapping.element_size() - atIndex:2]; - [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height 
length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel)); - m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/mps/ps_roi_pool_kernel.mm b/framework/include/torchvision/ops/mps/ps_roi_pool_kernel.mm deleted file mode 100644 index fc24f6990fa..00000000000 --- a/framework/include/torchvision/ops/mps/ps_roi_pool_kernel.mm +++ /dev/null @@ -1,200 +0,0 @@ -#include -#include -#include "mps_helpers.h" -#include "mps_kernels.h" - -namespace vision { -namespace ops { - -namespace { - -std::tuple ps_roi_pool_forward_kernel(const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - using namespace at::native::mps; - TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "ps_roi_pool_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - int64_t num_rois = rois.size(0); - int64_t channels = input.size(1); - int64_t height = input.size(2); - int64_t width = input.size(3); - float spatial_scale_f = static_cast(spatial_scale); - - TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, - "input channels must be a multiple of pooling height * pooling width"); - int64_t channels_out = channels / (pooled_height * pooled_width); - - auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); - auto output_size = output.numel(); - - if (output_size == 0) { - return std::make_tuple(output, channel_mapping); - } - - auto input_ = input.contiguous(); - auto rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(input_); - id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(output); - id channelMappingBuffer = getMTLBufferStorage(channel_mapping); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type()); 
- id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - [computeEncoder setBuffer:channelMappingBuffer - offset:channel_mapping.storage_offset() * channel_mapping.element_size() - atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return std::make_tuple(output, channel_mapping); -} - -at::Tensor ps_roi_pool_backward_kernel(const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - using namespace at::native::mps; - TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_pool backward with float16 inputs."); - TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); - - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, - channel_mapping_t{channel_mapping, "channel_mapping", 3}; - - at::CheckedFrom c = "ps_roi_pool_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); - at::checkAllSameType(c, {grad_t, rois_t}); - - float spatial_scale_f = static_cast(spatial_scale); - - auto num_rois = rois.size(0); - auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); - - if (grad.numel() == 0) { - return grad_input; - } - - int64_t channels_out = channels / (pooled_height * pooled_width); - int64_t output_size = grad.numel(); - - at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel"); - auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(grad_); - id roisBuffer = getMTLBufferStorage(rois_); - id channelMappingBuffer = getMTLBufferStorage(channel_mapping); - id outputBuffer = getMTLBufferStorage(grad_input); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { 
- id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad_, rois_, channel_mapping}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:channelMappingBuffer - offset:channel_mapping.storage_offset() * channel_mapping.element_size() - atIndex:2]; - [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), TORCH_FN(ps_roi_pool_forward_kernel)); - m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/mps/roi_align_kernel.mm b/framework/include/torchvision/ops/mps/roi_align_kernel.mm deleted file mode 100644 index d4ed8b43fd2..00000000000 --- a/framework/include/torchvision/ops/mps/roi_align_kernel.mm +++ /dev/null @@ -1,197 +0,0 @@ -#include -#include -#include "mps_helpers.h" -#include "mps_kernels.h" - -namespace vision { -namespace ops { - -namespace { - -at::Tensor roi_align_forward_kernel(const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned) { - using namespace at::native::mps; - TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_align_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - int64_t num_rois = rois.size(0); - int64_t channels = input.size(1); - int64_t height = input.size(2); - 
int64_t width = input.size(3); - float spatial_scale_f = static_cast(spatial_scale); - - at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); - - int64_t output_size = num_rois * pooled_height * pooled_width * channels; - - if (output.numel() == 0) { - return output; - } - - auto input_ = input.contiguous(); - auto rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(input_); - id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(output); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - - // A threadGroup is equivalent to a cuda's block. 
- NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return output; -} - -at::Tensor roi_align_backward_kernel(const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { - using namespace at::native::mps; - TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs."); - - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_align_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t}); - at::checkAllSameType(c, {grad_t, rois_t}); - - float spatial_scale_f = static_cast(spatial_scale); - - at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); - - if (grad.numel() == 0) { - return grad_input; - } - - int64_t n_stride = grad.stride(0); - int64_t c_stride = grad.stride(1); - int64_t h_stride = grad.stride(2); - int64_t w_stride = grad.stride(3); - int64_t output_size = grad.numel(); - - at::globalContext().alertNotDeterministic("roi_align_backward_kernel"); - auto rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(grad); - id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(grad_input); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - [computeEncoder setBytes:&n_stride length:sizeof(int64_t) 
atIndex:12]; - [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; - [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return grad_input; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel)); - m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel)); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/mps/roi_pool_kernel.mm b/framework/include/torchvision/ops/mps/roi_pool_kernel.mm deleted file mode 100644 index 816d8d70863..00000000000 --- a/framework/include/torchvision/ops/mps/roi_pool_kernel.mm +++ /dev/null @@ -1,196 +0,0 @@ -#include -#include -#include "mps_helpers.h" -#include "mps_kernels.h" - -namespace vision { -namespace ops { - -namespace { - -std::tuple roi_pool_forward_kernel(const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - using namespace at::native::mps; - TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - - at::CheckedFrom c = "roi_pool_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); - - int64_t num_rois = rois.size(0); - int64_t channels = input.size(1); - int64_t height = input.size(2); - int64_t width = input.size(3); - float spatial_scale_f = static_cast(spatial_scale); - - at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); - at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong)); - - int64_t output_size = num_rois * pooled_height * pooled_width * channels; - - if (output.numel() == 0) { - return std::make_tuple(output, argmax); - } - - auto input_ = input.contiguous(); - auto rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(input_); - id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(output); - id argmaxBuffer = getMTLBufferStorage(argmax); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - - [computeEncoder setComputePipelineState:visionPSO]; - // [N, C, H, W] - [computeEncoder 
setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; - [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - [computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3]; - - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - return std::make_tuple(output, argmax); -} - -at::Tensor roi_pool_backward_kernel(const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - using namespace at::native::mps; - TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); - TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs."); - TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor"); - - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; - - at::CheckedFrom c = "roi_pool_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); - at::checkAllSameType(c, {grad_t, rois_t}); - - float spatial_scale_f = static_cast(spatial_scale); - - at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); - - if (grad.numel() == 0) { - return grad_input; - } - - int64_t n_stride = grad.stride(0); - int64_t c_stride = grad.stride(1); - int64_t h_stride = grad.stride(2); - int64_t w_stride = grad.stride(3); - int64_t output_size = grad.numel(); - - at::globalContext().alertNotDeterministic("roi_pool_backward_kernel"); - auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); - - id inputBuffer = getMTLBufferStorage(grad); - id roisBuffer = getMTLBufferStorage(rois_); - id argmaxBuffer = getMTLBufferStorage(argmax_); - id outputBuffer = getMTLBufferStorage(grad_input); - id device = MPSDevice::getInstance()->device(); - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), - 1, - 1); - - const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - 
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_, argmax_});
-
-      [computeEncoder setComputePipelineState:visionPSO];
-      // [N, C, H, W]
-      [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
-      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
-      [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2];
-      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
-
-      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
-      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
-      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
-      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
-      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
-      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
-      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];
-      [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11];
-      [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12];
-      [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13];
-      [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14];
-
-      // A threadGroup is equivalent to a CUDA block.
-      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
-      if (tgSize > threadsPerBlock) {
-        tgSize = threadsPerBlock;
-      }
-
-      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
-      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
-
-      getMPSProfiler().endProfileKernel(visionPSO);
-    }
-  });
-  return grad_input;
-}
-
-} // namespace
-
-TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
-  m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel));
-  m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel));
-}
-
-} // namespace ops
-} // namespace vision
diff --git a/framework/include/torchvision/ops/nms.cpp b/framework/include/torchvision/ops/nms.cpp
deleted file mode 100644
index 07a934bce5a..00000000000
--- a/framework/include/torchvision/ops/nms.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#include "nms.h"
-
-#include <ATen/core/dispatch/Dispatcher.h>
-#include <torch/library.h>
-#include <torch/types.h>
-
-namespace vision {
-namespace ops {
-
-at::Tensor nms(
-    const at::Tensor& dets,
-    const at::Tensor& scores,
-    double iou_threshold) {
-  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms.nms");
-  static auto op = c10::Dispatcher::singleton()
-                       .findSchemaOrThrow("torchvision::nms", "")
-                       .typed<decltype(nms)>();
-  return op.call(dets, scores, iou_threshold);
-}
-
-TORCH_LIBRARY_FRAGMENT(torchvision, m) {
-  m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
-}
-
-} // namespace ops
-} // namespace vision
diff --git a/framework/include/torchvision/ops/nms.h b/framework/include/torchvision/ops/nms.h
deleted file mode 100644
index 8c75a242bff..00000000000
--- a/framework/include/torchvision/ops/nms.h
+++ /dev/null
@@ -1,15 +0,0 @@
-#pragma once
-
-#include <ATen/ATen.h>
-#include "../macros.h"
-
-namespace vision {
-namespace ops {
-
-VISION_API at::Tensor nms(
-    const at::Tensor& dets,
-    const at::Tensor& scores,
-    double iou_threshold);
-
-} // namespace ops
-} // namespace vision
diff --git a/framework/include/torchvision/ops/ops.h b/framework/include/torchvision/ops/ops.h
deleted file mode 
100644 index 77995e44197..00000000000 --- a/framework/include/torchvision/ops/ops.h +++ /dev/null @@ -1,8 +0,0 @@ -#pragma once - -#include "deform_conv2d.h" -#include "nms.h" -#include "ps_roi_align.h" -#include "ps_roi_pool.h" -#include "roi_align.h" -#include "roi_pool.h" diff --git a/framework/include/torchvision/ops/ps_roi_align.cpp b/framework/include/torchvision/ops/ps_roi_align.cpp deleted file mode 100644 index de458c0d62d..00000000000 --- a/framework/include/torchvision/ops/ps_roi_align.cpp +++ /dev/null @@ -1,112 +0,0 @@ -#include "ps_roi_align.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -std::tuple ps_roi_align( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_align", "") - .typed(); - return op.call( - input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); -} - -std::tuple ps_roi_align_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_align", "") - .typed(); - return op.call( - input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); -} - -namespace detail { - -at::Tensor _ps_roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width); -} - -at::Tensor _ps_roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - batch_size, - channels, - height, - width); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/ps_roi_align.h b/framework/include/torchvision/ops/ps_roi_align.h 
deleted file mode 100644 index 75650586bc6..00000000000 --- a/framework/include/torchvision/ops/ps_roi_align.h +++ /dev/null @@ -1,56 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple ps_roi_align( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio); - -VISION_API std::tuple ps_roi_align_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio); - -namespace detail { - -at::Tensor _ps_roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -at::Tensor _ps_roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width); - -} // namespace detail - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/ps_roi_pool.cpp b/framework/include/torchvision/ops/ps_roi_pool.cpp deleted file mode 100644 index 92469d5e380..00000000000 --- a/framework/include/torchvision/ops/ps_roi_pool.cpp +++ /dev/null @@ -1,104 +0,0 @@ -#include "ps_roi_pool.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -std::tuple ps_roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -std::tuple ps_roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::ps_roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -namespace detail { - -at::Tensor _ps_roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -at::Tensor _ps_roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - static auto op = - c10::Dispatcher::singleton() - 
.findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - channel_mapping, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/ps_roi_pool.h b/framework/include/torchvision/ops/ps_roi_pool.h deleted file mode 100644 index 4a3cc54e0e5..00000000000 --- a/framework/include/torchvision/ops/ps_roi_pool.h +++ /dev/null @@ -1,52 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple ps_roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API std::tuple ps_roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width); - -namespace detail { - -at::Tensor _ps_roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -at::Tensor _ps_roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width); - -} // namespace detail - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/roi_align.cpp b/framework/include/torchvision/ops/roi_align.cpp deleted file mode 100644 index aa6dccb44f2..00000000000 --- a/framework/include/torchvision/ops/roi_align.cpp +++ /dev/null @@ -1,132 +0,0 @@ -#include "roi_align.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -at::Tensor roi_align( - const at::Tensor& input, // Input feature map. - const at::Tensor& rois, // List of ROIs to pool over. - double spatial_scale, // The scale of the image features. ROIs will be - // scaled to this. - int64_t pooled_height, // The height of the pooled feature map. - int64_t pooled_width, // The width of the pooled feature - int64_t sampling_ratio, // The number of points to sample in each bin - bool aligned) // The flag for pixel shift -// along each axis. -{ - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_align", "") - .typed(); - return op.call( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); -} - -at::Tensor roi_align_symint( - const at::Tensor& input, // Input feature map. - const at::Tensor& rois, // List of ROIs to pool over. - double spatial_scale, // The scale of the image features. ROIs will be - // scaled to this. 
- c10::SymInt pooled_height, // The height of the pooled feature map. - c10::SymInt pooled_width, // The width of the pooled feature - int64_t sampling_ratio, // The number of points to sample in each bin - bool aligned) // The flag for pixel shift -// along each axis. -{ - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_align", "") - .typed(); - return op.call( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); -} - -namespace detail { - -at::Tensor _roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned); -} - -at::Tensor _roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor")); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/roi_align.h b/framework/include/torchvision/ops/roi_align.h deleted file mode 100644 index 072d6d4231c..00000000000 --- a/framework/include/torchvision/ops/roi_align.h +++ /dev/null @@ -1,58 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API at::Tensor roi_align( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned); - -VISION_API at::Tensor roi_align_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - bool aligned); - -namespace detail { - -at::Tensor _roi_align_backward( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned); - -at::Tensor _roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - 
c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned); - -} // namespace detail - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/roi_pool.cpp b/framework/include/torchvision/ops/roi_pool.cpp deleted file mode 100644 index 20ca3ca91e7..00000000000 --- a/framework/include/torchvision/ops/roi_pool.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#include "roi_pool.h" - -#include -#include -#include - -namespace vision { -namespace ops { - -std::tuple roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -std::tuple roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width) { - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_pool", "") - .typed(); - return op.call(input, rois, spatial_scale, pooled_height, pooled_width); -} - -namespace detail { - -at::Tensor _roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -at::Tensor _roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_pool_backward", "") - .typed(); - return op.call( - grad, - rois, - argmax, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width); -} - -} // namespace detail - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); -} - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/ops/roi_pool.h b/framework/include/torchvision/ops/roi_pool.h deleted file mode 100644 index e2133240f4f..00000000000 --- a/framework/include/torchvision/ops/roi_pool.h +++ /dev/null @@ -1,52 +0,0 @@ -#pragma once - -#include -#include "../macros.h" - -namespace vision { -namespace ops { - -VISION_API std::tuple roi_pool( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - 
-VISION_API std::tuple roi_pool_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width); - -namespace detail { - -at::Tensor _roi_pool_backward( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - -at::Tensor _roi_pool_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width); - -} // namespace detail - -} // namespace ops -} // namespace vision diff --git a/framework/include/torchvision/vision.cpp b/framework/include/torchvision/vision.cpp deleted file mode 100644 index 161b8ecfa2f..00000000000 --- a/framework/include/torchvision/vision.cpp +++ /dev/null @@ -1,41 +0,0 @@ -#include "vision.h" - -#ifndef MOBILE -#ifdef USE_PYTHON -#include -#endif -#endif -#include - -#ifdef WITH_CUDA -#include -#endif -#ifdef WITH_HIP -#include -#endif - -// If we are in a Windows environment, we need to define -// initialization functions for the _custom_ops extension. -// For PyMODINIT_FUNC to work, we need to include Python.h -#if !defined(MOBILE) && defined(_WIN32) -#ifdef USE_PYTHON -PyMODINIT_FUNC PyInit__C(void) { - // No need to do anything. - return NULL; -} -#endif // USE_PYTHON -#endif // !defined(MOBILE) && defined(_WIN32) - -namespace vision { -int64_t cuda_version() { -#ifdef WITH_CUDA - return CUDA_VERSION; -#else - return -1; -#endif -} - -TORCH_LIBRARY_FRAGMENT(torchvision, m) { - m.def("_cuda_version", &cuda_version); -} -} // namespace vision diff --git a/framework/include/torchvision/vision.h b/framework/include/torchvision/vision.h deleted file mode 100644 index 22f8c6cdd38..00000000000 --- a/framework/include/torchvision/vision.h +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include -#include "macros.h" - -namespace vision { -VISION_API int64_t cuda_version(); - -namespace detail { -extern "C" VISION_INLINE_VARIABLE auto _register_ops = &cuda_version; -#ifdef HINT_MSVC_LINKER_INCLUDE_SYMBOL -#pragma comment(linker, "/include:_register_ops") -#endif - -} // namespace detail -} // namespace vision diff --git a/framework/share/cmake/TorchVision/TorchVisionConfig.cmake b/framework/share/cmake/TorchVision/TorchVisionConfig.cmake deleted file mode 100644 index f04d2919ebf..00000000000 --- a/framework/share/cmake/TorchVision/TorchVisionConfig.cmake +++ /dev/null @@ -1,82 +0,0 @@ -# TorchVisionConfig.cmake -# -------------------- -# -# Exported targets:: Vision -# - - -####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### -####### Any changes to this file will be overwritten by the next CMake run #### -####### The input file was TorchVisionConfig.cmake.in ######## - -get_filename_component(PACKAGE_${CMAKE_FIND_PACKAGE_NAME}_COUNTER_1 "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) - -macro(set_and_check _var _file) - set(${_var} "${_file}") - if(NOT EXISTS "${_file}") - message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") - endif() -endmacro() - -macro(check_required_components _NAME) - foreach(comp ${${_NAME}_FIND_COMPONENTS}) - if(NOT ${_NAME}_${comp}_FOUND) - if(${_NAME}_FIND_REQUIRED_${comp}) - set(${_NAME}_FOUND FALSE) - endif() - endif() - 
endforeach() -endmacro() - -#################################################################################### - -set(PN TorchVision) - -# location of include/torchvision -set(${PN}_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/include") - -set(${PN}_LIBRARY "") -set(${PN}_DEFINITIONS USING_${PN}) - -check_required_components(${PN}) - - -if(NOT (CMAKE_VERSION VERSION_LESS 3.0)) -#----------------------------------------------------------------------------- -# Don't include targets if this file is being picked up by another -# project which has already built this as a subproject -#----------------------------------------------------------------------------- -if(NOT TARGET ${PN}::${PN}) -include("${CMAKE_CURRENT_LIST_DIR}/${PN}Targets.cmake") - -target_include_directories(${PN}::${PN} INTERFACE "${${PN}_INCLUDE_DIR}") - -if(OFF) - target_compile_definitions(${PN}::${PN} INTERFACE WITH_CUDA) -endif() - -find_package(Torch REQUIRED) -target_link_libraries(${PN}::${PN} INTERFACE torch) - -if(ON) - find_package(PNG REQUIRED) - target_link_libraries(${PN}::${PN} INTERFACE ${PNG_LIBRARY}) - target_compile_definitions(${PN}::${PN} INTERFACE PNG_FOUND) -endif() - -if(ON) - find_package(JPEG REQUIRED) - target_link_libraries(${PN}::${PN} INTERFACE ${JPEG_LIBRARIES}) - target_compile_definitions(${PN}::${PN} INTERFACE JPEG_FOUND) -endif() - -if (OFF) - if(NOT TARGET Python3::Python) - find_package(Python3 COMPONENTS Development) - endif() - target_link_libraries(torch INTERFACE Python3::Python) - target_compile_definitions(${PN}::${PN} INTERFACE USE_PYTHON) -endif() - -endif() -endif() diff --git a/framework/share/cmake/TorchVision/TorchVisionConfigVersion.cmake b/framework/share/cmake/TorchVision/TorchVisionConfigVersion.cmake deleted file mode 100644 index cb344ba7a65..00000000000 --- a/framework/share/cmake/TorchVision/TorchVisionConfigVersion.cmake +++ /dev/null @@ -1,43 +0,0 @@ -# This is a basic version file for the Config-mode of find_package(). -# It is used by write_basic_package_version_file() as input file for configure_file() -# to create a version-file which can be installed along a config.cmake file. -# -# The created file sets PACKAGE_VERSION_EXACT if the current version string and -# the requested version string are exactly the same and it sets -# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version. -# The variable CVF_VERSION must be set before calling configure_file(). 
- -set(PACKAGE_VERSION "0.18.0a0") - -if (PACKAGE_FIND_VERSION_RANGE) - # Package version must be in the requested version range - if ((PACKAGE_FIND_VERSION_RANGE_MIN STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MIN) - OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_GREATER PACKAGE_FIND_VERSION_MAX) - OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_GREATER_EQUAL PACKAGE_FIND_VERSION_MAX))) - set(PACKAGE_VERSION_COMPATIBLE FALSE) - else() - set(PACKAGE_VERSION_COMPATIBLE TRUE) - endif() -else() - if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION) - set(PACKAGE_VERSION_COMPATIBLE FALSE) - else() - set(PACKAGE_VERSION_COMPATIBLE TRUE) - if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) - set(PACKAGE_VERSION_EXACT TRUE) - endif() - endif() -endif() - - -# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: -if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "") - return() -endif() - -# check that the installed version has the same 32/64bit-ness as the one which is currently searching: -if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8") - math(EXPR installedBits "8 * 8") - set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") - set(PACKAGE_VERSION_UNSUITABLE TRUE) -endif() diff --git a/framework/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake b/framework/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake deleted file mode 100644 index 91aa482bb9c..00000000000 --- a/framework/share/cmake/TorchVision/TorchVisionTargets-noconfig.cmake +++ /dev/null @@ -1,20 +0,0 @@ -#---------------------------------------------------------------- -# Generated CMake target import file. -#---------------------------------------------------------------- - -# Commands may need to know the format version. -set(CMAKE_IMPORT_FILE_VERSION 1) - -# Import target "TorchVision::TorchVision" for configuration "" -set_property(TARGET TorchVision::TorchVision APPEND PROPERTY IMPORTED_CONFIGURATIONS NOCONFIG) -set_target_properties(TorchVision::TorchVision PROPERTIES - IMPORTED_LINK_DEPENDENT_LIBRARIES_NOCONFIG "torch" - IMPORTED_LOCATION_NOCONFIG "${_IMPORT_PREFIX}/lib/libtorchvision.dylib" - IMPORTED_SONAME_NOCONFIG "@rpath/libtorchvision.dylib" - ) - -list(APPEND _cmake_import_check_targets TorchVision::TorchVision ) -list(APPEND _cmake_import_check_files_for_TorchVision::TorchVision "${_IMPORT_PREFIX}/lib/libtorchvision.dylib" ) - -# Commands beyond this point should not need to know the version. -set(CMAKE_IMPORT_FILE_VERSION) diff --git a/framework/share/cmake/TorchVision/TorchVisionTargets.cmake b/framework/share/cmake/TorchVision/TorchVisionTargets.cmake deleted file mode 100644 index 1e07b7fc626..00000000000 --- a/framework/share/cmake/TorchVision/TorchVisionTargets.cmake +++ /dev/null @@ -1,102 +0,0 @@ -# Generated by CMake - -if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8) - message(FATAL_ERROR "CMake >= 2.8.0 required") -endif() -if(CMAKE_VERSION VERSION_LESS "2.8.3") - message(FATAL_ERROR "CMake >= 2.8.3 required") -endif() -cmake_policy(PUSH) -cmake_policy(VERSION 2.8.3...3.27) -#---------------------------------------------------------------- -# Generated CMake target import file. -#---------------------------------------------------------------- - -# Commands may need to know the format version. -set(CMAKE_IMPORT_FILE_VERSION 1) - -# Protect against multiple inclusion, which would fail when already imported targets are added once more. 
-set(_cmake_targets_defined "") -set(_cmake_targets_not_defined "") -set(_cmake_expected_targets "") -foreach(_cmake_expected_target IN ITEMS TorchVision::TorchVision) - list(APPEND _cmake_expected_targets "${_cmake_expected_target}") - if(TARGET "${_cmake_expected_target}") - list(APPEND _cmake_targets_defined "${_cmake_expected_target}") - else() - list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}") - endif() -endforeach() -unset(_cmake_expected_target) -if(_cmake_targets_defined STREQUAL _cmake_expected_targets) - unset(_cmake_targets_defined) - unset(_cmake_targets_not_defined) - unset(_cmake_expected_targets) - unset(CMAKE_IMPORT_FILE_VERSION) - cmake_policy(POP) - return() -endif() -if(NOT _cmake_targets_defined STREQUAL "") - string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}") - string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}") - message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n") -endif() -unset(_cmake_targets_defined) -unset(_cmake_targets_not_defined) -unset(_cmake_expected_targets) - - -# Compute the installation prefix relative to this file. -get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) -get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) -get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) -get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) -if(_IMPORT_PREFIX STREQUAL "/") - set(_IMPORT_PREFIX "") -endif() - -# Create imported target TorchVision::TorchVision -add_library(TorchVision::TorchVision SHARED IMPORTED) - -# Load information for each installed configuration. -file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/TorchVisionTargets-*.cmake") -foreach(_cmake_config_file IN LISTS _cmake_config_files) - include("${_cmake_config_file}") -endforeach() -unset(_cmake_config_file) -unset(_cmake_config_files) - -# Cleanup temporary variables. -set(_IMPORT_PREFIX) - -# Loop over all imported files and verify that they actually exist -foreach(_cmake_target IN LISTS _cmake_import_check_targets) - if(CMAKE_VERSION VERSION_LESS "3.28" - OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target} - OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}") - foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}") - if(NOT EXISTS "${_cmake_file}") - message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file - \"${_cmake_file}\" -but this file does not exist. Possible reasons include: -* The file was deleted, renamed, or moved to another location. -* An install or uninstall procedure did not complete successfully. -* The installation package was faulty and contained - \"${CMAKE_CURRENT_LIST_FILE}\" -but not all the files it references. -") - endif() - endforeach() - endif() - unset(_cmake_file) - unset("_cmake_import_check_files_for_${_cmake_target}") -endforeach() -unset(_cmake_target) -unset(_cmake_import_check_targets) - -# This file does not depend on other imported targets which have -# been exported from the same project but in a separate export set. - -# Commands beyond this point should not need to know the version. 
-set(CMAKE_IMPORT_FILE_VERSION) -cmake_policy(POP) From 2895f4fc71eaf727df293f8be97bf1a942fcc7c7 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sat, 16 Nov 2024 13:57:40 +0100 Subject: [PATCH 09/31] Removing build_xcode dir and included files. --- build_xcode/VisionTests/VisionTests.mm | 36 - .../csrc/ops/mps/deform_conv2d_kernal.mm | 934 ------------------ 2 files changed, 970 deletions(-) delete mode 100644 build_xcode/VisionTests/VisionTests.mm delete mode 100644 torchvision/csrc/ops/mps/deform_conv2d_kernal.mm diff --git a/build_xcode/VisionTests/VisionTests.mm b/build_xcode/VisionTests/VisionTests.mm deleted file mode 100644 index 62336a1b3d5..00000000000 --- a/build_xcode/VisionTests/VisionTests.mm +++ /dev/null @@ -1,36 +0,0 @@ -// -// VisionTests.m -// VisionTests -// -// Created by Thomas Martin on 12/10/2024. -// - -#import - -@interface VisionTests : XCTestCase - -@end - -@implementation VisionTests - -- (void)setUp { - // Put setup code here. This method is called before the invocation of each test method in the class. -} - -- (void)tearDown { - // Put teardown code here. This method is called after the invocation of each test method in the class. -} - -- (void)testExample { - // This is an example of a functional test case. - // Use XCTAssert and related functions to verify your tests produce the correct results. -} - -- (void)testPerformanceExample { - // This is an example of a performance test case. - [self measureBlock:^{ - // Put the code you want to measure the time of here. - }]; -} - -@end diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernal.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernal.mm deleted file mode 100644 index 2df529f25a2..00000000000 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernal.mm +++ /dev/null @@ -1,934 +0,0 @@ -// vision::ops:: -// deform_conv2d_kernal.mm -// - -#include -#include -#include -#include -#include "mps_helpers.h" -#include "mps_kernels.h" - - -namespace vision { -namespace ops { - -namespace { - -const int64_t tkMaxParallelImgs = 32; - - -void deformable_im2col(const at::Tensor& input, - const at::Tensor& data_offset, - const at::Tensor& data_mask, - int64_t n_in_channels, - int64_t height, - int64_t width, - int64_t weight_h, - int64_t weight_w, - int64_t pad_h, - int64_t pad_w, - int64_t stride_h, - int64_t stride_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t out_h, - int64_t out_w, - int64_t parallel_imgs, - int64_t deformable_group, - bool use_mask, - at::Tensor data_col) { - using namespace at::native::mps; - - // Validate tensors as of type mps. 
- TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); - TORCH_CHECK(data_offset.is_mps(), "data_offset must be a MPS tensor"); - TORCH_CHECK(data_mask.is_mps(), "data_mask must be a MPS tensor"); - - at::TensorArg input_t{input, "input", 1}, - data_offset_t{data_offset, "data_offset", 2}, - data_mask_t{data_mask, "data_mask", 3}; - - at::CheckedFrom c = "deformable_im2col"; - at::checkAllSameGPU(c, {input_t, data_offset_t, data_mask_t}); - at::checkAllSameType(c, {input_t, data_offset_t, data_mask_t}); - - - const int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs; - - // These function parameters have all been made contiguous by the caller function deform_conv2d_forward_kernel - // Check if it is safe to skip the following: - auto input_c = input.contiguous(); - auto data_offset_c = data_offset.contiguous(); - auto data_mask_c = data_mask.contiguous(); - - // Get a raw pointer to the underlying data structure of the tensors and cast it as a pointer to an MTLBuffer. - id inputBuffer = getMTLBufferStorage(input_c); - id data_offsetBuffer = getMTLBufferStorage(data_offset_c); - id data_maskBuffer = getMTLBufferStorage(data_mask_c); - id data_colBuffer = getMTLBufferStorage(data_col); - - id device = MPSDevice::getInstance()->device(); - - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - const std::string kernel = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(num_kernels), - static_cast(512)), - static_cast(4096)), - 1, - 1); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_c, data_offset_c, data_mask_c}); - - id computeEncoder = mpsStream->commandEncoder(); - [computeEncoder setComputePipelineState:visionPSO]; - - [computeEncoder setBuffer:inputBuffer offset:input_c.storage_offset() * input_c.element_size() atIndex:1]; - [computeEncoder setBuffer:data_offsetBuffer offset:data_offset_c.storage_offset() * data_offset_c.element_size() atIndex:2]; - [computeEncoder setBuffer:data_maskBuffer offset:data_mask_c.storage_offset() * data_mask_c.element_size() atIndex:3]; - [computeEncoder setBuffer:data_colBuffer offset:data_col.storage_offset() * data_col.element_size() atIndex:20]; - - [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:11]; - [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:12]; - [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:14]; - [computeEncoder setBytes:&n_in_channels length:sizeof(int64_t) atIndex:15]; - [computeEncoder setBytes:&deformable_group length:sizeof(int64_t) atIndex:16]; - [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:17]; - 
[computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:18]; - [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:19]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); - -} - -int get_greatest_divisor_below_bound(int n, int bound) { - for (int k = bound; k > 1; --k) { - if (n % k == 0) { - return k; - } - } - return 1; -} - -void compute_grad_input( - const at::Tensor& columns, - const at::Tensor& offset, - const at::Tensor& mask, - int64_t channels, - int64_t height, - int64_t width, - int64_t weight_h, //kernel_h - int64_t weight_w, //kernel_w - int64_t pad_h, - int64_t pad_w, - int64_t stride_h, - int64_t stride_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t parallel_imgs, //batch_sz - int64_t n_offset_grps, - bool use_mask, - at::Tensor grad_im) { - using namespace at::native::mps; - - at::globalContext().alertNotDeterministic("compute_grad_input"); - - auto columns_c = columns.contiguous(); - auto offset_c = offset.contiguous(); - auto mask_c = mask.contiguous(); - - const int64_t out_h = - (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - const int64_t out_w = - (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - - const int64_t num_kernels = - (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs; - - id columnsBuffer = getMTLBufferStorage(columns_c); - id offsetBuffer = getMTLBufferStorage(offset_c); - id maskBuffer = getMTLBufferStorage(mask_c); - id grad_imBuffer = getMTLBufferStorage(grad_im); - - id device = MPSDevice::getInstance()->device(); - - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - - const std::string kernel = "deformable_col2im_" + scalarToMetalTypeString(columns.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns, offset, mask}); - - [computeEncoder setComputePipelineState:visionPSO]; - - [computeEncoder setBuffer:columnsBuffer offset:columns_c.storage_offset() * columns_c.element_size() atIndex:1]; - [computeEncoder setBuffer:offsetBuffer offset:offset_c.storage_offset() * offset_c.element_size() atIndex:2]; - [computeEncoder setBuffer:maskBuffer offset:mask_c.storage_offset() * mask_c.element_size() atIndex:3]; - [computeEncoder setBuffer:grad_imBuffer - offset:grad_im.storage_offset() * grad_im.element_size() - atIndex:20]; - - [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:11]; - 
[computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:12]; - [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:14]; - [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:15]; - [computeEncoder setBytes:&n_offset_grps length:sizeof(int64_t) atIndex:16]; - [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:17]; - [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:18]; - [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:19]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(num_kernels), static_cast(512)), static_cast(4096)), - 1, - 1); - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); -} - -void compute_grad_offset_and_mask( - const at::Tensor& columns, - const at::Tensor& input, - const at::Tensor& offset, - const at::Tensor& mask, - int64_t channels, - int64_t height, - int64_t width, - int64_t weight_h, - int64_t weight_w, - int64_t pad_h, - int64_t pad_w, - int64_t stride_h, - int64_t stride_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t parallel_imgs, - int64_t n_offset_grps, - bool use_mask, - at::Tensor grad_offset, - at::Tensor grad_mask) { - - using namespace at::native::mps; - - auto columns_c = columns; //.contiguous(); - auto input_c = input; //.contiguous(); - auto offset_c = offset; //.contiguous(); - auto mask_c = mask; //.contiguous(); - - const int64_t out_h = - (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - const int64_t out_w = - (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - const int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w * - n_offset_grps * parallel_imgs; - - const int64_t offset_channels = 2 * weight_h * weight_w * n_offset_grps; - - id columnsBuffer = getMTLBufferStorage(columns_c); - id inputBuffer = getMTLBufferStorage(input_c); - id offsetBuffer = getMTLBufferStorage(offset_c); - id maskBuffer = getMTLBufferStorage(mask_c); - id grad_offsetBuffer = getMTLBufferStorage(grad_offset); - id grad_maskBuffer = getMTLBufferStorage(grad_mask); - - id device = MPSDevice::getInstance()->device(); - - MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_sync(mpsStream->queue(), ^() { - @autoreleasepool { - id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(num_kernels), static_cast(512)), static_cast(4096)), 1, 1); - - const std::string kernel = "deformable_col2im_coord_" + scalarToMetalTypeString(columns.scalar_type()); - id visionPSO = mps::visionPipelineState(device, kernel); - - // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns_c, input_c, offset_c, mask_c}); - - [computeEncoder setComputePipelineState:visionPSO]; - - [computeEncoder setBuffer:columnsBuffer offset:columns_c.storage_offset() * columns_c.element_size() atIndex:1]; - [computeEncoder setBuffer:inputBuffer offset:input_c.storage_offset() * input_c.element_size() atIndex:2]; - [computeEncoder setBuffer:offsetBuffer 
offset:offset_c.storage_offset() * offset_c.element_size() atIndex:3]; - [computeEncoder setBuffer:maskBuffer offset:mask_c.storage_offset() * mask_c.element_size() atIndex:4]; - [computeEncoder setBuffer:grad_offsetBuffer - offset:grad_offset.storage_offset() * grad_offset.element_size() - atIndex:22]; - [computeEncoder setBuffer:grad_maskBuffer - offset:grad_mask.storage_offset() * grad_mask.element_size() - atIndex:23]; - - [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:11]; - [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:12]; - [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:14]; - [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:15]; - [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:16]; - [computeEncoder setBytes:&offset_channels length:sizeof(int64_t) atIndex:17]; - [computeEncoder setBytes:&n_offset_grps length:sizeof(int64_t) atIndex:18]; - [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:19]; - [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:20]; - [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:21]; - - // A threadGroup is equivalent to a cuda's block. 
- NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - - - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - - getMPSProfiler().endProfileKernel(visionPSO); - } - }); -} - -std::tuple backward_gradient_inputs( - at::Tensor input, - at::Tensor weight, - at::Tensor offset, - at::Tensor mask, - at::Tensor grad_out, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - int64_t n_parallel_imgs, - bool use_mask) { - - int64_t batch_sz = input.size(0); - int64_t n_in_channels = input.size(1); - int64_t in_h = input.size(2); - int64_t in_w = input.size(3); - - n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); - - int64_t n_out_channels = weight.size(0); - int64_t weight_h = weight.size(2); - int64_t weight_w = weight.size(3); - - int64_t out_w = - (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int64_t out_h = - (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - - auto grad_input = at::zeros_like(input); - auto grad_offset = at::zeros_like(offset); - auto grad_mask = at::zeros_like(mask); - - if (batch_sz == 0) { - return std::make_tuple(grad_input, grad_offset, grad_mask); - } - - auto columns = at::empty( - {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, - input.options()); - - // Separate into blocks - grad_input = grad_input.reshape( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - input = input.reshape( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - - grad_offset = grad_offset.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - offset = offset.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - grad_mask = grad_mask.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - mask = mask.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - grad_out = grad_out - .reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_weight_grps, - n_out_channels / n_weight_grps, - out_h, - out_w}) - .permute({0, 2, 3, 1, 4, 5}); - - weight = weight.reshape( - {n_weight_grps, - weight.size(0) / n_weight_grps, - weight.size(1), - weight.size(2), - weight.size(3)}); - - columns = columns.view( - {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - for (int64_t elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { - columns.zero_(); - // Separate into weight groups - for (int64_t g = 0; g < n_weight_grps; g++) { - columns[g] = columns[g].addmm_( - weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); - } - - compute_grad_offset_and_mask( - columns, - input[elt], - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - grad_offset[elt], - grad_mask[elt]); - - compute_grad_input( - columns, - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - 
dilation_h, - dilation_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - grad_input[elt]); - } - - grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); - grad_offset = grad_offset.view( - {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); - - if (use_mask) { - grad_mask = grad_mask.view( - {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w}); - } - - return std::make_tuple(grad_input, grad_offset, grad_mask); -} - -at::Tensor backward_gradient_parameters( - at::Tensor input, - const at::Tensor& weight, - at::Tensor offset, - at::Tensor mask, - const at::Tensor& grad_out, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - int64_t n_parallel_imgs, - bool use_mask) { - - int64_t batch_sz = input.size(0); - int64_t n_in_channels = input.size(1); - int64_t in_h = input.size(2); - int64_t in_w = input.size(3); - - n_parallel_imgs = std::min(batch_sz, n_parallel_imgs); - - int64_t n_out_channels = weight.size(0); - int64_t weight_h = weight.size(2); - int64_t weight_w = weight.size(3); - - int64_t out_h = grad_out.size(2); - int64_t out_w = grad_out.size(3); - - auto grad_weight = at::zeros_like(weight); - if (batch_sz == 0) { - return grad_weight; - } - - at::Tensor grad_out_buf = grad_out - .reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_weight_grps, - n_out_channels / n_weight_grps, - out_h, - out_w}) - .permute({0, 2, 3, 1, 4, 5}) - .contiguous(); - - input = input.reshape( - {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); - - offset = offset.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - mask = mask.reshape( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - grad_weight = grad_weight.reshape( - {n_weight_grps, - grad_weight.size(0) / n_weight_grps, - grad_weight.size(1), - grad_weight.size(2), - grad_weight.size(3)}); - - auto columns = at::empty( - {n_weight_grps, - n_in_channels * weight_w * weight_h / n_weight_grps, - n_parallel_imgs * out_h * out_w}, - input.options()); - - for (int64_t elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { - deformable_im2col( - input[elt], - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - out_h, - out_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - columns); - - for (int64_t g = 0; g < n_weight_grps; g++) { - grad_weight[g] = - grad_weight[g] - .flatten(1) - .addmm_( - grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) - .view_as(grad_weight[g]); - } - } - - grad_weight = grad_weight.view( - {grad_weight.size(0) * grad_weight.size(1), - grad_weight.size(2), - grad_weight.size(3), - grad_weight.size(4)}); - return grad_weight; -} - -at::Tensor deform_conv2d_forward_kernel( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask) { - at::Tensor input_c = input.contiguous(); - at::Tensor offset_c = offset.contiguous(); - at::Tensor weight_c = weight.contiguous(); - at::Tensor mask_c = mask.contiguous(); - at::Tensor 
bias_c = bias.contiguous(); - - TORCH_CHECK(input_c.ndimension() == 4); - TORCH_CHECK(offset_c.ndimension() == 4); - TORCH_CHECK(!use_mask || mask_c.ndimension() == 4); - TORCH_CHECK(weight_c.ndimension() == 4); - TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor"); - - at::DeviceGuard guard(input_c.device()); - - int batch_sz = input_c.size(0); - int in_channels = input_c.size(1); - int in_h = input_c.size(2); - int in_w = input_c.size(3); - - int n_parallel_imgs = - get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); - - int out_channels = weight_c.size(0); - int weight_h = weight_c.size(2); - int weight_w = weight_c.size(3); - - int ker_h = dilation_h * (weight_h - 1) + 1; - int ker_w = dilation_w * (weight_w - 1) + 1; - int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; - int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; - - TORCH_CHECK( - weight_h > 0 && weight_w > 0, - "weight_h: ", - weight_h, - " weight_w: ", - weight_w); - TORCH_CHECK( - stride_h > 0 && stride_w > 0, - "stride_h: ", - stride_h, - " stride_w: ", - stride_w); - TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w); - TORCH_CHECK( - dilation_h > 0 && dilation_w > 0, - "dilation_h: ", - dilation_h, - " dilation_w: ", - dilation_w); - - TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1)); - TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0); - TORCH_CHECK( - (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w), - "offset.shape[1] is not valid: got: ", - offset_c.size(1), - " expected: ", - n_offset_grps * 2 * weight_h * weight_w); - TORCH_CHECK( - (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w), - "mask.shape[1] is not valid: got: ", - mask_c.size(1), - " expected: ", - n_offset_grps * weight_h * weight_w); - TORCH_CHECK(input_c.size(1) % n_offset_grps == 0); - - TORCH_CHECK( - (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset"); - TORCH_CHECK( - (offset_c.size(2) == out_h && offset_c.size(3) == out_w), - "offset output dims: (", - offset_c.size(2), - ", ", - offset_c.size(3), - ") - ", - "computed output dims: (", - out_h, - ", ", - out_w, - ")"); - TORCH_CHECK( - (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask"); - TORCH_CHECK( - (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)), - "mask output dims: (", - mask_c.size(2), - ", ", - mask_c.size(3), - ") - ", - "computed output dims: (", - out_h, - ", ", - out_w, - ")"); - TORCH_CHECK( - out_h > 0 && out_w > 0, - "Calculated output size too small - out_h: ", - out_h, - " out_w: ", - out_w); - - auto out = - at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options()); - if (batch_sz == 0) { - return out; - } - - // Separate batches into blocks - out = out.view( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - out_channels, - out_h, - out_w}); - input_c = input_c.view( - {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); - - offset_c = offset_c.view( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * 2 * weight_h * weight_w, - out_h, - out_w}); - - if (use_mask) { - mask_c = mask_c.view( - {batch_sz / n_parallel_imgs, - n_parallel_imgs, - n_offset_grps * weight_h * weight_w, - out_h, - out_w}); - } - - at::Tensor out_buf = at::zeros( - {batch_sz / n_parallel_imgs, - out_channels, - n_parallel_imgs * out_h, - out_w}, - out.options()); - - // Separate channels into convolution groups - out_buf = out_buf.view( - {out_buf.size(0), - n_weight_grps, - out_buf.size(1) / 
n_weight_grps, - out_buf.size(2), - out_buf.size(3)}); - weight_c = weight_c.view( - {n_weight_grps, - weight_c.size(0) / n_weight_grps, - weight_c.size(1), - weight_c.size(2), - weight_c.size(3)}); - - // Sample points and perform convolution - auto columns = at::zeros( - {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w}, - input_c.options()); - - for (int b = 0; b < batch_sz / n_parallel_imgs; b++) { - deformable_im2col( - input_c[b], - offset_c[b], - mask_c[b], - in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - out_h, - out_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - columns); - - columns = columns.view( - {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - for (int g = 0; g < n_weight_grps; g++) { - out_buf[b][g] = out_buf[b][g] - .flatten(1) - .addmm_(weight_c[g].flatten(1), columns[g]) - .view_as(out_buf[b][g]); - } - columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); - } - - out_buf = out_buf.view( - {batch_sz / n_parallel_imgs, - out_channels, - n_parallel_imgs, - out_h, - out_w}); - out_buf.transpose_(1, 2); - out.copy_(out_buf); - out = out.view({batch_sz, out_channels, out_h, out_w}); - - return out + bias_c.view({1, out_channels, 1, 1}); -} - -std::tuple -deform_conv2d_backward_kernel( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask) { - at::Tensor grad_out_c = grad_out.contiguous(); - at::Tensor input_c = input.contiguous(); - at::Tensor weight_c = weight.contiguous(); - at::Tensor offset_c = offset.contiguous(); - at::Tensor mask_c = mask.contiguous(); - at::Tensor bias_c = bias.contiguous(); - - const int64_t batch_sz = input_c.size(0); - const int64_t n_parallel_imgs = - get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); - - auto grad_input_and_offset_and_mask = backward_gradient_inputs( - input_c, - weight_c, - offset_c, - mask_c, - grad_out_c, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - n_weight_grps, - n_offset_grps, - n_parallel_imgs, - use_mask); - - auto grad_input = std::get<0>(grad_input_and_offset_and_mask); - auto grad_offset = std::get<1>(grad_input_and_offset_and_mask); - auto grad_mask = std::get<2>(grad_input_and_offset_and_mask); - - auto grad_weight = backward_gradient_parameters( - input_c, - weight_c, - offset_c, - mask_c, - grad_out_c, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - n_weight_grps, - n_offset_grps, - n_parallel_imgs, - use_mask); - - auto value = grad_out_c.sum({0, 2, 3}); - auto grad_bias = at::ones_like(bias_c) * value; - - return std::make_tuple( - grad_input, grad_weight, grad_offset, grad_mask, grad_bias); -} -} // namespace - - -TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), - TORCH_FN(deform_conv2d_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"), - TORCH_FN(deform_conv2d_backward_kernel)); -} - -} // namespace ops -} // namespace vision - From 2f06f7f10f16c07d24933ef6467dcc8eb116c8d8 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sat, 16 Nov 2024 14:00:42 +0100 Subject: [PATCH 10/31] Changed location references to pytorch 
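Note: the cache variables below hard-code absolute paths to a local pytorch checkout. As a non-authoritative alternative, the CMake package prefix of whichever torch build is importable can be queried from Python and fed to CMAKE_PREFIX_PATH (a sketch only; it assumes the torch package being linked against is on PYTHONPATH):

# Sketch: print the CMake prefix shipped with the installed torch package.
# TorchConfig.cmake and Caffe2Config.cmake live under this directory.
import torch

print(torch.utils.cmake_prefix_path)   # e.g. <site-packages>/torch/share/cmake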
--- CMakePresets.json | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index d9812eea2a1..7b9bba2ef08 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -2,8 +2,8 @@ "version": 6, "configurePresets": [ { - "name": "TorchVision", - "displayName": "TorchVision", + "name": "ARM64_MACOS_DEBUG", + "displayName": "ARM64_MACOS_DEBUG", "description": "TorchVision build using make generator", "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", @@ -11,11 +11,11 @@ "CMAKE_BUILD_TYPE":"Debug", "CMAKE_INSTALL_PREFIX": "/Users/thomas/Developer/projects/visionDev/vision/product", "TORCH_LIBRARY": - "/Users/thomas/Developer/projects/visionDev/libtorch/lib/libtorch.dylib", - "Torch_DIR": "/Users/thomas/Developer/projects/visionDev/libtorch/share/cmake/Torch", - "c10_LIBRARY": "/Users/thomas/Developer/projects/visionDev/libtorch/lib/libc10.dylib", - "kineto_LIBRARY": "/Users/thomas/Developer/projects/visionDev/libtorch/lib/libkineto.a", - "Caffe2_DIR": "/Users/thomas/Developer/projects/visionDev/libtorch/share/cmake/Caffe2", + "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/lib//libtorch.dylib", + "Torch_DIR": "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/share/cmake/Torch", + "c10_LIBRARY": "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/lib/libc10.dylib", + "kineto_LIBRARY": "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/lib/libkineto.a", + "Caffe2_DIR": "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/share/cmake/Caffe2", "PNG_LIBRARY_RELEASE": "/opt/homebrew/Cellar/libpng/1.6.43/lib/libpng.dylib", "PNG_PNG_INCLUDE_DIR": "/opt/homebrew/Cellar/libpng/1.6.43/include", "JPEG_LIBRARY_RELEASE": "/opt/homebrew/Cellar/jpeg-turbo/3.0.3/lib/libjpeg.dylib", From 66d76d3695f9810240debb8bf5e75b54020381c2 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sat, 16 Nov 2024 14:01:26 +0100 Subject: [PATCH 11/31] Clean up git - Removing .DS_Store --- .DS_Store | Bin 6148 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index bb0be2e98d6e1941f8401a7b7ae0215a03c0ab17..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKzi-n(6n@tlk~Bii0D|!()=0!cAgT_SGz_d5!9Y=na__e24fgDa5-ROq*B^^yR{Go;uS5(6K`a7z zHQ#@J^5f(F*GKj3B_h0}=WJT(xj{30pP-QDG^V25)-G?KwjQcp&w16f<4}PS=2V>c zSsIH}o>4frs13V3wZ=516N=ywp_~Sbd}c=XDVW7kUi5n3*w(glWs>_$|J@-ME@7>*7ty;|+x9{vf9UjL?D&FGbAxsWJik!@{yhCBb;LcI7 z3u;r)5J6**6cWdx5XgrpW${sEQZ9%nFOBT`hc7#aYLYID`n^mpA5}BeWfp9H0|RE+ zG}{~Mt=TeQ8Tfx0kmm!3O7sk_HL9ZniFyJ69hj9MFW(X}h8y$@t~H_t!gMH5hcffT zU^*OjZu~rhYmGXbn0b6K(=#(K6sCGd`&_vb^EBGjGGG}v%fPlScI5s4c=P-JY>+Kk z1}p>r6a&l|_=7%%WbW3=;N)G)p|7D*q~BVj3PED7W0jFt@hzwl)Nz>sdIr}TVS(5m N0Y!sNECYX(fuGl~v|a!J From c8eb2ea5248459705508abe6e5e5da459445cae3 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sat, 16 Nov 2024 14:02:38 +0100 Subject: [PATCH 12/31] Altering the kernel deformable_im2col to mimic the cpp kernel implementation. 
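For orientation, the per-thread base offsets that the hunk below turns into local pointers can be transcribed into plain Python roughly as follows (illustrative only; the names mirror the kernel, and the in_c/out_c derivation follows the CPU/CUDA reference implementation rather than anything added by this patch):

# Each flat thread index addresses one (input channel, batch image, output y, output x)
# cell; the bases below are the starting elements of the column/input/offset/mask views.
def im2col_bases(index, out_w, out_h, batch_sz, n_in_channels, n_offset_grps,
                 weight_h, weight_w, height, width):
    out_x = index % out_w
    out_y = (index // out_w) % out_h
    out_b = (index // (out_w * out_h)) % batch_sz
    in_c = index // (out_w * out_h * batch_sz)
    out_c = in_c * weight_h * weight_w
    grp_idx = in_c // (n_in_channels // n_offset_grps)

    columns_base = (out_c * (batch_sz * out_h * out_w)
                    + out_b * (out_h * out_w) + out_y * out_w + out_x)
    input_base = out_b * (n_in_channels * height * width) + in_c * (height * width)
    offset_base = (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w
    mask_base = (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w
    return columns_base, input_base, offset_base, mask_base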
--- torchvision/csrc/ops/mps/mps_kernels.h | 56 +++++++++++++------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index 002a3d4c242..3e2ab526da2 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -1061,9 +1061,9 @@ kernel void ps_roi_pool_backward( \ template kernel void deformable_im2col( constant int64_t & n [[buffer(0)]], - constant scalar_t * input_ptr [[buffer(1)]], - constant scalar_t * offset_ptr [[buffer(2)]], - constant scalar_t * mask_ptr [[buffer(3)]], + constant scalar_t * input [[buffer(1)]], + constant scalar_t * offset [[buffer(2)]], + constant scalar_t * mask [[buffer(3)]], constant int64_t & height [[buffer(4)]], constant int64_t & width [[buffer(5)]], constant int64_t & weight_h [[buffer(6)]], @@ -1080,7 +1080,7 @@ kernel void deformable_im2col( constant int64_t & out_h [[buffer(17)]], constant int64_t & out_w [[buffer(18)]], constant bool & use_mask [[buffer(19)]], - device scalar_t * columns_ptr [[buffer(20)]], + device scalar_t * columns [[buffer(20)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], uint2 tid2 [[thread_position_in_threadgroup]]) { @@ -1094,19 +1094,21 @@ kernel void deformable_im2col( integer_t c_per_offset_grp = n_in_channels / n_offset_grps; const integer_t grp_idx = in_c / c_per_offset_grp; - columns_ptr += - (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + - out_y * out_w + out_x); - - input_ptr += - (out_b * (n_in_channels * height * width) + in_c * (height * width)); - - offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * - out_h * out_w; - + auto columns_ptr = columns + + (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + + out_y * out_w + out_x); + + auto input_ptr = input + + (out_b * (n_in_channels * height * width) + in_c * (height * width)); + + auto offset_ptr = offset + + (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * + out_w; + + auto mask_ptr = mask; if (use_mask) { - mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * - out_h * out_w; + mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * + out_h * out_w; } for (int i = 0; i < weight_h; ++i) { @@ -1117,19 +1119,19 @@ kernel void deformable_im2col( scalar_t mask_value = 1; if (use_mask) { mask_value = - mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; + mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; } const scalar_t offset_h = - offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t offset_w = offset_ptr - [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; + offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; + const scalar_t offset_w = + offset_ptr[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; const scalar_t y = - (out_y * stride_h - pad_h) + i * dilation_h + offset_h; + (out_y * stride_h - pad_h) + i * dilation_h + offset_h; const scalar_t x = - (out_x * stride_w - pad_w) + j * dilation_w + offset_w; + (out_x * stride_w - pad_w) + j * dilation_w + offset_w; *columns_ptr = - mask_value * bilinear_interpolate(input_ptr, height, width, y, x, index); + mask_value * bilinear_interpolate(input_ptr, height, width, y, x, index); columns_ptr += batch_sz * out_h * out_w; } } @@ -1141,9 +1143,9 @@ template \ [[host_name("deformable_im2col_" #DTYPE)]] \ kernel void deformable_im2col( \ constant int64_t & n 
[[buffer(0)]], \ - constant DTYPE * input_ptr [[buffer(1)]], \ - constant DTYPE * offset_ptr [[buffer(2)]], \ - constant DTYPE * mask_ptr [[buffer(3)]], \ + constant DTYPE * input [[buffer(1)]], \ + constant DTYPE * offset [[buffer(2)]], \ + constant DTYPE * mask [[buffer(3)]], \ constant int64_t & height [[buffer(4)]], \ constant int64_t & width [[buffer(5)]], \ constant int64_t & weight_h [[buffer(6)]], \ @@ -1162,7 +1164,7 @@ kernel void deformable_im2col( \ constant bool & use_mask [[buffer(19)]], \ device DTYPE * columns_ptr [[buffer(20)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); From c92eaa41410e8b1bb164db24e5a44c5c12e28334 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sat, 16 Nov 2024 14:03:19 +0100 Subject: [PATCH 13/31] Re-ordering include sequence --- test/optest.py | 15 +++++++++++++++ torchvision/csrc/ops/mps/deform_conv2d_kernel.mm | 1 - 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 test/optest.py diff --git a/test/optest.py b/test/optest.py new file mode 100644 index 00000000000..68caf575dc1 --- /dev/null +++ b/test/optest.py @@ -0,0 +1,15 @@ +# ========================================================= +# BEGIN REPRO SCRIPT +# ========================================================= +import torch +from torch.testing._internal.optests import opcheck + +# Make sure you have loaded the library that contains the op +# via an import or torch.ops.load_library(...) +# op = torch.ops.torchvision.deform_conv2d.default +op = torch.ops.torchvision.roi_align.default +args, kwargs = torch.load("/var/folders/m7/m4jyvbb97ml6nftpw7b6fsk00000gn/T/pytorch_opcheck_safe_to_delete/repro_173109241941725.22.pt") +opcheck(op, args, kwargs, test_utils="test_autograd_registration") +# ========================================================= +# END REPRO SCRIPT +# ========================================================= \ No newline at end of file diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index 3a4ee6624b5..d48abf0e4e3 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -4,7 +4,6 @@ #include #include -#include #include #include #include "mps_helpers.h" From b445aed991334dca2fcd4fd7e3db3950dd99635b Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sat, 16 Nov 2024 14:05:38 +0100 Subject: [PATCH 14/31] Including mps in TestDeformConv::test_is_leaf_node --- test/test_ops.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1ba7a2c9efa..443fd97bc5d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1041,7 +1041,7 @@ def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, ) return DeformConvModuleWrapper(obj) if wrap else obj - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) def test_is_leaf_node(self, device): op_obj = self.make_obj(wrap=True).to(device=device) graph_node_names = get_graph_node_names(op_obj) @@ -1050,12 +1050,17 @@ def test_is_leaf_node(self, device): assert len(graph_node_names[0]) == len(graph_node_names[1]) assert len(graph_node_names[0]) == 1 + op_obj.n_inputs - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) + @pytest.mark.parametrize("dtype", 
(torch.float16, torch.float32, torch.float64)) # , ids=str) @pytest.mark.parametrize("contiguous", (True, False)) - @pytest.mark.parametrize("batch_sz", (0, 33)) + @pytest.mark.parametrize("batch_sz", (0, 3)) @pytest.mark.opcheck_only_one() - def test_forward(self, device, contiguous, batch_sz, dtype=None): + def test_forward(self, device, contiguous, batch_sz, dtype): dtype = dtype or self.dtype + + if device == "mps" and dtype is torch.float64: + pytest.skip("MPS does not support float64") + x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) in_channels = 6 out_channels = 2 @@ -1974,4 +1979,4 @@ def test_is_leaf_node(self, dim, p, block_size, inplace): if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__ + "::TestDeformConv::test_forward"]) From 951880c9cf54c44459e7c518e3eff462db1e74f3 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sun, 1 Dec 2024 15:19:25 +0100 Subject: [PATCH 15/31] Updates gitignore --- .gitignore | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index 2fc3fa261dc..1925049c7b6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,11 @@ # MacOS **/.DS_Store +build_xcode/ build/ dist/ +product/ +framework/ torchvision.egg-info/ torchvision/version.py */**/__pycache__ @@ -16,6 +19,9 @@ torchvision/version.py */**/*~ *~ +#Misc +collect_env.py + docs/build # sphinx-gallery docs/source/auto_examples/ From 9f68fd427ff55c93285880e331b135bfa517ac88 Mon Sep 17 00:00:00 2001 From: Goldfish Sound Date: Mon, 2 Dec 2024 17:14:09 +0100 Subject: [PATCH 16/31] Update .gitignore Removed the product which is not part of the repo. --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index dfd03a2fb4a..fbbbeaa3439 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,6 @@ build_xcode/ build/ dist/ -product/ framework/ torchvision.egg-info/ torchvision/version.py From dc305ae735d6c5231577e48438e7950757c3f3f2 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Mon, 2 Dec 2024 16:58:22 +0100 Subject: [PATCH 17/31] CleanUp --- CMakePresets.json | 2 +- .../csrc/ops/mps/deform_conv2d_kernel.mm | 180 +++++++++--------- torchvision/csrc/ops/mps/mps_kernels.h | 21 +- 3 files changed, 101 insertions(+), 102 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index 7b9bba2ef08..8d782b2837c 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -6,7 +6,7 @@ "displayName": "ARM64_MACOS_DEBUG", "description": "TorchVision build using make generator", "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "binaryDir": "${sourceDir}/build", "cacheVariables": { "CMAKE_BUILD_TYPE":"Debug", "CMAKE_INSTALL_PREFIX": "/Users/thomas/Developer/projects/visionDev/vision/product", diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index d48abf0e4e3..e8cba72b509 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -58,9 +58,9 @@ void deformable_im2col(const at::Tensor& input, // These function parameters have all been made contiguous by the caller function deform_conv2d_forward_kernel // Check if it is safe to skip the following: - auto input_c = input.contiguous(); - auto data_offset_c = data_offset.contiguous(); - auto data_mask_c = data_mask.contiguous(); + auto input_c = input; //.contiguous(); + auto data_offset_c = data_offset; //.contiguous(); + auto data_mask_c = data_mask; //.contiguous(); // Get a raw 
pointer to the underlying data structure of the tensors and cast it as a pointer to an MTLBuffer. id inputBuffer = getMTLBufferStorage(input_c); @@ -342,21 +342,21 @@ void compute_grad_offset_and_mask( } std::tuple backward_gradient_inputs( - at::Tensor input, - at::Tensor weight, - at::Tensor offset, - at::Tensor mask, - at::Tensor grad_out, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - int64_t n_parallel_imgs, - bool use_mask) { + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor mask, + at::Tensor grad_out, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + int64_t n_parallel_imgs, + bool use_mask) { int64_t batch_sz = input.size(0); int64_t n_in_channels = input.size(1); @@ -448,46 +448,46 @@ void compute_grad_offset_and_mask( } compute_grad_offset_and_mask( - columns, - input[elt], - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - grad_offset[elt], - grad_mask[elt]); + columns, + input[elt], + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_offset[elt], + grad_mask[elt]); compute_grad_input( - columns, - offset[elt], - mask[elt], - n_in_channels, - in_h, - in_w, - weight_h, - weight_w, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - n_parallel_imgs, - n_offset_grps, - use_mask, - grad_input[elt]); + columns, + offset[elt], + mask[elt], + n_in_channels, + in_h, + in_w, + weight_h, + weight_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + n_parallel_imgs, + n_offset_grps, + use_mask, + grad_input[elt]); } grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); @@ -503,21 +503,21 @@ void compute_grad_offset_and_mask( } at::Tensor backward_gradient_parameters( - at::Tensor input, - const at::Tensor& weight, - at::Tensor offset, - at::Tensor mask, - const at::Tensor& grad_out, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - int64_t n_parallel_imgs, - bool use_mask) { + at::Tensor input, + const at::Tensor& weight, + at::Tensor offset, + at::Tensor mask, + const at::Tensor& grad_out, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + int64_t n_parallel_imgs, + bool use_mask) { int64_t batch_sz = input.size(0); int64_t n_in_channels = input.size(1); @@ -623,22 +623,21 @@ void compute_grad_offset_and_mask( } at::Tensor deform_conv2d_forward_kernel( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& offset, - const at::Tensor& mask, - const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t n_weight_grps, - int64_t n_offset_grps, - bool use_mask) { + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + 
int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { - at::Tensor input_c = input.contiguous(); at::Tensor offset_c = offset.contiguous(); at::Tensor weight_c = weight.contiguous(); @@ -651,7 +650,7 @@ void compute_grad_offset_and_mask( TORCH_CHECK(weight_c.ndimension() == 4); TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor"); - at::DeviceGuard guard(input_c.device()); + // at::DeviceGuard guard(input_c.device()); int batch_sz = input_c.size(0); int in_channels = input_c.size(1); @@ -659,7 +658,7 @@ void compute_grad_offset_and_mask( int in_w = input_c.size(3); int n_parallel_imgs = - get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); + get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); int out_channels = weight_c.size(0); int weight_h = weight_c.size(2); @@ -754,12 +753,13 @@ void compute_grad_offset_and_mask( out_channels, out_h, out_w}); + input_c = input_c.view( {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w}); offset_c = offset_c.view( {batch_sz / n_parallel_imgs, - n_parallel_imgs, + n_parallel_imgs, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); @@ -831,7 +831,7 @@ void compute_grad_offset_and_mask( .view_as(out_buf[b][g]); } columns = - columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); } out_buf = out_buf.view( diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index a3dbf29c4c0..7aac77b7a28 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -1030,9 +1030,9 @@ kernel void ps_roi_pool_backward( \ template kernel void deformable_im2col( constant int64_t & n [[buffer(0)]], - constant scalar_t * input [[buffer(1)]], - constant scalar_t * offset [[buffer(2)]], - constant scalar_t * mask [[buffer(3)]], + constant scalar_t * input_ptr [[buffer(1)]], + constant scalar_t * offset_ptr [[buffer(2)]], + constant scalar_t * mask_ptr [[buffer(3)]], constant int64_t & height [[buffer(4)]], constant int64_t & width [[buffer(5)]], constant int64_t & weight_h [[buffer(6)]], @@ -1049,7 +1049,7 @@ kernel void deformable_im2col( constant int64_t & out_h [[buffer(17)]], constant int64_t & out_w [[buffer(18)]], constant bool & use_mask [[buffer(19)]], - device scalar_t * columns [[buffer(20)]], + device scalar_t * columns_ptr [[buffer(20)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], uint2 tid2 [[thread_position_in_threadgroup]]) { @@ -1063,18 +1063,17 @@ kernel void deformable_im2col( integer_t c_per_offset_grp = n_in_channels / n_offset_grps; const integer_t grp_idx = in_c / c_per_offset_grp; - auto columns_ptr = columns + + columns_ptr += (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + out_y * out_w + out_x); - auto input_ptr = input + + input_ptr += (out_b * (n_in_channels * height * width) + in_c * (height * width)); - auto offset_ptr = offset + + offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; - auto mask_ptr = mask; if (use_mask) { mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w; @@ -1112,9 +1111,9 @@ template \ [[host_name("deformable_im2col_" #DTYPE)]] \ kernel void deformable_im2col( \ constant int64_t & n [[buffer(0)]], \ - constant DTYPE * input [[buffer(1)]], \ - constant DTYPE * offset [[buffer(2)]], \ - 
constant DTYPE * mask [[buffer(3)]], \ + constant DTYPE * input_ptr [[buffer(1)]], \ + constant DTYPE * offset_ptr [[buffer(2)]], \ + constant DTYPE * mask_ptr [[buffer(3)]], \ constant int64_t & height [[buffer(4)]], \ constant int64_t & width [[buffer(5)]], \ constant int64_t & weight_h [[buffer(6)]], \ From e25e620ffe5b27a3d348064de71d64e2558b1f22 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Wed, 4 Dec 2024 13:31:20 +0100 Subject: [PATCH 18/31] Cleaned up - removed added exclusions. --- .gitignore | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitignore b/.gitignore index fbbbeaa3439..1450ab2c3c8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,3 @@ -# CMAKE -# CmakePresets.json - # MacOS **/.DS_Store From e4fb8c5407a11a82a0f65ea265f426e0a0ebe865 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Wed, 4 Dec 2024 13:47:29 +0100 Subject: [PATCH 19/31] Updated --- CMakePresets.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakePresets.json b/CMakePresets.json index 8d782b2837c..ed0e32d4018 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -9,7 +9,7 @@ "binaryDir": "${sourceDir}/build", "cacheVariables": { "CMAKE_BUILD_TYPE":"Debug", - "CMAKE_INSTALL_PREFIX": "/Users/thomas/Developer/projects/visionDev/vision/product", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/build/product", "TORCH_LIBRARY": "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/lib//libtorch.dylib", "Torch_DIR": "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/share/cmake/Torch", From e39867fcedc5c65768585aea9e97d347461ac4bf Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Wed, 4 Dec 2024 13:49:11 +0100 Subject: [PATCH 20/31] Removed CMakePresets.json --- CMakePresets.json | 32 -------------------------------- 1 file changed, 32 deletions(-) delete mode 100644 CMakePresets.json diff --git a/CMakePresets.json b/CMakePresets.json deleted file mode 100644 index ed0e32d4018..00000000000 --- a/CMakePresets.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "version": 6, - "configurePresets": [ - { - "name": "ARM64_MACOS_DEBUG", - "displayName": "ARM64_MACOS_DEBUG", - "description": "TorchVision build using make generator", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build", - "cacheVariables": { - "CMAKE_BUILD_TYPE":"Debug", - "CMAKE_INSTALL_PREFIX": "${sourceDir}/build/product", - "TORCH_LIBRARY": - "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/lib//libtorch.dylib", - "Torch_DIR": "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/share/cmake/Torch", - "c10_LIBRARY": "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/lib/libc10.dylib", - "kineto_LIBRARY": "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/lib/libkineto.a", - "Caffe2_DIR": "/Users/thomas/Developer/projects/pytorchDev/pytorch/torch/share/cmake/Caffe2", - "PNG_LIBRARY_RELEASE": "/opt/homebrew/Cellar/libpng/1.6.43/lib/libpng.dylib", - "PNG_PNG_INCLUDE_DIR": "/opt/homebrew/Cellar/libpng/1.6.43/include", - "JPEG_LIBRARY_RELEASE": "/opt/homebrew/Cellar/jpeg-turbo/3.0.3/lib/libjpeg.dylib", - "JPEG_INCLUDE_DIR": "/opt/homebrew/Cellar/jpeg-turbo/3.0.3/include", - "WITH_CUDA": "OFF", - "WITH_MPS": "ON", - "WITH_PNG": "ON", - "WITH_JPEG": "ON", - "WITH_WEBP": "OFF", - "WITH_AVIF": "OFF" - } - } - ] - } \ No newline at end of file From 3e2bc0ef384f513e12db2ec98a27e13ebfedface Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Wed, 4 Dec 2024 13:50:09 +0100 Subject: [PATCH 21/31] Updated to exclude CMakePresets.json --- .gitignore | 3 +++ 1 file changed, 3 
insertions(+) diff --git a/.gitignore b/.gitignore index 1450ab2c3c8..9f9df389d30 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# CMAKE +CmakePresets.json + # MacOS **/.DS_Store From bd62ab3b8ace0e303928bc40eb7773aa69d0c729 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Tue, 4 Mar 2025 15:03:34 +0100 Subject: [PATCH 22/31] Added bilinear_interpolate_2 function which is identical to the one used for cpp and cuda implementation of deform_conv2d, and the implementation used in the optest. kernel deformable_im2col: Using threadgroups_per_grid as the n_tgs in the MPS_1D_KERNEL_LOOP to prevent multiple index values generated by the macro, when threadgroups_per_grid is larger than 1. --- torchvision/csrc/ops/mps/mps_kernels.h | 155 ++++++++++++++++--------- 1 file changed, 103 insertions(+), 52 deletions(-) diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index 7aac77b7a28..d89e01a4ba3 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -37,6 +37,59 @@ inline void atomic_add_float(device half* data_ptr, const half val) atomic_fetch_add_explicit((device atomic_float*) data_ptr, static_cast(val), memory_order_relaxed); } +// ********************** TESTING CPU implementation of bilinear_interpolate ********************** +// This implementation is used by the cpu and cuda implementation of the deform_conv2d kernel +// and is needed here in order for the pytest operator test not to fail. + +template +inline scalar_t bilinear_interpolate_2( + constant scalar_t* in, + integer_t height, + integer_t width, + scalar_t h, + scalar_t w, + uint index /* index for debug only*/) { + if (h <= -1 || height <= h || w <= -1 || width <= w) { + return 0; + } + + integer_t h_low = floor(h); + integer_t w_low = floor(w); + integer_t h_high = h_low + 1; + integer_t w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = in[h_low * width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = in[h_low * width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = in[h_high * width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = in[h_high * width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + + + + + + + + + template inline T bilinear_interpolate( constant T* input, @@ -1022,38 +1075,34 @@ kernel void ps_roi_pool_backward( \ /*----------- START OF DEFORM_CONV2D KERNEL IMPLEMENTATION -----------------*/ - - - - - template kernel void deformable_im2col( - constant int64_t & n [[buffer(0)]], - constant scalar_t * input_ptr [[buffer(1)]], - constant scalar_t * offset_ptr [[buffer(2)]], - constant scalar_t * mask_ptr [[buffer(3)]], - constant int64_t & height [[buffer(4)]], - constant int64_t & width [[buffer(5)]], - constant int64_t & weight_h [[buffer(6)]], - constant int64_t & weight_w [[buffer(7)]], - constant int64_t & pad_h [[buffer(8)]], - constant int64_t & pad_w [[buffer(9)]], - constant int64_t & stride_h [[buffer(10)]], - constant int64_t & stride_w [[buffer(11)]], - constant int64_t & dilation_h [[buffer(12)]], - constant int64_t & dilation_w [[buffer(13)]], - constant int64_t & batch_sz [[buffer(14)]], - constant int64_t & n_in_channels [[buffer(15)]], - constant int64_t & n_offset_grps 
[[buffer(16)]], - constant int64_t & out_h [[buffer(17)]], - constant int64_t & out_w [[buffer(18)]], - constant bool & use_mask [[buffer(19)]], - device scalar_t * columns_ptr [[buffer(20)]], - uint2 tgid [[threadgroup_position_in_grid]], + constant scalar_t * input_ptr [[buffer(0)]], + constant scalar_t * offset_ptr [[buffer(1)]], + constant scalar_t * mask_ptr [[buffer(2)]], + device scalar_t * columns_ptr [[buffer(3)]], + constant int64_t & n [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & weight_h [[buffer(7)]], + constant int64_t & weight_w [[buffer(8)]], + constant int64_t & pad_h [[buffer(9)]], + constant int64_t & pad_w [[buffer(10)]], + constant int64_t & stride_h [[buffer(11)]], + constant int64_t & stride_w [[buffer(12)]], + constant int64_t & dilation_h [[buffer(13)]], + constant int64_t & dilation_w [[buffer(14)]], + constant int64_t & batch_sz [[buffer(15)]], + constant int64_t & n_in_channels [[buffer(16)]], + constant int64_t & n_offset_grps [[buffer(17)]], + constant int64_t & out_h [[buffer(18)]], + constant int64_t & out_w [[buffer(19)]], + constant bool & use_mask [[buffer(20)]], + uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]) { - MPS_1D_KERNEL_LOOP(index, n, 1) { + uint2 tid2 [[thread_position_in_threadgroup]], + uint2 tgpg [[threadgroups_per_grid]]) { + MPS_1D_KERNEL_LOOP(index, n, tgpg.x) { const integer_t out_x = index % out_w; const integer_t out_y = (index / out_w) % out_h; const integer_t out_b = (index / (out_w * out_h)) % batch_sz; @@ -1079,6 +1128,7 @@ kernel void deformable_im2col( out_h * out_w; } + // For each element in the filter for (int i = 0; i < weight_h; ++i) { for (int j = 0; j < weight_w; ++j) { const integer_t mask_idx = i * weight_w + j; @@ -1099,7 +1149,7 @@ kernel void deformable_im2col( const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; *columns_ptr = - mask_value * bilinear_interpolate(input_ptr, height, width, y, x, index); + mask_value * bilinear_interpolate_2(input_ptr, height, width, y, x, index); columns_ptr += batch_sz * out_h * out_w; } } @@ -1110,30 +1160,31 @@ kernel void deformable_im2col( template \ [[host_name("deformable_im2col_" #DTYPE)]] \ kernel void deformable_im2col( \ - constant int64_t & n [[buffer(0)]], \ - constant DTYPE * input_ptr [[buffer(1)]], \ - constant DTYPE * offset_ptr [[buffer(2)]], \ - constant DTYPE * mask_ptr [[buffer(3)]], \ - constant int64_t & height [[buffer(4)]], \ - constant int64_t & width [[buffer(5)]], \ - constant int64_t & weight_h [[buffer(6)]], \ - constant int64_t & weight_w [[buffer(7)]], \ - constant int64_t & pad_h [[buffer(8)]], \ - constant int64_t & pad_w [[buffer(9)]], \ - constant int64_t & stride_h [[buffer(10)]], \ - constant int64_t & stride_w [[buffer(11)]], \ - constant int64_t & dilation_h [[buffer(12)]], \ - constant int64_t & dilation_w [[buffer(13)]], \ - constant int64_t & batch_sz [[buffer(14)]], \ - constant int64_t & n_in_channels [[buffer(15)]], \ - constant int64_t & n_offset_grps [[buffer(16)]], \ - constant int64_t & out_h [[buffer(17)]], \ - constant int64_t & out_w [[buffer(18)]], \ - constant bool & use_mask [[buffer(19)]], \ - device DTYPE * columns_ptr [[buffer(20)]], \ + constant DTYPE * input_ptr [[buffer(0)]], \ + constant DTYPE * offset_ptr [[buffer(1)]], \ + constant DTYPE * mask_ptr [[buffer(2)]], \ + device DTYPE * columns_ptr [[buffer(3)]], \ + constant int64_t & n 
[[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & weight_h [[buffer(7)]], \ + constant int64_t & weight_w [[buffer(8)]], \ + constant int64_t & pad_h [[buffer(9)]], \ + constant int64_t & pad_w [[buffer(10)]], \ + constant int64_t & stride_h [[buffer(11)]], \ + constant int64_t & stride_w [[buffer(12)]], \ + constant int64_t & dilation_h [[buffer(13)]], \ + constant int64_t & dilation_w [[buffer(14)]], \ + constant int64_t & batch_sz [[buffer(15)]], \ + constant int64_t & n_in_channels [[buffer(16)]], \ + constant int64_t & n_offset_grps [[buffer(17)]], \ + constant int64_t & out_h [[buffer(18)]], \ + constant int64_t & out_w [[buffer(19)]], \ + constant bool & use_mask [[buffer(20)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); + uint2 tid2 [[thread_position_in_threadgroup]], \ + uint2 tgpg [[threadgroups_per_grid]]); From 350454f97ef59c20298d8777bc5ffb97922cd5d4 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Thu, 6 Mar 2025 15:37:03 +0100 Subject: [PATCH 23/31] Reorganized the numbering of argumnet indexes in img2col --- .../csrc/ops/mps/deform_conv2d_kernel.mm | 56 +++++++++---------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index e8cba72b509..18c4cb71119 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -55,17 +55,11 @@ void deformable_im2col(const at::Tensor& input, const int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs; - - // These function parameters have all been made contiguous by the caller function deform_conv2d_forward_kernel - // Check if it is safe to skip the following: - auto input_c = input; //.contiguous(); - auto data_offset_c = data_offset; //.contiguous(); - auto data_mask_c = data_mask; //.contiguous(); // Get a raw pointer to the underlying data structure of the tensors and cast it as a pointer to an MTLBuffer. 
- id inputBuffer = getMTLBufferStorage(input_c); - id data_offsetBuffer = getMTLBufferStorage(data_offset_c); - id data_maskBuffer = getMTLBufferStorage(data_mask_c); + id inputBuffer = getMTLBufferStorage(input); + id data_offsetBuffer = getMTLBufferStorage(data_offset); + id data_maskBuffer = getMTLBufferStorage(data_mask); id data_colBuffer = getMTLBufferStorage(data_col); id device = MPSDevice::getInstance()->device(); @@ -84,33 +78,33 @@ void deformable_im2col(const at::Tensor& input, 1); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_c, data_offset_c, data_mask_c}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input, data_offset, data_mask}); id computeEncoder = mpsStream->commandEncoder(); [computeEncoder setComputePipelineState:visionPSO]; - [computeEncoder setBuffer:inputBuffer offset:input_c.storage_offset() * input_c.element_size() atIndex:1]; - [computeEncoder setBuffer:data_offsetBuffer offset:data_offset_c.storage_offset() * data_offset_c.element_size() atIndex:2]; - [computeEncoder setBuffer:data_maskBuffer offset:data_mask_c.storage_offset() * data_mask_c.element_size() atIndex:3]; - [computeEncoder setBuffer:data_colBuffer offset:data_col.storage_offset() * data_col.element_size() atIndex:20]; + [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0]; + [computeEncoder setBuffer:data_offsetBuffer offset:data_offset.storage_offset() * data_offset.element_size() atIndex:1]; + [computeEncoder setBuffer:data_maskBuffer offset:data_mask.storage_offset() * data_mask.element_size() atIndex:2]; + [computeEncoder setBuffer:data_colBuffer offset:data_col.storage_offset() * data_col.element_size() atIndex:3]; - [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:0]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&stride_h length:sizeof(int64_t) atIndex:10]; - [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:11]; - [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:12]; - [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:14]; - [computeEncoder setBytes:&n_in_channels length:sizeof(int64_t) atIndex:15]; - [computeEncoder setBytes:&deformable_group length:sizeof(int64_t) atIndex:16]; - [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:17]; - [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:18]; - [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:19]; + [computeEncoder setBytes:&num_kernels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&weight_h length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&weight_w length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pad_h length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&pad_w length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&stride_h 
length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&stride_w length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&dilation_h length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&dilation_w length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:¶llel_imgs length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:&n_in_channels length:sizeof(int64_t) atIndex:16]; + [computeEncoder setBytes:&deformable_group length:sizeof(int64_t) atIndex:17]; + [computeEncoder setBytes:&out_h length:sizeof(int64_t) atIndex:18]; + [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:19]; + [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:20]; // A threadGroup is equivalent to a cuda's block. NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; From b31a28cb3647011b6f9754b6c9c53653cf98ada3 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Tue, 11 Mar 2025 11:18:16 +0100 Subject: [PATCH 24/31] Added threadgroups_per_grid to deformable_col2im and deformable_col2im_coord kernels Modified the MPS_1D_KERNEL_LOOP to use the new threadgroups_per_grid --- torchvision/csrc/ops/mps/mps_kernels.h | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index d89e01a4ba3..80759a8b3d5 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -1213,10 +1213,12 @@ kernel void deformable_col2im( device scalar_t * grad_im [[buffer(20)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ + uint2 tid2 [[thread_position_in_threadgroup]], + uint2 tgpg [[threadgroups_per_grid]]) { const integer_t grad_im_numel = width * height * channels * batch_sz; - MPS_1D_KERNEL_LOOP(index, n, 1) { + + MPS_1D_KERNEL_LOOP(index, n, tgpg.x) { const integer_t out_x = index % out_w; const integer_t out_y = (index / out_w) % out_h; const integer_t b = (index / (out_w * out_h)) % batch_sz; @@ -1300,7 +1302,8 @@ kernel void deformable_col2im( \ device DTYPE * grad_im [[buffer(20)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); + uint2 tid2 [[thread_position_in_threadgroup]], \ + uint2 tgpg [[threadgroups_per_grid]]); template @@ -1338,8 +1341,6 @@ scalar_t get_coordinate_weight( - - template kernel void deformable_col2im_coord( constant int64_t & n [[buffer(0)]], @@ -1368,8 +1369,9 @@ kernel void deformable_col2im_coord( device scalar_t* grad_mask [[buffer(23)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]) { - MPS_1D_KERNEL_LOOP(index, n, 1) { + uint2 tid2 [[thread_position_in_threadgroup]], + uint2 tgpg [[threadgroups_per_grid]]) { + MPS_1D_KERNEL_LOOP(index, n, tgpg.x) { scalar_t grad_offset_val = 0; scalar_t grad_mask_val = 0; integer_t w = index % out_w; @@ -1432,7 +1434,7 @@ kernel void deformable_col2im_coord( if (use_mask && is_y_direction) { grad_mask_val += col_ptr[col_pos] * - bilinear_interpolate(im_ptr, height, width, y, x, index); + bilinear_interpolate_2(im_ptr, height, width, y, x, index); } im_ptr += height * width; @@ -1483,7 +1485,8 @@ kernel void deformable_col2im_coord( \ device DTYPE * grad_mask [[buffer(23)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 
[[thread_position_in_threadgroup]]); + uint2 tid2 [[thread_position_in_threadgroup]], \ + uint2 tgpg [[threadgroups_per_grid]]); /* ----------END OF DEFORM_CONV2D KERNELS ----------------------*/ From 358dacceb5df0bf545f65ed92e4be4b2ac64589c Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Tue, 11 Mar 2025 11:21:24 +0100 Subject: [PATCH 25/31] Added printTensor utility function - only temporarily Substitutes addmm_ with addmm in the forward pass because addmm_ failes for weight groups > 1 (see comment in the code) --- .../csrc/ops/mps/deform_conv2d_kernel.mm | 51 +++++++++++++++++-- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index 18c4cb71119..ee04ebed1b7 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -8,6 +8,8 @@ #include #include "mps_helpers.h" #include "mps_kernels.h" +#include +#include namespace vision { @@ -15,8 +17,44 @@ namespace { -const int64_t tkMaxParallelImgs = 32; +const int tkMaxParallelImgs = 32; +// Helper function to print the tensor content +void printTensor(const at::Tensor& tensor, int indent = 0) { + // Print indentation + for (int i = 0; i < indent; ++i) { + std::cout << " "; + } + + // Check if the tensor is a scalar + if (tensor.dim() == 0) { + std::cout << tensor.item() << std::endl; + return; + } + + // Check if the tensor is 1-dimensional + if (tensor.dim() == 1) { + std::cout << "["; + for (int64_t i = 0; i < tensor.size(0); ++i) { + std::cout << tensor[i].item(); + if (i < tensor.size(0) - 1) { + std::cout << ", "; + } + } + std::cout << "]" << std::endl; + return; + } + + // Handle multi-dimensional tensors + std::cout << "[" << std::endl; + for (int64_t i = 0; i < tensor.size(0); ++i) { + printTensor(tensor[i], indent + 1); + } + for (int i = 0; i < indent; ++i) { + std::cout << " "; + } + std::cout << "]" << std::endl; +} void deformable_im2col(const at::Tensor& input, const at::Tensor& data_offset, @@ -818,10 +856,13 @@ void compute_grad_offset_and_mask( columns = columns.view( {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); + + // The use of addmm_ has a bug in pytorch, so we use addmm instead + // This needs to be fixed in the future for (int g = 0; g < n_weight_grps; g++) { - out_buf[b][g] = out_buf[b][g] - .flatten(1) - .addmm_(weight_c[g].flatten(1), columns[g]) + out_buf[b][g] = + addmm(out_buf[b][g] + .flatten(1), weight_c[g].flatten(1), columns[g]) .view_as(out_buf[b][g]); } columns = @@ -864,7 +905,7 @@ void compute_grad_offset_and_mask( at::Tensor offset_c = offset.contiguous(); at::Tensor mask_c = mask.contiguous(); at::Tensor bias_c = bias.contiguous(); - + std::cout << "\ndeform_conv2d_backward_kernel" << std::endl; const int64_t batch_sz = input_c.size(0); const int64_t n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); From 7da876af9ddf5cfb5bcec14ee32e9ba98326cc41 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Tue, 11 Mar 2025 11:34:57 +0100 Subject: [PATCH 26/31] Modifying TestDeformConv to include mps tests. 
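For context, a standalone parity check of the kind test_forward automates might look like the sketch below (illustrative only; it assumes a torchvision build with the MPS deform_conv2d kernel from this series registered, and the shapes and tolerances are made up for the example):

# Compare the MPS forward result of deform_conv2d against the CPU reference.
import torch
from torchvision import ops

x = torch.randn(2, 6, 10, 10)
weight = torch.randn(2, 3, 3, 2)            # out_ch=2, in_ch/groups=3, kh=3, kw=2
offset = torch.randn(2, 2 * 3 * 2, 8, 9)    # 2*kh*kw offset channels, 8x9 output
mask = torch.rand(2, 3 * 2, 8, 9)

ref = ops.deform_conv2d(x, offset, weight, mask=mask)
res = ops.deform_conv2d(x.to("mps"), offset.to("mps"), weight.to("mps"),
                        mask=mask.to("mps")).cpu()
torch.testing.assert_close(res, ref, rtol=1e-5, atol=1e-5)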
test_forward is passing test_backward is failing --- test/test_ops.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 443fd97bc5d..30d7b5021f8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -929,7 +929,9 @@ def test_batched_nms_implementations(self, seed): class TestDeformConv: dtype = torch.float64 - + mps_dtype = torch.float32 + mps_backward_atol = 2e-2 + def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1): stride_h, stride_w = _pair(stride) pad_h, pad_w = _pair(padding) @@ -1053,7 +1055,7 @@ def test_is_leaf_node(self, device): @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64)) # , ids=str) @pytest.mark.parametrize("contiguous", (True, False)) - @pytest.mark.parametrize("batch_sz", (0, 3)) + @pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.opcheck_only_one() def test_forward(self, device, contiguous, batch_sz, dtype): dtype = dtype or self.dtype @@ -1108,28 +1110,31 @@ def test_wrong_sizes(self): wrong_mask = torch.rand_like(mask[:, :2]) layer(x, offset, wrong_mask) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) - @pytest.mark.parametrize("batch_sz", (0, 33)) + @pytest.mark.parametrize("batch_sz", (0, 6)) @pytest.mark.opcheck_only_one() - def test_backward(self, device, contiguous, batch_sz): + def test_backward(self, device, contiguous, batch_sz, deterministic=False): + atol = self.mps_backward_atol if device == "mps" else 1e-05 + dtype = self.mps_dtype if device == "mps" else self.dtype + x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args( - device, contiguous, batch_sz, self.dtype + device, contiguous, batch_sz, dtype ) def func(x_, offset_, mask_, weight_, bias_): return ops.deform_conv2d( x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_ ) - - gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True) + with DeterministicGuard(deterministic): + gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True) def func_no_mask(x_, offset_, weight_, bias_): return ops.deform_conv2d( x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None ) - - gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True) + with DeterministicGuard(deterministic): + gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True) @torch.jit.script def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_): From 25a2944b0607df9b0aca5f54820377e102194ad1 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Tue, 11 Mar 2025 13:00:17 +0100 Subject: [PATCH 27/31] House Cleaning: Getting rid of redundant contiguous conversions of tensors. 
--- .../csrc/ops/mps/deform_conv2d_kernel.mm | 41 ++++++++----------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index ee04ebed1b7..23c9dcb15af 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -191,10 +191,6 @@ void compute_grad_input( at::globalContext().alertNotDeterministic("compute_grad_input"); - auto columns_c = columns.contiguous(); - auto offset_c = offset.contiguous(); - auto mask_c = mask.contiguous(); - const int64_t out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; const int64_t out_w = @@ -203,9 +199,9 @@ void compute_grad_input( const int64_t num_kernels = (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs; - id columnsBuffer = getMTLBufferStorage(columns_c); - id offsetBuffer = getMTLBufferStorage(offset_c); - id maskBuffer = getMTLBufferStorage(mask_c); + id columnsBuffer = getMTLBufferStorage(columns); + id offsetBuffer = getMTLBufferStorage(offset); + id maskBuffer = getMTLBufferStorage(mask); id grad_imBuffer = getMTLBufferStorage(grad_im); id device = MPSDevice::getInstance()->device(); @@ -223,9 +219,9 @@ void compute_grad_input( [computeEncoder setComputePipelineState:visionPSO]; - [computeEncoder setBuffer:columnsBuffer offset:columns_c.storage_offset() * columns_c.element_size() atIndex:1]; - [computeEncoder setBuffer:offsetBuffer offset:offset_c.storage_offset() * offset_c.element_size() atIndex:2]; - [computeEncoder setBuffer:maskBuffer offset:mask_c.storage_offset() * mask_c.element_size() atIndex:3]; + [computeEncoder setBuffer:columnsBuffer offset:columns.storage_offset() * columns.element_size() atIndex:1]; + [computeEncoder setBuffer:offsetBuffer offset:offset.storage_offset() * offset.element_size() atIndex:2]; + [computeEncoder setBuffer:maskBuffer offset:mask.storage_offset() * mask.element_size() atIndex:3]; [computeEncoder setBuffer:grad_imBuffer offset:grad_im.storage_offset() * grad_im.element_size() atIndex:20]; @@ -291,11 +287,6 @@ void compute_grad_offset_and_mask( using namespace at::native::mps; - auto columns_c = columns; //.contiguous(); - auto input_c = input; //.contiguous(); - auto offset_c = offset; //.contiguous(); - auto mask_c = mask; //.contiguous(); - const int64_t out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; const int64_t out_w = @@ -305,10 +296,10 @@ void compute_grad_offset_and_mask( const int64_t offset_channels = 2 * weight_h * weight_w * n_offset_grps; - id columnsBuffer = getMTLBufferStorage(columns_c); - id inputBuffer = getMTLBufferStorage(input_c); - id offsetBuffer = getMTLBufferStorage(offset_c); - id maskBuffer = getMTLBufferStorage(mask_c); + id columnsBuffer = getMTLBufferStorage(columns); + id inputBuffer = getMTLBufferStorage(input); + id offsetBuffer = getMTLBufferStorage(offset); + id maskBuffer = getMTLBufferStorage(mask); id grad_offsetBuffer = getMTLBufferStorage(grad_offset); id grad_maskBuffer = getMTLBufferStorage(grad_mask); @@ -324,14 +315,14 @@ void compute_grad_offset_and_mask( id visionPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns_c, input_c, offset_c, mask_c}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns, input, offset, mask}); [computeEncoder setComputePipelineState:visionPSO]; - 
[computeEncoder setBuffer:columnsBuffer offset:columns_c.storage_offset() * columns_c.element_size() atIndex:1]; - [computeEncoder setBuffer:inputBuffer offset:input_c.storage_offset() * input_c.element_size() atIndex:2]; - [computeEncoder setBuffer:offsetBuffer offset:offset_c.storage_offset() * offset_c.element_size() atIndex:3]; - [computeEncoder setBuffer:maskBuffer offset:mask_c.storage_offset() * mask_c.element_size() atIndex:4]; + [computeEncoder setBuffer:columnsBuffer offset:columns.storage_offset() * columns.element_size() atIndex:1]; + [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:2]; + [computeEncoder setBuffer:offsetBuffer offset:offset.storage_offset() * offset.element_size() atIndex:3]; + [computeEncoder setBuffer:maskBuffer offset:mask.storage_offset() * mask.element_size() atIndex:4]; [computeEncoder setBuffer:grad_offsetBuffer offset:grad_offset.storage_offset() * grad_offset.element_size() atIndex:22]; @@ -905,7 +896,7 @@ void compute_grad_offset_and_mask( at::Tensor offset_c = offset.contiguous(); at::Tensor mask_c = mask.contiguous(); at::Tensor bias_c = bias.contiguous(); - std::cout << "\ndeform_conv2d_backward_kernel" << std::endl; + const int64_t batch_sz = input_c.size(0); const int64_t n_parallel_imgs = get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs); From bf7784d52d66b931442c82beb25b4f3c2182019a Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sun, 20 Apr 2025 13:55:44 +0200 Subject: [PATCH 28/31] Renaming of bilinear_interpolate2 to bilinear_interpolate_deform_conv2d --- torchvision/csrc/ops/mps/mps_kernels.h | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index 80759a8b3d5..6c9f169068a 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -37,12 +37,12 @@ inline void atomic_add_float(device half* data_ptr, const half val) atomic_fetch_add_explicit((device atomic_float*) data_ptr, static_cast(val), memory_order_relaxed); } -// ********************** TESTING CPU implementation of bilinear_interpolate ********************** +// ********************** deform_conv2d implementation of bilinear_interpolate ********************** // This implementation is used by the cpu and cuda implementation of the deform_conv2d kernel // and is needed here in order for the pytest operator test not to fail. 
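The helper announced in the comment above samples the input at a fractional (y, x) position. Below is a minimal Python sketch of the bilinear rule it is described as sharing with the CPU and CUDA deform_conv2d kernels, where out-of-range corner pixels simply contribute zero (boundary details in the actual Metal code may differ slightly):

import math

def bilinear_sample(img, y, x):
    # img is an H x W grid (list of rows); (y, x) is a fractional location.
    H, W = len(img), len(img[0])
    if y <= -1 or y >= H or x <= -1 or x >= W:
        return 0.0
    y0, x0 = math.floor(y), math.floor(x)
    ly, lx = y - y0, x - x0
    def at(r, c):
        return img[r][c] if 0 <= r < H and 0 <= c < W else 0.0
    return ((1 - ly) * (1 - lx) * at(y0, x0) + (1 - ly) * lx * at(y0, x0 + 1)
            + ly * (1 - lx) * at(y0 + 1, x0) + ly * lx * at(y0 + 1, x0 + 1))

# Sampling halfway between the four corners of a 2x2 patch gives their mean.
print(bilinear_sample([[0.0, 1.0], [2.0, 3.0]], 0.5, 0.5))  # 1.5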
template -inline scalar_t bilinear_interpolate_2( +inline scalar_t bilinear_interpolate_deform_conv2d( constant scalar_t* in, integer_t height, integer_t width, @@ -84,12 +84,6 @@ inline scalar_t bilinear_interpolate_2( - - - - - - template inline T bilinear_interpolate( constant T* input, @@ -1066,15 +1060,8 @@ kernel void ps_roi_pool_backward( \ - - - - - - /*----------- START OF DEFORM_CONV2D KERNEL IMPLEMENTATION -----------------*/ - template kernel void deformable_im2col( constant scalar_t * input_ptr [[buffer(0)]], @@ -1149,7 +1136,7 @@ kernel void deformable_im2col( const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; *columns_ptr = - mask_value * bilinear_interpolate_2(input_ptr, height, width, y, x, index); + mask_value * bilinear_interpolate_deform_conv2d(input_ptr, height, width, y, x, index); columns_ptr += batch_sz * out_h * out_w; } } @@ -1216,8 +1203,6 @@ kernel void deformable_col2im( uint2 tid2 [[thread_position_in_threadgroup]], uint2 tgpg [[threadgroups_per_grid]]) { const integer_t grad_im_numel = width * height * channels * batch_sz; - - MPS_1D_KERNEL_LOOP(index, n, tgpg.x) { const integer_t out_x = index % out_w; const integer_t out_y = (index / out_w) % out_h; @@ -1434,7 +1419,7 @@ kernel void deformable_col2im_coord( if (use_mask && is_y_direction) { grad_mask_val += col_ptr[col_pos] * - bilinear_interpolate_2(im_ptr, height, width, y, x, index); + bilinear_interpolate_deform_conv2d(im_ptr, height, width, y, x, index); } im_ptr += height * width; From 9d3105fa9c175ab107d97800d1ce2b32cd60a598 Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sun, 20 Apr 2025 13:59:16 +0200 Subject: [PATCH 29/31] Added constant mps_backward_eps for eps in backward test. Skipping backward test for batch_sz == 0 --- test/test_ops.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index f403be72abd..289a6444f5e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -931,6 +931,7 @@ class TestDeformConv: dtype = torch.float64 mps_dtype = torch.float32 mps_backward_atol = 2e-2 + mps_backward_eps = 1e-3 def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1): stride_h, stride_w = _pair(stride) @@ -1112,12 +1113,18 @@ def test_wrong_sizes(self): @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) - @pytest.mark.parametrize("batch_sz", (0, 6)) + @pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.opcheck_only_one() def test_backward(self, device, contiguous, batch_sz, deterministic=False): + # Batch size of zero fails a check un OperationUtils.mm because tensors with zero as a dimension + # cause the Placeholder::Placeholder to fail. 
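The new mps_backward_eps constant above is the finite-difference step the gradcheck calls below will use on MPS. Because mps_dtype is float32, the default step of roughly 1e-6 sits close to single-precision rounding noise and the numerical Jacobian loses accuracy; a step of 1e-3 trades a little truncation error for far less cancellation. A toy illustration of the effect (the function and evaluation point are arbitrary, not taken from the test suite):

import torch

def central_diff(f, x, eps):
    return (f(x + eps) - f(x - eps)) / (2 * eps)

f = lambda t: t * t          # exact derivative at 1.0 is 2.0
x = torch.tensor(1.0, dtype=torch.float32)

for eps in (1e-6, 1e-3):
    print(f"eps={eps:g}: float32 estimate {central_diff(f, x, eps).item():.5f}")
# The 1e-6 estimate typically drifts a few percent away from 2.0 in float32,
# while the 1e-3 estimate stays very close to it.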
+ if device == "mps" and batch_sz == 0: + pytest.skip("MPS does not currently support zero batch size for backpropagation") + atol = self.mps_backward_atol if device == "mps" else 1e-05 dtype = self.mps_dtype if device == "mps" else self.dtype - + eps = self.mps_backward_eps if device == "mps" else 1e-6 + x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args( device, contiguous, batch_sz, dtype ) @@ -1126,15 +1133,15 @@ def func(x_, offset_, mask_, weight_, bias_): return ops.deform_conv2d( x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_ ) - with DeterministicGuard(deterministic): - gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True) - + + gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True, atol=atol, eps=eps) + def func_no_mask(x_, offset_, weight_, bias_): return ops.deform_conv2d( x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None ) - with DeterministicGuard(deterministic): - gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True) + + gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True, atol=atol, eps=eps) @torch.jit.script def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_): @@ -1147,7 +1154,7 @@ def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_): lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation), (x, offset, mask, weight, bias), nondet_tol=1e-5, - fast_mode=True, + fast_mode=True, eps=eps, atol=atol ) @torch.jit.script @@ -1161,7 +1168,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_): lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation), (x, offset, weight, bias), nondet_tol=1e-5, - fast_mode=True, + fast_mode=True, eps=eps, atol=atol ) @needs_cuda From 501d617f77c63bc3970255bf18dfdba71a32335a Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sun, 20 Apr 2025 14:03:16 +0200 Subject: [PATCH 30/31] Removed unused includes Removed temporary debug functions Removed redundant comments Temporary substitution of .addmm op with addop. 
Clean-up of debug std::cout statements --- .../csrc/ops/mps/deform_conv2d_kernel.mm | 81 +++++-------------- 1 file changed, 18 insertions(+), 63 deletions(-) diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index 23c9dcb15af..7377760b169 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -1,6 +1,4 @@ -// vision::ops:: -// deform_conv2d_kernel.mm -// + #include #include @@ -8,9 +6,6 @@ #include #include "mps_helpers.h" #include "mps_kernels.h" -#include -#include - namespace vision { namespace ops { @@ -19,43 +14,6 @@ const int tkMaxParallelImgs = 32; -// Helper function to print the tensor content -void printTensor(const at::Tensor& tensor, int indent = 0) { - // Print indentation - for (int i = 0; i < indent; ++i) { - std::cout << " "; - } - - // Check if the tensor is a scalar - if (tensor.dim() == 0) { - std::cout << tensor.item() << std::endl; - return; - } - - // Check if the tensor is 1-dimensional - if (tensor.dim() == 1) { - std::cout << "["; - for (int64_t i = 0; i < tensor.size(0); ++i) { - std::cout << tensor[i].item(); - if (i < tensor.size(0) - 1) { - std::cout << ", "; - } - } - std::cout << "]" << std::endl; - return; - } - - // Handle multi-dimensional tensors - std::cout << "[" << std::endl; - for (int64_t i = 0; i < tensor.size(0); ++i) { - printTensor(tensor[i], indent + 1); - } - for (int i = 0; i < indent; ++i) { - std::cout << " "; - } - std::cout << "]" << std::endl; -} - void deformable_im2col(const at::Tensor& input, const at::Tensor& data_offset, const at::Tensor& data_mask, @@ -78,7 +36,6 @@ void deformable_im2col(const at::Tensor& input, at::Tensor data_col) { using namespace at::native::mps; - // Validate tensors as of type mps. TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); TORCH_CHECK(data_offset.is_mps(), "data_offset must be a MPS tensor"); TORCH_CHECK(data_mask.is_mps(), "data_mask must be a MPS tensor"); @@ -94,7 +51,6 @@ void deformable_im2col(const at::Tensor& input, const int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs; - // Get a raw pointer to the underlying data structure of the tensors and cast it as a pointer to an MTLBuffer. id inputBuffer = getMTLBufferStorage(input); id data_offsetBuffer = getMTLBufferStorage(data_offset); id data_maskBuffer = getMTLBufferStorage(data_mask); @@ -115,7 +71,6 @@ void deformable_im2col(const at::Tensor& input, 1, 1); - // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input, data_offset, data_mask}); id computeEncoder = mpsStream->commandEncoder(); @@ -144,7 +99,6 @@ void deformable_im2col(const at::Tensor& input, [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:19]; [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:20]; - // A threadGroup is equivalent to a cuda's block. 
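tkMaxParallelImgs = 32 above caps how many images a single im2col pass may cover; an earlier hunk picks n_parallel_imgs with get_greatest_divisor_below_bound(batch_sz, tkMaxParallelImgs). That helper's body is not part of this patch, so the sketch below only assumes the semantics its name and its CPU/CUDA counterparts suggest: the largest divisor of the batch size that does not exceed the bound.

def greatest_divisor_below_bound(n, bound):
    # Assumed semantics: largest d with d | n and d <= bound.
    for d in range(min(n, bound), 0, -1):
        if n % d == 0:
            return d
    return 1

# Under that assumption, the new test batch size of 33 cannot be processed in
# groups of 32 and would fall back to groups of 11 (33 = 3 * 11).
print(greatest_divisor_below_bound(33, 32))  # 11
print(greatest_divisor_below_bound(32, 32))  # 32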
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; @@ -175,15 +129,15 @@ void compute_grad_input( int64_t channels, int64_t height, int64_t width, - int64_t weight_h, //kernel_h - int64_t weight_w, //kernel_w + int64_t weight_h, + int64_t weight_w, int64_t pad_h, int64_t pad_w, int64_t stride_h, int64_t stride_w, int64_t dilation_h, int64_t dilation_w, - int64_t parallel_imgs, //batch_sz + int64_t parallel_imgs, int64_t n_offset_grps, bool use_mask, at::Tensor grad_im) { @@ -214,7 +168,6 @@ void compute_grad_input( const std::string kernel = "deformable_col2im_" + scalarToMetalTypeString(columns.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); - // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns, offset, mask}); [computeEncoder setComputePipelineState:visionPSO]; @@ -244,7 +197,6 @@ void compute_grad_input( [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:18]; [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:19]; - // A threadGroup is equivalent to a cuda's block. NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; @@ -314,7 +266,6 @@ void compute_grad_offset_and_mask( const std::string kernel = "deformable_col2im_coord_" + scalarToMetalTypeString(columns.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); - // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(visionPSO, kernel, {columns, input, offset, mask}); [computeEncoder setComputePipelineState:visionPSO]; @@ -349,7 +300,6 @@ void compute_grad_offset_and_mask( [computeEncoder setBytes:&out_w length:sizeof(int64_t) atIndex:20]; [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:21]; - // A threadGroup is equivalent to a cuda's block. 
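The two hunks above swap the in-place Tensor.addmm_ for the out-of-place addmm; per the added comment, the in-place path produced zero values on MPS for weight group > 1, so this is a workaround rather than a change in the math. Both forms are meant to compute input + mat1 @ mat2, which a quick CPU check confirms (shapes below are arbitrary):

import torch

g, m, k, n = 2, 4, 5, 3
base = torch.randn(g, m, n)
a = torch.randn(g, m, k)
b = torch.randn(g, k, n)

for i in range(g):
    out_of_place = torch.addmm(base[i], a[i], b[i])   # new tensor: base + a @ b
    in_place = base[i].clone().addmm_(a[i], b[i])     # same math, written into the clone
    assert torch.allclose(out_of_place, in_place)
    assert torch.allclose(out_of_place, base[i] + a[i] @ b[i])
print("addmm and addmm_ agree on CPU")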
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; @@ -466,7 +416,7 @@ void compute_grad_offset_and_mask( columns.zero_(); // Separate into weight groups for (int64_t g = 0; g < n_weight_grps; g++) { - columns[g] = columns[g].addmm_( + columns[g] = addmm(columns[g], weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); } @@ -627,12 +577,11 @@ void compute_grad_offset_and_mask( use_mask, columns); + // We need to use addmm instead of addmm_ here to avoid zero values for weight group > 1 for (int64_t g = 0; g < n_weight_grps; g++) { grad_weight[g] = - grad_weight[g] - .flatten(1) - .addmm_( - grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) + addmm((grad_weight[g].flatten(1)), + grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) .view_as(grad_weight[g]); } } @@ -660,7 +609,6 @@ void compute_grad_offset_and_mask( int64_t n_weight_grps, int64_t n_offset_grps, bool use_mask) { - at::Tensor input_c = input.contiguous(); at::Tensor offset_c = offset.contiguous(); at::Tensor weight_c = weight.contiguous(); @@ -673,8 +621,6 @@ void compute_grad_offset_and_mask( TORCH_CHECK(weight_c.ndimension() == 4); TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor"); - // at::DeviceGuard guard(input_c.device()); - int batch_sz = input_c.size(0); int in_channels = input_c.size(1); int in_h = input_c.size(2); @@ -848,7 +794,7 @@ void compute_grad_offset_and_mask( columns = columns.view( {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); - // The use of addmm_ has a bug in pytorch, so we use addmm instead + // The use of in-place .addmm_ has a bug in pytorch, so we use addmm instead // This needs to be fixed in the future for (int g = 0; g < n_weight_grps; g++) { out_buf[b][g] = @@ -890,6 +836,15 @@ void compute_grad_offset_and_mask( int64_t n_weight_grps, int64_t n_offset_grps, bool use_mask) { + + TORCH_CHECK(grad_out.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(weight.is_mps(), "weight must be a MPS tensor"); + TORCH_CHECK(offset.is_mps(), "offset must be a MPS tensor"); + TORCH_CHECK(mask.is_mps(), "mask must be a MPS tensor"); + TORCH_CHECK(bias.is_mps(), "bias must be a MPS tensor"); + TORCH_CHECK(grad_out.scalar_type() != at::kHalf, "MPS does not support deform_conv2 backward with float16 inputs."); + at::Tensor grad_out_c = grad_out.contiguous(); at::Tensor input_c = input.contiguous(); at::Tensor weight_c = weight.contiguous(); From a294c2e4bdff9143f120bd981738efc0299a00af Mon Sep 17 00:00:00 2001 From: Thomas Martin Date: Sun, 20 Apr 2025 14:45:21 +0200 Subject: [PATCH 31/31] Delete --- Deform_conv2d_kernals.metal | 458 ------------------------------------ 1 file changed, 458 deletions(-) delete mode 100644 Deform_conv2d_kernals.metal diff --git a/Deform_conv2d_kernals.metal b/Deform_conv2d_kernals.metal deleted file mode 100644 index fb18af38ebf..00000000000 --- a/Deform_conv2d_kernals.metal +++ /dev/null @@ -1,458 +0,0 @@ -// -// Deform_conv2d_kernals.metal -// torchvision -// -// Created by Thomas Martin on 14/10/2024. 
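The forward path touched above follows the textbook im2col scheme: deformable_im2col gathers the (offset-shifted, optionally masked) input patches into columns, and the convolution itself then collapses into one matrix product per weight group, accumulated with addmm into out_buf. For an ordinary convolution with a single weight group the same identity can be checked directly with torch.nn.functional.unfold; the deformable variant only changes how the columns are gathered, not this reduction (sizes below are arbitrary):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)     # (N, C_in, H, W)
w = torch.randn(6, 3, 3, 3)     # (C_out, C_in, kH, kW)

cols = F.unfold(x, kernel_size=3, padding=1)     # (N, C_in*kH*kW, out_h*out_w)
out = (w.flatten(1) @ cols).view(2, 6, 8, 8)     # one matmul per image

assert torch.allclose(out, F.conv2d(x, w, padding=1), atol=1e-5)
print("unfold + matmul matches conv2d")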
-// - -// This include will only work when the remaining code is embedded in a C string in mps_kernels.h -//#include - -#include -#include - -using namespace metal; - -// ********************************************************************** -// MACROS AND HELPER FUNCTIONS SHOULD NOT BE INCLUDED IN THE FINAL SOURCE -// AS THEY ARE ALREADY INCLUDED IN mps_kernels.h -// ********************************************************************** - -/*----------Macros----------*/ - -#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \ - for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ - i += (tptg.x * n_tgs)) - -#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint) - - -/*----------Helpers--------*/ - -template -inline T ceil_div(T n, T m) { - return (n + m - 1) / m; -} - - -template -inline void atomic_add_float( device T* data_ptr, const T val) -{ -#if __METAL_VERSION__ >= 300 - // atomic_float is supported in Metal 3 (macOS Ventura) onward. - atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); -#else - // Custom atomic addition implementation - // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 - // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 - // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide) - - // Create an atomic uint pointer for atomic transaction. - device atomic_uint* atom_var = (device atomic_uint*)data_ptr; - // Create necessary storage. - uint fetched_uint, assigning_uint; - T fetched_float, assigning_float; - - // Replace the value in atom_var with 0 and return the previous value in atom_var. - fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed); - // Read out the previous value as float. - fetched_float = *( (thread T*) &fetched_uint ); - - // Do addition and represent the addition result in uint for atomic transaction. - assigning_float = fetched_float + val; - assigning_uint = *((thread uint*) &assigning_float); - - // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). - while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) { - // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. - // Try to assign 0 and get the previously assigned addition result. - uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); - T fetched_float_again = *( (thread T*) &fetched_uint_again ); - // Re-add again - fetched_float = *((thread T*) &(fetched_uint)); - // Previously assigned addition result + addition result from other threads. 
- assigning_float = fetched_float_again + fetched_float; - assigning_uint = *( (thread uint*) &assigning_float); - } -#endif -} - - -template -kernel void deformable_im2col( - index_t n [[buffer(0)]], - constant scalar_t* input_ptr [[buffer(1)]], - constant scalar_t* offset_ptr [[buffer(2)]], - constant scalar_t* mask_ptr [[buffer(3)]], - index_t height [[buffer(4)]], - index_t width [[buffer(5)]], - index_t weight_h [[buffer(6)]], - index_t weight_w [[buffer(7)]], - index_t pad_h [[buffer(8)]], - index_t pad_w [[buffer(9)]], - index_t stride_h [[buffer(10)]], - index_t stride_w [[buffer(11)]], - index_t dilation_h [[buffer(12)]], - index_t dilation_w [[buffer(13)]], - index_t batch_sz [[buffer(14)]], // parallel_imgs - index_t n_in_channels [[buffer(15)]], - index_t n_offset_grps [[buffer(16)]], //deformable_grp - index_t out_h [[buffer(17)]], - index_t out_w [[buffer(18)]], - constant bool & use_mask [[buffer(19)]], - device scalar_t* columns_ptr [[buffer(20)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]) { - MPS_1D_KERNEL_LOOP(index, n, 1) { - const index_t out_x = index % out_w; - const index_t out_y = (index / out_w) % out_h; - const index_t out_b = (index / (out_w * out_h)) % batch_sz; - const index_t in_c = index / (out_w * out_h * batch_sz); - const index_t out_c = in_c * weight_h * weight_w; - - index_t c_per_offset_grp = n_in_channels / n_offset_grps; - const index_t grp_idx = in_c / c_per_offset_grp; - - columns_ptr += - (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) + - out_y * out_w + out_x); - - input_ptr += - (out_b * (n_in_channels * height * width) + in_c * (height * width)); - - offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * - out_h * out_w; - - if (use_mask) { - mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * - out_h * out_w; - } - - for (int i = 0; i < weight_h; ++i) { - for (int j = 0; j < weight_w; ++j) { - const index_t mask_idx = i * weight_w + j; - const index_t offset_idx = 2 * mask_idx; - - scalar_t mask_value = 1; - if (use_mask) { - mask_value = - mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x]; - } - - const scalar_t offset_h = - offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t offset_w = offset_ptr - [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x]; - const scalar_t y = - (out_y * stride_h - pad_h) + i * dilation_h + offset_h; - const scalar_t x = - (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - *columns_ptr = - mask_value * bilinear_interpolate(input_ptr, height, width, y, x); - columns_ptr += batch_sz * out_h * out_w; - } - } - } - -} - -#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE) \ -template \ -[[host_name("deformable_im2col_" #DTYPE)]] \ -template \ -kernel void deformable_im2col( \ -index_t n [[buffer(0)]], \ -constant scalar_t* input_ptr [[buffer(1)]], \ -constant scalar_t* offset_ptr [[buffer(2)]], \ -constant scalar_t* mask_ptr [[buffer(3)]], \ -index_t height [[buffer(4)]], \ -index_t width [[buffer(5)]], \ -index_t weight_h [[buffer(6)]], \ -index_t weight_w [[buffer(7)]], \ -index_t pad_h [[buffer(8)]], \ -index_t pad_w [[buffer(9)]], \ -index_t stride_h [[buffer(10)]], \ -index_t stride_w [[buffer(11)]], \ -index_t dilation_h [[buffer(12)]], \ -index_t dilation_w [[buffer(13)]], \ -index_t batch_sz [[buffer(14)]], \ -index_t n_in_channels [[buffer(15)]], \ -index_t n_offset_grps [[buffer(16)]], \ -index_t out_h 
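Each thread of the deleted deformable_im2col kernel above recovers its (output x, output y, image index, input channel) from one flat loop index, mirroring the CPU/CUDA kernels. Restated in plain Python, with a round-trip check (all sizes are made up):

def decompose(index, out_w, out_h, batch_sz):
    out_x = index % out_w
    out_y = (index // out_w) % out_h
    out_b = (index // (out_w * out_h)) % batch_sz
    in_c = index // (out_w * out_h * batch_sz)
    return out_x, out_y, out_b, in_c

def recompose(out_x, out_y, out_b, in_c, out_w, out_h, batch_sz):
    return ((in_c * batch_sz + out_b) * out_h + out_y) * out_w + out_x

out_w, out_h, batch_sz, channels = 5, 4, 3, 2
for index in range(out_w * out_h * batch_sz * channels):
    coords = decompose(index, out_w, out_h, batch_sz)
    assert recompose(*coords, out_w, out_h, batch_sz) == index
print("flat index round-trips")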
[[buffer(17)]], \ -index_t out_w [[buffer(18)]], \ -constant bool & use_mask [[buffer(19)]], \ -device scalar_t* columns_ptr [[buffer(20)]], \ -uint2 tgid [[threadgroup_position_in_grid]], \ -uint2 tptg [[threads_per_threadgroup]], \ -uint2 tid2 [[thread_position_in_threadgroup]]); - - - - - - - - -template -kernel void deformable_col2im( - index_t n [[buffer(0)]], - constant scalar_t* col [[buffer(1)]], - constant scalar_t* offset_ptr [[buffer(2)]], - constant scalar_t* mask_ptr [[buffer(3)]], - index_t channels [[buffer(4)]], - index_t height [[buffer(5)]], - index_t width [[buffer(6)]], - index_t kernel_h [[buffer(7)]], - index_t kernel_w [[buffer(8)]], - index_t pad_h [[buffer(9)]], - index_t pad_w [[buffer(10)]], - index_t stride_h [[buffer(11)]], - index_t stride_w [[buffer(12)]], - index_t dilation_h [[buffer(13)]], - index_t dilation_w [[buffer(14)]], - index_t batch_sz [[buffer(15)]], //parallel_imgs - index_t n_offset_grps [[buffer(16)]], - index_t out_h [[buffer(17)]], - index_t out_w [[buffer(18)]], - constant bool & use_mask [[buffer(19)]], - constant scalar_t* grad_im [[buffer(20)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - const index_t grad_im_numel = width * height * channels * batch_sz; - - MPS_1D_KERNEL_LOOP(index, n, 1) { - const index_t out_x = index % out_w; - const index_t out_y = (index / out_w) % out_h; - const index_t b = (index / (out_w * out_h)) % batch_sz; - const index_t j = (index / (out_w * out_h * batch_sz)) % kernel_w; - const index_t i = - (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h; - const index_t c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h); - - index_t c_per_offset_grp = channels / n_offset_grps; - const index_t offset_grp = c / c_per_offset_grp; - - offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * - out_h * out_w; - - if (use_mask) { - mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * - out_h * out_w; - } - - const index_t mask_idx = i * kernel_w + j; - const index_t offset_idx = 2 * mask_idx; - - const index_t offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x; - const index_t offset_w_ptr = - ((offset_idx + 1) * out_h + out_y) * out_w + out_x; - - const scalar_t offset_h = offset_ptr[offset_h_ptr]; - const scalar_t offset_w = offset_ptr[offset_w_ptr]; - - scalar_t mask_value = 1; - if (use_mask) { - mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; - } - - const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; - const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - - for (index_t dy = -1; dy <= 1; dy++) { - for (index_t dx = -1; dx <= 1; dx++) { - index_t yp = (index_t)y + dy; - index_t xp = (index_t)x + dx; - if (0 <= yp && yp < height && 0 <= xp && xp < width && - abs(y - yp) < 1 && abs(x - xp) < 1) { - index_t grad_pos = ((b * channels + c) * height + yp) * width + xp; - scalar_t weight = (1 - abs(y - yp)) * (1 - abs(x - xp)); - // MSL doesn't support at::native::fastAtomicAdd - if (grad_pos >= 0 && grad_pos < grad_im_numel) { - // Atomically add the computed value directly - atomic_add_float(grad_im + grad_pos, static_cast(mask_value * weight * col[index])); - } - } - } - } - } -} - -#define REGISTER_DEFORMABLE_COL2IM_OP(DTYPE) \ -template \ -[[host_name("deformable_col2im_" #DTYPE)]] \ -template \ -kernel void deformable_col2im( \ - index_t n [[buffer(0)]], \ - constant scalar_t* col [[buffer(1)]], \ - 
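The deleted deformable_col2im kernel above is the adjoint of the bilinear gather: each column value is scattered back onto the integer pixels surrounding its fractional sampling point with weights (1 - |y - yp|) * (1 - |x - xp|), using atomic adds because many threads can target the same pixel. The serial Python sketch below reproduces just that scatter rule (no atomics needed here; the grid size and coordinates are arbitrary):

def scatter_bilinear(grad_im, y, x, val):
    # grad_im: H x W grid (list of rows), updated in place.
    H, W = len(grad_im), len(grad_im[0])
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            yp, xp = int(y) + dy, int(x) + dx
            if 0 <= yp < H and 0 <= xp < W and abs(y - yp) < 1 and abs(x - xp) < 1:
                grad_im[yp][xp] += (1 - abs(y - yp)) * (1 - abs(x - xp)) * val

grad = [[0.0] * 4 for _ in range(4)]
scatter_bilinear(grad, 1.25, 2.5, 1.0)
# The four neighbours of (1.25, 2.5) receive weights 0.375, 0.375, 0.125, 0.125 (summing to 1).
print(grad[1][2], grad[1][3], grad[2][2], grad[2][3])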
constant scalar_t* offset_ptr [[buffer(2)]], \ - constant scalar_t* mask_ptr [[buffer(3)]], \ - index_t channels [[buffer(4)]], \ - index_t height [[buffer(5)]], \ - index_t width [[buffer(6)]], \ - index_t kernel_h [[buffer(7)]], \ - index_t kernel_w [[buffer(8)]], \ - index_t pad_h [[buffer(9)]], \ - index_t pad_w [[buffer(10)]], \ - index_t stride_h [[buffer(11)]], \ - index_t stride_w [[buffer(12)]], \ - index_t dilation_h [[buffer(13)]], \ - index_t dilation_w [[buffer(14)]], \ - index_t batch_sz [[buffer(15)]], \ - index_t n_offset_grps [[buffer(16)]], \ - index_t out_h [[buffer(17)]], \ - index_t out_w [[buffer(18)]], \ - constant bool & use_mask [[buffer(19)]], \ - constant scalar_t* grad_im [[buffer(20)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); - - - -template -kernel void deformable_col2im_coord( - index_t n [[buffer(0)]], - constant scalar_t* col_ptr [[buffer(1)]], - constant scalar_t* im_ptr [[buffer(2)]], //input - constant scalar_t* offset_ptr [[buffer(3)]], - constant scalar_t* mask_ptr [[buffer(4)]], - index_t channels [[buffer(5)]], - index_t height [[buffer(6)]], - index_t width [[buffer(7)]], - index_t weight_h [[buffer(8)]], - index_t weight_w [[buffer(9)]], - index_t pad_h [[buffer(10)]], - index_t pad_w [[buffer(11)]], - index_t stride_h [[buffer(12)]], - index_t stride_w [[buffer(13)]], - index_t dilation_h [[buffer(14)]], - index_t dilation_w [[buffer(15)]], - index_t batch_sz [[buffer(16)]], //parallel_imgs - index_t offset_channels [[buffer(17)]], - index_t n_offset_grps [[buffer(18)]], - index_t out_h [[buffer(19)]], - index_t out_w [[buffer(20)]], - constant bool & use_mask [[buffer(21)]], - constant scalar_t* grad_offset [[buffer(22)]], - constant scalar_t* grad_mask [[buffer(23)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]) { - MPS_1D_KERNEL_LOOP(index, n, 1) { - scalar_t grad_offset_val = 0; - scalar_t grad_mask_val = 0; - index_t w = index % out_w; - index_t h = (index / out_w) % out_h; - index_t w_w = (index / (out_w * out_h * 2)) % weight_w; - index_t w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h; - index_t c = (index / (out_w * out_h)) % offset_channels; - index_t b = index / (out_w * out_h * offset_channels); - - const index_t offset_grp = c / (2 * weight_h * weight_w); - const index_t col_step = weight_h * weight_w; - - index_t c_per_offset_grp = channels / n_offset_grps; - - col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * - out_w * out_h; - im_ptr += - (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width; - offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * - out_h * out_w; - - if (use_mask) { - mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w * - out_h * out_w; - } - - const index_t offset_c = c - offset_grp * 2 * weight_h * weight_w; - const bool is_y_direction = offset_c % 2 == 0; - - const index_t c_bound = c_per_offset_grp * weight_h * weight_w; - for (index_t col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { - const index_t col_pos = - (((col_c * batch_sz + b) * out_h) + h) * out_w + w; - - index_t out_x = col_pos % out_w; - index_t out_y = (col_pos / out_w) % out_h; - index_t j = (col_pos / (out_w * out_h * batch_sz)) % weight_w; - index_t i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h; - - const index_t mask_idx = i * 
weight_w + j; - - const index_t offset_h_ptr = - (((2 * mask_idx) * out_h + out_y) * out_w + out_x); - const index_t offset_w_ptr = - (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x); - const scalar_t offset_h = offset_ptr[offset_h_ptr]; - const scalar_t offset_w = offset_ptr[offset_w_ptr]; - - scalar_t mask_value = 1; - if (use_mask) { - mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x]; - } - - scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h; - scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w; - - const scalar_t weight = - get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction); - grad_offset_val += mask_value * weight * col_ptr[col_pos]; - - if (use_mask && is_y_direction) { - grad_mask_val += col_ptr[col_pos] * - bilinear_interpolate(im_ptr, height, width, y, x); - } - - im_ptr += height * width; - } - - grad_offset[index] = grad_offset_val; - - if (use_mask && is_y_direction) { - const index_t idx = - ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w + - w_w) * - out_h + - h) * - out_w + - w; - grad_mask[idx] = grad_mask_val; - } - } -} - -#define REGISTER_DEFORMABLE_COL2IM_COORD_OP(DTYPE) \ -template \ -[[host_name("deformable_col2im_coord_" #DTYPE)]] \ -template \ -kernel void deformable_col2im_coord( \ - index_t n [[buffer(0)]],\ - constant scalar_t* col_ptr [[buffer(1)]], \ - constant scalar_t* im_ptr [[buffer(2)]], \ - constant scalar_t* offset_ptr [[buffer(3)]], \ - constant scalar_t* mask_ptr [[buffer(4)]], \ - index_t channels [[buffer(5)]], \ - index_t height [[buffer(6)]], \ - index_t width [[buffer(7)]], \ - index_t weight_h [[buffer(8)]], \ - index_t weight_w [[buffer(9)]], \ - index_t pad_h [[buffer(10)]], \ - index_t pad_w [[buffer(11)]], \ - index_t stride_h [[buffer(12)]], \ - index_t stride_w [[buffer(13)]], \ - index_t dilation_h [[buffer(14)]], \ - index_t dilation_w [[buffer(15)]], \ - index_t batch_sz [[buffer(16)]], \ - index_t offset_channels [[buffer(17)]], \ - index_t n_offset_grps [[buffer(18)]], \ - index_t out_h [[buffer(19)]], \ - index_t out_w [[buffer(20)]], \ - constant bool & use_mask [[buffer(21)]], \ - constant scalar_t* grad_offset [[buffer(22)]], \ - constant scalar_t* grad_mask [[buffer(23)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]);
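Taken together, these last patches are about making torchvision.ops.deform_conv2d usable end to end on MPS, including its backward pass. As a closing orientation for readers new to the operator, the CPU snippet below shows the expected offset and mask shapes and a simple sanity check: with all offsets zero and a mask of ones, deformable convolution reduces to an ordinary convolution. The sizes are arbitrary; on MPS the tensors would additionally need to live on the "mps" device, and per the check added above the backward pass rejects float16.

import torch
import torch.nn.functional as F
from torchvision import ops

n, c_in, c_out, h, w, k = 2, 3, 4, 8, 8, 3
x = torch.randn(n, c_in, h, w)
weight = torch.randn(c_out, c_in, k, k)
bias = torch.randn(c_out)

# offset: (N, 2*kH*kW*n_offset_grps, out_h, out_w); mask: (N, kH*kW*n_offset_grps, out_h, out_w)
offset = torch.zeros(n, 2 * k * k, h, w)
mask = torch.ones(n, k * k, h, w)

out = ops.deform_conv2d(x, offset, weight, bias, padding=1, mask=mask)
ref = F.conv2d(x, weight, bias, padding=1)
assert torch.allclose(out, ref, atol=1e-5)
print(out.shape)  # torch.Size([2, 4, 8, 8])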