Skip to content

Commit c29f732

Browse files
DevinTDH and danilojsl committed
[SPARKNLP-1333] LightPipeline Improvements (#14726)
commit 21e014a Merge: e369fb7 6b9356b Author: Danilo Burbano <[email protected]> Date: Mon Jan 19 19:13:38 2026 -0500 Merge branch 'release/632-release-candidate' into feature/SPARKNLP-1333-Improve-DocumentAssembler-in-LightPipelines commit e369fb7 Author: Danilo Burbano <[email protected]> Date: Mon Jan 19 19:11:07 2026 -0500 [SPARKNLP-1333] Refactor to better handle idCol logic and outputCols filtering, in Python added **kwargs for backwards compatibility commit 701b801 Author: Danilo Burbano <[email protected]> Date: Sun Jan 18 07:06:48 2026 -0500 [SPARKNLP-1333] Refactor annotator LightPipeline for QA models on python commit 2ef4d15 Author: Danilo Burbano <[email protected]> Date: Thu Jan 15 19:46:13 2026 -0500 [SPARKNLP-1333] Adding ids input for LightPipeline Co-authored-by: Danilo Burbano <[email protected]>
1 parent 6b9356b commit c29f732

File tree

8 files changed

+713
-116
lines changed

8 files changed

+713
-116
lines changed

python/sparknlp/base/light_pipeline.py

Lines changed: 185 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,25 @@ class LightPipeline:
6363
}
6464
"""
6565

66-
def __init__(self, pipelineModel, parse_embeddings=False):
66+
def __init__(self, pipelineModel, parse_embeddings=False, output_cols=None):
67+
"""
68+
Parameters
69+
----------
70+
pipelineModel : PipelineModel
71+
The fitted Spark NLP pipeline model.
72+
parse_embeddings : bool, optional
73+
Whether to parse embeddings.
74+
output_cols : list[str], optional
75+
List of output columns to return in results (optional).
76+
"""
77+
if output_cols is None:
78+
output_cols = []
79+
6780
self.pipeline_model = pipelineModel
6881
self.parse_embeddings = parse_embeddings
69-
self._lightPipeline = _internal._LightPipeline(pipelineModel, parse_embeddings).apply()
82+
self.output_cols = output_cols
83+
84+
self._lightPipeline = _internal._LightPipeline(pipelineModel, parse_embeddings, output_cols).apply()
7085

7186
def _validateStagesInputCols(self, stages):
7287
annotator_types = self._getAnnotatorTypes(stages)
@@ -157,22 +172,14 @@ def __get_result(annotation):
157172

158173
return result
159174

160-
def fullAnnotate(self, target, optional_target=""):
161-
"""Annotates the data provided into `Annotation` type results.
162-
163-
The data should be either a list or a str.
164-
165-
Parameters
166-
----------
167-
target : list or str or float
168-
The data to be annotated
169-
optional_target: list or str
170-
Optional data to be annotated (currently used for Question Answering)
175+
def fullAnnotate(self, *args, **kwargs):
176+
"""
177+
Annotate and return full Annotation objects.
171178
172-
Returns
173-
-------
174-
List[Annotation]
175-
The result of the annotation
179+
Supports:
180+
- fullAnnotate(text: str)
181+
- fullAnnotate(texts: list[str])
182+
- fullAnnotate(ids: list[int], texts: list[str])
176183
177184
Examples
178185
--------
@@ -191,25 +198,46 @@ def fullAnnotate(self, target, optional_target=""):
191198
Annotation(named_entity, 30, 36, B-LOC, {'word': 'Baghdad'}),
192199
Annotation(named_entity, 37, 37, O, {'word': '.'})]
193200
"""
201+
if "target" in kwargs:
202+
args = (kwargs["target"],) + args
203+
if "optional_target" in kwargs:
204+
args = args + (kwargs["optional_target"],)
205+
194206
stages = self.pipeline_model.stages
195207
if not self._skipPipelineValidation(stages):
196208
self._validateStagesInputCols(stages)
197209

198-
if optional_target == "":
199-
if self.__isTextInput(target):
200-
result = self.__fullAnnotateText(target)
201-
elif self.__isAudioInput(target):
202-
result = self.__fullAnnotateAudio(target)
203-
else:
204-
raise TypeError(
205-
"argument for annotation must be 'str' or list[str] or list[float] or list[list[float]]")
206-
else:
207-
if self.__isTextInput(target) and self.__isTextInput(optional_target):
208-
result = self.__fullAnnotateQuestionAnswering(target, optional_target)
209-
else:
210-
raise TypeError("arguments for annotation must be 'str' or list[str]")
210+
input_type = self.__detectInputType(args)
211211

212-
return result
212+
if input_type == "ids_texts":
213+
ids, texts = args
214+
results = self._lightPipeline.fullAnnotateWithIdsJava(ids, texts)
215+
return [self.__buildStages(r) for r in results]
216+
217+
if input_type == "qa":
218+
question, context = args
219+
return self.__fullAnnotateQuestionAnswering(question, context)
220+
221+
if input_type == "text":
222+
target = args[0]
223+
return self.__fullAnnotateText(target)
224+
225+
if input_type == "audio":
226+
audios = args[0]
227+
return self.__fullAnnotateAudio(audios)
228+
229+
if input_type == "image":
230+
images = args[0]
231+
return self.fullAnnotateImage(images)
232+
233+
raise TypeError(
234+
"Unsupported input for fullAnnotate(). Expected: "
235+
"(text: str | list[str]), "
236+
"(ids: list[int], texts: list[str]), "
237+
"(question: str, context: str), "
238+
"(audio: list[float] | list[list[float]]), or "
239+
"(image_path: str | list[str])."
240+
)
213241

214242
@staticmethod
215243
def __isTextInput(target):
@@ -326,22 +354,19 @@ def __buildStages(self, annotations_result):
326354
stages[annotator_type] = self._annotationFromJava(annotations)
327355
return stages
328356

329-
def annotate(self, target, optional_target=""):
330-
"""Annotates the data provided, extracting the results.
331357

332-
The data should be either a list or a str.
358+
def annotate(self, *args, **kwargs):
359+
"""
360+
Annotate text(s) or text(s) with IDs using the LightPipeline.
333361
334-
Parameters
335-
----------
336-
target : list or str
337-
The data to be annotated
338-
optional_target: list or str
339-
Optional data to be annotated (currently used for Question Answering)
362+
Supports:
363+
- annotate(text: str)
364+
- annotate(texts: list[str])
365+
- annotate(ids: list[int], texts: list[str])
340366
341367
Returns
342368
-------
343-
List[dict] or dict
344-
The result of the annotation
369+
list[dict[str, list[str]]]
345370
346371
Examples
347372
--------
@@ -353,39 +378,125 @@ def annotate(self, target, optional_target=""):
353378
>>> result["ner"]
354379
['B-ORG', 'O', 'O', 'B-PER', 'O', 'O', 'B-LOC', 'O']
355380
"""
356-
357-
def reformat(annotations):
358-
return {k: list(v) for k, v in annotations.items()}
381+
if "target" in kwargs:
382+
args = (kwargs["target"],) + args
383+
if "optional_target" in kwargs:
384+
args = args + (kwargs["optional_target"],)
359385

360386
stages = self.pipeline_model.stages
361387
if not self._skipPipelineValidation(stages):
362388
self._validateStagesInputCols(stages)
363389

364-
if optional_target == "":
365-
if type(target) is str:
366-
annotations = self._lightPipeline.annotateJava(target)
367-
result = reformat(annotations)
368-
elif type(target) is list:
369-
if type(target[0]) is list:
370-
raise TypeError("target is a 1D list")
371-
annotations = self._lightPipeline.annotateJava(target)
372-
result = list(map(lambda a: reformat(a), list(annotations)))
390+
input_type = self.__detectInputType(args)
391+
392+
if input_type == "ids_texts":
393+
ids, texts = args
394+
results = self._lightPipeline.annotateWithIdsJava(ids, texts)
395+
return [dict((k, list(v)) for k, v in r.items()) for r in results]
396+
397+
if input_type == "qa":
398+
question, context = args
399+
if isinstance(question, list) and isinstance(context, list):
400+
results = self._lightPipeline.annotateJava(question, context)
373401
else:
374-
raise TypeError("target for annotation must be 'str' or list")
402+
results = self._lightPipeline.annotateJava([question], [context])
375403

376-
else:
377-
if type(target) is str and type(optional_target) is str:
378-
annotations = self._lightPipeline.annotateJava(target, optional_target)
379-
result = reformat(annotations)
380-
elif type(target) is list and type(optional_target) is list:
381-
if type(target[0]) is list or type(optional_target[0]) is list:
382-
raise TypeError("target and optional_target is a 1D list")
383-
annotations = self._lightPipeline.annotateJava(target, optional_target)
384-
result = list(map(lambda a: reformat(a), list(annotations)))
404+
if isinstance(results, dict):
405+
results = [results]
406+
407+
results = [dict((k, list(v)) for k, v in r.items()) for r in results]
408+
409+
if len(results) == 1:
410+
return results[0]
411+
return results
412+
413+
414+
if input_type == "text":
415+
target = args[0]
416+
if isinstance(target, str):
417+
return self._lightPipeline.annotateJava(target)
385418
else:
386-
raise TypeError("target and optional_target for annotation must be both 'str' or both lists")
419+
results = self._lightPipeline.annotateJava(target)
420+
return [dict((k, list(v)) for k, v in r.items()) for r in results]
421+
422+
raise TypeError(
423+
"Unsupported input for annotate(). Expected: "
424+
"(text: str | list[str]), "
425+
"(ids: list[int], texts: list[str]), "
426+
"or (question: str, context: str)."
427+
)
428+
429+
def __detectInputType(self, args):
430+
"""
431+
Determine the input type pattern for fullAnnotate() and annotate().
432+
Returns one of: 'ids_texts', 'qa', 'text', 'audio', 'image', or 'unknown'.
433+
"""
434+
if len(args) == 2:
435+
a1, a2 = args
436+
437+
# (ids, texts)
438+
if (
439+
isinstance(a1, list)
440+
and all(isinstance(i, int) for i in a1)
441+
and isinstance(a2, list)
442+
and all(isinstance(t, str) for t in a2)
443+
):
444+
return "ids_texts"
445+
446+
# (question, context)
447+
if isinstance(a1, str) and isinstance(a2, str):
448+
return "qa"
449+
450+
# (questions[], contexts[])
451+
if (
452+
isinstance(a1, list)
453+
and all(isinstance(q, str) for q in a1)
454+
and isinstance(a2, list)
455+
and all(isinstance(c, str) for c in a2)
456+
):
457+
return "qa"
458+
459+
elif len(args) == 1:
460+
a1 = args[0]
461+
462+
# 🚧 Guard: ignore anything that’s not str or list
463+
if not isinstance(a1, (str, list)):
464+
return "unknown"
465+
466+
# 🧩 Case 1: plain string
467+
if isinstance(a1, str):
468+
if self.__isPath(a1):
469+
return "image"
470+
if self.__isTextInput(a1):
471+
return "text"
472+
return "unknown"
473+
474+
# 🧩 Case 2: list — ensure homogeneous types
475+
if isinstance(a1, list) and len(a1) > 0:
476+
first_elem = a1[0]
477+
478+
# Guard clause — mixed or invalid types
479+
if not all(isinstance(x, (str, float, list)) for x in a1):
480+
return "unknown"
481+
482+
# Text list
483+
if all(isinstance(x, str) for x in a1) and self.__isTextInput(a1):
484+
return "text"
485+
486+
# Audio list
487+
if all(isinstance(x, float) for x in a1) or (
488+
all(isinstance(x, list) for x in a1) and all(isinstance(i, float) for sub in a1 for i in sub)
489+
):
490+
return "audio"
491+
492+
# Image list (only strings allowed)
493+
if all(isinstance(x, str) for x in a1) and all("/" in x for x in a1):
494+
return "image"
495+
496+
return "unknown"
497+
498+
return "unknown"
387499

388-
return result
389500

390501
def transform(self, dataframe):
391502
"""Transforms a dataframe provided with the stages of the LightPipeline.
@@ -400,7 +511,14 @@ def transform(self, dataframe):
400511
:class:`pyspark.sql.DataFrame`
401512
The transformed DataFrame
402513
"""
403-
return self.pipeline_model.transform(dataframe)
514+
transformed_df = self.pipeline_model.transform(dataframe)
515+
516+
if self.output_cols:
517+
original_cols = dataframe.columns
518+
filtered_cols = list(dict.fromkeys(original_cols + self.output_cols))
519+
transformed_df = transformed_df.select(*filtered_cols)
520+
521+
return transformed_df
404522

405523
def setIgnoreUnsupported(self, value):
406524
"""Sets whether to ignore unsupported AnnotatorModels.

python/sparknlp/internal/__init__.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -785,11 +785,27 @@ def __init__(self, name, language, remote_loc):
785785

786786

787787
class _LightPipeline(ExtendedJavaWrapper):
788-
def __init__(self, pipelineModel, parse_embeddings):
788+
def __init__(self, pipelineModel, parse_embeddings=False, output_cols=None):
789+
"""
790+
Internal wrapper around the JVM LightPipeline class.
791+
792+
Parameters
793+
----------
794+
pipelineModel : PipelineModel
795+
A fitted Spark NLP pipeline model.
796+
parse_embeddings : bool, optional
797+
Whether to parse embeddings from embeddings annotators.
798+
output_cols : list[str], optional
799+
List of output columns to include in the result. If not provided, returns all.
800+
"""
801+
if output_cols is None:
802+
output_cols = []
803+
789804
super(_LightPipeline, self).__init__(
790805
"com.johnsnowlabs.nlp.LightPipeline",
791806
pipelineModel._to_java(),
792-
parse_embeddings,
807+
bool(parse_embeddings),
808+
output_cols,
793809
)
794810

795811

0 commit comments

Comments
 (0)