Skip to content

Commit fd39000

Browse files
authored
Safetensors now optionally exposes total params (#1509)
1 parent 7cb9528 commit fd39000

File tree

6 files changed

+111
-31
lines changed

6 files changed

+111
-31
lines changed

packages/hub/src/lib/file-download-info.spec.ts

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ describe("fileDownloadInfo", () => {
55
it("should fetch LFS file info", async () => {
66
const info = await fileDownloadInfo({
77
repo: {
8-
name: "bert-base-uncased",
8+
name: "google-bert/bert-base-uncased",
99
type: "model",
1010
},
1111
path: "tf_model.h5",
@@ -19,7 +19,7 @@ describe("fileDownloadInfo", () => {
1919
it("should fetch raw LFS pointer info", async () => {
2020
const info = await fileDownloadInfo({
2121
repo: {
22-
name: "bert-base-uncased",
22+
name: "google-bert/bert-base-uncased",
2323
type: "model",
2424
},
2525
path: "tf_model.h5",
@@ -34,7 +34,7 @@ describe("fileDownloadInfo", () => {
3434
it("should fetch non-LFS file info", async () => {
3535
const info = await fileDownloadInfo({
3636
repo: {
37-
name: "bert-base-uncased",
37+
name: "google-bert/bert-base-uncased",
3838
type: "model",
3939
},
4040
path: "tokenizer_config.json",

packages/hub/src/lib/file-exists.spec.ts

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ describe("fileExists", () => {
55
it("should return true for file that exists", async () => {
66
const info = await fileExists({
77
repo: {
8-
name: "bert-base-uncased",
8+
name: "google-bert/bert-base-uncased",
99
type: "model",
1010
},
1111
path: "tf_model.h5",
@@ -18,7 +18,7 @@ describe("fileExists", () => {
1818
it("should return false for file that does not exist", async () => {
1919
const info = await fileExists({
2020
repo: {
21-
name: "bert-base-uncased",
21+
name: "google-bert/bert-base-uncased",
2222
type: "model",
2323
},
2424
path: "tf_model.h5dadazdzazd",

packages/hub/src/lib/list-files.spec.ts

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ describe("listFiles", () => {
66
it("should fetch the list of files from the repo", async () => {
77
const cursor = listFiles({
88
repo: {
9-
name: "bert-base-uncased",
9+
name: "google-bert/bert-base-uncased",
1010
type: "model",
1111
},
1212
revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
@@ -67,7 +67,7 @@ describe("listFiles", () => {
6767
it("should fetch the list of files from the repo, including last commit", async () => {
6868
const cursor = listFiles({
6969
repo: {
70-
name: "bert-base-uncased",
70+
name: "google-bert/bert-base-uncased",
7171
type: "model",
7272
},
7373
revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",

packages/hub/src/lib/parse-safetensors-metadata.spec.ts

Lines changed: 25 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ import { sum } from "../utils/sum";
55
describe("parseSafetensorsMetadata", () => {
66
it("fetch info for single-file (with the default conventional filename)", async () => {
77
const parse = await parseSafetensorsMetadata({
8-
repo: "bert-base-uncased",
8+
repo: "google-bert/bert-base-uncased",
99
computeParametersCount: true,
1010
revision: "86b5e0934494bd15c9632b12f734a8a67f723594",
1111
});
@@ -88,7 +88,7 @@ describe("parseSafetensorsMetadata", () => {
8888
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964);
8989
});
9090

91-
it("fetch info for sharded (with the default conventional filename) with file path", async () => {
91+
it("fetch info for sharded with file path", async () => {
9292
const parse = await parseSafetensorsMetadata({
9393
repo: "Alignment-Lab-AI/ALAI-gemma-7b",
9494
computeParametersCount: true,
@@ -110,6 +110,29 @@ describe("parseSafetensorsMetadata", () => {
110110
assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896);
111111
});
112112

113+
it("fetch info for sharded, but get param count directly from metadata", async () => {
114+
const parse = await parseSafetensorsMetadata({
115+
repo: "hf-internal-testing/sharded-model-metadata-num-parameters",
116+
computeParametersCount: true,
117+
revision: "999395eb3db277f3d7a0393402b02486ca91cef8",
118+
});
119+
120+
assert(parse.sharded);
121+
assert.deepStrictEqual(parse.parameterTotal, 109_482_240);
122+
// total params = 109M
123+
});
124+
125+
it("fetch info for single-file, but get param count directly from metadata", async () => {
126+
const parse = await parseSafetensorsMetadata({
127+
repo: "hf-internal-testing/single-file-model",
128+
computeParametersCount: true,
129+
revision: "75fcd3fed0285ac7f1092897ff2aefdf24bf872e",
130+
});
131+
132+
assert(!parse.sharded);
133+
assert.deepStrictEqual(parse.parameterTotal, 109_482_240);
134+
});
135+
113136
it("should detect sharded safetensors filename", async () => {
114137
const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors
115138
const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);

packages/hub/src/lib/parse-safetensors-metadata.ts

Lines changed: 76 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,20 @@ class SafetensorParseError extends Error {}
4242
type FileName = string;
4343

4444
export type TensorName = string;
45-
export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL";
45+
export type Dtype =
46+
| "F64"
47+
| "F32"
48+
| "F16"
49+
| "F8_E4M3"
50+
| "F8_E5M2"
51+
| "BF16"
52+
| "I64"
53+
| "I32"
54+
| "I16"
55+
| "I8"
56+
| "U16"
57+
| "U8"
58+
| "BOOL";
4659

4760
export interface TensorInfo {
4861
dtype: Dtype;
@@ -51,13 +64,13 @@ export interface TensorInfo {
5164
}
5265

5366
export type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
54-
__metadata__: Record<string, string>;
67+
__metadata__: { total_parameters?: string | number } & Record<string, string>;
5568
};
5669

5770
export interface SafetensorsIndexJson {
5871
dtype?: string;
5972
/// ^there's sometimes a dtype but it looks inconsistent.
60-
metadata?: Record<string, string>;
73+
metadata?: { total_parameters?: string | number } & Record<string, string>;
6174
/// ^ why the naming inconsistency?
6275
weight_map: Record<TensorName, FileName>;
6376
}
@@ -69,12 +82,14 @@ export type SafetensorsParseFromRepo =
6982
sharded: false;
7083
header: SafetensorsFileHeader;
7184
parameterCount?: Partial<Record<Dtype, number>>;
85+
parameterTotal?: number;
7286
}
7387
| {
7488
sharded: true;
7589
index: SafetensorsIndexJson;
7690
headers: SafetensorsShardedHeaders;
7791
parameterCount?: Partial<Record<Dtype, number>>;
92+
parameterTotal?: number;
7893
};
7994

8095
async function parseSingleFile(
@@ -127,7 +142,7 @@ async function parseShardedIndex(
127142
*/
128143
fetch?: typeof fetch;
129144
} & Partial<CredentialsParams>
130-
): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> {
145+
): Promise<SafetensorsIndexJson> {
131146
const indexBlob = await downloadFile({
132147
...params,
133148
path,
@@ -137,14 +152,28 @@ async function parseShardedIndex(
137152
throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`);
138153
}
139154

140-
// no validation for now, we assume it's a valid IndexJson.
141-
let index: SafetensorsIndexJson;
142155
try {
143-
index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
156+
// no validation for now, we assume it's a valid IndexJson.
157+
const index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
158+
return index;
144159
} catch (error) {
145160
throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`);
146161
}
162+
}
147163

164+
async function fetchAllHeaders(
165+
path: string,
166+
index: SafetensorsIndexJson,
167+
params: {
168+
repo: RepoDesignation;
169+
revision?: string;
170+
hubUrl?: string;
171+
/**
172+
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
173+
*/
174+
fetch?: typeof fetch;
175+
} & Partial<CredentialsParams>
176+
): Promise<SafetensorsShardedHeaders> {
148177
const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);
149178
const filenames = [...new Set(Object.values(index.weight_map))];
150179
const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
@@ -156,7 +185,7 @@ async function parseShardedIndex(
156185
PARALLEL_DOWNLOADS
157186
)
158187
);
159-
return { index, headers: shardedMap };
188+
return shardedMap;
160189
}
161190

162191
/**
@@ -189,12 +218,12 @@ export async function parseSafetensorsMetadata(
189218
params: {
190219
/** Only models are supported */
191220
repo: RepoDesignation;
221+
path?: string;
192222
/**
193223
* Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
194224
*
195225
* @default false
196226
*/
197-
path?: string;
198227
computeParametersCount?: boolean;
199228
hubUrl?: string;
200229
revision?: string;
@@ -223,27 +252,55 @@ export async function parseSafetensorsMetadata(
223252
throw new TypeError("Only model repos should contain safetensors files.");
224253
}
225254

226-
if (RE_SAFETENSORS_FILE.test(params.path ?? "") || (await fileExists({ ...params, path: SAFETENSORS_FILE }))) {
255+
if (
256+
(params.path && RE_SAFETENSORS_FILE.test(params.path)) ||
257+
(await fileExists({ ...params, path: SAFETENSORS_FILE }))
258+
) {
227259
const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params);
228260
return {
229261
sharded: false,
230262
header,
231-
...(params.computeParametersCount && {
232-
parameterCount: computeNumOfParamsByDtypeSingleFile(header),
233-
}),
263+
...(params.computeParametersCount
264+
? {
265+
parameterCount: computeNumOfParamsByDtypeSingleFile(header),
266+
parameterTotal:
267+
/// shortcut: get param count directly from metadata
268+
header.__metadata__.total_parameters
269+
? typeof header.__metadata__.total_parameters === "number"
270+
? header.__metadata__.total_parameters
271+
: typeof header.__metadata__.total_parameters === "string"
272+
? parseInt(header.__metadata__.total_parameters)
273+
: undefined
274+
: undefined,
275+
}
276+
: undefined),
234277
};
235278
} else if (
236-
RE_SAFETENSORS_INDEX_FILE.test(params.path ?? "") ||
279+
(params.path && RE_SAFETENSORS_INDEX_FILE.test(params.path)) ||
237280
(await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))
238281
) {
239-
const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params);
282+
const path = params.path ?? SAFETENSORS_INDEX_FILE;
283+
const index = await parseShardedIndex(path, params);
284+
const shardedMap = await fetchAllHeaders(path, index, params);
285+
240286
return {
241287
sharded: true,
242288
index,
243-
headers,
244-
...(params.computeParametersCount && {
245-
parameterCount: computeNumOfParamsByDtypeSharded(headers),
246-
}),
289+
headers: shardedMap,
290+
...(params.computeParametersCount
291+
? {
292+
parameterCount: computeNumOfParamsByDtypeSharded(shardedMap),
293+
parameterTotal:
294+
/// shortcut: get param count directly from metadata
295+
index.metadata?.total_parameters
296+
? typeof index.metadata.total_parameters === "number"
297+
? index.metadata.total_parameters
298+
: typeof index.metadata.total_parameters === "string"
299+
? parseInt(index.metadata.total_parameters)
300+
: undefined
301+
: undefined,
302+
}
303+
: undefined),
247304
};
248305
} else {
249306
throw new Error("model id does not seem to contain safetensors weights");

packages/hub/src/lib/paths-info.spec.ts

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ describe("pathsInfo", () => {
66
it("should fetch LFS path info", async () => {
77
const result: PathInfo[] = await pathsInfo({
88
repo: {
9-
name: "bert-base-uncased",
9+
name: "google-bert/bert-base-uncased",
1010
type: "model",
1111
},
1212
paths: ["tf_model.h5"],
@@ -35,7 +35,7 @@ describe("pathsInfo", () => {
3535
securityFileStatus: SecurityFileStatus;
3636
})[] = await pathsInfo({
3737
repo: {
38-
name: "bert-base-uncased",
38+
name: "google-bert/bert-base-uncased",
3939
type: "model",
4040
},
4141
paths: ["tf_model.h5"],
@@ -59,7 +59,7 @@ describe("pathsInfo", () => {
5959
it("non-LFS pointer should have lfs undefined", async () => {
6060
const result: PathInfo[] = await pathsInfo({
6161
repo: {
62-
name: "bert-base-uncased",
62+
name: "google-bert/bert-base-uncased",
6363
type: "model",
6464
},
6565
paths: ["config.json"],

0 commit comments

Comments
 (0)