Skip to content

Add an option to use Azure endpoints for the /completions & /search operations. #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 22, 2022
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
/public/dist
__pycache__
build
*.egg
.vscode/settings.json
2 changes: 2 additions & 0 deletions openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

organization = os.environ.get("OPENAI_ORGANIZATION")
api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
api_type = os.environ.get("OPENAI_API_TYPE", "open_ai")
api_version = None
verify_ssl_certs = True # No effect. Certificates are always verified.
proxy = None
Expand All @@ -50,6 +51,7 @@
"Search",
"api_base",
"api_key",
"api_type",
"api_key_path",
"api_version",
"app_info",
Expand Down
10 changes: 8 additions & 2 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import openai
from openai import error, util, version
from openai.openai_response import OpenAIResponse
from openai.util import ApiType

TIMEOUT_SECS = 600
MAX_CONNECTION_RETRIES = 2
Expand Down Expand Up @@ -70,9 +71,10 @@ def parse_stream(rbody):


class APIRequestor:
def __init__(self, key=None, api_base=None, api_version=None, organization=None):
def __init__(self, key=None, api_base=None, api_type=None, api_version=None, organization=None):
self.api_base = api_base or openai.api_base
self.api_key = key or util.default_api_key()
self.api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
self.api_version = api_version or openai.api_version
self.organization = organization or openai.organization

Expand Down Expand Up @@ -193,9 +195,13 @@ def request_headers(
headers = {
"X-OpenAI-Client-User-Agent": json.dumps(ua),
"User-Agent": user_agent,
"Authorization": "Bearer %s" % (self.api_key,),
}

if self.api_type == ApiType.OPEN_AI:
headers["Authorization"] = "Bearer %s" % (self.api_key,)
elif self.api_type == ApiType.AZURE:
headers["api-key"] = self.api_key

if self.organization:
headers["OpenAI-Organization"] = self.organization

Expand Down
25 changes: 21 additions & 4 deletions openai/api_resources/abstract/api_resource.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from urllib.parse import quote_plus

from openai import api_requestor, error, util
import openai
from openai.openai_object import OpenAIObject
from openai.util import ApiType


class APIResource(OpenAIObject):
api_prefix = ""
azure_api_prefix = 'openai/deployments'
azure_api_version = '?api-version=2021-11-01-preview'

@classmethod
def retrieve(cls, id, api_key=None, request_id=None, **params):
Expand All @@ -32,7 +36,7 @@ def class_url(cls):
return "/%s/%ss" % (cls.api_prefix, base)
return "/%ss" % (base)

def instance_url(self):
def instance_url(self, operation=None):
id = self.get("id")

if not isinstance(id, str):
Expand All @@ -43,9 +47,22 @@ def instance_url(self):
"id",
)

base = self.class_url()
extn = quote_plus(id)
return "%s/%s" % (base, extn)
if self.typed_api_type == ApiType.AZURE:
if not operation:
raise error.InvalidRequestError(
"The request needs an operation (eg: 'search') for the Azure OpenAI API type."
)
extn = quote_plus(id)
return "/%s/%s/%s%s" % (self.azure_api_prefix, extn, operation, self.azure_api_version)

elif self.typed_api_type == ApiType.OPEN_AI:
base = self.class_url()
extn = quote_plus(id)
return "%s/%s" % (base, extn)

else:
raise error.InvalidAPIType('Unsupported API type %s' % self.api_type)


# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
Expand Down
52 changes: 42 additions & 10 deletions openai/api_resources/abstract/engine_api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,56 @@
from typing import Optional
from urllib.parse import quote_plus

import openai
from openai import api_requestor, error, util
from openai.api_resources.abstract.api_resource import APIResource
from openai.openai_response import OpenAIResponse
from openai.util import ApiType

MAX_TIMEOUT = 20


class EngineAPIResource(APIResource):
engine_required = True
plain_old_data = False
azure_api_prefix = 'openai/deployments'
azure_api_version = '?api-version=2021-11-01-preview'

def __init__(self, engine: Optional[str] = None, **kwargs):
super().__init__(engine=engine, **kwargs)

@classmethod
def class_url(cls, engine: Optional[str] = None):
def class_url(cls, engine: Optional[str] = None, api_type : Optional[str] = None):
# Namespaces are separated in object names with periods (.) and in URLs
# with forward slashes (/), so replace the former with the latter.
base = cls.OBJECT_NAME.replace(".", "/") # type: ignore
if engine is None:
return "/%ss" % (base)
typed_api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)

if typed_api_type == ApiType.AZURE:
if engine is None:
raise error.InvalidRequestError(
"You must provide the deployment name in the 'engine' parameter to access the Azure OpenAI service"
)
extn = quote_plus(engine)
return "/%s/%s/%ss%s" % (cls.azure_api_prefix, extn, base, cls.azure_api_version)

elif typed_api_type == ApiType.OPEN_AI:
if engine is None:
return "/%ss" % (base)

extn = quote_plus(engine)
return "/engines/%s/%ss" % (extn, base)

else:
raise error.InvalidAPIType('Unsupported API type %s' % api_type)

extn = quote_plus(engine)
return "/engines/%s/%ss" % (extn, base)

@classmethod
def create(
cls,
api_key=None,
api_base=None,
api_type=None,
request_id=None,
api_version=None,
organization=None,
Expand All @@ -58,10 +78,11 @@ def create(
requestor = api_requestor.APIRequestor(
api_key,
api_base=api_base,
api_type=api_type,
api_version=api_version,
organization=organization,
)
url = cls.class_url(engine)
url = cls.class_url(engine, api_type)
response, _, api_key = requestor.request(
"post", url, params, stream=stream, request_id=request_id
)
Expand Down Expand Up @@ -103,14 +124,25 @@ def instance_url(self):
"id",
)

base = self.class_url(self.engine)
extn = quote_plus(id)
url = "%s/%s" % (base, extn)
params_connector = '?'
if self.typed_api_type == ApiType.AZURE:
extn = quote_plus(id)
base = self.OBJECT_NAME.replace(".", "/")
url = "/%s/%s/%ss/%s%s" % (self.azure_api_prefix, self.engine, base, extn, self.azure_api_version)
params_connector = '&'

elif self.typed_api_type == ApiType.OPEN_AI:
base = self.class_url(self.engine, self.api_type)
extn = quote_plus(id)
url = "%s/%s" % (base, extn)

else:
raise error.InvalidAPIType('Unsupported API type %s' % self.api_type)

timeout = self.get("timeout")
if timeout is not None:
timeout = quote_plus(str(timeout))
url += "?timeout={}".format(timeout)
url += params_connector + "timeout={}".format(timeout)
return url

def wait(self, timeout=None):
Expand Down
10 changes: 8 additions & 2 deletions openai/api_resources/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from openai import util
from openai.api_resources.abstract import ListableAPIResource, UpdateableAPIResource
from openai.error import TryAgain
from openai.error import InvalidAPIType, TryAgain
from openai.util import ApiType


class Engine(ListableAPIResource, UpdateableAPIResource):
Expand All @@ -26,7 +27,12 @@ def generate(self, timeout=None, **params):
util.log_info("Waiting for model to warm up", error=e)

def search(self, **params):
return self.request("post", self.instance_url() + "/search", params)
if self.typed_api_type == ApiType.AZURE:
return self.request("post", self.instance_url("search"), params)
elif self.typed_api_type == ApiType.OPEN_AI:
return self.request("post", self.instance_url() + "/search", params)
else:
raise InvalidAPIType('Unsupported API type %s' % self.api_type)

def embeddings(self, **params):
return self.request("post", self.instance_url() + "/embeddings", params)
3 changes: 3 additions & 0 deletions openai/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ class RateLimitError(OpenAIError):
class ServiceUnavailableError(OpenAIError):
pass

class InvalidAPIType(OpenAIError):
pass


class SignatureVerificationError(OpenAIError):
def __init__(self, message, sig_header, http_body=None):
Expand Down
13 changes: 13 additions & 0 deletions openai/openai_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from copy import deepcopy
from typing import Optional

import openai
from openai import api_requestor, util
from openai.openai_response import OpenAIResponse
from openai.util import ApiType


class OpenAIObject(dict):
Expand All @@ -14,6 +16,7 @@ def __init__(
id=None,
api_key=None,
api_version=None,
api_type=None,
organization=None,
response_ms: Optional[int] = None,
api_base=None,
Expand All @@ -30,6 +33,7 @@ def __init__(

object.__setattr__(self, "api_key", api_key)
object.__setattr__(self, "api_version", api_version)
object.__setattr__(self, "api_type", api_type)
object.__setattr__(self, "organization", organization)
object.__setattr__(self, "api_base_override", api_base)
object.__setattr__(self, "engine", engine)
Expand Down Expand Up @@ -90,6 +94,7 @@ def __reduce__(self):
self.get("id", None),
self.api_key,
self.api_version,
self.api_type,
self.organization,
),
dict(self), # state
Expand Down Expand Up @@ -128,11 +133,13 @@ def refresh_from(
values,
api_key=None,
api_version=None,
api_type=None,
organization=None,
response_ms: Optional[int] = None,
):
self.api_key = api_key or getattr(values, "api_key", None)
self.api_version = api_version or getattr(values, "api_version", None)
self.api_type = api_type or getattr(values, "api_type", None)
self.organization = organization or getattr(values, "organization", None)
self._response_ms = response_ms or getattr(values, "_response_ms", None)

Expand Down Expand Up @@ -164,6 +171,7 @@ def request(
requestor = api_requestor.APIRequestor(
key=self.api_key,
api_base=self.api_base_override or self.api_base(),
api_type=self.api_type,
api_version=self.api_version,
organization=self.organization,
)
Expand Down Expand Up @@ -233,6 +241,10 @@ def to_dict_recursive(self):
def openai_id(self):
return self.id

@property
def typed_api_type(self):
return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(openai.api_type)

# This class overrides __setitem__ to throw exceptions on inputs that it
# doesn't like. This can cause problems when we try to copy an object
# wholesale because some data that's returned from the API may not be valid
Expand All @@ -243,6 +255,7 @@ def __copy__(self):
self.get("id"),
self.api_key,
api_version=self.api_version,
api_type=self.api_type,
organization=self.organization,
)

Expand Down
2 changes: 1 addition & 1 deletion openai/openai_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def organization(self) -> Optional[str]:
@property
def response_ms(self) -> Optional[int]:
h = self._headers.get("Openai-Processing-Ms")
return None if h is None else int(h)
return None if h is None else round(float(h))
27 changes: 25 additions & 2 deletions openai/tests/test_api_requestor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json

import pytest
import requests
from pytest_mock import MockerFixture

from openai import Model
from openai.api_requestor import APIRequestor


@pytest.mark.requestor
def test_requestor_sets_request_id(mocker: MockerFixture) -> None:
# Fake out 'requests' and confirm that the X-Request-Id header is set.

Expand All @@ -25,3 +26,25 @@ def fake_request(self, *args, **kwargs):
Model.retrieve("xxx", request_id=fake_request_id) # arbitrary API resource
got_request_id = got_headers.get("X-Request-Id")
assert got_request_id == fake_request_id

@pytest.mark.requestor
def test_requestor_open_ai_headers() -> None:
api_requestor = APIRequestor(key="test_key", api_type="open_ai")
headers = {"Test_Header": "Unit_Test_Header"}
headers = api_requestor.request_headers(method="get", extra=headers, request_id="test_id")
print(headers)
assert "Test_Header"in headers
assert headers["Test_Header"] == "Unit_Test_Header"
assert "Authorization"in headers
assert headers["Authorization"] == "Bearer test_key"

@pytest.mark.requestor
def test_requestor_azure_headers() -> None:
api_requestor = APIRequestor(key="test_key", api_type="azure")
headers = {"Test_Header": "Unit_Test_Header"}
headers = api_requestor.request_headers(method="get", extra=headers, request_id="test_id")
print(headers)
assert "Test_Header"in headers
assert headers["Test_Header"] == "Unit_Test_Header"
assert "api-key"in headers
assert headers["api-key"] == "test_key"
Loading