Skip to content

Commit cf53588

Browse files
Add a blank starter package for gpt2 (#1185)
1 parent 7d69946 commit cf53588

18 files changed

+4855
-0
lines changed

gpt2/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# GPT2
2+
Run GPT2 in the browser with TFJS. TODO(mattsoulanille || pforderique): Expand on this.

gpt2/demo/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
esbuild_bundle_meta.json

gpt2/demo/README.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# GPT2 demo
2+
3+
## Contents
4+
5+
This demo shows how to use the GPT2 model to generate text.
6+
7+
## Setup
8+
9+
cd into the `gpt2` folder. From the root of the repo, this is located at `gpt2/`. From the demo, it's `../`.
10+
11+
Install dependencies:
12+
```sh
13+
yarn
14+
```
15+
16+
cd into the demo and install dependencies:
17+
18+
```sh
19+
cd demo
20+
yarn
21+
```
22+
23+
build the demo's dependencies. You'll need to re-run this whenever you make changes to the `@tfjs-models/gpt2` package that this demo uses.
24+
```sh
25+
yarn build-deps
26+
```
27+
28+
start the dev demo server:
29+
```sh
30+
yarn watch
31+
```

gpt2/demo/index.html

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
<!-- copyright 2023 google llc.
2+
3+
licensed under the apache license, version 2.0 (the "license");
4+
you may not use this file except in compliance with the license.
5+
you may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================-->
15+
16+
17+
<head>
18+
<style>
19+
.model-div {
20+
display: flex;
21+
margin-bottom: 1em;
22+
}
23+
.model-textbox {
24+
flex-grow: 1;
25+
min-height: 10em;
26+
}
27+
.generate-button {
28+
29+
}
30+
</style>
31+
</head>
32+
<body>
33+
<h1>GPT2</h1>
34+
<script src="bundle.js" defer></script>
35+
<div class="model-div">
36+
<textarea class="model-textbox" type="text">I like walking my dog at </textarea>
37+
</div>
38+
<button disabled="true" class="generate-button" type="button">Generate</button>
39+
40+
</body>

gpt2/demo/index.ts

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {GPT2, load} from '@tensorflow-models/gpt2';
19+
import * as lil from 'lil-gui';
20+
import * as tf from '@tensorflow/tfjs-core';
21+
import {setWasmPaths} from '@tensorflow/tfjs-backend-wasm';
22+
import '@tensorflow/tfjs-backend-webgl';
23+
import '@tensorflow/tfjs-backend-webgpu';
24+
import '@tensorflow/tfjs-backend-cpu';
25+
26+
setWasmPaths('node_modules/@tensorflow/tfjs-backend-wasm/wasm-out/');
27+
28+
(window as any).tf = tf;
29+
30+
tf.setBackend('webgl');
31+
32+
const state = {
33+
backend: tf.getBackend(),
34+
};
35+
36+
const gui = new lil.GUI();
37+
const backendController = gui.add(state, 'backend', ['wasm', 'webgl', 'webgpu', 'cpu'])
38+
.onChange(async (backend: string) => {
39+
const lastBackend = tf.getBackend();
40+
let success = false;
41+
try {
42+
success = await tf.setBackend(backend);
43+
} catch (e) {
44+
console.warn(e.message);
45+
}
46+
if (!success) {
47+
alert(`Failed to use backend ${backend}. Check the console for errors.`);
48+
tf.setBackend(lastBackend);
49+
state.backend = lastBackend;
50+
backendController.updateDisplay();
51+
return;
52+
}
53+
}).listen(true);
54+
55+
const textElement = document.querySelector(".model-textbox") as HTMLTextAreaElement;
56+
57+
function setText(text: string) {
58+
textElement.textContent = text;
59+
}
60+
function getText() {
61+
return textElement.textContent || '';
62+
}
63+
64+
const button = document.querySelector('.generate-button') as HTMLButtonElement;
65+
if (button == null) {
66+
throw new Error('No button found for generating text');
67+
}
68+
69+
button.onclick = generate;
70+
71+
let gpt2: GPT2;
72+
async function init() {
73+
gpt2 = await load();
74+
button.disabled = false;
75+
}
76+
77+
async function generate() {
78+
button.disabled = true;
79+
80+
const text = getText();
81+
setText(text + await gpt2.generate(text));
82+
button.disabled = false;
83+
}
84+
85+
init();

gpt2/demo/package.json

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
{
2+
"name": "gpt2_demo",
3+
"version": "1.0.0",
4+
"license": "Apache-2.0",
5+
"private": true,
6+
"scripts": {
7+
"watch": "esbuild index.ts --bundle --outfile=bundle.js --target=es6 --servedir=. --serve --sourcemap --sources-content=true --preserve-symlinks",
8+
"build": "mkdirp dist && cp index.html dist && esbuild index.ts --bundle --target=es6 --outfile=dist/bundle.js --sourcemap --sources-content=true --preserve-symlinks --minify --metafile=./esbuild_bundle_meta.json",
9+
"build-model": "cd .. && yarn && yarn build-npm",
10+
"build-deps": "yarn build-model"
11+
},
12+
"dependencies": {
13+
"@tensorflow-models/gpt2": "link:../",
14+
"@tensorflow/tfjs-backend-wasm": "^4.10.0",
15+
"@tensorflow/tfjs-backend-webgl": "^4.10.0",
16+
"@tensorflow/tfjs-backend-webgpu": "^4.10.0",
17+
"lil-gui": "^0.18.2"
18+
},
19+
"devDependencies": {
20+
"@tensorflow/tfjs-backend-cpu": "^4.10.0",
21+
"esbuild": "^0.19.0",
22+
"eslint": "^8.46.0",
23+
"eslint-config-google": "^0.14.0",
24+
"mkdirp": "^3.0.1"
25+
},
26+
"eslintConfig": {
27+
"extends": "google",
28+
"rules": {
29+
"require-jsdoc": 0,
30+
"valid-jsdoc": 0
31+
},
32+
"env": {
33+
"es6": true
34+
},
35+
"parserOptions": {
36+
"ecmaVersion": 8,
37+
"sourceType": "module"
38+
}
39+
},
40+
"eslintIgnore": [
41+
"dist/"
42+
]
43+
}

0 commit comments

Comments
 (0)