Skip to content
This repository was archived by the owner on Apr 30, 2021. It is now read-only.

Commit d338140

Browse files
committed
code for v2
1 parent 6dc53a5 commit d338140

File tree

8 files changed

+79
-184
lines changed

8 files changed

+79
-184
lines changed

data/.gitignore

Lines changed: 0 additions & 1 deletion
This file was deleted.

data/data_gen.py

Lines changed: 0 additions & 71 deletions
This file was deleted.
-187 Bytes
Binary file not shown.
-950 Bytes
Binary file not shown.

deepsegment/deepsegment.py

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,78 @@
1-
from seqtag import predictor
1+
from keras.models import model_from_json
2+
import numpy as np
3+
from seqtag_keras.layers import CRF
4+
import pydload
5+
import pickle
6+
import os
7+
8+
# Download URLs for each supported language's model artifacts
# ('checkpoint', 'utils', 'params'), keyed by language code.
# DeepSegment.__init__ looks a language up here to decide whether it is supported.
model_links = {
9+
'en': {
10+
'checkpoint': 'https://github.com/bedapudi6788/deepsegment/releases/download/v1.0.2/en_checkpoint',
11+
'utils': 'https://github.com/bedapudi6788/deepsegment/releases/download/v1.0.2/en_utils',
12+
'params': 'https://github.com/bedapudi6788/deepsegment/releases/download/v1.0.2/en_params'
13+
}
14+
}
215

316

417
class DeepSegment():
    """Sentence segmenter backed by a Keras sequence tagger with a CRF layer.

    The model definition ('params'), weights ('checkpoint') and the pickled
    preprocessing object ('utils') are fetched from ``model_links`` on first
    use and stored under ``~/.DeepSegment_<lang_code>``; later constructions
    reuse the files already on disk.

    The loaded model and data converter are stored on the class itself, so
    they are shared by every instance and replaced whenever a new instance is
    constructed.
    """

    # Shared by all instances; populated by __init__.
    seqtag_model = None
    data_converter = None

    def __init__(self, lang_code='en'):
        """Download (if needed) and load the model for ``lang_code``.

        Args:
            lang_code: Key into ``model_links``. For an unsupported code a
                notice is printed and nothing is loaded, so a later call to
                ``segment`` raises ``RuntimeError`` unless a model was loaded
                by another instance.
        """
        if lang_code not in model_links:
            print("DeepSegment doesn't support '" + lang_code + "' yet.")
            print("Please raise a issue at https://github.com/bedapudi6788/deepsegment to add this language into future checklist.")
            return

        lang_path = os.path.join(os.path.expanduser("~"), '.DeepSegment_' + lang_code)
        # exist_ok avoids the check-then-create race between concurrent processes.
        os.makedirs(lang_path, exist_ok=True)

        links = model_links[lang_code]
        # (artifact key / local file name, label used in the download notice)
        artifacts = (
            ('checkpoint', 'checkpoint'),
            ('utils', 'preprocessing utils'),
            ('params', 'model params'),
        )
        paths = {}
        for name, label in artifacts:
            paths[name] = os.path.join(lang_path, name)
            self._download_if_missing(links[name], paths[name], label)

        # Architecture first, then weights, then the preprocessing object.
        with open(paths['params']) as params_file:
            DeepSegment.seqtag_model = model_from_json(params_file.read(), custom_objects={'CRF': CRF})
        DeepSegment.seqtag_model.load_weights(paths['checkpoint'])

        # NOTE(security): this unpickles a file downloaded from the network.
        # pickle.load can execute arbitrary code, so the URLs in model_links
        # must only ever point at trusted release assets.
        with open(paths['utils'], 'rb') as utils_file:
            DeepSegment.data_converter = pickle.load(utils_file)

    @staticmethod
    def _download_if_missing(url, save_to_path, description):
        """Fetch ``url`` to ``save_to_path`` unless that file already exists."""
        if os.path.exists(save_to_path):
            return
        print('Downloading', description, url, 'to', save_to_path)
        pydload.dload(url=url, save_to_path=save_to_path, max_time=None)

    def segment(self, sents):
        """Split unpunctuated text into sentences.

        Args:
            sents: A single string, or an iterable of strings. Each string is
                tokenised on whitespace.

        Returns:
            One flat list of sentence strings covering all inputs in order
            (boundaries between the input strings are not marked). Empty or
            whitespace-only inputs contribute nothing.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        if DeepSegment.seqtag_model is None or DeepSegment.data_converter is None:
            raise RuntimeError('Please load the model first')

        if isinstance(sents, str):
            sents = [sents]

        tokenized = [sent.strip().split() for sent in sents]
        # Token-less inputs produce no output, so keep them away from the model.
        tokenized = [tokens for tokens in tokenized if tokens]
        if not tokenized:
            return []

        encoded_sents = DeepSegment.data_converter.transform(tokenized)
        all_scores = DeepSegment.seqtag_model.predict(encoded_sents)
        # assumes predict yields one (positions, n_tags) score array per input — TODO confirm
        all_tags = [np.argmax(scores, axis=1).tolist() for scores in all_scores]

        segmented_sentences = []
        for tokens, tags in zip(tokenized, all_tags):
            current_words = []
            # zip stops at the shorter sequence, so extra tag positions are ignored.
            for word, tag in zip(tokens, tags):
                # Tag index 2 starts a new sentence (presumably the
                # sentence-begin label of the converter's tag vocabulary —
                # confirm). Nothing to flush before the first word.
                if tag == 2 and current_words:
                    segmented_sentences.append(' '.join(current_words))
                    current_words = []
                current_words.append(word)
            if current_words:
                segmented_sentences.append(' '.join(current_words))

        return segmented_sentences

deepsegment/test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from deepsegment import DeepSegment
2+
3+
# Smoke check: construct the segmenter with its default arguments and print
# the segmentation of one unpunctuated sample string.
m = DeepSegment()
4+
5+
print(m.segment('I am batman who are you'))

run_tests.py

Lines changed: 0 additions & 88 deletions
This file was deleted.

setup.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
# NOTE(review): this URL names the 'Deep-Segmentation' repository, while the
# model download links in deepsegment/deepsegment.py point at 'deepsegment' —
# confirm which one is current.
URL = 'https://github.com/bedapudi6788/Deep-Segmentation'
1818
1919
AUTHOR = 'BEDAPUDI PRANEETH'
20-
REQUIRES_PYTHON = '>=2.6.0'
21-
VERSION = '1.0.1'
20+
REQUIRES_PYTHON = '>=3.5.0'
21+
VERSION = '2.0.0'
2222

2323
# What packages are required for this module to be executed?
2424
# deepsegment/deepsegment.py imports the CRF layer from seqtag_keras and uses
# pydload for model downloads.
# NOTE(review): that module also imports keras and numpy directly; presumably
# they arrive transitively via seqtag-keras — confirm, or list them here.
REQUIRED = [
25-
'seqtag'
25+
'seqtag-keras',
26+
'pydload'
2627
]
2728

2829
# What packages are optional?

0 commit comments

Comments
 (0)