Skip to content
This repository was archived by the owner on Apr 30, 2021. It is now read-only.

Commit 274f2dc

Browse files
committed
add segment_long
1 parent dcd06e0 commit 274f2dc

File tree

2 files changed

+53
-9
lines changed

2 files changed

+53
-9
lines changed

deepsegment/deepsegment.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import pydload
55
import pickle
66
import os
7+
import logging
8+
import time
79

810
model_links = {
911
'en': {
@@ -14,6 +16,17 @@
1416
}
1517

1618

19+
def chunk(l, n):
    """Split a sequence into consecutive pieces of at most ``n`` items.

    Args:
        l: Any sliceable sequence that supports ``len()`` (list, str, ...).
        n: Maximum size of each piece; must be a positive integer.

    Returns:
        A list of slices of ``l`` in their original order. Every piece holds
        ``n`` items except possibly the last one. When ``l`` is empty the
        result is ``[l]``, so callers always receive at least one piece.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        # A negative step used to produce an empty range and silently hand
        # back the input unchunked; zero already raised inside range().
        raise ValueError("n must be a positive integer, got {!r}".format(n))

    pieces = [l[start:start + n] for start in range(0, len(l), n)]

    # Keep the historical contract: never return an empty list.
    return pieces or [l]
29+
1730
class DeepSegment():
1831
seqtag_model = None
1932
data_converter = None
@@ -54,25 +67,61 @@ def segment(self, sents):
5467
if not DeepSegment.seqtag_model:
5568
print('Please load the model first')
5669

70+
string_output = False
5771
if not isinstance(sents, list):
72+
logging.warn("Batch input strings for faster inference.")
73+
string_output = True
5874
sents = [sents]
5975

6076
sents = [sent.strip().split() for sent in sents]
77+
78+
max_len = len(max(sents, key=len))
79+
if max_len >= 40:
80+
logging.warn("Consider using segment_long for longer sentences.")
81+
6182
encoded_sents = DeepSegment.data_converter.transform(sents)
6283
all_tags = DeepSegment.seqtag_model.predict(encoded_sents)
6384
all_tags = [np.argmax(_tags, axis=1).tolist() for _tags in all_tags]
6485

65-
segmented_sentences = []
66-
for sent, tags in zip(sents, all_tags):
86+
segmented_sentences = [[] for _ in sents]
87+
for sent_index, (sent, tags) in enumerate(zip(sents, all_tags)):
6788
segmented_sent = []
6889
for i, (word, tag) in enumerate(zip(sent, tags)):
6990
if tag == 2 and i > 0 and segmented_sent:
7091
segmented_sent = ' '.join(segmented_sent)
71-
segmented_sentences.append(segmented_sent)
92+
segmented_sentences[sent_index].append(segmented_sent)
7293
segmented_sent = []
7394

7495
segmented_sent.append(word)
7596
if segmented_sent:
76-
segmented_sentences.append(' '.join(segmented_sent))
97+
segmented_sentences[sent_index].append(' '.join(segmented_sent))
7798

99+
if string_output:
100+
return segmented_sentences[0]
101+
78102
return segmented_sentences
103+
104+
def segment_long(self, sent, n_window=None):
    """Segment one long string by sliding a word window over it.

    The text is split on whitespace and fed to ``self.segment`` one window
    at a time. Every segment of a window except the last is final; the last
    one may be an unfinished sentence, so its words are carried over and
    prepended to the next window.

    Args:
        sent: The text to segment. Lists are rejected (no batching).
        n_window: Target number of words per window. Falsy values
            (``None``, ``0``) fall back to 10 with a warning.

    Returns:
        A list of segment strings, or ``None`` if ``sent`` is a list.
    """
    if not n_window:
        # logging.warn is a deprecated alias that was removed in Python 3.13.
        logging.warning("Using default n_window=10. Set this parameter based on your data.")
        n_window = 10

    if isinstance(sent, list):
        logging.error("segment_long doesn't support batching as of now. Batching will be added in a future release.")
        return None

    segmented = []
    words = sent.split()
    prefix = []
    while words:
        # Top the carried-over words up to n_window. If the carry-over
        # already fills or exceeds the window, take a full fresh window: a
        # zero or negative slice size would consume nothing (or slice from
        # the wrong end) and the loop could never finish.
        current_n_window = n_window - len(prefix)
        if current_n_window <= 0:
            current_n_window = n_window

        window = prefix + words[:current_n_window]
        words = words[current_n_window:]

        segmented_window = self.segment([' '.join(window)])[0]
        segmented += segmented_window[:-1]
        # The trailing segment may continue into the next window.
        prefix = segmented_window[-1].split()

    # Flush whatever is still carried over; without this the final segment
    # of the input would be lost.
    if prefix:
        segmented.append(' '.join(prefix))

    return segmented

deepsegment/test.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)