4
4
import pydload
5
5
import pickle
6
6
import os
7
+ import logging
8
+ import time
7
9
8
10
model_links = {
9
11
'en' : {
14
16
}
15
17
16
18
19
def chunk(l, n):
    """Split ``l`` into consecutive slices holding at most ``n`` items each.

    The slices are collected into a list and returned (this is not a
    generator). If slicing produces no pieces at all -- for example when
    ``l`` is empty -- the result is a one-element list containing ``l``
    itself.
    """
    pieces = [l[start:start + n] for start in range(0, len(l), n)]
    return pieces if pieces else [l]
29
+
17
30
class DeepSegment ():
18
31
seqtag_model = None
19
32
data_converter = None
@@ -54,25 +67,61 @@ def segment(self, sents):
54
67
if not DeepSegment .seqtag_model :
55
68
print ('Please load the model first' )
56
69
70
+ string_output = False
57
71
if not isinstance (sents , list ):
72
+ logging .warn ("Batch input strings for faster inference." )
73
+ string_output = True
58
74
sents = [sents ]
59
75
60
76
sents = [sent .strip ().split () for sent in sents ]
77
+
78
+ max_len = len (max (sents , key = len ))
79
+ if max_len >= 40 :
80
+ logging .warn ("Consider using segment_long for longer sentences." )
81
+
61
82
encoded_sents = DeepSegment .data_converter .transform (sents )
62
83
all_tags = DeepSegment .seqtag_model .predict (encoded_sents )
63
84
all_tags = [np .argmax (_tags , axis = 1 ).tolist () for _tags in all_tags ]
64
85
65
- segmented_sentences = []
66
- for sent , tags in zip (sents , all_tags ):
86
+ segmented_sentences = [[] for _ in sents ]
87
+ for sent_index , ( sent , tags ) in enumerate ( zip (sents , all_tags ) ):
67
88
segmented_sent = []
68
89
for i , (word , tag ) in enumerate (zip (sent , tags )):
69
90
if tag == 2 and i > 0 and segmented_sent :
70
91
segmented_sent = ' ' .join (segmented_sent )
71
- segmented_sentences .append (segmented_sent )
92
+ segmented_sentences [ sent_index ] .append (segmented_sent )
72
93
segmented_sent = []
73
94
74
95
segmented_sent .append (word )
75
96
if segmented_sent :
76
- segmented_sentences .append (' ' .join (segmented_sent ))
97
+ segmented_sentences [ sent_index ] .append (' ' .join (segmented_sent ))
77
98
99
+ if string_output :
100
+ return segmented_sentences [0 ]
101
+
78
102
return segmented_sentences
103
+
104
def segment_long(self, sent, n_window=None):
    """Segment one long string by sliding a window of words over it.

    The text is split on whitespace and fed to ``self.segment`` in windows
    of roughly ``n_window`` words. Every segment of a window except the
    last is final; the last one may be an unfinished sentence, so its
    words are carried over as the prefix of the next window.

    Args:
        sent: The text to segment, as a single string. Lists are not
            supported.
        n_window: Target number of words per window. When falsy or
            negative, a default of 10 is used and a warning is logged.

    Returns:
        A list of segment strings, or ``None`` when ``sent`` is a list
        (an error is logged in that case).
    """
    # A negative window would produce negative slice bounds below and could
    # loop forever, so it is treated like a missing value.
    if not n_window or n_window < 0:
        logging.warning("Using default n_window=10. Set this parameter based on your data.")
        n_window = 10

    if isinstance(sent, list):
        logging.error("segment_long doesn't support batching as of now. Batching will be added in a future release.")
        return None

    segmented = []
    remaining = sent.split()
    prefix = []
    while remaining:
        # Leave room for the carried-over words. If the carry-over already
        # fills (or exceeds) the window, still take a full window of new
        # words so that every iteration consumes input and the loop ends.
        take = n_window - len(prefix)
        if take <= 0:
            take = n_window

        window = prefix + remaining[:take]
        remaining = remaining[take:]
        segmented_window = self.segment([' '.join(window)])[0]
        segmented += segmented_window[:-1]
        prefix = segmented_window[-1].split()

    # The final carried-over segment has no following window to be merged
    # into, so it must be emitted here or the last sentence is lost.
    if prefix:
        segmented.append(' '.join(prefix))

    return segmented
0 commit comments