Skip to content

Commit 6d6fb38

Browse files
authored
feat: add InvokeLLMWithStructuredOutput functionality (#369)
* feat: add InvokeLLMWithStructuredOutput functionality - Introduced a new method InvokeLLMWithStructuredOutput to the BackwardsInvocation interface for handling structured output requests. - Added corresponding request and response types to support structured output. - Implemented the method in both RealBackwardsInvocation and MockedDifyInvocation for testing purposes. - Updated permission handling and task execution for the new structured output invocation type. This enhancement allows for more flexible and detailed responses from the LLM, improving the overall functionality of the invocation system. * refactor: enhance LLMResultChunkWithStructuredOutput structure - Updated the LLMResultChunkWithStructuredOutput type to include additional fields: Model, SystemFingerprint, and Delta. - Added comments to clarify the reasoning behind the structure and the use of type embedding for JSON marshaling. This change improves the clarity and functionality of the LLMResultChunkWithStructuredOutput type, ensuring proper JSON serialization. * refactor: streamline LLMResultChunk construction in InvokeLLMWithStructuredOutput - Simplified the construction of LLMResultChunk and LLMResultChunkWithStructuredOutput by removing unnecessary type embedding. - Enhanced readability and maintainability of the code while preserving functionality. This change contributes to cleaner code and improved clarity in the handling of structured output responses.
1 parent 7a7848b commit 6d6fb38

File tree

6 files changed

+120
-0
lines changed

6 files changed

+120
-0
lines changed

internal/core/dify_invocation/invcation.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import (
99
type BackwardsInvocation interface {
1010
// InvokeLLM
1111
InvokeLLM(payload *InvokeLLMRequest) (*stream.Stream[model_entities.LLMResultChunk], error)
12+
// InvokeLLMWithStructuredOutput
13+
InvokeLLMWithStructuredOutput(payload *InvokeLLMWithStructuredOutputRequest) (
14+
*stream.Stream[model_entities.LLMResultChunkWithStructuredOutput], error)
1215
// InvokeTextEmbedding
1316
InvokeTextEmbedding(payload *InvokeTextEmbeddingRequest) (*model_entities.TextEmbeddingResult, error)
1417
// InvokeRerank

internal/core/dify_invocation/real/http_request.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ func (i *RealBackwardsInvocation) InvokeLLM(payload *dify_invocation.InvokeLLMRe
115115
return StreamResponse[model_entities.LLMResultChunk](i, "POST", "invoke/llm", http_requests.HttpPayloadJson(payload))
116116
}
117117

118+
func (i *RealBackwardsInvocation) InvokeLLMWithStructuredOutput(payload *dify_invocation.InvokeLLMWithStructuredOutputRequest) (
119+
*stream.Stream[model_entities.LLMResultChunkWithStructuredOutput], error,
120+
) {
121+
return StreamResponse[model_entities.LLMResultChunkWithStructuredOutput](i, "POST", "/invoke/llm/structured-output", http_requests.HttpPayloadJson(payload))
122+
}
123+
118124
func (i *RealBackwardsInvocation) InvokeTextEmbedding(payload *dify_invocation.InvokeTextEmbeddingRequest) (*model_entities.TextEmbeddingResult, error) {
119125
return Request[model_entities.TextEmbeddingResult](i, "POST", "invoke/text-embedding", http_requests.HttpPayloadJson(payload))
120126
}

internal/core/dify_invocation/tester/mock.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,62 @@ func (m *MockedDifyInvocation) InvokeLLM(payload *dify_invocation.InvokeLLMReque
136136
return stream, nil
137137
}
138138

139+
func (m *MockedDifyInvocation) InvokeLLMWithStructuredOutput(payload *dify_invocation.InvokeLLMWithStructuredOutputRequest) (
140+
*stream.Stream[model_entities.LLMResultChunkWithStructuredOutput], error,
141+
) {
142+
// generate json from payload.StructuredOutputSchema
143+
structuredOutput, err := jsonschema.GenerateValidateJson(payload.StructuredOutputSchema)
144+
if err != nil {
145+
return nil, err
146+
}
147+
148+
// marshal jsonSchema to string
149+
structuredOutputString := parser.MarshalJson(structuredOutput)
150+
151+
// split structuredOutputString into 10 parts and write them to the stream
152+
parts := []string{}
153+
for i := 0; i < 10; i++ {
154+
start := i * len(structuredOutputString) / 10
155+
end := (i + 1) * len(structuredOutputString) / 10
156+
if i == 9 { // last part
157+
end = len(structuredOutputString)
158+
}
159+
parts = append(parts, structuredOutputString[start:end])
160+
}
161+
162+
stream := stream.NewStream[model_entities.LLMResultChunkWithStructuredOutput](11)
163+
routine.Submit(nil, func() {
164+
for i, part := range parts {
165+
stream.Write(model_entities.LLMResultChunkWithStructuredOutput{
166+
Model: model_entities.LLMModel(payload.Model),
167+
SystemFingerprint: "test",
168+
Delta: model_entities.LLMResultChunkDelta{
169+
Index: &[]int{i}[0],
170+
Message: model_entities.PromptMessage{
171+
Role: model_entities.PROMPT_MESSAGE_ROLE_ASSISTANT,
172+
Content: part,
173+
Name: "test",
174+
ToolCalls: []model_entities.PromptMessageToolCall{},
175+
},
176+
},
177+
})
178+
}
179+
// write the last part
180+
stream.Write(model_entities.LLMResultChunkWithStructuredOutput{
181+
Model: model_entities.LLMModel(payload.Model),
182+
SystemFingerprint: "test",
183+
Delta: model_entities.LLMResultChunkDelta{
184+
Index: &[]int{10}[0],
185+
},
186+
LLMStructuredOutput: model_entities.LLMStructuredOutput{
187+
StructuredOutput: structuredOutput,
188+
},
189+
})
190+
stream.Close()
191+
})
192+
return stream, nil
193+
}
194+
139195
func (m *MockedDifyInvocation) InvokeTextEmbedding(payload *dify_invocation.InvokeTextEmbeddingRequest) (*model_entities.TextEmbeddingResult, error) {
140196
result := model_entities.TextEmbeddingResult{
141197
Model: payload.Model,

internal/core/dify_invocation/types.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ type InvokeType string
1818

1919
const (
2020
INVOKE_TYPE_LLM InvokeType = "llm"
21+
INVOKE_TYPE_LLM_STRUCTURED_OUTPUT InvokeType = "llm_structured_output"
2122
INVOKE_TYPE_TEXT_EMBEDDING InvokeType = "text_embedding"
2223
INVOKE_TYPE_RERANK InvokeType = "rerank"
2324
INVOKE_TYPE_TTS InvokeType = "tts"
@@ -51,6 +52,15 @@ type InvokeLLMRequest struct {
5152
InvokeLLMSchema
5253
}
5354

55+
type InvokeLLMWithStructuredOutputRequest struct {
56+
BaseInvokeDifyRequest
57+
requests.BaseRequestInvokeModel
58+
// requests.InvokeLLMSchema
59+
// TODO: as completion_params in requests.InvokeLLMSchema is "model_parameters", we declare another one here
60+
InvokeLLMSchema
61+
StructuredOutputSchema map[string]any `json:"structured_output_schema" validate:"required"`
62+
}
63+
5464
type InvokeTextEmbeddingRequest struct {
5565
BaseInvokeDifyRequest
5666
requests.BaseRequestInvokeModel

internal/core/plugin_daemon/backwards_invocation/task.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,12 @@ var (
154154
},
155155
"error": "permission denied, you need to enable app access in plugin manifest",
156156
},
157+
dify_invocation.INVOKE_TYPE_LLM_STRUCTURED_OUTPUT: {
158+
"func": func(declaration *plugin_entities.PluginDeclaration) bool {
159+
return declaration.Resource.Permission.AllowInvokeLLM()
160+
},
161+
"error": "permission denied, you need to enable llm access in plugin manifest",
162+
},
157163
}
158164
)
159165

@@ -250,6 +256,9 @@ var (
250256
dify_invocation.INVOKE_TYPE_FETCH_APP: func(handle *BackwardsInvocation) {
251257
genericDispatchTask(handle, executeDifyInvocationFetchAppTask)
252258
},
259+
dify_invocation.INVOKE_TYPE_LLM_STRUCTURED_OUTPUT: func(handle *BackwardsInvocation) {
260+
genericDispatchTask(handle, executeDifyInvocationLLMStructuredOutputTask)
261+
},
253262
}
254263
)
255264

@@ -337,6 +346,26 @@ func executeDifyInvocationLLMTask(
337346
}
338347
}
339348

349+
func executeDifyInvocationLLMStructuredOutputTask(
350+
handle *BackwardsInvocation,
351+
request *dify_invocation.InvokeLLMWithStructuredOutputRequest,
352+
) {
353+
response, err := handle.backwardsInvocation.InvokeLLMWithStructuredOutput(request)
354+
if err != nil {
355+
handle.WriteError(fmt.Errorf("invoke llm with structured output model failed: %s", err.Error()))
356+
return
357+
}
358+
359+
for response.Next() {
360+
value, err := response.Read()
361+
if err != nil {
362+
handle.WriteError(fmt.Errorf("read llm with structured output model failed: %s", err.Error()))
363+
return
364+
}
365+
handle.WriteResponse("stream", value)
366+
}
367+
}
368+
340369
func executeDifyInvocationTextEmbeddingTask(
341370
handle *BackwardsInvocation,
342371
request *dify_invocation.InvokeTextEmbeddingRequest,

pkg/entities/model_entities/llm.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,22 @@ type LLMResultChunk struct {
188188
Delta LLMResultChunkDelta `json:"delta" validate:"required"`
189189
}
190190

191+
type LLMStructuredOutput struct {
192+
StructuredOutput map[string]any `json:"structured_output" validate:"omitempty"`
193+
}
194+
195+
type LLMResultChunkWithStructuredOutput struct {
196+
// You might argue that why not embed LLMResultChunk directly?
197+
// `LLMResultChunk` has implemented interface `MarshalJSON`, due to Golang's type embedding,
198+
// it also effectively implements the `MarshalJSON` method of `LLMResultChunkWithStructuredOutput`,
199+
// resulting in a unexpected JSON marshaling of `LLMResultChunkWithStructuredOutput`
200+
Model LLMModel `json:"model" validate:"required"`
201+
SystemFingerprint string `json:"system_fingerprint" validate:"omitempty"`
202+
Delta LLMResultChunkDelta `json:"delta" validate:"required"`
203+
204+
LLMStructuredOutput
205+
}
206+
191207
/*
192208
This is a compatibility layer for the old LLMResultChunk format.
193209
The old one has the `PromptMessages` field, we need to ensure the new one is backward compatible.

0 commit comments

Comments
 (0)