@@ -42,7 +42,20 @@ class SafetensorParseError extends Error {}
42
42
type FileName = string ;
43
43
44
44
export type TensorName = string ;
45
- export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL" ;
45
+ export type Dtype =
46
+ | "F64"
47
+ | "F32"
48
+ | "F16"
49
+ | "F8_E4M3"
50
+ | "F8_E5M2"
51
+ | "BF16"
52
+ | "I64"
53
+ | "I32"
54
+ | "I16"
55
+ | "I8"
56
+ | "U16"
57
+ | "U8"
58
+ | "BOOL" ;
46
59
47
60
export interface TensorInfo {
48
61
dtype : Dtype ;
@@ -51,13 +64,13 @@ export interface TensorInfo {
51
64
}
52
65
53
66
export type SafetensorsFileHeader = Record < TensorName , TensorInfo > & {
54
- __metadata__ : Record < string , string > ;
67
+ __metadata__ : { total_parameters ?: string | number } & Record < string , string > ;
55
68
} ;
56
69
57
70
export interface SafetensorsIndexJson {
58
71
dtype ?: string ;
59
72
/// ^there's sometimes a dtype but it looks inconsistent.
60
- metadata ?: Record < string , string > ;
73
+ metadata ?: { total_parameters ?: string | number } & Record < string , string > ;
61
74
/// ^ why the naming inconsistency?
62
75
weight_map : Record < TensorName , FileName > ;
63
76
}
@@ -69,12 +82,14 @@ export type SafetensorsParseFromRepo =
69
82
sharded : false ;
70
83
header : SafetensorsFileHeader ;
71
84
parameterCount ?: Partial < Record < Dtype , number > > ;
85
+ parameterTotal ?: number ;
72
86
}
73
87
| {
74
88
sharded : true ;
75
89
index : SafetensorsIndexJson ;
76
90
headers : SafetensorsShardedHeaders ;
77
91
parameterCount ?: Partial < Record < Dtype , number > > ;
92
+ parameterTotal ?: number ;
78
93
} ;
79
94
80
95
async function parseSingleFile (
@@ -127,7 +142,7 @@ async function parseShardedIndex(
127
142
*/
128
143
fetch ?: typeof fetch ;
129
144
} & Partial < CredentialsParams >
130
- ) : Promise < { index : SafetensorsIndexJson ; headers : SafetensorsShardedHeaders } > {
145
+ ) : Promise < SafetensorsIndexJson > {
131
146
const indexBlob = await downloadFile ( {
132
147
...params ,
133
148
path,
@@ -137,14 +152,28 @@ async function parseShardedIndex(
137
152
throw new SafetensorParseError ( `Failed to parse file ${ path } : failed to fetch safetensors index.` ) ;
138
153
}
139
154
140
- // no validation for now, we assume it's a valid IndexJson.
141
- let index : SafetensorsIndexJson ;
142
155
try {
143
- index = JSON . parse ( await indexBlob . slice ( 0 , 10_000_000 ) . text ( ) ) ;
156
+ // no validation for now, we assume it's a valid IndexJson.
157
+ const index = JSON . parse ( await indexBlob . slice ( 0 , 10_000_000 ) . text ( ) ) ;
158
+ return index ;
144
159
} catch ( error ) {
145
160
throw new SafetensorParseError ( `Failed to parse file ${ path } : not a valid JSON.` ) ;
146
161
}
162
+ }
147
163
164
+ async function fetchAllHeaders (
165
+ path : string ,
166
+ index : SafetensorsIndexJson ,
167
+ params : {
168
+ repo : RepoDesignation ;
169
+ revision ?: string ;
170
+ hubUrl ?: string ;
171
+ /**
172
+ * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
173
+ */
174
+ fetch ?: typeof fetch ;
175
+ } & Partial < CredentialsParams >
176
+ ) : Promise < SafetensorsShardedHeaders > {
148
177
const pathPrefix = path . slice ( 0 , path . lastIndexOf ( "/" ) + 1 ) ;
149
178
const filenames = [ ...new Set ( Object . values ( index . weight_map ) ) ] ;
150
179
const shardedMap : SafetensorsShardedHeaders = Object . fromEntries (
@@ -156,7 +185,7 @@ async function parseShardedIndex(
156
185
PARALLEL_DOWNLOADS
157
186
)
158
187
) ;
159
- return { index , headers : shardedMap } ;
188
+ return shardedMap ;
160
189
}
161
190
162
191
/**
@@ -189,12 +218,12 @@ export async function parseSafetensorsMetadata(
189
218
params : {
190
219
/** Only models are supported */
191
220
repo : RepoDesignation ;
221
+ path ?: string ;
192
222
/**
193
223
* Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
194
224
*
195
225
* @default false
196
226
*/
197
- path ?: string ;
198
227
computeParametersCount ?: boolean ;
199
228
hubUrl ?: string ;
200
229
revision ?: string ;
@@ -223,27 +252,55 @@ export async function parseSafetensorsMetadata(
223
252
throw new TypeError ( "Only model repos should contain safetensors files." ) ;
224
253
}
225
254
226
- if ( RE_SAFETENSORS_FILE . test ( params . path ?? "" ) || ( await fileExists ( { ...params , path : SAFETENSORS_FILE } ) ) ) {
255
+ if (
256
+ ( params . path && RE_SAFETENSORS_FILE . test ( params . path ) ) ||
257
+ ( await fileExists ( { ...params , path : SAFETENSORS_FILE } ) )
258
+ ) {
227
259
const header = await parseSingleFile ( params . path ?? SAFETENSORS_FILE , params ) ;
228
260
return {
229
261
sharded : false ,
230
262
header,
231
- ...( params . computeParametersCount && {
232
- parameterCount : computeNumOfParamsByDtypeSingleFile ( header ) ,
233
- } ) ,
263
+ ...( params . computeParametersCount
264
+ ? {
265
+ parameterCount : computeNumOfParamsByDtypeSingleFile ( header ) ,
266
+ parameterTotal :
267
+ /// shortcut: get param count directly from metadata
268
+ header . __metadata__ . total_parameters
269
+ ? typeof header . __metadata__ . total_parameters === "number"
270
+ ? header . __metadata__ . total_parameters
271
+ : typeof header . __metadata__ . total_parameters === "string"
272
+ ? parseInt ( header . __metadata__ . total_parameters )
273
+ : undefined
274
+ : undefined ,
275
+ }
276
+ : undefined ) ,
234
277
} ;
235
278
} else if (
236
- RE_SAFETENSORS_INDEX_FILE . test ( params . path ?? "" ) ||
279
+ ( params . path && RE_SAFETENSORS_INDEX_FILE . test ( params . path ) ) ||
237
280
( await fileExists ( { ...params , path : SAFETENSORS_INDEX_FILE } ) )
238
281
) {
239
- const { index, headers } = await parseShardedIndex ( params . path ?? SAFETENSORS_INDEX_FILE , params ) ;
282
+ const path = params . path ?? SAFETENSORS_INDEX_FILE ;
283
+ const index = await parseShardedIndex ( path , params ) ;
284
+ const shardedMap = await fetchAllHeaders ( path , index , params ) ;
285
+
240
286
return {
241
287
sharded : true ,
242
288
index,
243
- headers,
244
- ...( params . computeParametersCount && {
245
- parameterCount : computeNumOfParamsByDtypeSharded ( headers ) ,
246
- } ) ,
289
+ headers : shardedMap ,
290
+ ...( params . computeParametersCount
291
+ ? {
292
+ parameterCount : computeNumOfParamsByDtypeSharded ( shardedMap ) ,
293
+ parameterTotal :
294
+ /// shortcut: get param count directly from metadata
295
+ index . metadata ?. total_parameters
296
+ ? typeof index . metadata . total_parameters === "number"
297
+ ? index . metadata . total_parameters
298
+ : typeof index . metadata . total_parameters === "string"
299
+ ? parseInt ( index . metadata . total_parameters )
300
+ : undefined
301
+ : undefined ,
302
+ }
303
+ : undefined ) ,
247
304
} ;
248
305
} else {
249
306
throw new Error ( "model id does not seem to contain safetensors weights" ) ;
0 commit comments