Skip to content

Commit 740ca22

Browse files
authored
feat(example): added suggest_categories_gpt script (#66)
1 parent bbb228a commit 740ca22

File tree

2 files changed

+233
-12
lines changed

2 files changed

+233
-12
lines changed

examples/suggest_categories.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,28 @@
77
from collections import Counter
88
from datetime import datetime, timedelta, timezone
99
from tabulate import tabulate
10-
from typing import Dict
10+
from typing import Dict, List, Tuple, Any
1111

1212
from aw_core import Event
1313
import aw_client
1414
from aw_client import queries
1515

16+
1617
# set up client
1718
awc = aw_client.ActivityWatchClient("test")
1819

1920

20-
def get_events():
21+
def example_categories():
    """
    Returns a small hard-coded list of categories, for use as example input.

    Each entry is a (name, rule) pair: name is the category hierarchy as a tuple
    of strings, and rule is a dict describing a case-insensitive regex rule.
    A fresh list (and fresh rule dict) is built on every call.
    """
    # TODO: Use tools in aw-research to load categories from toml file
    name = ("Work", "ActivityWatch")
    rule = dict(type="regex", regex="aw-|activitywatch", ignore_case=True)
    return [(name, rule)]
29+
30+
31+
def get_events(categories=List[Tuple[Tuple[str], Dict[str, Any]]]):
2132
"""
2233
Retrieves AFK-filtered events, only returns events which are Uncategorized.
2334
"""
@@ -26,14 +37,6 @@ def get_events():
2637
now = datetime.now(tz=timezone.utc)
2738
timeperiods = [(start, now)]
2839

29-
# TODO: Use tools in aw-research to load categories from toml file
30-
categories = [
31-
(
32-
["Work"],
33-
{"type": "regex", "regex": "aw-|activitywatch", "ignore_case": True},
34-
),
35-
]
36-
3740
canonicalQuery = queries.canonicalEvents(
3841
queries.DesktopQueryParams(
3942
bid_window="aw-watcher-window_",
@@ -66,8 +69,9 @@ def events2words(events):
6669
yield (word, e.duration)
6770

6871

69-
if __name__ == "__main__":
70-
events = get_events()
72+
def main():
73+
categories = example_categories()
74+
events = get_events(categories)
7175

7276
# find most common words, by duration
7377
corpus: Dict[str, timedelta] = Counter() # type: ignore
@@ -79,3 +83,7 @@ def events2words(events):
7983
# The top words are rarely useful for categorization, as they are usually browsers and other categories
8084
# of activity which are too broad for it to make sense as a rule (except as a fallback).
8185
print(tabulate(corpus.most_common(50), headers=["word", "duration"])) # type: ignore
86+
87+
88+
if __name__ == "__main__":
89+
main()

examples/suggest_categories_gpt.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
"""
2+
Uses GPT-3 to suggest new categories.
3+
4+
Builds on suggest_categories.py for basic operations, like getting events.
5+
"""
6+
7+
from datetime import datetime, timezone
8+
from typing import Any, Dict, List, Tuple
9+
10+
from aw_core import Event
11+
from aw_transform.classify import Rule, categorize
12+
13+
from suggest_categories import example_categories, get_events
14+
15+
Category = Tuple[List[str], Dict[str, Any]]
16+
17+
18+
def _format_category(name, rule) -> str:
    """Renders one existing category as a bullet for the prompt."""
    entry = f" - Category: {' > '.join(name)}"
    # Not every rule carries a regex (e.g. rules that only exist to group
    # subcategories), so only list one when it is actually present.
    regex = rule.get("regex")
    if regex is not None:
        entry += f"\n Regex: {regex}"
    return entry


def prompt_preamble(categories: List[Category]) -> str:
    """
    Builds the few-shot prompt preamble that is sent before the actual question.

    The preamble lists the existing categories (name and, when the rule has one,
    its regex), followed by a "---" separator and a series of worked
    question/answer examples. The examples after the separator use the same
    response format that parse_gpt_response understands.

    Categories whose rule has no "regex" key are listed by name only instead
    of raising a KeyError.
    """
    categories_str = "\n\n".join(
        _format_category(name, rule) for name, rule in categories
    )

    prompt = f"""We will classify window titles into user-defined categories defined by regular expressions.
If a suitable one doesn't exists, we will create one with a suitable regex.

Existing categories:

{categories_str}

---

What category should "ActivityWatch - wwww.github.com" be in?
Category: Work > ActivityWatch

What category should "reddit: the front page of the internet" be in?
New Category: Media > Social > Reddit
Regex: reddit

What category should "Twitter" be in?
New Category: Media > Social > Twitter
Regex: Twitter

What category should "Demis Hassabis: DeepMind - AI, Superintelligence & the Future of Humanity | Lex Fridman Podcast - YouTube - Mozilla Firefox" be in?
New Category: Media > Video > YouTube
Regex: YouTube

What category should "Tweetdeck" be in?
Modify Category: Media > Social > Twitter
Append Regex: Tweetdeck

What category should "cloudflare-ipfs.com | 524: A timeout occurred - cloudflare-ipfs.com - Mozilla Firefox" be in?
Skip: No suitable category found or to suggest, best left as uncategorized.

What category should "Mozilla Firefox" be in?
Skip: No suitable category found or to suggest, best left as uncategorized.

What category should "RimWorld" be in?
New Category: Games > RimWorld
Regex: RimWorld

What category should "Minecraft" be in?
New Category: Games > Minecraft
Regex: Minecraft

What category should "Free Porn Videos & Sex Movies - Porno, XXX, Porn Tube | Pornhub — Mozilla Firefox" be in?
New Category: Media > Porn
Regex: Pornhub"""

    return prompt
73+
74+
75+
def process_prompt(prompt, categories, quiet=False) -> List[Category]:
    """
    processes the prompt preamble for categories created/modified in the prompt

    Everything after the first "---" separator is treated as a series of worked
    examples separated by blank lines. Each example consists of a question line
    containing the quoted window title, followed by the answer lines, which are
    fed to parse_gpt_response so that the categories end up in the same state
    the examples describe.

    Returns the resulting category list. A prompt without a separator, and
    examples without an answer, are skipped instead of raising IndexError.
    """
    _, separator, examples = prompt.partition("---")
    if not separator:
        # nothing after the separator to replay
        return categories

    for entry in examples.split("\n\n"):
        # strip once, so title and response are taken from the same text
        entry = entry.strip()
        if not entry:
            continue

        question, newline, response = entry.partition("\n")
        if not newline:
            # a question without an answer, nothing to parse
            continue

        # The title is everything between the first and the last double-quote
        # of the question line, so titles that themselves contain
        # double-quotes are kept intact.
        title = question.partition('"')[2].rpartition('"')[0]
        categories = parse_gpt_response(response, categories, title=title, quiet=quiet)
    return categories
85+
86+
87+
def gpt_suggest(event: Event, categories: List[Category]) -> List[Category]:
    """
    Asks OpenAI GPT-3 which category the title of an uncategorized event belongs in.

    The given categories are deep-copied first, so the caller's list is not
    modified. The copy is then updated with the categories from the worked
    examples in the prompt preamble, and finally with whatever the model
    suggests for this event. The updated copy is returned.

    The API key is read from the OPENAI_API_KEY environment variable.
    """
    import os
    import openai
    from copy import deepcopy

    openai.api_key = os.getenv("OPENAI_API_KEY")

    window_title = event.data["title"]
    working = deepcopy(categories)

    # Replay the examples in the preamble, so that categories the examples
    # introduce are known when the model's answer refers to them.
    preamble = prompt_preamble(working)
    working = process_prompt(preamble, working, quiet=True)

    request = {
        "model": "text-davinci-002",
        "prompt": f'{preamble}\n\nWhat category should "{window_title}" be in?',
        "temperature": 0,
        "max_tokens": 64,
        "top_p": 1.0,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
    }
    completion = openai.Completion.create(**request)
    answer = completion["choices"][0]["text"]

    # visual separator between the output for consecutive events
    print("-" * 80)

    return parse_gpt_response(answer, working, window_title)
124+
125+
126+
def check_is_category(text: str, category: Category) -> bool:
    """
    Checks whether a title would be classified as the given category.

    Wraps the text as the title of a single synthetic event (timestamped now),
    categorizes it using only this category's rule, and compares the
    resulting "$category" with the category's name.
    """
    cat_name, cat_rule = category
    probe = Event(timestamp=datetime.now(tz=timezone.utc), data={"title": text})
    ruleset = [(list(cat_name), Rule(cat_rule))]
    classified = categorize([probe], ruleset)
    return classified[0].data["$category"] == list(cat_name)
133+
134+
135+
def _response_field(line: str):
    """Returns the text after the first ':' in a response line, or None if there is none."""
    _, colon, value = line.partition(":")
    value = value.strip()
    return value if colon and value else None


def _is_valid_regex(pattern: str) -> bool:
    """Returns True if the pattern compiles as a regular expression."""
    import re

    try:
        re.compile(pattern)
    except re.error:
        return False
    return True


def parse_gpt_response(text: str, categories: List[Category], title=None, quiet=False):
    """
    Applies a single response (in the format used by the prompt) to the categories.

    Understood responses, determined by the first line:

     - "Category: A > B": an existing category was chosen, categories are left as-is.
     - "New Category: A > B" followed by "Regex: ...": a new category is appended,
       unless a title was given and the regex fails to match it.
     - "Modify Category: A > B" followed by "Append Regex: ...": the regex is added
       as an alternative to the existing category's regex.
     - "Skip: ...": nothing is done.

    Anything else, including truncated responses and regexes which do not
    compile, is reported and ignored rather than raising, since the text
    comes straight from the model.

    The list (and the rule dicts in it) is modified in place and also returned.
    Informational messages are suppressed when quiet is set, warnings are not.
    """
    line1, *lines = text.strip().split("\n")
    line2 = lines[0].strip() if lines else ""

    if line1.startswith("Category:"):
        # chose existing category
        cat_name = tuple(line1.split(":", 1)[1].strip().split(" > "))
        if not quiet:
            print(f"Chose existing category {cat_name} (title: {title})")
        # parents of existing categories count as existing too, hence the prefix match
        known = any(
            tuple(name)[: len(cat_name)] == cat_name for name, _ in categories
        )
        if not known:
            print(f"No category named {cat_name} found, skipping")
    elif line1.startswith("New Category:"):
        category = line1.split(":", 1)[1].strip().split(" > ")
        regex = _response_field(line2)
        if regex is None:
            print(f"Malformed response, missing regex: '{text.strip()}'")
        elif not _is_valid_regex(regex):
            print(f"Bad suggested regex '{regex}'. Not a valid regular expression.")
        else:
            cat: Category = (
                category,
                {"type": "regex", "regex": regex, "ignore_case": True},
            )
            if title and not check_is_category(title, cat):
                print(
                    f"Bad suggested regex '{cat[1]['regex']}'. Title '{title}' does not match category {category}."
                )
            else:
                categories.append(cat)
                # only announce the category once it has actually been added
                if not quiet:
                    print(
                        f"Added category {category} with regex {regex} (title: {title})"
                    )
    elif line1.startswith("Modify Category:"):
        category = line1.split(":", 1)[1].strip().split(" > ")
        regex = _response_field(line2) if line2.startswith("Append Regex:") else None
        if regex is None:
            print(f"Malformed response, missing 'Append Regex:': '{text.strip()}'")
        else:
            # get existing category
            for name, rule in categories:
                # names may be lists or tuples, so normalize before comparing
                if list(name) != category:
                    continue
                existing = rule.get("regex")
                combined = f"{existing}|{regex}" if existing else regex
                if not _is_valid_regex(combined):
                    print(
                        f"Bad suggested regex '{regex}'. Not a valid regular expression."
                    )
                else:
                    rule["regex"] = combined
                    if not quiet:
                        print(f"Appended {regex} to {category}")
                break
            else:
                print(f"No category named {tuple(category)} found, skipping")
    elif line1.startswith("Skip:"):
        pass
    else:
        print(f"Unknown response: '{text.strip()}'")

    return categories
180+
181+
182+
def main():
    """
    Suggests categories for the 100 longest uncategorized events.

    Events are processed longest first. Since every suggestion may add or change
    a category, each event is categorized again with the current categories just
    before it is considered, and skipped if it is no longer uncategorized.
    """
    categories = example_categories()
    longest_first = sorted(
        get_events(categories), key=lambda ev: ev.duration, reverse=True
    )

    for candidate in longest_first[:100]:
        # rules are rebuilt for every event, as the categories change along the way
        current_rules = [(list(name), Rule(rule)) for name, rule in categories]
        recategorized, *_ = categorize([candidate], current_rules)

        # only events still uncategorized under the current rules are sent to the model
        if list(recategorized.data["$category"]) == ["Uncategorized"]:
            categories = gpt_suggest(recategorized, categories)
196+
197+
198+
def test_parse_gpt_response():
    """
    Checks that processing the examples in the prompt preamble creates the
    Twitter category exactly once, with the Tweetdeck regex appended to it.
    """
    base = example_categories()
    parsed = process_prompt(prompt_preamble(base), base)

    twitter_rules = [
        rule for name, rule in parsed if tuple(name) == ("Media", "Social", "Twitter")
    ]
    assert len(twitter_rules) == 1
    assert twitter_rules[0]["regex"] == "Twitter|Tweetdeck"
210+
211+
212+
if __name__ == "__main__":
213+
main()

0 commit comments

Comments
 (0)