
Commit 660e508

datumbox authored and facebook-github-bot committed

[fbsync] Refactor preset transforms (#5562)

Summary:
* Refactor preset transforms
* Making presets public.

Reviewed By: vmoens

Differential Revision: D34878982

fbshipit-source-id: f1551f5d1f98a58a6820d28b4bd8a0dcbdc53226

1 parent 062b3ef, commit 660e508

40 files changed: +207 additions, -174 deletions
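
The commit renames the dataset-named eval presets to task-named ones and makes them public: ImageNetEval becomes ImageClassificationEval, CocoEval becomes ObjectDetectionEval, RaftEval becomes OpticalFlowEval, VocEval becomes SemanticSegmentationEval, and Kinect400Eval becomes VideoClassificationEval. A minimal usage sketch of the renamed classification preset follows; it assumes the preset keeps the resize / center-crop / normalize behavior of the ImageNetEval it replaces and takes the constructor arguments visible in the diffs below (crop_size, resize_size, interpolation). It is not an example taken from this commit.

    import torch
    from torchvision.prototype import transforms

    # Build the renamed preset with the crop_size used throughout this commit;
    # resize_size=256 is an assumed default, not a value taken from the diff.
    preprocess = transforms.ImageClassificationEval(crop_size=224, resize_size=256)

    # Apply it to a dummy uint8 image tensor; the preset is expected to resize,
    # center-crop to 224x224, convert to float and normalize with ImageNet stats.
    img = torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8)
    batch = preprocess(img).unsqueeze(0)  # shape: (1, 3, 224, 224)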

references/classification/train.py (1 addition, 1 deletion)

@@ -163,7 +163,7 @@ def load_data(traindir, valdir, args):
         weights = prototype.models.get_weight(args.weights)
         preprocessing = weights.transforms()
     else:
-        preprocessing = prototype.transforms.ImageNetEval(
+        preprocessing = prototype.transforms.ImageClassificationEval(
             crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
         )

references/detection/train.py (1 addition, 1 deletion)

@@ -57,7 +57,7 @@ def get_transform(train, args):
         weights = prototype.models.get_weight(args.weights)
         return weights.transforms()
     else:
-        return prototype.transforms.CocoEval()
+        return prototype.transforms.ObjectDetectionEval()


 def get_args_parser(add_help=True):

references/optical_flow/train.py (1 addition, 1 deletion)

@@ -137,7 +137,7 @@ def validate(model, args):
             weights = prototype.models.get_weight(args.weights)
             preprocessing = weights.transforms()
         else:
-            preprocessing = prototype.transforms.RaftEval()
+            preprocessing = prototype.transforms.OpticalFlowEval()
     else:
         preprocessing = OpticalFlowPresetEval()

references/segmentation/train.py (1 addition, 1 deletion)

@@ -42,7 +42,7 @@ def get_transform(train, args):
         weights = prototype.models.get_weight(args.weights)
         return weights.transforms()
     else:
-        return prototype.transforms.VocEval(resize_size=520)
+        return prototype.transforms.SemanticSegmentationEval(resize_size=520)


 def criterion(inputs, target):

references/video_classification/train.py (1 addition, 1 deletion)

@@ -157,7 +157,7 @@ def main(args):
         weights = prototype.models.get_weight(args.weights)
         transform_test = weights.transforms()
     else:
-        transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171))
+        transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171))

     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_test from {cache_path}")

torchvision/prototype/models/alexnet.py (2 additions, 2 deletions)

@@ -1,7 +1,7 @@
 from functools import partial
 from typing import Any, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.alexnet import AlexNet
@@ -16,7 +16,7 @@
 class AlexNet_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             "task": "image_classification",
             "architecture": "AlexNet",

torchvision/prototype/models/convnext.py (5 additions, 5 deletions)

@@ -1,7 +1,7 @@
 from functools import partial
 from typing import Any, List, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.convnext import ConvNeXt, CNBlockConfig
@@ -56,7 +56,7 @@ def _convnext(
 class ConvNeXt_Tiny_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236),
         meta={
             **_COMMON_META,
             "num_params": 28589128,
@@ -70,7 +70,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
 class ConvNeXt_Small_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230),
         meta={
             **_COMMON_META,
             "num_params": 50223688,
@@ -84,7 +84,7 @@ class ConvNeXt_Small_Weights(WeightsEnum):
 class ConvNeXt_Base_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 88591464,
@@ -98,7 +98,7 @@ class ConvNeXt_Base_Weights(WeightsEnum):
 class ConvNeXt_Large_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 197767336,

torchvision/prototype/models/densenet.py (5 additions, 5 deletions)

@@ -3,7 +3,7 @@
 from typing import Any, Optional, Tuple

 import torch.nn as nn
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.densenet import DenseNet
@@ -78,7 +78,7 @@ def _densenet(
 class DenseNet121_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 7978856,
@@ -92,7 +92,7 @@ class DenseNet121_Weights(WeightsEnum):
 class DenseNet161_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 28681000,
@@ -106,7 +106,7 @@ class DenseNet161_Weights(WeightsEnum):
 class DenseNet169_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 14149480,
@@ -120,7 +120,7 @@ class DenseNet169_Weights(WeightsEnum):
 class DenseNet201_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet201-c1103571.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 20013928,

torchvision/prototype/models/detection/faster_rcnn.py (4 additions, 4 deletions)

@@ -1,7 +1,7 @@
 from typing import Any, Optional, Union

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.faster_rcnn import (
@@ -43,7 +43,7 @@
 class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 41755286,
@@ -57,7 +57,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
 class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 19386354,
@@ -71,7 +71,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
 class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 19386354,

torchvision/prototype/models/detection/fcos.py (2 additions, 2 deletions)

@@ -1,7 +1,7 @@
 from typing import Any, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.fcos import (
@@ -27,7 +27,7 @@
 class FCOS_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "FCOS",
