Skip to content

Commit 8265cc3

Browse files
committed
Add more models as part of GA models
Summary: I realized that we don't have object detection models in our repo. This change is based on reviewing the model lists at https://mlcommons.org/benchmarks/inference-mobile/ and https://ai-benchmark.com/tests.html and compiling a list of models that we don't have in our examples.
1 parent aec1322 commit 8265cc3

File tree

20 files changed

+493
-3
lines changed

20 files changed

+493
-3
lines changed

.ci/scripts/test_model.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,9 @@ elif [[ "${BACKEND}" == *"xnnpack"* ]]; then
317317
echo "Testing ${MODEL_NAME} with xnnpack..."
318318
WITH_QUANTIZATION=true
319319
WITH_DELEGATION=true
320-
if [[ "$MODEL_NAME" == "mobilebert" ]]; then
321-
# TODO(T197452682)
320+
if [[ "$MODEL_NAME" == "mobilebert" || "$MODEL_NAME" == "albert" ]]; then
321+
# TODO(https://github.com/pytorch/executorch/issues/12341)
322+
# mobilebert and albert are incompatible with XNNPACK quantization (xlsr and bilstm are also listed as incompatible but are not excluded by this check)
322323
WITH_QUANTIZATION=false
323324
fi
324325
test_model_with_xnnpack "${WITH_QUANTIZATION}" "${WITH_DELEGATION}"

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ jobs:
6363
contents: read
6464
strategy:
6565
matrix:
66-
model: [linear, add, add_mul, ic3, ic4, mv2, mv3, resnet18, resnet50, vit, w2l, mobilebert, emformer_join, emformer_transcribe]
66+
model: [linear, add, add_mul, ic3, ic4, mv2, mv3, resnet18, resnet50, vit, w2l, mobilebert, emformer_join, emformer_transcribe, efficientnet_b4, detr_resnet50, segformer_ade, albert, xlsr, bilstm]
6767
backend: [portable, xnnpack-quantization-delegation]
6868
runner: [linux.arm64.2xlarge]
6969
include:

examples/models/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ class Model(str, Enum):
3737
EfficientSam = "efficient_sam"
3838
Qwen25 = "qwen2_5"
3939
Phi4Mini = "phi_4_mini"
40+
EfficientNetB4 = "efficientnet_b4"
41+
DetrResNet50 = "detr_resnet50"
42+
SegformerADE = "segformer_ade"
43+
Albert = "albert"
44+
BiLSTM = "bilstm"
45+
Swin2SR2x = "swin2sr_2x"
46+
TrOCRHandwritten = "trocr_handwritten"
47+
XLSR = "xlsr"
4048

4149
def __str__(self) -> str:
4250
return self.value
@@ -82,6 +90,14 @@ def __str__(self) -> str:
8290
str(Model.EfficientSam): ("efficient_sam", "EfficientSAM"),
8391
str(Model.Qwen25): ("qwen2_5", "Qwen2_5Model"),
8492
str(Model.Phi4Mini): ("phi_4_mini", "Phi4MiniModel"),
93+
str(Model.EfficientNetB4): ("efficientnet_b4", "EfficientNetB4Model"),
94+
str(Model.DetrResNet50): ("detr_resnet50", "DetrResNet50Model"),
95+
str(Model.SegformerADE): ("segformer_ade", "SegformerADEModel"),
96+
str(Model.Albert): ("albert", "AlbertModelExample"),
97+
str(Model.BiLSTM): ("bilstm", "BidirectionalLSTMModel"),
98+
str(Model.Swin2SR2x): ("swin2sr_2x", "Swin2SR2xModel"),
99+
str(Model.TrOCRHandwritten): ("trocr_handwritten", "TrOCRHandwrittenModel"),
100+
str(Model.XLSR): ("xlsr", "XLSRModel"),
85101
}
86102

87103
__all__ = [

examples/models/albert/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import AlbertModelExample
8+
9+
__all__ = [
10+
"AlbertModelExample",
11+
]

examples/models/albert/model.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
11+
from transformers import AlbertModel, AutoTokenizer # @manual
12+
13+
from ..model_base import EagerModelBase
14+
15+
16+
class AlbertModelExample(EagerModelBase):
    """Example model wrapper around the pretrained ``albert-base-v2`` checkpoint."""

    def __init__(self):
        # Nothing is set up here; the checkpoint and tokenizer are loaded on
        # demand by the methods below.
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Load ``albert-base-v2`` with ``return_dict=False`` and return it in eval mode."""
        logging.info("Loading ALBERT model")
        # pyre-ignore
        albert = AlbertModel.from_pretrained("albert-base-v2", return_dict=False)
        albert.eval()
        logging.info("Loaded ALBERT model")
        return albert

    def get_example_inputs(self):
        """Return a one-element tuple holding the token ids of a sample sentence."""
        tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        encoded = tokenizer("Hello, my dog is cute", return_tensors="pt")
        return (encoded["input_ids"],)

examples/models/bilstm/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import BidirectionalLSTMModel
8+
9+
__all__ = [
10+
"BidirectionalLSTMModel",
11+
]

examples/models/bilstm/model.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
import torch.nn as nn
11+
12+
from ..model_base import EagerModelBase
13+
14+
15+
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM for sequence modeling.

    A stacked, bidirectional ``nn.LSTM`` (``batch_first=True``) followed by a
    linear layer applied to the LSTM output at the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Number of hidden units per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Size of the output vector produced by the linear layer.
    """

    def __init__(self, input_size=100, hidden_size=128, num_layers=2, num_classes=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers, batch_first=True, bidirectional=True
        )

        # Output layer (hidden_size * 2 because of bidirectional)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Run the LSTM over ``x`` and classify from the last time step.

        Args:
            x: Input tensor; indexed as (batch, seq_len, input_size) because the
                LSTM is built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        # Zero initial states, shaped (num_layers * 2, batch, hidden_size) for a
        # bidirectional LSTM. They are created directly with the input's dtype
        # and device so the module also works after .double()/.half() casts;
        # a float32 state would be rejected as a dtype mismatch.
        h0 = torch.zeros(
            self.num_layers * 2,
            x.size(0),
            self.hidden_size,
            dtype=x.dtype,
            device=x.device,
        )
        c0 = torch.zeros(
            self.num_layers * 2,
            x.size(0),
            self.hidden_size,
            dtype=x.dtype,
            device=x.device,
        )

        # LSTM forward pass
        out, _ = self.lstm(x, (h0, c0))

        # Take the last time step output
        out = self.fc(out[:, -1, :])
        return out
43+
44+
45+
class BidirectionalLSTMTextClassifier(nn.Module):
    """Text classifier: embedding -> bidirectional LSTM -> max-pool -> linear.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each embedding vector (the LSTM input size).
        hidden_size: Number of LSTM hidden units per direction.
        num_classes: Size of the output vector produced by the linear layer.
    """

    def __init__(
        self, vocab_size=10000, embedding_dim=128, hidden_size=256, num_classes=2
    ):
        super().__init__()
        self.hidden_size = hidden_size

        # Token ids -> dense vectors.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Single-layer bidirectional LSTM over the embedded sequence.
        self.lstm = nn.LSTM(
            embedding_dim, hidden_size, bidirectional=True, batch_first=True
        )

        # Both directions are concatenated, hence hidden_size * 2 inputs.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Embed ``x``, run the LSTM, max-pool over dim 1 and classify."""
        sequence_features, _ = self.lstm(self.embedding(x))

        # Global max pooling over the sequence dimension (dim=1).
        pooled = sequence_features.max(dim=1).values

        return self.fc(pooled)
78+
79+
80+
class BidirectionalLSTMModel(EagerModelBase):
    """Example model wrapper that builds a ``BidirectionalLSTM`` and sample input."""

    def __init__(self):
        # Nothing is set up here; the model is built in get_eager_model().
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Construct the bidirectional LSTM and return it in eval mode."""
        logging.info("Loading Bidirectional LSTM model")
        lstm_config = {
            "input_size": 100,
            "hidden_size": 128,
            "num_layers": 2,
            "num_classes": 10,
        }
        net = BidirectionalLSTM(**lstm_config)
        net.eval()
        logging.info("Loaded Bidirectional LSTM model")
        return net

    def get_example_inputs(self):
        """Return a one-element tuple with a random (1, 50, 100) tensor."""
        # Example: (batch_size=1, seq_len=50, input_size=100)
        batch_size, seq_len, input_size = 1, 50, 100
        sample = torch.randn((batch_size, seq_len, input_size))
        return (sample,)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import DetrResNet50Model
8+
9+
__all__ = [
10+
"DetrResNet50Model",
11+
]
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from transformers import DetrForObjectDetection
11+
12+
from ..model_base import EagerModelBase
13+
14+
15+
class DetrWrapper(torch.nn.Module):
    """Wrapper for HuggingFace DETR model to make it torch.export compatible.

    The wrapped model's output object is reduced to a plain
    ``(logits, pred_boxes)`` tuple.
    """

    def __init__(self, model_name="facebook/detr-resnet-50"):
        super().__init__()
        # Load the pretrained checkpoint and put it in eval mode.
        self.detr = DetrForObjectDetection.from_pretrained(model_name)
        self.detr.eval()

    def forward(self, pixel_values):
        """Run DETR without gradient tracking and return ``(logits, pred_boxes)``."""
        # pixel_values: [batch, 3, height, width] - RGB image
        with torch.no_grad():
            detections = self.detr(pixel_values)
        # Only the class logits and predicted boxes are exposed to callers.
        return detections.logits, detections.pred_boxes
29+
30+
31+
class DetrResNet50Model(EagerModelBase):
    """Example model wrapper for the ``facebook/detr-resnet-50`` checkpoint."""

    def __init__(self):
        # Nothing is set up here; the checkpoint is loaded in get_eager_model().
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Build the ``DetrWrapper`` around the checkpoint and return it in eval mode."""
        logging.info("Loading DETR ResNet-50 model from HuggingFace")
        wrapper = DetrWrapper("facebook/detr-resnet-50")
        wrapper.eval()
        logging.info("Loaded DETR ResNet-50 model")
        return wrapper

    def get_example_inputs(self):
        """Return a one-element tuple with a random (1, 3, 800, 800) tensor."""
        # DETR standard input size: 800x800 RGB image (can handle various sizes)
        batch, channels, height, width = 1, 3, 800, 800
        image = torch.randn((batch, channels, height, width))
        return (image,)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import EfficientNetB4Model
8+
9+
__all__ = [
10+
"EfficientNetB4Model",
11+
]

0 commit comments

Comments
 (0)