Skip to content

OpenAI: support any path configuration #3452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
* @author Stefan Vassilev
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @author lambochen
*/
@AutoConfiguration(after = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class,
SpringAiRetryAutoConfiguration.class })
Expand All @@ -72,6 +73,7 @@ public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiConnectionProperties
.baseUrl(resolved.baseUrl())
.apiKey(new SimpleApiKey(resolved.apiKey()))
.headers(resolved.headers())
.audioSpeechPath(speechProperties.getAudioSpeechPath())
.restClientBuilder(restClientBuilderProvider.getIfAvailable(RestClient::builder))
.webClientBuilder(webClientBuilderProvider.getIfAvailable(WebClient::builder))
.responseErrorHandler(responseErrorHandler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.springframework.ai.openai.OpenAiAudioSpeechOptions;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

Expand All @@ -30,6 +31,7 @@
* @author Ahmed Yousri
* @author Stefan Vassilev
* @author Jonghoon Park
* @author lambochen
*/
@ConfigurationProperties(OpenAiAudioSpeechProperties.CONFIG_PREFIX)
public class OpenAiAudioSpeechProperties extends OpenAiParentProperties {
Expand All @@ -44,6 +46,8 @@ public class OpenAiAudioSpeechProperties extends OpenAiParentProperties {

private static final OpenAiAudioApi.SpeechRequest.AudioResponseFormat DEFAULT_RESPONSE_FORMAT = OpenAiAudioApi.SpeechRequest.AudioResponseFormat.MP3;

private String audioSpeechPath = OpenAiApiConstants.DEFAULT_AUDIO_SPEECH_PATH;

@NestedConfigurationProperty
private OpenAiAudioSpeechOptions options = OpenAiAudioSpeechOptions.builder()
.model(DEFAULT_SPEECH_MODEL)
Expand All @@ -60,4 +64,12 @@ public void setOptions(OpenAiAudioSpeechOptions options) {
this.options = options;
}

public String getAudioSpeechPath() {
return audioSpeechPath;
}

public void setAudioSpeechPath(String audioSpeechPath) {
this.audioSpeechPath = audioSpeechPath;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
* @author Stefan Vassilev
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @author lambochen
*/
@AutoConfiguration(after = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class,
SpringAiRetryAutoConfiguration.class })
Expand All @@ -72,6 +73,7 @@ public OpenAiAudioTranscriptionModel openAiAudioTranscriptionModel(OpenAiConnect
.baseUrl(resolved.baseUrl())
.apiKey(new SimpleApiKey(resolved.apiKey()))
.headers(resolved.headers())
.audioTranscriptionPath(transcriptionProperties.getAudioTranscriptionPath())
.restClientBuilder(restClientBuilderProvider.getIfAvailable(RestClient::builder))
.webClientBuilder(webClientBuilderProvider.getIfAvailable(WebClient::builder))
.responseErrorHandler(responseErrorHandler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,18 @@

import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

/**
* Configuration properties for OpenAI audio transcription.
*
* Default values for required options are model = whisper-1, temperature = 0.7, and
* response format = text.
*
* @author lambochen
*/
@ConfigurationProperties(OpenAiAudioTranscriptionProperties.CONFIG_PREFIX)
public class OpenAiAudioTranscriptionProperties extends OpenAiParentProperties {

Expand All @@ -32,6 +41,8 @@ public class OpenAiAudioTranscriptionProperties extends OpenAiParentProperties {

private static final OpenAiAudioApi.TranscriptResponseFormat DEFAULT_RESPONSE_FORMAT = OpenAiAudioApi.TranscriptResponseFormat.TEXT;

private String audioTranscriptionPath = OpenAiApiConstants.DEFAULT_AUDIO_TRANSCRIPTION_PATH;

@NestedConfigurationProperty
private OpenAiAudioTranscriptionOptions options = OpenAiAudioTranscriptionOptions.builder()
.model(DEFAULT_TRANSCRIPTION_MODEL)
Expand All @@ -47,4 +58,12 @@ public void setOptions(OpenAiAudioTranscriptionOptions options) {
this.options = options;
}

public String getAudioTranscriptionPath() {
return audioTranscriptionPath;
}

public void setAudioTranscriptionPath(String audioTranscriptionPath) {
this.audioTranscriptionPath = audioTranscriptionPath;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@
package org.springframework.ai.model.openai.autoconfigure;

import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

/**
* @author lambochen
*/
@ConfigurationProperties(OpenAiChatProperties.CONFIG_PREFIX)
public class OpenAiChatProperties extends OpenAiParentProperties {

public static final String CONFIG_PREFIX = "spring.ai.openai.chat";

public static final String DEFAULT_CHAT_MODEL = "gpt-4o-mini";

public static final String DEFAULT_COMPLETIONS_PATH = "/v1/chat/completions";
public static final String DEFAULT_COMPLETIONS_PATH = OpenAiApiConstants.DEFAULT_COMPLETIONS_PATH;

private static final Double DEFAULT_TEMPERATURE = 0.7;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,21 @@

import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

/**
* @author lambochen
*/
@ConfigurationProperties(OpenAiEmbeddingProperties.CONFIG_PREFIX)
public class OpenAiEmbeddingProperties extends OpenAiParentProperties {

public static final String CONFIG_PREFIX = "spring.ai.openai.embedding";

public static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002";

public static final String DEFAULT_EMBEDDINGS_PATH = "/v1/embeddings";
public static final String DEFAULT_EMBEDDINGS_PATH = OpenAiApiConstants.DEFAULT_EMBEDDINGS_PATH;

private MetadataMode metadataMode = MetadataMode.EMBED;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

Expand All @@ -33,7 +34,7 @@ public class OpenAiImageProperties extends OpenAiParentProperties {

public static final String CONFIG_PREFIX = "spring.ai.openai.image";

public static final String DEFAULT_IMAGES_PATH = "v1/images/generations";
public static final String DEFAULT_IMAGES_PATH = OpenAiApiConstants.DEFAULT_IMAGES_PATH;

private String imagesPath = DEFAULT_IMAGES_PATH;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
* @author Stefan Vassilev
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @author lambochen
*/
@AutoConfiguration(after = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class,
SpringAiRetryAutoConfiguration.class })
Expand All @@ -70,6 +71,7 @@ public OpenAiModerationModel openAiModerationModel(OpenAiConnectionProperties co
.baseUrl(resolved.baseUrl())
.apiKey(new SimpleApiKey(resolved.apiKey()))
.headers(resolved.headers())
.moderationPath(moderationProperties.getModerationPath())
.restClientBuilder(restClientBuilderProvider.getIfAvailable(RestClient::builder))
.responseErrorHandler(responseErrorHandler)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
package org.springframework.ai.model.openai.autoconfigure;

import org.springframework.ai.openai.OpenAiModerationOptions;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

/**
* OpenAI Moderation autoconfiguration properties.
*
* @author Ahmed Yousri
* @author lambochen
* @since 0.9.0
*/
@ConfigurationProperties(OpenAiModerationProperties.CONFIG_PREFIX)
Expand All @@ -37,6 +39,8 @@ public class OpenAiModerationProperties extends OpenAiParentProperties {
@NestedConfigurationProperty
private OpenAiModerationOptions options = OpenAiModerationOptions.builder().build();

private String moderationPath = OpenAiApiConstants.DEFAULT_MODERATION_PATH;

public OpenAiModerationOptions getOptions() {
return this.options;
}
Expand All @@ -45,4 +49,12 @@ public void setOptions(OpenAiModerationOptions options) {
this.options = options;
}

public String getModerationPath() {
return moderationPath;
}

public void setModerationPath(String moderationPath) {
this.moderationPath = moderationPath;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
* @author David Frizelle
* @author Alexandros Pappas
* @author Filip Hrisafov
* @author lambochen
*/
public class OpenAiApi {

Expand Down Expand Up @@ -1886,9 +1887,9 @@ public Builder(OpenAiApi api) {

private MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();

private String completionsPath = "/v1/chat/completions";
private String completionsPath = OpenAiApiConstants.DEFAULT_COMPLETIONS_PATH;

private String embeddingsPath = "/v1/embeddings";
private String embeddingsPath = OpenAiApiConstants.DEFAULT_EMBEDDINGS_PATH;

private RestClient.Builder restClientBuilder = RestClient.builder();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
* @author Ilayaperumal Gopinathan
* @author Jonghoon Park
* @author Filip Hrisafov
* @author lambochen
* @since 0.8.1
*/
public class OpenAiAudioApi {
Expand All @@ -58,6 +59,12 @@ public class OpenAiAudioApi {

private final WebClient webClient;

private final String audioSpeechPath;

private final String audioTranscriptionPath;

private final String audioTranslationPath;

/**
* Create a new audio api.
* @param baseUrl api base URL.
Expand All @@ -70,6 +77,34 @@ public class OpenAiAudioApi {
public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers,
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler) {
this(baseUrl, apiKey, headers, OpenAiApiConstants.DEFAULT_AUDIO_SPEECH_PATH,
OpenAiApiConstants.DEFAULT_AUDIO_TRANSCRIPTION_PATH, OpenAiApiConstants.DEFAULT_AUDIO_TRANSLATION_PATH,
restClientBuilder, webClientBuilder, responseErrorHandler);
}

/**
* Create a new audio api.
* @param baseUrl api base URL.
* @param apiKey OpenAI apiKey.
* @param headers the http headers to use.
* @param restClientBuilder RestClient builder.
* @param webClientBuilder WebClient builder.
* @param responseErrorHandler Response error handler.
* @param audioSpeechPath Audio speech path.
* @param audioTranscriptionPath Audio transcriptions path.
* @param audioTranslationPath Audio translations path.
*/
public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String audioSpeechPath,
String audioTranscriptionPath, String audioTranslationPath, RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
Assert.hasText(audioSpeechPath, "audioSpeechPath cannot be null or empty");
Assert.hasText(audioTranscriptionPath, "audioTranscriptionPath cannot be null or empty");
Assert.hasText(audioTranslationPath, "audioTranslationPath cannot be null or empty");

this.audioSpeechPath = audioSpeechPath;
this.audioTranscriptionPath = audioTranscriptionPath;
this.audioTranslationPath = audioTranslationPath;

Consumer<HttpHeaders> authHeaders = h -> {
h.addAll(headers);
Expand Down Expand Up @@ -108,7 +143,7 @@ public static Builder builder() {
* @return Response entity containing the audio binary.
*/
public ResponseEntity<byte[]> createSpeech(SpeechRequest requestBody) {
return this.restClient.post().uri("/v1/audio/speech").body(requestBody).retrieve().toEntity(byte[].class);
return this.restClient.post().uri(this.audioSpeechPath).body(requestBody).retrieve().toEntity(byte[].class);
}

/**
Expand All @@ -125,7 +160,7 @@ public ResponseEntity<byte[]> createSpeech(SpeechRequest requestBody) {
public Flux<ResponseEntity<byte[]>> stream(SpeechRequest requestBody) {

return this.webClient.post()
.uri("/v1/audio/speech")
.uri(this.audioSpeechPath)
.body(Mono.just(requestBody), SpeechRequest.class)
.accept(MediaType.APPLICATION_OCTET_STREAM)
.exchangeToFlux(clientResponse -> {
Expand Down Expand Up @@ -175,7 +210,7 @@ public String getFilename() {
}

return this.restClient.post()
.uri("/v1/audio/transcriptions")
.uri(this.audioTranscriptionPath)
.body(multipartBody)
.retrieve()
.toEntity(responseType);
Expand Down Expand Up @@ -215,7 +250,7 @@ public String getFilename() {
multipartBody.add("temperature", requestBody.temperature());

return this.restClient.post()
.uri("/v1/audio/translations")
.uri(this.audioTranslationPath)
.body(multipartBody)
.retrieve()
.toEntity(responseType);
Expand Down Expand Up @@ -777,6 +812,12 @@ public static class Builder {

private String baseUrl = OpenAiApiConstants.DEFAULT_BASE_URL;

private String audioSpeechPath = OpenAiApiConstants.DEFAULT_AUDIO_SPEECH_PATH;

private String audioTranscriptionPath = OpenAiApiConstants.DEFAULT_AUDIO_TRANSCRIPTION_PATH;

private String audioTranslationPath = OpenAiApiConstants.DEFAULT_AUDIO_TRANSLATION_PATH;

private ApiKey apiKey;

private MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
Expand All @@ -793,6 +834,24 @@ public Builder baseUrl(String baseUrl) {
return this;
}

public Builder audioSpeechPath(String audioSpeechPath) {
Assert.hasText(audioSpeechPath, "audioSpeechPath cannot be null or empty");
this.audioSpeechPath = audioSpeechPath;
return this;
}

public Builder audioTranscriptionPath(String audioTranscriptionPath) {
Assert.hasText(audioTranscriptionPath, "audioTranscriptionPath cannot be null or empty");
this.audioTranscriptionPath = audioTranscriptionPath;
return this;
}

public Builder audioTranslationPath(String audioTranslationPath) {
Assert.hasText(audioTranslationPath, "audioTranslationPath cannot be null or empty");
this.audioTranslationPath = audioTranslationPath;
return this;
}

public Builder apiKey(ApiKey apiKey) {
Assert.notNull(apiKey, "apiKey cannot be null");
this.apiKey = apiKey;
Expand Down Expand Up @@ -831,7 +890,8 @@ public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) {

public OpenAiAudioApi build() {
Assert.notNull(this.apiKey, "apiKey must be set");
return new OpenAiAudioApi(this.baseUrl, this.apiKey, this.headers, this.restClientBuilder,
return new OpenAiAudioApi(this.baseUrl, this.apiKey, this.headers, this.audioSpeechPath,
this.audioTranscriptionPath, this.audioTranslationPath, this.restClientBuilder,
this.webClientBuilder, this.responseErrorHandler);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ public static class Builder {

private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER;

private String imagesPath = "v1/images/generations";
private String imagesPath = OpenAiApiConstants.DEFAULT_IMAGES_PATH;

public Builder baseUrl(String baseUrl) {
Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
Expand Down
Loading