@@ -63,10 +63,25 @@ class LightPipeline:
6363 }
6464 """
6565
66- def __init__ (self , pipelineModel , parse_embeddings = False ):
66+ def __init__ (self , pipelineModel , parse_embeddings = False , output_cols = None ):
67+ """
68+ Parameters
69+ ----------
70+ pipelineModel : PipelineModel
71+ The fitted Spark NLP pipeline model.
72+ parse_embeddings : bool, optional
73+ Whether to parse embeddings.
74+ output_cols : list[str], optional
75+ List of output columns to return in results (optional).
76+ """
77+ if output_cols is None :
78+ output_cols = []
79+
6780 self .pipeline_model = pipelineModel
6881 self .parse_embeddings = parse_embeddings
69- self ._lightPipeline = _internal ._LightPipeline (pipelineModel , parse_embeddings ).apply ()
82+ self .output_cols = output_cols
83+
84+ self ._lightPipeline = _internal ._LightPipeline (pipelineModel , parse_embeddings , output_cols ).apply ()
7085
7186 def _validateStagesInputCols (self , stages ):
7287 annotator_types = self ._getAnnotatorTypes (stages )
@@ -157,22 +172,14 @@ def __get_result(annotation):
157172
158173 return result
159174
160- def fullAnnotate (self , target , optional_target = "" ):
161- """Annotates the data provided into `Annotation` type results.
162-
163- The data should be either a list or a str.
164-
165- Parameters
166- ----------
167- target : list or str or float
168- The data to be annotated
169- optional_target: list or str
170- Optional data to be annotated (currently used for Question Answering)
175+ def fullAnnotate (self , * args , ** kwargs ):
176+ """
177+ Annotate and return full Annotation objects.
171178
172- Returns
173- -------
174- List[Annotation]
175- The result of the annotation
179+ Supports both:
180+ - fullAnnotate(text: str)
181+ - fullAnnotate(texts: list[str])
182+ - fullAnnotate(ids: list[int], texts: list[str])
176183
177184 Examples
178185 --------
@@ -191,25 +198,46 @@ def fullAnnotate(self, target, optional_target=""):
191198 Annotation(named_entity, 30, 36, B-LOC, {'word': 'Baghdad'}),
192199 Annotation(named_entity, 37, 37, O, {'word': '.'})]
193200 """
201+ if "target" in kwargs :
202+ args = (kwargs ["target" ],) + args
203+ if "optional_target" in kwargs :
204+ args = args + (kwargs ["optional_target" ],)
205+
194206 stages = self .pipeline_model .stages
195207 if not self ._skipPipelineValidation (stages ):
196208 self ._validateStagesInputCols (stages )
197209
198- if optional_target == "" :
199- if self .__isTextInput (target ):
200- result = self .__fullAnnotateText (target )
201- elif self .__isAudioInput (target ):
202- result = self .__fullAnnotateAudio (target )
203- else :
204- raise TypeError (
205- "argument for annotation must be 'str' or list[str] or list[float] or list[list[float]]" )
206- else :
207- if self .__isTextInput (target ) and self .__isTextInput (optional_target ):
208- result = self .__fullAnnotateQuestionAnswering (target , optional_target )
209- else :
210- raise TypeError ("arguments for annotation must be 'str' or list[str]" )
210+ input_type = self .__detectInputType (args )
211211
212- return result
212+ if input_type == "ids_texts" :
213+ ids , texts = args
214+ results = self ._lightPipeline .fullAnnotateWithIdsJava (ids , texts )
215+ return [self .__buildStages (r ) for r in results ]
216+
217+ if input_type == "qa" :
218+ question , context = args
219+ return self .__fullAnnotateQuestionAnswering (question , context )
220+
221+ if input_type == "text" :
222+ target = args [0 ]
223+ return self .__fullAnnotateText (target )
224+
225+ if input_type == "audio" :
226+ audios = args [0 ]
227+ return self .__fullAnnotateAudio (audios )
228+
229+ if input_type == "image" :
230+ images = args [0 ]
231+ return self .fullAnnotateImage (images )
232+
233+ raise TypeError (
234+ "Unsupported input for fullAnnotate(). Expected: "
235+ "(text: str | list[str]), "
236+ "(ids: list[int], texts: list[str]), "
237+ "(question: str, context: str), "
238+ "(audio: list[float] | list[list[float]]), or "
239+ "(image_path: str | list[str])."
240+ )
213241
214242 @staticmethod
215243 def __isTextInput (target ):
@@ -326,22 +354,19 @@ def __buildStages(self, annotations_result):
326354 stages [annotator_type ] = self ._annotationFromJava (annotations )
327355 return stages
328356
329- def annotate (self , target , optional_target = "" ):
330- """Annotates the data provided, extracting the results.
331357
332- The data should be either a list or a str.
358+ def annotate (self , * args , ** kwargs ):
359+ """
360+ Annotate text(s) or text(s) with IDs using the LightPipeline.
333361
334- Parameters
335- ----------
336- target : list or str
337- The data to be annotated
338- optional_target: list or str
339- Optional data to be annotated (currently used for Question Answering)
362+ Supports both:
363+ - annotate(text: str)
364+ - annotate(texts: list[str])
365+ - annotate(ids: list[int], texts: list[str])
340366
341367 Returns
342368 -------
343- List[dict] or dict
344- The result of the annotation
369+ list[dict[str, list[str]]]
345370
346371 Examples
347372 --------
@@ -353,39 +378,125 @@ def annotate(self, target, optional_target=""):
353378 >>> result["ner"]
354379 ['B-ORG', 'O', 'O', 'B-PER', 'O', 'O', 'B-LOC', 'O']
355380 """
356-
357- def reformat (annotations ):
358- return {k : list (v ) for k , v in annotations .items ()}
381+ if "target" in kwargs :
382+ args = (kwargs ["target" ],) + args
383+ if "optional_target" in kwargs :
384+ args = args + (kwargs ["optional_target" ],)
359385
360386 stages = self .pipeline_model .stages
361387 if not self ._skipPipelineValidation (stages ):
362388 self ._validateStagesInputCols (stages )
363389
364- if optional_target == "" :
365- if type (target ) is str :
366- annotations = self ._lightPipeline .annotateJava (target )
367- result = reformat (annotations )
368- elif type (target ) is list :
369- if type (target [0 ]) is list :
370- raise TypeError ("target is a 1D list" )
371- annotations = self ._lightPipeline .annotateJava (target )
372- result = list (map (lambda a : reformat (a ), list (annotations )))
390+ input_type = self .__detectInputType (args )
391+
392+ if input_type == "ids_texts" :
393+ ids , texts = args
394+ results = self ._lightPipeline .annotateWithIdsJava (ids , texts )
395+ return [dict ((k , list (v )) for k , v in r .items ()) for r in results ]
396+
397+ if input_type == "qa" :
398+ question , context = args
399+ if isinstance (question , list ) and isinstance (context , list ):
400+ results = self ._lightPipeline .annotateJava (question , context )
373401 else :
374- raise TypeError ( "target for annotation must be 'str' or list" )
402+ results = self . _lightPipeline . annotateJava ([ question ], [ context ] )
375403
376- else :
377- if type (target ) is str and type (optional_target ) is str :
378- annotations = self ._lightPipeline .annotateJava (target , optional_target )
379- result = reformat (annotations )
380- elif type (target ) is list and type (optional_target ) is list :
381- if type (target [0 ]) is list or type (optional_target [0 ]) is list :
382- raise TypeError ("target and optional_target is a 1D list" )
383- annotations = self ._lightPipeline .annotateJava (target , optional_target )
384- result = list (map (lambda a : reformat (a ), list (annotations )))
404+ if isinstance (results , dict ):
405+ results = [results ]
406+
407+ results = [dict ((k , list (v )) for k , v in r .items ()) for r in results ]
408+
409+ if len (results ) == 1 :
410+ return results [0 ]
411+ return results
412+
413+
414+ if input_type == "text" :
415+ target = args [0 ]
416+ if isinstance (target , str ):
417+ return self ._lightPipeline .annotateJava (target )
385418 else :
386- raise TypeError ("target and optional_target for annotation must be both 'str' or both lists" )
419+ results = self ._lightPipeline .annotateJava (target )
420+ return [dict ((k , list (v )) for k , v in r .items ()) for r in results ]
421+
422+ raise TypeError (
423+ "Unsupported input for annotate(). Expected: "
424+ "(text: str | list[str]), "
425+ "(ids: list[int], texts: list[str]), "
426+ "or (question: str, context: str)."
427+ )
428+
429+ def __detectInputType (self , args ):
430+ """
431+ Determine the input type pattern for fullAnnotate().
432+ Returns one of: 'ids_texts', 'qa', 'text', 'audio', 'image', or 'unknown'.
433+ """
434+ if len (args ) == 2 :
435+ a1 , a2 = args
436+
437+ # (ids, texts)
438+ if (
439+ isinstance (a1 , list )
440+ and all (isinstance (i , int ) for i in a1 )
441+ and isinstance (a2 , list )
442+ and all (isinstance (t , str ) for t in a2 )
443+ ):
444+ return "ids_texts"
445+
446+ # (question, context)
447+ if isinstance (a1 , str ) and isinstance (a2 , str ):
448+ return "qa"
449+
450+ # (questions[], contexts[])
451+ if (
452+ isinstance (a1 , list )
453+ and all (isinstance (q , str ) for q in a1 )
454+ and isinstance (a2 , list )
455+ and all (isinstance (c , str ) for c in a2 )
456+ ):
457+ return "qa"
458+
459+ elif len (args ) == 1 :
460+ a1 = args [0 ]
461+
462+ # 🚧 Guard: ignore anything that’s not str or list
463+ if not isinstance (a1 , (str , list )):
464+ return "unknown"
465+
466+ # 🧩 Case 1: plain string
467+ if isinstance (a1 , str ):
468+ if self .__isPath (a1 ):
469+ return "image"
470+ if self .__isTextInput (a1 ):
471+ return "text"
472+ return "unknown"
473+
474+ # 🧩 Case 2: list — ensure homogeneous types
475+ if isinstance (a1 , list ) and len (a1 ) > 0 :
476+ first_elem = a1 [0 ]
477+
478+ # Guard clause — mixed or invalid types
479+ if not all (isinstance (x , (str , float , list )) for x in a1 ):
480+ return "unknown"
481+
482+ # Text list
483+ if all (isinstance (x , str ) for x in a1 ) and self .__isTextInput (a1 ):
484+ return "text"
485+
486+ # Audio list
487+ if all (isinstance (x , float ) for x in a1 ) or (
488+ all (isinstance (x , list ) for x in a1 ) and all (isinstance (i , float ) for sub in a1 for i in sub )
489+ ):
490+ return "audio"
491+
492+ # Image list (only strings allowed)
493+ if all (isinstance (x , str ) for x in a1 ) and all ("/" in x for x in a1 ):
494+ return "image"
495+
496+ return "unknown"
497+
498+ return "unknown"
387499
388- return result
389500
390501 def transform (self , dataframe ):
391502 """Transforms a dataframe provided with the stages of the LightPipeline.
@@ -400,7 +511,14 @@ def transform(self, dataframe):
400511 :class:`pyspark.sql.DataFrame`
401512 The transformed DataFrame
402513 """
403- return self .pipeline_model .transform (dataframe )
514+ transformed_df = self .pipeline_model .transform (dataframe )
515+
516+ if self .output_cols :
517+ original_cols = dataframe .columns
518+ filtered_cols = list (dict .fromkeys (original_cols + self .output_cols ))
519+ transformed_df = transformed_df .select (* filtered_cols )
520+
521+ return transformed_df
404522
405523 def setIgnoreUnsupported (self , value ):
406524 """Sets whether to ignore unsupported AnnotatorModels.
0 commit comments