Skip to content

Commit 3988b32

Browse files
committed
Implement worker function affinity for more efficient subsequent executions
1 parent ce27d3e commit 3988b32

File tree

5 files changed

+153
-109
lines changed

5 files changed

+153
-109
lines changed

src/lib/lib.ts

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@ export * from "./sync/mod.ts";
88
export { SharedJsonBuffer } from "./json_buffer.ts";
99

1010
let globalPool: WorkerPool | null = null;
11-
const functionIdCache = new WeakMap<UserFunction, string>();
1211
let globalConfig = { maxWorkers: navigator.hardwareConcurrency || 4 };
1312

13+
const globalFunctionRegistry = new WeakMap<
14+
UserFunction,
15+
{ id: string; code: string }
16+
>();
17+
1418
export function initRuntime(config: { maxWorkers: number }) {
1519
if (globalPool) throw new Error("Runtime already initialized");
1620
globalConfig = { ...globalConfig, ...config };
@@ -56,7 +60,6 @@ export function spawn<R>(
5660
fn: (this: void) => R | Promise<R>,
5761
): JoinHandle<R>;
5862

59-
// Implementation
6063
export function spawn(arg1: any, arg2?: any): JoinHandle<any> {
6164
const pool = getPool();
6265
const { resolve, reject, promise } = Promise.withResolvers<
@@ -66,49 +69,58 @@ export function spawn(arg1: any, arg2?: any): JoinHandle<any> {
6669
let args: any[] = [];
6770
let fn: UserFunction;
6871

69-
// Runtime Logic: Keeps checking for the Symbol
72+
// Argument parsing
7073
if (arg1 && Object.prototype.hasOwnProperty.call(arg1, moveTag)) {
7174
args = arg1;
7275
fn = arg2;
7376
} else {
7477
fn = arg1;
7578
}
7679

77-
let fnId = functionIdCache.get(fn);
78-
if (!fnId) {
79-
fnId = Math.random().toString().slice(2);
80-
functionIdCache.set(fn, fnId);
81-
}
80+
let meta = globalFunctionRegistry.get(fn);
8281

83-
const callerLocation = getCallerLocation();
82+
if (!meta) {
83+
// Cache miss: Generate ID and patch code
84+
const id = Math.random().toString(36).slice(2);
85+
const callerLocation = getCallerLocation();
8486

85-
(async () => {
87+
// We wrap this in a try-catch block inside the cache logic
88+
// to fail early if toString fails
8689
try {
87-
const finalCode = patchDynamicImports(
90+
const code = patchDynamicImports(
8891
"export default " + fn.toString(),
8992
callerLocation.filePath,
9093
);
94+
meta = { id, code };
95+
globalFunctionRegistry.set(fn, meta);
96+
} catch (err) {
97+
console.error(err);
98+
return {
99+
join: () =>
100+
Promise.resolve({
101+
ok: false,
102+
error: err instanceof Error
103+
? err
104+
: new Error("Failed to compile function"),
105+
}),
106+
abort: () => {},
107+
};
108+
}
109+
}
91110

111+
// Task submission
112+
(async () => {
113+
try {
92114
const task: ThreadTask = {
93-
fnId,
94-
code: finalCode,
115+
fnId: meta!.id,
116+
code: meta!.code,
95117
args,
96118
};
97119

98-
try {
99-
const val = await pool.submit(task);
100-
resolve({ ok: true, value: val });
101-
} catch (err) {
102-
resolve({ ok: false, error: err as Error });
103-
}
120+
const val = await pool.submit(task);
121+
resolve({ ok: true, value: val });
104122
} catch (err) {
105-
console.error(err);
106-
resolve({
107-
ok: false,
108-
error: err instanceof Error
109-
? err
110-
: new Error("Failed to extract function source"),
111-
});
123+
resolve({ ok: false, error: err as Error });
112124
}
113125
})();
114126

@@ -121,6 +133,7 @@ export function spawn(arg1: any, arg2?: any): JoinHandle<any> {
121133
export function shutdown() {
122134
if (globalPool) {
123135
globalPool.terminate();
136+
globalPool = null;
124137
}
125138
}
126139

src/lib/patch_import.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ export function patchDynamicImports(
9696
// Apply replacements
9797
replacements.sort((a, b) => b.start - a.start);
9898
let modifiedCode = code;
99-
for (const rep of replacements) {
99+
for (let i = 0; i < replacements.length; i++) {
100+
const rep = replacements[i]!;
100101
const before = modifiedCode.slice(0, rep.start);
101102
const after = modifiedCode.slice(rep.end);
102103
modifiedCode = before + rep.text + after;

src/lib/pool.ts

Lines changed: 82 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,97 @@
11
import { deserialize, serialize } from "./shared.ts";
22
import type { ThreadTask, WorkerResponsePayload } from "./types.ts";
33

4-
/**
5-
* In Vite, the worker detection will only work if the new URL() constructor is used directly inside the new Worker() declaration.
6-
* Additionally, all options parameters must be static values (i.e. string literals).
7-
*/
84
let newWorker: () => Worker;
95

106
export function workerOverride(fn: () => Worker) {
117
newWorker = fn;
128
}
139

10+
interface TrackedWorker extends Worker {
11+
// Set of function IDs this worker has already compiled
12+
_loadedFnIds: Set<string>;
13+
_pending: Map<
14+
number,
15+
{ resolve: (val: any) => void; reject: (err: any) => void }
16+
>;
17+
}
18+
1419
export class WorkerPool {
15-
private workers: Worker[] = [];
16-
private workerLoad = new Map<Worker, number>();
17-
private pending = new Map<
18-
Worker,
19-
Map<number, { resolve: (val: any) => void; reject: (err: any) => void }>
20-
>();
21-
22-
private codeCache = new Map<string, Promise<string>>();
20+
private workers: TrackedWorker[] = [];
2321
private maxThreads: number;
2422
private taskIdCounter = 0;
2523

2624
constructor(maxThreads?: number) {
27-
this.maxThreads = maxThreads ?? navigator.hardwareConcurrency ?? 4;
25+
this.maxThreads = maxThreads || navigator.hardwareConcurrency || 4;
2826
}
2927

30-
private spawnWorker(): Worker {
31-
const worker = newWorker();
32-
this.pending.set(worker, new Map());
33-
this.workerLoad.set(worker, 0);
28+
private createWorker(): TrackedWorker {
29+
const worker = newWorker() as TrackedWorker;
30+
31+
worker._loadedFnIds = new Set();
32+
worker._pending = new Map();
3433

35-
// 1. Success / Task Error Handler
3634
worker.onmessage = (e: MessageEvent<WorkerResponsePayload>) => {
3735
const { taskId, type } = e.data;
38-
const workerPending = this.pending.get(worker);
39-
const p = workerPending?.get(taskId);
36+
const p = worker._pending.get(taskId);
4037

4138
if (p) {
42-
this.workerLoad.set(
43-
worker,
44-
Math.max(0, this.workerLoad.get(worker) ?? 0 - 1),
45-
);
46-
4739
if (type === "ERROR") {
4840
const err = new Error(e.data.error);
4941
if (e.data.stack) err.stack = e.data.stack;
5042
p.reject(err);
5143
} else {
52-
// Rehydrate the result (Restore Mutex/Channel methods)
53-
const result = deserialize(e.data.result);
54-
p.resolve(result);
44+
p.resolve(deserialize(e.data.result));
5545
}
56-
workerPending?.delete(taskId);
46+
worker._pending.delete(taskId);
5747
}
5848
};
5949

60-
// 2. Crash Handler
6150
worker.onerror = (e) => {
6251
e.preventDefault();
63-
const workerPending = this.pending.get(worker);
64-
if (workerPending) {
65-
for (const [_, p] of workerPending) {
66-
p.reject(new Error(`Worker Crashed: ${e.message}`));
67-
}
68-
workerPending.clear();
69-
}
70-
this.workerLoad.delete(worker);
71-
this.pending.delete(worker);
72-
this.workers = this.workers.filter((w) => w !== worker);
52+
const err = new Error(`Worker Crashed: ${e.message}`);
53+
for (const p of worker._pending.values()) p.reject(err);
54+
worker._pending.clear();
55+
this.removeWorker(worker);
7356
};
7457

7558
this.workers.push(worker);
7659
return worker;
7760
}
7861

79-
async submit(task: ThreadTask): Promise<any> {
80-
const { fnId, code, args } = task;
62+
private removeWorker(worker: TrackedWorker) {
63+
this.workers = this.workers.filter((w) => w !== worker);
64+
worker.terminate();
65+
}
8166

82-
// Select Worker (Least Loaded)
83-
let selectedWorker: Worker;
84-
if (this.workers.length < this.maxThreads) {
85-
selectedWorker = this.spawnWorker();
86-
} else {
87-
selectedWorker = this.workers.reduce((prev, curr) => {
88-
const prevLoad = this.workerLoad.get(prev)!;
89-
const currLoad = this.workerLoad.get(curr)!;
90-
return prevLoad < currLoad ? prev : curr;
91-
});
92-
}
67+
private async executeTask(
68+
worker: TrackedWorker,
69+
task: ThreadTask,
70+
): Promise<any> {
71+
const { fnId, code, args } = task;
72+
const taskId = this.taskIdCounter++;
9373

9474
const { promise, resolve, reject } = Promise.withResolvers();
9575

96-
const taskId = this.taskIdCounter++;
97-
this.pending.get(selectedWorker)!.set(taskId, { resolve, reject });
98-
this.workerLoad.set(
99-
selectedWorker,
100-
(this.workerLoad.get(selectedWorker) || 0) + 1,
101-
);
76+
worker._pending.set(taskId, { resolve, reject });
10277

103-
// Serialize Args & Unify Transferables
10478
const serializedArgs = args.map(serialize);
10579
const values = serializedArgs.map((r) => r.value);
10680
const transferList = [
10781
...new Set(serializedArgs.flatMap((r) => r.transfer)),
10882
];
10983

110-
selectedWorker.postMessage(
84+
const hasCode = worker._loadedFnIds.has(fnId);
85+
if (!hasCode) {
86+
worker._loadedFnIds.add(fnId);
87+
}
88+
89+
worker.postMessage(
11190
{
11291
type: "RUN",
11392
taskId,
11493
fnId,
115-
code,
94+
code: hasCode ? undefined : code,
11695
args: values,
11796
},
11897
{ transfer: transferList },
@@ -121,11 +100,49 @@ export class WorkerPool {
121100
return await promise;
122101
}
123102

103+
async submit(task: ThreadTask): Promise<any> {
104+
let bestCandidate: TrackedWorker | undefined;
105+
let bestCandidateLoad = Infinity;
106+
// Score: 0 = Idle+Affinity, 1 = Idle, 2 = Busy+Affinity, 3 = Busy
107+
let bestCandidateScore = 4;
108+
109+
for (let i = 0; i < this.workers.length; i++) {
110+
const w = this.workers[i]!;
111+
const load = w._pending.size;
112+
const hasAffinity = w._loadedFnIds.has(task.fnId);
113+
114+
if (load === 0 && hasAffinity) {
115+
return await this.executeTask(w, task);
116+
}
117+
118+
let score = 4;
119+
if (load === 0) score = 1;
120+
else if (hasAffinity) score = 2;
121+
else score = 3;
122+
123+
if (
124+
score < bestCandidateScore ||
125+
(score === bestCandidateScore && load < bestCandidateLoad)
126+
) {
127+
bestCandidate = w;
128+
bestCandidateScore = score;
129+
bestCandidateLoad = load;
130+
}
131+
}
132+
133+
if (bestCandidateScore >= 2 && this.workers.length < this.maxThreads) {
134+
return await this.executeTask(this.createWorker(), task);
135+
}
136+
137+
if (bestCandidate) {
138+
return await this.executeTask(bestCandidate, task);
139+
}
140+
141+
return await this.executeTask(this.createWorker(), task);
142+
}
143+
124144
terminate() {
125145
for (const w of this.workers) w.terminate();
126146
this.workers = [];
127-
this.workerLoad.clear();
128-
this.pending.clear();
129-
this.codeCache.clear();
130147
}
131148
}

src/lib/types.ts

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,6 @@ export type Result<T, E = Error> =
55
| { ok: true; value: T }
66
| { ok: false; error: E };
77

8-
export interface ThreadTask {
9-
fnId: string;
10-
code: string;
11-
args: unknown[];
12-
}
13-
148
export interface JoinHandle<T> {
159
join(): Promise<Result<T, Error>>;
1610
abort(): void;
@@ -31,6 +25,15 @@ export type SharedMemoryView =
3125
| DataView<SharedArrayBuffer>
3226
| SharedJsonBuffer<any>;
3327

28+
/**
29+
* The internal task structure passed from lib -> pool
30+
*/
31+
export interface ThreadTask {
32+
fnId: string;
33+
code: string; // The code is always available here, but Pool decides if it sends it
34+
args: unknown[];
35+
}
36+
3437
/**
3538
* The wire format.
3639
* 't': type
@@ -45,7 +48,7 @@ export type WorkerTaskPayload = {
4548
type: "RUN";
4649
taskId: number;
4750
fnId: string;
48-
code: string;
51+
code?: string; // Optional: Only sent if worker doesn't have it
4952
args: Envelope[];
5053
};
5154

0 commit comments

Comments
 (0)