Skip to content

Commit 97c9733

Browse files
committed
[FR] Add support for structured extraction with Ollama models
Fixes #68
1 parent 1cda053 commit 97c9733

File tree

2 files changed

+255
-3
lines changed

2 files changed

+255
-3
lines changed

src/extraction.jl

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,24 @@
66
# There are potential formats: 1) JSON-based for OpenAI compatible APIs, 2) XML-based for Anthropic compatible APIs (used also by Hermes-2-Pro model).
77
#
88

9+
"""
10+
JSON_PRIMITIVE_TYPES
11+
12+
A set of primitive types that are supported by JSON. If a type
13+
is not in this set, the JSON typer [`to_json_type`](@ref) will
14+
assume that the type is a `struct` and will attempt to recursively
15+
unpack the fields of the struct.
16+
"""
17+
const JSON_PRIMITIVE_TYPES = Union{
18+
Integer,
19+
Real,
20+
AbstractString,
21+
Bool,
22+
Nothing,
23+
Missing,
24+
AbstractArray
25+
}
26+
927
######################
1028
# 1) OpenAI / JSON format
1129
######################
@@ -15,7 +33,14 @@ to_json_type(n::Type{<:Real}) = "number"
1533
to_json_type(n::Type{<:Integer}) = "integer"
1634
to_json_type(b::Type{Bool}) = "boolean"
1735
to_json_type(t::Type{<:Union{Missing, Nothing}}) = "null"
18-
to_json_type(t::Type{<:Any}) = "string" # object?
36+
to_json_type(t::Type{T}) where {T <: AbstractArray} = to_json_type(eltype(t)) * "[]"
37+
to_json_type(t::Type{Any}) = throw(ArgumentError("""
38+
Type $t is not a valid type for to_json_type. Please provide a valid type found in:
39+
40+
$JSON_PRIMITIVE_TYPES
41+
42+
You may be using to_json_schema but forgot to properly type the fields of your struct.
43+
"""))
1944

2045
has_null_type(T::Type{Missing}) = true
2146
has_null_type(T::Type{Nothing}) = true
@@ -236,3 +261,108 @@ Extract zero, one or more specified items from the provided data.
236261
struct ItemsExtract{T <: Any}
237262
items::Vector{T}
238263
end
264+
265+
"""
266+
typed_json_schema(x::Type{T}) where {T}
267+
268+
Convert a Julia type to a JSON schema that lists keys as field names and values as
269+
the types of those field names.
270+
271+
WARNING! Every field in your struct, and all nested structs, must be typed using a subtype of values in [`JSON_PRIMITIVE_TYPES`](@ref)
272+
before calling this function. Otherwise, you will get a recursion error.
273+
274+
## Example
275+
276+
```julia
277+
# Simple flat structure where each field is a primitive type
278+
struct SimpleSingleton
279+
singleton_value::Int
280+
end
281+
282+
typed_json_schema(SimpleSingleton)
283+
```
284+
285+
```
286+
Dict{Any, Any} with 1 entry:
287+
:singleton_value => "integer"
288+
```
289+
290+
Or using nested structs
291+
292+
```julia
293+
# Test a struct that contains another struct.
294+
struct Nested
295+
inside_element::SimpleSingleton
296+
end
297+
298+
typed_json_schema(Nested)
299+
```
300+
301+
```julia
302+
Dict{Any, Any} with 1 entry:
303+
:inside_element => Dict{Any, Any}("singleton_value" => "integer")
304+
```
305+
306+
Lists of created Julia types will be specified as `List[Object]` with the value being the type of the elements,
307+
i.e.
308+
309+
```julia
310+
# Test a struct with a vector of primitives
311+
struct ABunchOfVectors
312+
strings::Vector{String}
313+
ints::Vector{Int}
314+
floats::Vector{Float64}
315+
nested_vector::Vector{Nested}
316+
end
317+
318+
typed_json_schema(ABunchOfVectors)
319+
```
320+
321+
```
322+
Dict{Any, Any} with 4 entries:
323+
:strings => "string[]"
324+
:ints => "integer[]"
325+
:nested_vector => Dict("list[Object]"=>"{\"inside_element\":{\"singleton_value\":\"integer\"}}")
326+
:floats => "number[]"
327+
```
328+
329+
## Resources
330+
- the [original issue](https://github.com/svilupp/PromptingTools.jl/issues/143)
331+
- the [motivation](https://www.boundaryml.com/blog/type-definition-prompting-baml)
332+
"""
333+
function typed_json_schema(x::Type{T}) where {T}
334+
# We can return early if the type is a non-array primitive
335+
if T <: JSON_PRIMITIVE_TYPES && !(T <: AbstractArray)
336+
return to_json_type(T)
337+
end
338+
339+
# If there are no fields, return the type
340+
if isempty(fieldnames(T))
341+
# Check if this is a vector type. If so, return the type of the elements.
342+
if T <: AbstractArray
343+
# Now check if the element type is a non-primitive. If so, recursively call typed_json_schema.
344+
if eltype(T) <: JSON_PRIMITIVE_TYPES
345+
return to_json_type(T)
346+
else
347+
return Dict("list[Object]" => JSON3.write(typed_json_schema(eltype(T))))
348+
# return "List[" * JSON3.write(typed_json_schema(eltype(T))) * "]"
349+
end
350+
end
351+
352+
# Check if the type is a non-primitive.
353+
if T <: JSON_PRIMITIVE_TYPES
354+
return to_json_type(T)
355+
else
356+
return typed_json_schema(T)
357+
end
358+
end
359+
360+
# Preallocate a mapping
361+
mapping = Dict()
362+
for (type, field) in zip(T.types, fieldnames(T))
363+
mapping[field] = typed_json_schema(type)
364+
end
365+
366+
# Get property names
367+
return mapping
368+
end

test/extraction.jl

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,11 @@ end
225225
end
226226
output = function_call_signature(MyMeasurement2)#|> JSON3.pretty
227227
expected_output = Dict{String, Any}("name" => "MyMeasurement2_extractor",
228-
"parameters" => Dict{String, Any}("properties" => Dict{String, Any}("height" => Dict{
228+
"parameters" => Dict{String, Any}(
229+
"properties" => Dict{String, Any}(
230+
"height" => Dict{
229231
String,
230-
Any,
232+
Any
231233
}("type" => "integer"),
232234
"weight" => Dict{String, Any}("type" => "number"),
233235
"age" => Dict{String, Any}("type" => "integer")),
@@ -240,3 +242,123 @@ end
240242
schema = function_call_signature(MaybeExtract{MyMeasurement2})
241243
@test schema["name"] == "MaybeExtractMyMeasurement2_extractor"
242244
end
245+
@testset "to_json_schema-primitive_types" begin
246+
@test to_json_schema(Int) == Dict("type" => "integer")
247+
@test to_json_schema(Float64) == Dict("type" => "number")
248+
@test to_json_schema(Bool) == Dict("type" => "boolean")
249+
@test to_json_schema(String) == Dict("type" => "string")
250+
@test_throws ArgumentError to_json_schema(Any) # Type Any is not supported
251+
end
252+
@testset "to_json_schema-structs" begin
253+
# Function to check the equivalence of two JSON strings, since Dict is
254+
# unordered, we need to sort keys before comparison.
255+
function check_json_equivalence(json1::AbstractString, json2::AbstractString)
256+
println("\ncheck_json_equivalence\n===json1===")
257+
println(json1)
258+
println("===json2===")
259+
println(json2)
260+
println()
261+
# JSON dictionary
262+
d1 = JSON3.read(json1)
263+
d2 = JSON3.read(json2)
264+
265+
# Get all the keys
266+
k1 = sort(collect(keys(d1)))
267+
k2 = sort(collect(keys(d2)))
268+
269+
# Test that all the keys are present
270+
@test setdiff(k1, k2) == []
271+
@test setdiff(k2, k1) == []
272+
273+
# Test that all the values are equivalent
274+
for (k, v) in d1
275+
@test d2[k] == v
276+
end
277+
278+
# @test JSON3.write(JSON3.read(json1)) == JSON3.write(JSON3.read(json2))
279+
end
280+
function check_json_equivalence(d::Dict, s::AbstractString)
281+
return check_json_equivalence(JSON3.write(d), s)
282+
end
283+
284+
# Simple flat structure where each field is a primitive type
285+
struct SimpleSingleton
286+
singleton_value::Int
287+
end
288+
289+
check_json_equivalence(
290+
JSON3.write(typed_json_schema(SimpleSingleton)),
291+
"{\"singleton_value\":\"integer\"}"
292+
)
293+
294+
# Test a struct that contains another struct.
295+
struct Nested
296+
inside_element::SimpleSingleton
297+
end
298+
299+
check_json_equivalence(
300+
JSON3.write(typed_json_schema(Nested)),
301+
"{\"inside_element\":{\"singleton_value\":\"integer\"}}"
302+
)
303+
304+
# Test a struct with two primitive types
305+
struct IntFloatFlat
306+
int_value::Int
307+
float_value::Float64
308+
end
309+
check_json_equivalence(
310+
typed_json_schema(IntFloatFlat),
311+
"{\"int_value\":\"integer\",\"float_value\":\"number\"}"
312+
)
313+
314+
# Test a struct that contains all primitive types
315+
struct AllJSONPrimitives
316+
int::Integer
317+
float::Real
318+
string::AbstractString
319+
bool::Bool
320+
nothing::Nothing
321+
missing::Missing
322+
323+
# Array types
324+
array_of_strings::Vector{String}
325+
array_of_ints::Vector{Int}
326+
array_of_floats::Vector{Float64}
327+
array_of_bools::Vector{Bool}
328+
array_of_nothings::Vector{Nothing}
329+
array_of_missings::Vector{Missing}
330+
end
331+
332+
check_json_equivalence(
333+
typed_json_schema(AllJSONPrimitives),
334+
"{\"int\":\"integer\",\"float\":\"number\",\"string\":\"string\",\"bool\":\"boolean\",\"nothing\":\"null\",\"missing\":\"null\",\"array_of_strings\":\"string[]\",\"array_of_ints\":\"integer[]\",\"array_of_floats\":\"number[]\",\"array_of_bools\":\"boolean[]\",\"array_of_nothings\":\"null[]\",\"array_of_missings\":\"null[]\"}"
335+
)
336+
337+
# Test a struct with a vector of primitives
338+
struct ABunchOfVectors
339+
strings::Vector{String}
340+
ints::Vector{Int}
341+
floats::Vector{Float64}
342+
nested_vector::Vector{Nested}
343+
end
344+
345+
check_json_equivalence(
346+
typed_json_schema(ABunchOfVectors),
347+
"{\"strings\":\"string[]\",\"ints\":\"integer[]\",\"nested_vector\":{\"list[Object]\":\"{\\\"inside_element\\\":{\\\"singleton_value\\\":\\\"integer\\\"}}\"},\"floats\":\"number[]\"}"
348+
)
349+
350+
# Weird struct with a bunch of different types
351+
struct Monster
352+
name::String
353+
age::Int
354+
height::Float64
355+
friends::Vector{String}
356+
nested::Nested
357+
flat::IntFloatFlat
358+
end
359+
360+
check_json_equivalence(
361+
typed_json_schema(Monster),
362+
"{\"flat\":{\"float_value\":\"number\",\"int_value\":\"integer\"},\"nested\":{\"inside_element\":{\"singleton_value\":\"integer\"}},\"age\":\"integer\",\"name\":\"string\",\"height\":\"number\",\"friends\":\"string[]\"}"
363+
)
364+
end;

0 commit comments

Comments
 (0)