Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
df25fe2
feat(component.rankers): Add HuggingFace API (text-embeddings-inferen…
atopx May 20, 2025
9f12090
update test flow & doc loaders
atopx May 20, 2025
71192d6
Support run_async for HuggingFaceAPIRanker
atopx May 20, 2025
670523d
Add release note for HuggingFace API support in component.rankers
atopx May 21, 2025
ddc2aad
Add release note for HuggingFace API support in component.rankers
atopx May 21, 2025
35da7ea
Add release note for HuggingFace API support in component.rankers
atopx May 21, 2025
ace8b4c
Add release note for HuggingFace API support in component.rankers
atopx May 21, 2025
bd5e317
Merge remote-tracking branch 'upstream/main'
atopx May 21, 2025
2ffc4b3
Merge branch 'main' into main
atopx May 22, 2025
1ebc6f1
Merge branch 'main' into main
atopx May 23, 2025
391232a
fix:
atopx May 23, 2025
aa5348b
fix(HuggingFaceTEIRanker): change the token check logic to use the re…
atopx May 23, 2025
75a2d69
fix(format): run `hatch run format`
atopx May 23, 2025
7133e00
fix:
atopx May 23, 2025
66a4c16
Merge branch 'main' into main
atopx May 23, 2025
976b8ac
fix HuggingFaceTEIRanker:
atopx May 23, 2025
30d0d05
Merge branch 'main' into main
atopx May 23, 2025
243b3f7
Merge branch 'main' into main
atopx May 23, 2025
bc9529e
fix(HuggingFaceTEIRanker) :Revise the docstring of the HuggingFaceTEI…
atopx May 25, 2025
91aa51f
Merge branch 'main' into main
atopx May 25, 2025
d028977
fix:unit test for HuggingFaceTEIRanker raise message
atopx May 26, 2025
d6d1b68
Merge branch 'main' of github.com:atopx/haystack
atopx May 26, 2025
0b4043e
Merge branch 'main' into main
atopx May 26, 2025
542b472
Merge branch 'main' into tei-ranker
anakin87 May 27, 2025
6f2569b
fix fmt
anakin87 May 27, 2025
e80f814
minor refinements
anakin87 May 27, 2025
a88e2f2
refine release note
anakin87 May 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,6 @@ haystack/json-schemas

# Zed configs
.zed/*

# uv
uv.lock
10 changes: 8 additions & 2 deletions docs/pydoc/config/rankers_api.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/rankers]
modules: ["lost_in_the_middle", "meta_field", "meta_field_grouping_ranker", "transformers_similarity",
"sentence_transformers_diversity", "sentence_transformers_similarity"]
modules: [
"hugging_face_tei",
"lost_in_the_middle",
"meta_field",
"meta_field_grouping_ranker",
"sentence_transformers_diversity",
"sentence_transformers_similarity",
"transformers_similarity"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
Expand Down
2 changes: 2 additions & 0 deletions haystack/components/rankers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lazy_imports import LazyImporter

_import_structure = {
"hugging_face_tei": ["HuggingFaceTEIRanker"],
"lost_in_the_middle": ["LostInTheMiddleRanker"],
"meta_field": ["MetaFieldRanker"],
"meta_field_grouping_ranker": ["MetaFieldGroupingRanker"],
Expand All @@ -17,6 +18,7 @@
}

if TYPE_CHECKING:
from .hugging_face_tei import HuggingFaceTEIRanker
from .lost_in_the_middle import LostInTheMiddleRanker
from .meta_field import MetaFieldRanker
from .meta_field_grouping_ranker import MetaFieldGroupingRanker
Expand Down
270 changes: 270 additions & 0 deletions haystack/components/rankers/hugging_face_tei.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import copy
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urljoin

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.requests_utils import async_request_with_retry, request_with_retry


class TruncationDirection(str, Enum):
    """
    Side of the text that is cut off when an input is longer than the model's limit.

    Members:
        LEFT: Cut from the left side, i.e. the start of the text.
        RIGHT: Cut from the right side, i.e. the end of the text.
    """

    # The member values are the strings placed in the request payload as `truncation_direction`.
    LEFT = "Left"
    RIGHT = "Right"


@component
class HuggingFaceTEIRanker:
    """
    Ranks documents based on their semantic similarity to the query.

    It can be used with a Text Embeddings Inference (TEI) API endpoint:
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)
    - [Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints)

    Usage example:
    ```python
    from haystack import Document
    from haystack.components.rankers import HuggingFaceTEIRanker
    from haystack.utils import Secret

    reranker = HuggingFaceTEIRanker(
        url="http://localhost:8080",
        top_k=5,
        timeout=30,
        token=Secret.from_token("my_api_token")
    )

    docs = [Document(content="The capital of France is Paris"), Document(content="The capital of Germany is Berlin")]

    result = reranker.run(query="What is the capital of France?", documents=docs)

    ranked_docs = result["documents"]
    print(ranked_docs)
    >> [Document(id=..., content: 'The capital of France is Paris', score: 0.9979767),
    >>  Document(id=..., content: 'The capital of Germany is Berlin', score: 0.13982213)]
    ```
    """

    def __init__(
        self,
        *,
        url: str,
        top_k: int = 10,
        raw_scores: bool = False,
        timeout: Optional[int] = 30,
        max_retries: int = 3,
        retry_status_codes: Optional[List[int]] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    ) -> None:
        """
        Initializes the TEI reranker component.

        :param url: Base URL of the TEI reranking service (for example, "https://api.example.com").
        :param top_k: Maximum number of top documents to return.
        :param raw_scores: If True, include raw relevance scores in the API payload.
        :param timeout: Request timeout in seconds.
        :param max_retries: Maximum number of retry attempts for failed requests.
        :param retry_status_codes: List of HTTP status codes that will trigger a retry.
            When None, HTTP 408, 418, 429 and 503 will be retried (default: None).
        :param token: The Hugging Face token to use as HTTP bearer authorization. Not always required
            depending on your TEI server configuration.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        """
        self.url = url
        self.top_k = top_k
        self.timeout = timeout
        self.token = token
        self.max_retries = max_retries
        self.retry_status_codes = retry_status_codes
        self.raw_scores = raw_scores

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            url=self.url,
            top_k=self.top_k,
            # `raw_scores` must be serialized too, otherwise a to_dict/from_dict round trip
            # silently resets it to the constructor default.
            raw_scores=self.raw_scores,
            timeout=self.timeout,
            token=self.token.to_dict() if self.token else None,
            max_retries=self.max_retries,
            retry_status_codes=self.retry_status_codes,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEIRanker":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        return default_from_dict(cls, data)

    def _prepare_request(
        self, query: str, documents: List[Document], truncation_direction: Optional[TruncationDirection]
    ) -> Dict[str, Any]:
        """
        Builds the keyword arguments shared by the sync and async rerank requests.

        :param query: The user query string to guide reranking.
        :param documents: List of `Document` objects whose `content` is sent as `texts`.
        :param truncation_direction: If set, enables text truncation in the specified direction.

        :returns: Keyword arguments to pass to `request_with_retry` / `async_request_with_retry`.
        """
        payload: Dict[str, Any] = {
            "query": query,
            "texts": [doc.content for doc in documents],
            "raw_scores": self.raw_scores,
        }
        if truncation_direction:
            payload.update({"truncate": True, "truncation_direction": truncation_direction.value})

        headers: Dict[str, str] = {}
        # Resolve the secret once; send the Authorization header only when a non-empty token is available.
        token_value = self.token.resolve_value() if self.token else None
        if token_value:
            headers["Authorization"] = f"Bearer {token_value}"

        return {
            "method": "POST",
            # NOTE: joining with an absolute path ("/rerank") replaces any path component of `self.url`,
            # so `url` is expected to be the server's base URL.
            "url": urljoin(self.url, "/rerank"),
            "json": payload,
            "timeout": self.timeout,
            "headers": headers,
            "attempts": self.max_retries,
            "status_codes_to_retry": self.retry_status_codes,
        }

    def _compose_response(
        self, result: Union[Dict[str, str], List[Dict[str, Any]]], top_k: Optional[int], documents: List[Document]
    ) -> Dict[str, List[Document]]:
        """
        Processes the API response into a structured format.

        :param result: The raw, JSON-decoded response from the API.
        :param top_k: Optional override for the maximum number of documents to return.
            When falsy, the component's `top_k` is used.
        :param documents: The documents that were sent for reranking; each `index` in `result`
            refers to a position in this list.

        :returns: A dictionary with the following keys:
            - `documents`: A list of reranked documents.

        :raises RuntimeError:
            - If the API returns an error response or a response that is not a list.
        """
        if isinstance(result, dict) and "error" in result:
            error_type = result.get("error_type", "UnknownError")
            error_msg = result.get("error", "No additional information.")
            raise RuntimeError(f"HuggingFaceTEIRanker API call failed ({error_type}): {error_msg}")

        # Ensure we have a list of score dicts
        if not isinstance(result, list):
            # Expected list or dict, but encountered an unknown response format.
            error_msg = f"Expected a list of score dictionaries, but got `{type(result).__name__}`. "
            error_msg += f"Response content: {result}"
            raise RuntimeError(f"Unexpected response format from text-embeddings-inference rerank API: {error_msg}")

        # Determine number of docs to return
        final_k = min(top_k or self.top_k, len(result))

        # Select the first `final_k` entries in the order given by the API.
        # Documents are shallow-copied so the caller's objects do not get their score overwritten.
        ranked_docs = []
        for item in result[:final_k]:
            index: int = item["index"]
            doc_copy = copy.copy(documents[index])
            doc_copy.score = item["score"]
            ranked_docs.append(doc_copy)
        return {"documents": ranked_docs}

    @component.output_types(documents=List[Document])
    def run(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None,
        truncation_direction: Optional[TruncationDirection] = None,
    ) -> Dict[str, List[Document]]:
        """
        Reranks the provided documents by relevance to the query using the TEI API.

        :param query: The user query string to guide reranking.
        :param documents: List of `Document` objects to rerank.
        :param top_k: Optional override for the maximum number of documents to return.
        :param truncation_direction: If set, enables text truncation in the specified direction.

        :returns: A dictionary with the following keys:
            - `documents`: A list of reranked documents.

        :raises requests.exceptions.RequestException:
            - If the API request fails.

        :raises RuntimeError:
            - If the API returns an error response.
        """
        # Return empty if no documents provided
        if not documents:
            return {"documents": []}

        # Call the external service with retry
        request_kwargs = self._prepare_request(query, documents, truncation_direction)
        response = request_with_retry(**request_kwargs)

        result: Union[Dict[str, str], List[Dict[str, Any]]] = response.json()

        return self._compose_response(result, top_k, documents)

    @component.output_types(documents=List[Document])
    async def run_async(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None,
        truncation_direction: Optional[TruncationDirection] = None,
    ) -> Dict[str, List[Document]]:
        """
        Asynchronously reranks the provided documents by relevance to the query using the TEI API.

        :param query: The user query string to guide reranking.
        :param documents: List of `Document` objects to rerank.
        :param top_k: Optional override for the maximum number of documents to return.
        :param truncation_direction: If set, enables text truncation in the specified direction.

        :returns: A dictionary with the following keys:
            - `documents`: A list of reranked documents.

        :raises httpx.RequestError:
            - If the API request fails.
        :raises RuntimeError:
            - If the API returns an error response.
        """
        # Return empty if no documents provided
        if not documents:
            return {"documents": []}

        # Call the external service with retry
        request_kwargs = self._prepare_request(query, documents, truncation_direction)
        response = await async_request_with_retry(**request_kwargs)

        result: Union[Dict[str, str], List[Dict[str, Any]]] = response.json()

        return self._compose_response(result, top_k, documents)
4 changes: 2 additions & 2 deletions haystack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"jinja2_extensions": ["Jinja2TimeExtension"],
"jupyter": ["is_in_jupyter"],
"misc": ["expit", "expand_page_range"],
"requests_utils": ["request_with_retry"],
"requests_utils": ["request_with_retry", "async_request_with_retry"],
"type_serialization": ["deserialize_type", "serialize_type"],
}

Expand All @@ -33,7 +33,7 @@
from .jinja2_extensions import Jinja2TimeExtension
from .jupyter import is_in_jupyter
from .misc import expand_page_range, expit
from .requests_utils import request_with_retry
from .requests_utils import async_request_with_retry, request_with_retry
from .type_serialization import deserialize_type, serialize_type
else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
Loading
Loading