Skip to content

Commit 9ddecfc

Browse files
chore: Cleanup release
1 parent 9f8cbf8 commit 9ddecfc

File tree

5 files changed: 27 additions (+27) and 113 deletions (-113).

gridworld_td/agents/world-agent.ts

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
import { ITDAgentOptions, TDAgent } from "../../lib/agents/q-agent";
1+
import { IQAgentOptions, QAgent } from "../../lib/agents/q-agent";
22
import { zeros } from "../../lib/zeros";
33

44
export interface IWorldAgentState {
@@ -7,13 +7,13 @@ export interface IWorldAgentState {
77
reset_episode?: boolean;
88
}
99

10-
export interface IWorldAgentOpts extends ITDAgentOptions {
10+
export interface IWorldAgentOpts extends IQAgentOptions {
1111
gw: number;
1212
gh: number;
1313
gs: number;
1414
}
1515

16-
export class WorldAgent extends TDAgent {
16+
export class WorldAgent extends QAgent {
1717
V?: number[];
1818
G?: number[];
1919
gh: number;

lib/agents/q-agent.ts

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import { zeros } from "../zeros";
22
import { setConst, sampleWeighted, randi } from "../utilities";
33

4-
export interface ITDAgentOptions {
4+
export interface IQAgentOptions {
55
update?: 'qlearn' | 'sarsa';
66
gamma?: number;
77
epsilon?: number;
@@ -18,7 +18,7 @@ export interface ITDAgentOptions {
1818
// QAgent uses TD (Q-Learning, SARSA)
1919
// - does not require environment model :)
2020
// - learns from experience :)
21-
export abstract class TDAgent {
21+
export abstract class QAgent {
2222
update: string;
2323
gamma: number;
2424
epsilon: number;
@@ -39,7 +39,7 @@ export abstract class TDAgent {
3939
sa_seen: number[];
4040
pq: number[] | Float64Array;
4141

42-
constructor(opt: ITDAgentOptions) {
42+
constructor(opt: IQAgentOptions) {
4343
this.update = opt.update ?? 'qlearn'; // qlearn | sarsa
4444
this.gamma = opt.gamma ?? 0.75; // future reward discount factor
4545
this.epsilon = opt.epsilon ?? 0.1; // for epsilon-greedy policy

lib/index.ts

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,18 @@
1+
// agents
2+
export { DeterministPG, IDeterministPGOptions } from "./agents/determinist-pg";
3+
export { DPAgent, IDPAgentOptions } from "./agents/dp-agent";
4+
export { DQNAgent, IDQNAgentOptions, IDQNAgentJSON } from "./agents/dqn-agent";
5+
export { QAgent, IQAgentOptions } from "./agents/q-agent";
6+
export { RecurrentReinforceAgent, IRecurrentReinforceAgentOption } from "./agents/recurrent-reinforce-agent";
7+
export { SimpleReinforceAgent, ISimpleReinforceAgentOption } from "./agents/simple-reinforce-agent";
8+
9+
// lib
10+
export { Graph } from "./graph";
11+
export { LSTM, ILstmModelLayer, ILstmModel, ILSTMCell } from "./lstm";
12+
export { Mat, IMatJson } from "./mat";
13+
export { Net, INetJSON } from "./net";
14+
export { RandMat } from "./rand-mat";
15+
export { Solver } from "./solver";
16+
export { Tuple } from "./tuple";
17+
export { sig, fillRand, fillRandn, randn, gaussRandom, randf, randi, maxi, samplei, sampleWeighted, setConst, assert } from "./utilities"
18+
export { zeros } from "./zeros";

lib/lstm.ts

Lines changed: 0 additions & 95 deletions
Original file line number | Diff line number | Diff line change
@@ -155,98 +155,3 @@ export class LSTM {
155155
bd.update(alpha);
156156
}
157157
}
158-
159-
// export const initLSTM = function(input_size: number, hidden_sizes: number[], output_size: number) {
160-
// // hidden size should be a list
161-
// var model = {};
162-
// for(let d=0;d<hidden_sizes.length;d++) { // loop over depths
163-
// const prev_size = d === 0 ? input_size : hidden_sizes[d - 1];
164-
// const hidden_size = hidden_sizes[d];
165-
//
166-
// // gates parameters
167-
// model['Wix'+d] = new RandMat(hidden_size, prev_size , 0, 0.08);
168-
// model['Wih'+d] = new RandMat(hidden_size, hidden_size , 0, 0.08);
169-
// model['bi'+d] = new Mat(hidden_size, 1);
170-
// model['Wfx'+d] = new RandMat(hidden_size, prev_size , 0, 0.08);
171-
// model['Wfh'+d] = new RandMat(hidden_size, hidden_size , 0, 0.08);
172-
// model['bf'+d] = new Mat(hidden_size, 1);
173-
// model['Wox'+d] = new RandMat(hidden_size, prev_size , 0, 0.08);
174-
// model['Woh'+d] = new RandMat(hidden_size, hidden_size , 0, 0.08);
175-
// model['bo'+d] = new Mat(hidden_size, 1);
176-
// // cell write params
177-
// model['Wcx'+d] = new RandMat(hidden_size, prev_size , 0, 0.08);
178-
// model['Wch'+d] = new RandMat(hidden_size, hidden_size , 0, 0.08);
179-
// model['bc'+d] = new Mat(hidden_size, 1);
180-
// }
181-
// // decoder params
182-
// model['Whd'] = new RandMat(output_size, hidden_size, 0, 0.08);
183-
// model['bd'] = new Mat(output_size, 1);
184-
// return model;
185-
// }
186-
//
187-
// export const forwardLSTM = function(G: Graph, model, hidden_sizes, x, prev) {
188-
// // forward prop for a single tick of LSTM
189-
// // G is graph to append ops to
190-
// // model contains LSTM parameters
191-
// // x is 1D column vector with observation
192-
// // prev is a struct containing hidden and cell
193-
// // from previous iteration
194-
//
195-
// if(prev == null || typeof prev.h === 'undefined') {
196-
// var hidden_prevs = [];
197-
// var cell_prevs = [];
198-
// for(var d=0;d<hidden_sizes.length;d++) {
199-
// hidden_prevs.push(new Mat(hidden_sizes[d],1));
200-
// cell_prevs.push(new Mat(hidden_sizes[d],1));
201-
// }
202-
// } else {
203-
// var hidden_prevs = prev.h;
204-
// var cell_prevs = prev.c;
205-
// }
206-
//
207-
// var hidden = [];
208-
// var cell = [];
209-
// for(var d=0;d<hidden_sizes.length;d++) {
210-
//
211-
// var input_vector = d === 0 ? x : hidden[d-1];
212-
// var hidden_prev = hidden_prevs[d];
213-
// var cell_prev = cell_prevs[d];
214-
//
215-
// // input gate
216-
// var h0 = G.mul(model['Wix'+d], input_vector);
217-
// var h1 = G.mul(model['Wih'+d], hidden_prev);
218-
// var input_gate = G.sigmoid(G.add(G.add(h0,h1),model['bi'+d]));
219-
//
220-
// // forget gate
221-
// var h2 = G.mul(model['Wfx'+d], input_vector);
222-
// var h3 = G.mul(model['Wfh'+d], hidden_prev);
223-
// var forget_gate = G.sigmoid(G.add(G.add(h2, h3),model['bf'+d]));
224-
//
225-
// // output gate
226-
// var h4 = G.mul(model['Wox'+d], input_vector);
227-
// var h5 = G.mul(model['Woh'+d], hidden_prev);
228-
// var output_gate = G.sigmoid(G.add(G.add(h4, h5),model['bo'+d]));
229-
//
230-
// // write operation on cells
231-
// var h6 = G.mul(model['Wcx'+d], input_vector);
232-
// var h7 = G.mul(model['Wch'+d], hidden_prev);
233-
// var cell_write = G.tanh(G.add(G.add(h6, h7),model['bc'+d]));
234-
//
235-
// // compute new cell activation
236-
// var retain_cell = G.eltmul(forget_gate, cell_prev); // what do we keep from cell
237-
// var write_cell = G.eltmul(input_gate, cell_write); // what do we write to cell
238-
// var cell_d = G.add(retain_cell, write_cell); // new cell contents
239-
//
240-
// // compute hidden state as gated, saturated cell activations
241-
// var hidden_d = G.eltmul(output_gate, G.tanh(cell_d));
242-
//
243-
// hidden.push(hidden_d);
244-
// cell.push(cell_d);
245-
// }
246-
//
247-
// // one decoder to outputs at end
248-
// var output = G.add(G.mul(model['Whd'], hidden[hidden.length - 1]),model['bd']);
249-
//
250-
// // return cell memory, hidden representation and output
251-
// return {'h':hidden, 'c':cell, 'o' : output};
252-
// }

webpack.config.js

Lines changed: 3 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,35 +1,26 @@
11
const path = require('path');
2-
const webpack = require('webpack');
32

43
const web = {
54
target: 'web',
6-
entry: './gridworld_td/index.ts',
5+
entry: './lib/index.ts',
76
devtool: 'source-map',
87
devServer: {
98
contentBase: path.join(__dirname, 'dist'),
109
compress: true,
1110
port: 9000
1211
},
1312
mode: 'production',
14-
plugins: [
15-
// new webpack.ProvidePlugin({
16-
// jQuery: 'jQuery',
17-
// }),
18-
],
1913
module: {
2014
rules: [
2115
{
22-
test: /\.tsx?$/,
16+
test: /\.ts$/,
2317
use: 'ts-loader',
2418
exclude: /node_modules/,
2519
},
2620
],
2721
},
2822
resolve: {
29-
extensions: ['.tsx', '.ts', '.js'],
30-
// alias: {
31-
// 'jQuery': path.resolve(__dirname, './node_modules/jquery/dist/jquery.js'),
32-
// }
23+
extensions: ['.ts'],
3324
},
3425
output: {
3526
filename: 'reinforce-browser.js',

0 commit comments

Comments
 (0)