@@ -13,6 +13,15 @@ def validate_float(value):
13
13
raise ValueError ("Value must be convertible to a float" )
14
14
15
15
16
+ def validate_int (value ):
17
+ if type (value ) == int :
18
+ return value
19
+ try :
20
+ return int (value .strip ("'" ).strip ("\" " ))
21
+ except (TypeError , ValueError ):
22
+ raise ValueError ("Value must be convertible to an integer" )
23
+
24
+
16
25
class WeaviateSettings (BaseSettings ):
17
26
weaviate_uri : Optional [str ] = Field (
18
27
default = "localhost:8080" ,
@@ -60,3 +69,32 @@ def get_weaviate_grpc_uri(self):
60
69
61
70
def get_weaviate_grpc_port (self ):
62
71
return int (self .weaviate_grpc_uri .split (":" )[1 ])
72
+
73
+
74
+ class LlmSettings (BaseSettings ):
75
+ llm_server_url : Optional [str ] = Field (
76
+ default = "http://localhost:9000/v1" ,
77
+ alias = "MODEL_LLM_SERVER_URL" ,
78
+ )
79
+ model_id : Optional [str ] = Field (
80
+ default = "rubra-ai/Phi-3-mini-128k-instruct" ,
81
+ alias = "MODEL_ID" ,
82
+ )
83
+ max_tokens : Optional [int ] = Field (
84
+ default = 256 ,
85
+ alias = "MAX_TOKENS" ,
86
+ )
87
+ model_temperature : Optional [float ] = Field (
88
+ default = 0.01 ,
89
+ alias = "MODEL_TEMPERATURE" ,
90
+ )
91
+
92
+ @field_validator ("max_tokens" , mode = "before" )
93
+ @classmethod
94
+ def validate_max_tokens (cls , v ):
95
+ return validate_int (v )
96
+
97
+ @field_validator ("model_temperature" , mode = "before" )
98
+ @classmethod
99
+ def validate_model_temperature (cls , v ):
100
+ return validate_float (v )
0 commit comments