Skip to content

Commit 145ff6c

Browse files
committed
Add more models as part of GA models
Summary: I realized that we don't have object detection models in our repo. This change is based on looking at the model lists at https://mlcommons.org/benchmarks/inference-mobile/ and https://ai-benchmark.com/tests.html and picking out the models that we don't yet have in our examples.
1 parent aaf0a4c commit 145ff6c

File tree

25 files changed

+1234
-1
lines changed

25 files changed

+1234
-1
lines changed

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ jobs:
6363
contents: read
6464
strategy:
6565
matrix:
66-
model: [linear, add, add_mul, ic3, ic4, mv2, mv3, resnet18, resnet50, vit, w2l, mobilebert, emformer_join, emformer_transcribe]
66+
model: [linear, add, add_mul, ic3, ic4, mv2, mv3, resnet18, resnet50, vit, w2l, mobilebert, emformer_join, emformer_transcribe, efficientnet_b4, yolov4_tiny, unet, albert, ssd_mobilenetv2, deeplabv3_mobilenet, bert, bilstm, esrgan, srgan, crnn]
6767
backend: [portable, xnnpack-quantization-delegation]
6868
runner: [linux.arm64.2xlarge]
6969
include:

examples/models/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,18 @@ class Model(str, Enum):
3737
EfficientSam = "efficient_sam"
3838
Qwen25 = "qwen2_5"
3939
Phi4Mini = "phi_4_mini"
40+
# Newly added models
41+
EfficientNetB4 = "efficientnet_b4"
42+
YOLOv4Tiny = "yolov4_tiny"
43+
UNet = "unet"
44+
Albert = "albert"
45+
SSDMobileNetV2 = "ssd_mobilenetv2"
46+
DeepLabV3MobileNet = "deeplabv3_mobilenet"
47+
Bert = "bert"
48+
BiLSTM = "bilstm"
49+
ESRGAN = "esrgan"
50+
SRGAN = "srgan"
51+
CRNN = "crnn"
4052

4153
def __str__(self) -> str:
4254
return self.value
@@ -82,6 +94,18 @@ def __str__(self) -> str:
8294
str(Model.EfficientSam): ("efficient_sam", "EfficientSAM"),
8395
str(Model.Qwen25): ("qwen2_5", "Qwen2_5Model"),
8496
str(Model.Phi4Mini): ("phi_4_mini", "Phi4MiniModel"),
97+
# Newly added models
98+
str(Model.EfficientNetB4): ("efficientnet_b4", "EfficientNetB4Model"),
99+
str(Model.YOLOv4Tiny): ("yolov4_tiny", "YOLOv4TinyModel"),
100+
str(Model.UNet): ("unet", "UNetBinaryModel"),
101+
str(Model.Albert): ("albert", "AlbertModelExample"),
102+
str(Model.SSDMobileNetV2): ("ssd_mobilenetv2", "SSDMobileNetV2Model"),
103+
str(Model.DeepLabV3MobileNet): ("deeplabv3_mobilenet", "DeepLabV3MobileNetModel"),
104+
str(Model.Bert): ("bert", "BertModelExample"),
105+
str(Model.BiLSTM): ("bilstm", "BiLSTMModel"),
106+
str(Model.ESRGAN): ("esrgan", "ESRGANModel"),
107+
str(Model.SRGAN): ("srgan", "SRGANModel"),
108+
str(Model.CRNN): ("crnn", "CRNNModel"),
85109
}
86110

87111
__all__ = [

examples/models/albert/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import AlbertModelExample, AlbertLargeModelExample

examples/models/albert/model.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
11+
from transformers import AlbertModel, AutoTokenizer # @manual
12+
13+
from ..model_base import EagerModelBase
14+
15+
16+
class AlbertModelExample(EagerModelBase):
    """Example wrapper around the pretrained ``albert-base-v2`` checkpoint."""

    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Load ``albert-base-v2`` with ``return_dict=False`` and return it in eval mode."""
        logging.info("Loading ALBERT model")
        # pyre-ignore
        albert = AlbertModel.from_pretrained("albert-base-v2", return_dict=False)
        albert.eval()
        logging.info("Loaded ALBERT model")
        return albert

    def get_example_inputs(self):
        """Return a 1-tuple holding the ``input_ids`` tensor of a fixed sample sentence."""
        tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        encoded = tokenizer("Hello, my dog is cute", return_tensors="pt")
        return (encoded["input_ids"],)
33+
34+
35+
class AlbertLargeModelExample(EagerModelBase):
    """Example wrapper around the pretrained ``albert-large-v2`` checkpoint."""

    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Load ``albert-large-v2`` with ``return_dict=False`` and return it in eval mode."""
        logging.info("Loading ALBERT-Large model")
        # pyre-ignore
        albert = AlbertModel.from_pretrained("albert-large-v2", return_dict=False)
        albert.eval()
        logging.info("Loaded ALBERT-Large model")
        return albert

    def get_example_inputs(self):
        """Return a 1-tuple holding the ``input_ids`` tensor of a fixed sample sentence."""
        tokenizer = AutoTokenizer.from_pretrained("albert-large-v2")
        encoded = tokenizer("Hello, my dog is cute", return_tensors="pt")
        return (encoded["input_ids"],)

examples/models/bert/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import BertModelExample, BertLargeModelExample

examples/models/bert/model.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
11+
from transformers import AutoTokenizer, BertModel # @manual
12+
13+
from ..model_base import EagerModelBase
14+
15+
16+
class BertModelExample(EagerModelBase):
    """Example wrapper around the pretrained ``google-bert/bert-base-uncased`` checkpoint."""

    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Load BERT base (uncased) with ``return_dict=False`` and return it in eval mode."""
        logging.info("Loading BERT model")
        # pyre-ignore
        bert = BertModel.from_pretrained(
            "google-bert/bert-base-uncased", return_dict=False
        )
        bert.eval()
        logging.info("Loaded BERT model")
        return bert

    def get_example_inputs(self):
        """Return a 1-tuple holding the ``input_ids`` tensor of a fixed sample sentence."""
        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        encoded = tokenizer("Hello, my dog is cute", return_tensors="pt")
        return (encoded["input_ids"],)
33+
34+
35+
class BertLargeModelExample(EagerModelBase):
    """Example wrapper around the pretrained ``google-bert/bert-large-uncased`` checkpoint."""

    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Load BERT large (uncased) with ``return_dict=False`` and return it in eval mode."""
        logging.info("Loading BERT-Large model")
        # pyre-ignore
        bert = BertModel.from_pretrained(
            "google-bert/bert-large-uncased", return_dict=False
        )
        bert.eval()
        logging.info("Loaded BERT-Large model")
        return bert

    def get_example_inputs(self):
        """Return a 1-tuple holding the ``input_ids`` tensor of a fixed sample sentence."""
        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-large-uncased")
        encoded = tokenizer("Hello, my dog is cute", return_tensors="pt")
        return (encoded["input_ids"],)

examples/models/bilstm/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import BidirectionalLSTMModel, BidirectionalLSTMTextModel

examples/models/bilstm/model.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
import torch.nn as nn
11+
12+
from ..model_base import EagerModelBase
13+
14+
15+
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM for sequence modeling.

    A stacked, batch-first, bidirectional ``nn.LSTM`` followed by a linear
    layer applied to the LSTM output at the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, input_size=100, hidden_size=128, num_layers=2, num_classes=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Output layer (hidden_size * 2 because of bidirectional)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Run the LSTM over ``x`` and project the last time step.

        Args:
            x: Tensor of shape (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, num_classes).
        """
        # Zero initial states, shape (num_layers * 2, batch, hidden_size) for a
        # bidirectional LSTM. They are allocated directly with the input's dtype
        # and device so the module also works after .double()/.half() or on a
        # non-default device, without a CPU allocation followed by a copy.
        h0 = torch.zeros(
            self.num_layers * 2,
            x.size(0),
            self.hidden_size,
            dtype=x.dtype,
            device=x.device,
        )
        c0 = torch.zeros(
            self.num_layers * 2,
            x.size(0),
            self.hidden_size,
            dtype=x.dtype,
            device=x.device,
        )

        # LSTM forward pass
        out, _ = self.lstm(x, (h0, c0))

        # Take the last time step output (both directions concatenated)
        out = self.fc(out[:, -1, :])
        return out
47+
48+
49+
class BidirectionalLSTMTextClassifier(nn.Module):
    """Text classifier: embedding -> bidirectional LSTM -> max-pool over time -> linear.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding (the LSTM input size).
        hidden_size: Hidden size of each LSTM direction.
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, vocab_size=10000, embedding_dim=128, hidden_size=256, num_classes=2):
        super().__init__()
        self.hidden_size = hidden_size

        # Token-index -> dense-vector lookup.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Single-layer, batch-first, bidirectional recurrent encoder.
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_size,
            bidirectional=True,
            batch_first=True,
        )

        # Both directions are concatenated, hence 2 * hidden_size inputs.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Classify a batch of token-index sequences.

        Args:
            x: Integer tensor of token indices, shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, num_classes).
        """
        token_vectors = self.embedding(x)
        sequence_features, _ = self.lstm(token_vectors)

        # Global max pooling over the sequence dimension (dim 1).
        pooled = sequence_features.max(dim=1).values

        return self.fc(pooled)
83+
84+
85+
class BidirectionalLSTMModel(EagerModelBase):
    """Example wrapper that builds a ``BidirectionalLSTM`` and a matching random input."""

    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Construct the bidirectional LSTM with fixed sizes and return it in eval mode."""
        logging.info("Loading Bidirectional LSTM model")
        lstm_model = BidirectionalLSTM(
            input_size=100, hidden_size=128, num_layers=2, num_classes=10
        )
        lstm_model.eval()
        logging.info("Loaded Bidirectional LSTM model")
        return lstm_model

    def get_example_inputs(self):
        """Return a 1-tuple with a random float tensor of shape (1, 50, 100)."""
        # (batch_size=1, seq_len=50, input_size=100)
        return (torch.randn((1, 50, 100)),)
105+
106+
107+
class BidirectionalLSTMTextModel(EagerModelBase):
    """Example wrapper that builds a ``BidirectionalLSTMTextClassifier`` and sample tokens."""

    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Construct the text classifier with fixed sizes and return it in eval mode."""
        logging.info("Loading Bidirectional LSTM text classifier")
        classifier = BidirectionalLSTMTextClassifier(
            vocab_size=10000, embedding_dim=128, hidden_size=256, num_classes=2
        )
        classifier.eval()
        logging.info("Loaded Bidirectional LSTM text classifier")
        return classifier

    def get_example_inputs(self):
        """Return a 1-tuple with random token indices in [0, 10000) of shape (1, 100)."""
        # (batch_size=1, seq_len=100) - token indices
        return (torch.randint(0, 10000, (1, 100)),)

examples/models/crnn/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import CRNNModel, CRNNMobileModel, CRNNRGBModel

0 commit comments

Comments
 (0)