Skip to content

Commit 0f57dca

Browse files
authored
[Fix] Fetch when vector id string contains spaces (#372)
## Problem Some data operations fail when the vector id string contains a space. ```python from pinecone import Pinecone pc = Pinecone() pc.fetch(ids=["id with string"]) # no results returned, even when vector exists ``` ## Solution The problem occurred due to the way spaces were being encoded as `+` instead of `%20` in url query params. The fix was a small adjustment to our code generation templates. I added test coverage for upsert / query / fetch with various weird ids to make sure the change in encoding hasn't broken any other use cases that could pop up. ## Type of Change - [x] Bug fix (non-breaking change which fixes an issue)
1 parent d9df375 commit 0f57dca

File tree

6 files changed

+146
-9
lines changed

6 files changed

+146
-9
lines changed

codegen/apis

Submodule apis updated from e9b47c7 to 062b114

pinecone/core/openapi/shared/rest.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
import ssl
66
import os
7-
from urllib.parse import urlencode
7+
from urllib.parse import urlencode, quote
88

99
import urllib3
1010

@@ -182,7 +182,7 @@ def request(
182182
if (method != "DELETE") and ("Content-Type" not in headers):
183183
headers["Content-Type"] = "application/json"
184184
if query_params:
185-
url += "?" + urlencode(query_params)
185+
url += "?" + urlencode(query_params, quote_via=quote)
186186
if ("Content-Type" not in headers) or (re.search("json", headers["Content-Type"], re.IGNORECASE)):
187187
request_body = None
188188
if body is not None:
@@ -240,8 +240,10 @@ def request(
240240
raise PineconeApiException(status=0, reason=msg)
241241
# For `GET`, `HEAD`
242242
else:
243+
if query_params:
244+
url += "?" + urlencode(query_params, quote_via=quote)
243245
r = self.pool_manager.request(
244-
method, url, fields=query_params, preload_content=_preload_content, timeout=timeout, headers=headers
246+
method, url, preload_content=_preload_content, timeout=timeout, headers=headers
245247
)
246248
except urllib3.exceptions.SSLError as e:
247249
msg = "{0}\n{1}".format(type(e).__name__, str(e))

tests/integration/data/conftest.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
import json
55
from ..helpers import get_environment_var, random_string
6-
from .seed import setup_data, setup_list_data
6+
from .seed import setup_data, setup_list_data, setup_weird_ids_data
77

88
# Test matrix needs to consider the following dimensions:
99
# - pod vs serverless
@@ -60,13 +60,16 @@ def index_name():
6060

6161
@pytest.fixture(scope="session")
6262
def namespace():
63-
# return 'banana'
6463
return random_string(10)
6564

6665

6766
@pytest.fixture(scope="session")
6867
def list_namespace():
69-
# return 'list-banana'
68+
return random_string(10)
69+
70+
71+
@pytest.fixture(scope="session")
72+
def weird_ids_namespace():
7073
return random_string(10)
7174

7275

@@ -89,9 +92,12 @@ def index_host(index_name, metric, spec):
8992

9093

9194
@pytest.fixture(scope="session", autouse=True)
92-
def seed_data(idx, namespace, index_host, list_namespace):
95+
def seed_data(idx, namespace, index_host, list_namespace, weird_ids_namespace):
9396
print("Seeding data in host " + index_host)
9497

98+
print("Seeding data in weird is namespace " + weird_ids_namespace)
99+
setup_weird_ids_data(idx, weird_ids_namespace, True)
100+
95101
print('Seeding list data in namespace "' + list_namespace + '"')
96102
setup_list_data(idx, list_namespace, True)
97103

tests/integration/data/seed.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ..helpers import poll_fetch_for_ids_in_namespace
22
from pinecone import Vector
33
from .utils import embedding_values
4+
import itertools
45

56

67
def setup_data(idx, target_namespace, wait):
@@ -43,3 +44,82 @@ def setup_list_data(idx, target_namespace, wait):
4344

4445
if wait:
4546
poll_fetch_for_ids_in_namespace(idx, ids=["999"], namespace=target_namespace)
47+
48+
49+
def weird_invalid_ids():
50+
invisible = [
51+
"⠀", # U+2800
52+
" ", # U+00A0
53+
"­", # U+00AD
54+
"឴", # U+17F4
55+
"᠎", # U+180E
56+
" ", # U+2000
57+
" ", # U+2001
58+
" ", # U+2002
59+
]
60+
emojis = list("🌲🍦")
61+
two_byte = list("田中さんにあげて下さい")
62+
quotes = ["‘", "’", "“", "”", "„", "‟", "‹", "›", "❛", "❜", "❝", "❞", "❮", "❯", """, "'", "「", "」"]
63+
64+
return invisible + emojis + two_byte + quotes
65+
66+
67+
def weird_valid_ids():
68+
# Drawing inspiration from the big list of naughty strings https://github.com/minimaxir/big-list-of-naughty-strings/blob/master/blns.txt
69+
ids = []
70+
71+
numbers = list("1234567890")
72+
invisible = [" ", "\n", "\t", "\r"]
73+
punctuation = list("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~")
74+
escaped = [f"\\{c}" for c in punctuation]
75+
76+
characters = numbers + invisible + punctuation + escaped
77+
ids.extend(characters)
78+
ids.extend(["".join(x) for x in itertools.combinations_with_replacement(characters, 2)])
79+
80+
boolean_ish = [
81+
"undefined",
82+
"nil",
83+
"null",
84+
"Null",
85+
"NULL",
86+
"None",
87+
"True",
88+
"False",
89+
"true",
90+
"false",
91+
]
92+
ids.extend(boolean_ish)
93+
94+
script_injection = [
95+
"<script>alert(0)</script>",
96+
"<svg><script>123<1>alert(3)</script>",
97+
'" onfocus=JaVaSCript:alert(10) autofocus',
98+
"javascript:alert(1)",
99+
"javascript:alert(1);",
100+
'<img src\x32=x onerror="javascript:alert(182)">' "1;DROP TABLE users",
101+
"' OR 1=1 -- 1",
102+
"' OR '1'='1",
103+
]
104+
ids.extend(script_injection)
105+
106+
unwanted_interpolation = [
107+
"$HOME",
108+
"$ENV{'HOME'}",
109+
"%d",
110+
"%s",
111+
"%n",
112+
"%x",
113+
"{0}",
114+
]
115+
ids.extend(unwanted_interpolation)
116+
117+
return ids
118+
119+
120+
def setup_weird_ids_data(idx, target_namespace, wait):
121+
weird_ids = weird_valid_ids()
122+
batch_size = 100
123+
for i in range(0, len(weird_ids), batch_size):
124+
chunk = weird_ids[i : i + batch_size]
125+
idx.upsert(vectors=[(x, embedding_values(2)) for x in chunk], namespace=target_namespace)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pytest
2+
from .seed import weird_valid_ids, weird_invalid_ids
3+
4+
5+
class TestHandlingOfWeirdIds:
6+
def test_fetch_weird_ids(self, idx, weird_ids_namespace):
7+
weird_ids = weird_valid_ids()
8+
batch_size = 100
9+
for i in range(0, len(weird_ids), batch_size):
10+
ids_to_fetch = weird_ids[i : i + batch_size]
11+
results = idx.fetch(ids=ids_to_fetch, namespace=weird_ids_namespace)
12+
assert results.usage["read_units"] > 0
13+
assert len(results.vectors) == len(ids_to_fetch)
14+
for id in ids_to_fetch:
15+
assert id in results.vectors
16+
assert results.vectors[id].id == id
17+
assert results.vectors[id].metadata == None
18+
assert results.vectors[id].values != None
19+
assert len(results.vectors[id].values) == 2
20+
21+
@pytest.mark.parametrize("id_to_query", weird_valid_ids())
22+
def test_query_weird_ids(self, idx, weird_ids_namespace, id_to_query):
23+
results = idx.query(id=id_to_query, top_k=10, namespace=weird_ids_namespace, include_values=True)
24+
assert results.usage["read_units"] > 0
25+
assert len(results.matches) == 10
26+
assert results.namespace == weird_ids_namespace
27+
assert results.matches[0].id != None
28+
assert results.matches[0].metadata == None
29+
assert results.matches[0].values != None
30+
assert len(results.matches[0].values) == 2
31+
32+
def test_list_weird_ids(self, idx, weird_ids_namespace):
33+
expected_ids = set(weird_valid_ids())
34+
id_iterator = idx.list(namespace=weird_ids_namespace)
35+
for page in id_iterator:
36+
for id in page:
37+
assert id in expected_ids
38+
39+
@pytest.mark.parametrize("id_to_upsert", weird_invalid_ids())
40+
def test_weird_invalid_ids(self, idx, weird_ids_namespace, id_to_upsert):
41+
with pytest.raises(Exception) as e:
42+
idx.upsert(vectors=[(id_to_upsert, [0.1, 0.1])], namespace=weird_ids_namespace)
43+
assert "Vector ID must be ASCII" in str(e.value)
44+
45+
def test_null_character(self, idx, weird_ids_namespace):
46+
with pytest.raises(Exception) as e:
47+
idx.upsert(vectors=[("\0", [0.1, 0.1])], namespace=weird_ids_namespace)
48+
49+
assert "Vector ID must not contain null character" in str(e.value)

0 commit comments

Comments
 (0)