Commit b4289c1

added classifier lambda code
1 parent ee99e19 commit b4289c1

13 files changed: +447 -0 lines changed

.github/workflows/classification.yml

Lines changed: 39 additions & 0 deletions
name: classification

on:
  pull_request:
    branches:
      - main
    paths:
      - "classification/**"

jobs:
  classification:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: ./classification
    steps:
      - name: Checkout
        uses: actions/checkout@v2
        with:
          ref: ${{ github.ref }}
      - name: Build container
        run: |
          docker build --tag classification:latest .
      - name: Configure AWS Credentials
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: us-east-1
      - name: Push2ECR
        id: ecr
        uses: jwalton/gh-ecr-push@v1
        with:
          access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          region: us-east-1
          image: classification:latest
      - name: Update lambda with image
        run: aws lambda update-function-code --function-name classification --image-uri 968911158010.dkr.ecr.us-east-1.amazonaws.com/classification:latest
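
Once this workflow has pushed a fresh image and pointed the function at it, the deployed Lambda accepts a JSON event with the same keys that Classifier.__call__ reads (texts, model_name, tokenizer_name). A minimal invocation sketch with boto3, assuming AWS credentials are configured locally and the function name classification from the last step above:

import json

import boto3  # host-side dependency; not part of the Lambda image

client = boto3.client("lambda", region_name="us-east-1")

# Event shape mirrors what main.lambda_handler forwards to Classifier.__call__
event = {
    "texts": ["I love this!", "This is terrible."],
    "model_name": "cardiffnlp/twitter-roberta-base-sentiment",
    "tokenizer_name": "roberta-base",
}

response = client.invoke(
    FunctionName="classification",
    Payload=json.dumps(event).encode("utf-8"),
)
print(json.loads(response["Payload"].read()))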

classification/.gitignore

Lines changed: 136 additions & 0 deletions
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

.aws-sam
*.pyc
.vscode
.DS_store
**.bin
**.ipynb_checkpoints

classification/Dockerfile

Lines changed: 21 additions & 0 deletions
FROM amazon/aws-lambda-python

ARG MODEL_DIR=./models

ENV TRANSFORMERS_CACHE=$MODEL_DIR
ENV TRANSFORMERS_VERBOSITY=error

RUN yum -y install gcc-c++

COPY requirements.txt requirements.txt
RUN pip install torch==1.8.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --no-cache-dir
RUN pip install -r requirements.txt --no-cache-dir

COPY ./ ./

# Running the test suite here also downloads and caches the transformer model inside the image
RUN pip install pytest --no-cache-dir && pytest tests -s -vv

RUN chmod -R 0777 $MODEL_DIR

CMD ["main.lambda_handler"]

classification/README.MD

Lines changed: 2 additions & 0 deletions
## Classification service
Classification using AWS Lambda & Transformers

classification/main.py

Lines changed: 11 additions & 0 deletions
from src.classifier import Classifier

# Instantiated at module load so the model is warmed up before the first invocation
pipeline = Classifier()


def lambda_handler(event, context):
    try:
        return pipeline(event)
    except Exception:
        # Re-raise so failures surface in Lambda logs and metrics
        raise
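
For a quick in-process smoke test, no container or AWS setup is needed; the handler can be called directly with a dict event. A sketch, assuming it is run from the classification/ directory (so the src imports resolve) and that the model weights can be downloaded on first use:

from main import lambda_handler

event = {
    "texts": ["the movie was okay, not great"],
    "model_name": "cardiffnlp/twitter-roberta-base-sentiment",
    "tokenizer_name": "roberta-base",
}

# context is unused by the handler, so None is fine for local testing
print(lambda_handler(event, None))
# -> {"predictions": [{"label": ..., "score": ...}]}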

classification/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
transformers==4.*
2+
tqdm==4.*
3+
scikit-learn==0.24.*

classification/src/__init__.py

Whitespace-only changes.

classification/src/classifier.py

Lines changed: 81 additions & 0 deletions
import warnings
from functools import lru_cache

warnings.filterwarnings("ignore")

from transformers import (AutoConfig, AutoModelForSequenceClassification,
                          AutoTokenizer, Pipeline, pipeline)

from src import config, utils

logger = utils.create_logger(project_name=config.PREDICTION_TYPE, level="INFO")


class Classifier:
    def __init__(self):
        # Warm up with positional args so the lru_cache key matches the call made in __call__
        _ = self.get_sentiment_pipeline(config.DEFAULT_MODEL_NAME, config.DEFAULT_TOKENIZER_NAME)

    @staticmethod
    @lru_cache(maxsize=config.CACHE_MAXSIZE)
    def get_sentiment_pipeline(model_name: str, tokenizer_name: str) -> Pipeline:
        """Build (or fetch from the cache) the sentiment pipeline for the given model and tokenizer.

        Args:
            model_name (str): name of the model
            tokenizer_name (str): name of the tokenizer

        Returns:
            Pipeline: sentiment-analysis pipeline
        """
        logger.info(f"Loading model: {model_name}")
        id2label = config.ID_SENTIMENT_MAPPING[model_name]
        label2id = {label: idx for idx, label in id2label.items()}

        model_config = AutoConfig.from_pretrained(model_name)
        model_config.label2id = label2id
        model_config.id2label = id2label
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, config=model_config
        )
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        classification_pipeline = pipeline(
            "sentiment-analysis", model=model, tokenizer=tokenizer
        )
        return classification_pipeline

    def get_clean_text(self, text: str) -> str:
        """Normalize text before inference.

        Args:
            text (str): raw input text

        Returns:
            str: stripped, lower-cased text
        """
        return text.strip().lower()

    def __call__(self, request: dict) -> dict:
        """Predict the sentiment of the given texts.

        Args:
            request (dict): request containing the list of texts to classify

        Returns:
            dict: predicted classes and rounded scores for the given texts
        """
        texts = [self.get_clean_text(text) for text in request["texts"]]
        model_name = request["model_name"]
        tokenizer_name = request["tokenizer_name"]

        logger.info(f"Predicting sentiment for {len(texts)} texts")
        classification_pipeline = self.get_sentiment_pipeline(model_name, tokenizer_name)

        predictions = classification_pipeline(texts)
        for i, pred in enumerate(predictions):
            predictions[i]["score"] = round(pred["score"], 2)

        return {"predictions": predictions}
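
Since get_sentiment_pipeline is wrapped in lru_cache, repeated requests for the same (model_name, tokenizer_name) pair reuse the already-loaded pipeline, and up to CACHE_MAXSIZE distinct pairs stay resident at once. A sketch making that visible through the standard cache_info() counters:

from src.classifier import Classifier

clf = Classifier()  # __init__ already loaded the default model (one cache miss)

request = {
    "texts": ["cautiously optimistic"],
    "model_name": "cardiffnlp/twitter-roberta-base-emotion",
    "tokenizer_name": "roberta-base",
}
clf(request)  # new (model, tokenizer) pair -> cache miss, model is loaded
clf(request)  # same pair -> cache hit, no reload

print(Classifier.get_sentiment_pipeline.cache_info())
# CacheInfo(hits=1, misses=2, maxsize=4, currsize=2)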

classification/src/config.py

Lines changed: 20 additions & 0 deletions
PREDICTION_TYPE = "classification"

DEFAULT_MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment"
DEFAULT_TOKENIZER_NAME = "roberta-base"
ID_SENTIMENT_MAPPING = {  # add an entry for every model to be supported
    "cardiffnlp/twitter-roberta-base-sentiment": {
        0: "NEGATIVE",
        1: "NEUTRAL",
        2: "POSITIVE",
    },
    "cardiffnlp/twitter-roberta-base-emotion": {
        0: "ANGER",
        1: "JOY",
        2: "OPTIMISM",
        3: "SADNESS",
    },
}

# cache
CACHE_MAXSIZE = 4
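
Per the inline comment, supporting another checkpoint is just a matter of adding its id-to-label mapping. A hypothetical sketch; the model name and label order below are placeholders and must match how the checkpoint was actually trained:

# Hypothetical entry: "some-org/binary-sentiment" is a placeholder name,
# and the index order must mirror the checkpoint's training labels.
ID_SENTIMENT_MAPPING["some-org/binary-sentiment"] = {
    0: "NEGATIVE",
    1: "POSITIVE",
}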
