Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def webhook_secret(self, value: str | None) -> None: # type: ignore
@override
def base_url(self) -> _httpx.URL:
if base_url is not None:
return _httpx.URL(base_url)
return self._enforce_trailing_slash(_httpx.URL(base_url))

return super().base_url

Expand Down
2 changes: 2 additions & 0 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def _get_azure_ad_token(self) -> str | None:

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
self._refresh_api_key()
Comment thread
LittleChenLiya marked this conversation as resolved.
Outdated
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}

options = model_copy(options)
Expand Down Expand Up @@ -603,6 +604,7 @@ async def _get_azure_ad_token(self) -> str | None:

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
await self._refresh_api_key()
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}

options = model_copy(options)
Expand Down
78 changes: 78 additions & 0 deletions tests/lib/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,84 @@ def test_client_copying_override_options(client: Client) -> None:
assert copied._custom_query == {"api-version": "2022-05-01"}


@pytest.mark.respx()
def test_client_api_key_provider_refresh_sync(respx_mock: MockRouter) -> None:
respx_mock.post(
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
).mock(
side_effect=[
httpx.Response(500, json={"error": "server error"}),
httpx.Response(200, json={"foo": "bar"}),
]
)

counter = 0

def api_key_provider() -> str:
nonlocal counter

counter += 1

if counter == 1:
return "first"

return "second"

client = AzureOpenAI(
api_version="2024-02-01",
api_key=api_key_provider,
azure_endpoint="https://example-resource.azure.openai.com",
)
client.chat.completions.create(messages=[], model="gpt-4")

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 2

assert calls[0].request.headers.get("api-key") == "first"
assert calls[1].request.headers.get("api-key") == "second"


@pytest.mark.asyncio
@pytest.mark.respx()
async def test_client_api_key_provider_refresh_async(respx_mock: MockRouter) -> None:
respx_mock.post(
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
).mock(
side_effect=[
httpx.Response(500, json={"error": "server error"}),
httpx.Response(200, json={"foo": "bar"}),
]
)

counter = 0

async def api_key_provider() -> str:
nonlocal counter

counter += 1

if counter == 1:
return "first"

return "second"

client = AsyncAzureOpenAI(
api_version="2024-02-01",
api_key=api_key_provider,
azure_endpoint="https://example-resource.azure.openai.com",
)

await client.chat.completions.create(messages=[], model="gpt-4")

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 2

assert calls[0].request.headers.get("api-key") == "first"
assert calls[1].request.headers.get("api-key") == "second"


@pytest.mark.respx()
def test_client_token_provider_refresh_sync(respx_mock: MockRouter) -> None:
respx_mock.post(
Expand Down
11 changes: 9 additions & 2 deletions tests/test_module_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,15 @@ def test_base_url_option() -> None:

openai.base_url = "http://foo.com"

assert openai.base_url == URL("http://foo.com")
assert openai.completions._client.base_url == URL("http://foo.com")
assert openai.base_url == "http://foo.com"
assert openai.completions._client.base_url.raw_path == b"/"


def test_base_url_option_without_trailing_slash() -> None:
openai.base_url = "http://foo.com/custom/path"

assert openai.base_url == "http://foo.com/custom/path"
assert openai.completions._client.base_url == URL("http://foo.com/custom/path/")


def test_timeout_option() -> None:
Expand Down