1
- from seqtag import predictor
1
+ from keras .models import model_from_json
2
+ import numpy as np
3
+ from seqtag_keras .layers import CRF
4
+ import pydload
5
+ import pickle
6
+ import os
7
+
8
+ model_links = {
9
+ 'en' : {
10
+ 'checkpoint' : 'https://github.com/bedapudi6788/deepsegment/releases/download/v1.0.2/en_checkpoint' ,
11
+ 'utils' : 'https://github.com/bedapudi6788/deepsegment/releases/download/v1.0.2/en_utils' ,
12
+ 'params' : 'https://github.com/bedapudi6788/deepsegment/releases/download/v1.0.2/en_params'
13
+ }
14
+ }
2
15
3
16
4
17
class DeepSegment ():
5
18
seqtag_model = None
6
- def __init__ (self , config_path ):
19
+ data_converter = None
20
+ def __init__ (self , lang_code = 'en' ):
21
+ if lang_code not in model_links :
22
+ print ("DeepSegment doesn't support '" + lang_code + "' yet." )
23
+ print ("Please raise a issue at https://github.com/bedapudi6788/deepsegment to add this language into future checklist." )
24
+ return None
25
+
7
26
# loading the model
8
- DeepSegment .seqtag_model = predictor .load_model (config_path )
9
-
10
- def segment (self , text ):
27
+ home = os .path .expanduser ("~" )
28
+ lang_path = os .path .join (home , '.DeepSegment_' + lang_code )
29
+ checkpoint_path = os .path .join (lang_path , 'checkpoint' )
30
+ utils_path = os .path .join (lang_path , 'utils' )
31
+ params_path = os .path .join (lang_path , 'params' )
32
+
33
+ if not os .path .exists (lang_path ):
34
+ os .mkdir (lang_path )
35
+
36
+ if not os .path .exists (checkpoint_path ):
37
+ print ('Downloading checkpoint' , model_links [lang_code ]['checkpoint' ], 'to' , checkpoint_path )
38
+ pydload .dload (url = model_links [lang_code ]['checkpoint' ], save_to_path = checkpoint_path , max_time = None )
39
+
40
+ if not os .path .exists (utils_path ):
41
+ print ('Downloading preprocessing utils' , model_links [lang_code ]['utils' ], 'to' , utils_path )
42
+ pydload .dload (url = model_links [lang_code ]['utils' ], save_to_path = utils_path , max_time = None )
43
+
44
+ if not os .path .exists (params_path ):
45
+ print ('Downloading model params' , model_links [lang_code ]['utils' ], 'to' , params_path )
46
+ pydload .dload (url = model_links [lang_code ]['params' ], save_to_path = params_path , max_time = None )
47
+
48
+
49
+ DeepSegment .seqtag_model = model_from_json (open (params_path ).read (), custom_objects = {'CRF' : CRF })
50
+ DeepSegment .seqtag_model .load_weights (checkpoint_path )
51
+ DeepSegment .data_converter = pickle .load (open (utils_path , 'rb' ))
52
+
53
+ def segment (self , sents ):
11
54
if not DeepSegment .seqtag_model :
12
55
print ('Please load the model first' )
13
56
14
- text = text .strip ().split ()
15
- tags = predictor .predict (DeepSegment .seqtag_model , text )
16
- sents = []
17
-
18
- current_sent = []
19
- for i , word in enumerate (text ):
20
- if tags [i ] == 'B-sent' :
21
- if current_sent :
22
- sents .append (' ' .join (current_sent ))
23
- current_sent = [word ]
24
- else :
25
- current_sent .append (word )
26
-
27
- sents .append (' ' .join (current_sent ))
28
-
29
- return sents
57
+ if not isinstance (sents , list ):
58
+ sents = [sents ]
59
+
60
+ sents = [sent .strip ().split () for sent in sents ]
61
+ encoded_sents = DeepSegment .data_converter .transform (sents )
62
+ all_tags = DeepSegment .seqtag_model .predict (encoded_sents )
63
+ all_tags = [np .argmax (_tags , axis = 1 ).tolist () for _tags in all_tags ]
64
+
65
+ segmented_sentences = []
66
+ for sent , tags in zip (sents , all_tags ):
67
+ segmented_sent = []
68
+ for i , (word , tag ) in enumerate (zip (sent , tags )):
69
+ if tag == 2 and i > 0 and segmented_sent :
70
+ segmented_sent = ' ' .join (segmented_sent )
71
+ segmented_sentences .append (segmented_sent )
72
+ segmented_sent = []
73
+
74
+ segmented_sent .append (word )
75
+ if segmented_sent :
76
+ segmented_sentences .append (' ' .join (segmented_sent ))
77
+
78
+ return segmented_sentences
0 commit comments