Skip to content

Commit 9eccb74

Browse files
Johannes Hötter
authored and committed
change to graphql and add autolf v1
1 parent 2497c03 commit 9eccb74

File tree

8 files changed

+333
-151
lines changed

8 files changed

+333
-151
lines changed

onetask/__init__.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
# -*- coding: utf-8 -*-
22

3-
from typing import Callable, List
4-
5-
from onetask import api_calls, settings
3+
from typing import Callable
64
from wasabi import msg
7-
8-
from onetask.labeling_function import build_keywords_lf, unpack_python_function
5+
from onetask import api_calls, settings, util, auto_lf
96

107

118
class Client:
@@ -16,38 +13,43 @@ def __init__(
1613
self.session_token = api_calls.create_session_token(
1714
user_name=user_name, password=password
1815
)
16+
self.project_id = project_id
1917
if self.session_token is not None:
2018
msg.good("Logged in to system.")
19+
if not api_calls.ProjectByProjectId(
20+
self.project_id, self.session_token
21+
).exists:
22+
msg.fail(f"Project with ID {self.project_id} does not exist.")
2123
else:
2224
msg.fail("Could not log in. Please check your username and password.")
23-
self.project_id = project_id
2425

25-
def register_custom_lf(self, lf: Callable) -> None:
26-
fn_name, source_code, description = unpack_python_function(lf)
27-
_ = api_calls.RegisterLabelingFunctionCall(
28-
fn_name=fn_name,
29-
source_code=source_code,
30-
description=description,
31-
project_id=self.project_id,
32-
session_token=self.session_token,
33-
)
34-
msg.good(f"Registered labeling function '{fn_name}'.")
26+
def manually_labeled_records(self, as_df: bool = True):
    """Fetch the manually labeled records of this project.

    :param as_df: if True (and any records exist), return them as a DataFrame.
    :return: the records, either as a DataFrame or as a plain list.
    """
    raw = api_calls.ManuallyLabeledRecords(self.project_id, self.session_token).data
    unpacked = util.unpack_records(raw)
    # Fall back to the raw list when a DataFrame was not requested or
    # when there is nothing to convert.
    if not (as_df and unpacked):
        return unpacked
    return util.records_to_df(unpacked)
3535

36-
def register_keywords_lf(
37-
self,
38-
label: str,
39-
keywords: List[str],
40-
attributes: List[str],
41-
lowercase: bool = True,
36+
def autogenerate_regex_labeling_functions(
    self, nlp, attribute, num_functions: int = 10
):
    """Derive regex candidates from the manually labeled records and print
    labeling-function suggestions for them.

    :param nlp: spaCy language model used for tokenization.
    :param attribute: record attribute whose text is analyzed.
    :param num_functions: maximum number of regex candidates to derive.
    """
    labeled = self.manually_labeled_records(as_df=True)
    # Guard clause: nothing to analyze without manual labels.
    if len(labeled) == 0:
        msg.fail("No manually labeled records available!")
        return
    regex_candidates = auto_lf.derive_regex_candidates(
        nlp, labeled, attribute, most_common=num_functions
    )
    auto_lf.create_regex_fns(labeled, regex_candidates, attribute)
47+
48+
def register_lf(self, lf: Callable) -> None:
    """Register a Python labeling function in the onetask system.

    :param lf: the labeling function to upload; its source code and
        docstring are extracted and sent to the server.
    """
    unpacked = util.unpack_python_function(lf, self.project_id)
    project_id, name, source_code, docs = unpacked
    api_calls.CreateLabelingFunction(
        project_id, name, source_code, docs, self.session_token
    )
    msg.good(f"Registered labeling function '{name}'.")

onetask/api_calls.py

Lines changed: 104 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
# -*- coding: utf-8 -*-
22
import pkg_resources
3-
43
from onetask import exceptions, settings
4+
import requests
55

66
try:
77
version = pkg_resources.get_distribution("onetask").version
88
except pkg_resources.DistributionNotFound:
99
version = "noversion"
1010

11-
import requests
12-
from typing import Dict, Any, Optional, Union
13-
14-
from better_abc import ABCMeta, abstract_attribute
15-
1611

1712
# no call to the onetask system, therefore include it here
1813
def create_session_token(user_name: str, password: str):
@@ -40,32 +35,29 @@ def create_session_token(user_name: str, password: str):
4035
return session_token
4136

4237

43-
class OneTaskCall(metaclass=ABCMeta):
44-
def __init__(
45-
self, url: str, session_token: str, data: Optional[Dict[str, Any]] = None
46-
):
38+
class GraphQLRequest:
39+
def __init__(self, query, variables, session_token):
40+
self.query = query
41+
self.variables = variables
42+
self.session_token = session_token
43+
44+
def execute(self):
45+
body = {
46+
"query": self.query,
47+
"variables": self.variables,
48+
}
49+
4750
headers = {
4851
"Content-Type": "application/json",
4952
"User-Agent": f"python-sdk-{version}",
50-
"Authorization": f"Bearer {session_token}",
53+
"Authorization": f"Bearer {self.session_token}",
5154
}
5255

53-
if data is None:
54-
self.response = requests.request(self.method, url, headers=headers)
55-
else:
56-
self.response = requests.request(
57-
self.method, url, json=data, headers=headers
58-
)
59-
60-
@abstract_attribute
61-
def method(self):
62-
pass
56+
response = requests.post(url=settings.graphql(), json=body, headers=headers)
6357

64-
@property
65-
def content(self) -> Union[Dict[str, Any], exceptions.APIError]:
66-
status_code = self.response.status_code
58+
status_code = response.status_code
6759

68-
json_data = self.response.json()
60+
json_data = response.json()
6961

7062
if status_code == 200:
7163
return json_data
@@ -80,35 +72,95 @@ def content(self) -> Union[Dict[str, Any], exceptions.APIError]:
8072
raise exception
8173

8274

83-
class PostCall(OneTaskCall):
84-
def __init__(
85-
self,
86-
url: str,
87-
session_token: str,
88-
data: Dict[str, Any],
89-
):
90-
self.method = "POST"
75+
class ProjectByProjectId(GraphQLRequest):
    """GraphQL query that checks whether a project exists and fetches its
    label names.

    Attributes:
        data: the raw GraphQL response dict, or None if the request failed.
        exists: True if the project was found, False otherwise.
    """

    def __init__(self, project_id, session_token):
        QUERY = """
        query ($projectId: ID!) {
            projectByProjectId(projectId: $projectId) {
                id
                labels {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        """

        variables = {
            "projectId": project_id,
        }

        super().__init__(QUERY, variables, session_token)
        try:
            self.data = self.execute()
            # Guard against responses without a "data" key (e.g. a GraphQL
            # error payload); the original chained .get() would raise
            # AttributeError on None here.
            project = (self.data.get("data") or {}).get("projectByProjectId")
            self.exists = project is not None
        except exceptions.APIError:
            # Fix: self.data previously stayed unset on failure, so any
            # later access (e.g. by ManuallyLabeledRecords) raised
            # AttributeError instead of seeing a well-defined value.
            self.data = None
            self.exists = False
102+
103+
104+
class ManuallyLabeledRecords(GraphQLRequest):
    """GraphQL query fetching all records of a project that carry a manual
    label, together with their label associations."""

    def __init__(self, project_id, session_token):
        # First look up the project's label names; they serve as the
        # `manual` filter of the record search below.
        project_data = ProjectByProjectId(project_id, session_token).data
        label_edges = project_data["data"]["projectByProjectId"]["labels"]["edges"]
        label_names = [edge["node"]["name"] for edge in label_edges]

        QUERY = """
        query ($projectId: ID!, $manual: [String!]) {
            searchRecords(projectId: $projectId, manual: $manual) {
                data
                labelAssociations {
                    edges {
                        node {
                            label {
                                name
                            }
                            source
                        }
                    }
                }
            }
        }
        """

        variables = {"projectId": project_id, "manual": label_names}

        super().__init__(QUERY, variables, session_token)
        self.data = self.execute()
133+
134+
135+
class CreateLabelingFunction(GraphQLRequest):
    """GraphQL mutation registering a labeling function for a project."""

    def __init__(self, project_id, name, function, description, session_token):
        QUERY = """
        mutation (
            $projectId: ID!,
            $name: String!,
            $function: String!,
            $description: String!
        ) {
            createLabelingFunction(
                projectId: $projectId,
                name: $name,
                function: $function,
                description: $description
            ) {
                labelingFunction {
                    id
                }
            }
        }
        """

        variables = {
            "projectId": project_id,
            "name": name,
            "function": function,
            "description": description,
        }

        super().__init__(QUERY, variables, session_token)
        _ = self.execute()

onetask/auto_lf.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from collections import defaultdict
2+
from tqdm import tqdm
3+
from collections import Counter
4+
import re
5+
from collections import defaultdict
6+
import numpy as np
7+
from wasabi import msg
8+
9+
10+
def derive_regex_candidates(nlp, df, attribute, most_common=10):
    """Derive frequent regex candidates from the texts of labeled records.

    Each text is tokenized with spaCy; candidates are built from relevant
    tokens and their syntactic left/right children, with pure-digit tokens
    generalized via their shape (e.g. "2021" -> "[0-9][0-9][0-9][0-9]").
    Anchors `^`/`$` mark candidates occurring at the start/end of a text.

    :param nlp: spaCy language model.
    :param df: DataFrame holding the manually labeled records.
    :param attribute: column of `df` whose text is analyzed.
    :param most_common: number of most frequent candidates to return.
    :return: list of regex strings ordered by frequency.
    """
    if len(df) < 100:
        msg.warn(
            "Only very few records to analyze; it's best to continue labeling further records before analysis"
        )

    def normalize_token(token):
        # Generalize digit-only tokens via their shape; mixed alnum tokens
        # (shape contains "x") are kept verbatim.
        if "d" in token.shape_ and "x" not in token.shape_:
            return token.shape_.replace("d", "[0-9]")
        return token.text

    def is_relevant_token(token):
        return not (token.is_punct or token.is_stop or token.is_bracket)

    candidates = []
    for text in tqdm(df[attribute], total=len(df)):
        doc = nlp(text.lower())
        for token in doc:
            if not is_relevant_token(token):
                continue
            has_children = False
            for token_left in token.lefts:
                if is_relevant_token(token_left):
                    prefix = "^" if token_left.i == 0 else " "
                    # Fix: compare the token index (token.i) against the
                    # token count; the original used token.idx, which is a
                    # *character* offset, so the "$" anchor almost never
                    # triggered where intended.
                    suffix = "$" if token.i == len(doc) - 1 else " "
                    candidates.append(
                        f"{prefix}{normalize_token(token_left)}.*?{normalize_token(token)}{suffix}"
                    )
                    has_children = True
            for token_right in token.rights:
                if is_relevant_token(token_right):
                    prefix = "^" if token.i == 0 else " "
                    suffix = "$" if token_right.i == len(doc) - 1 else " "
                    candidates.append(
                        f"{prefix}{normalize_token(token)}.*?{normalize_token(token_right)}{suffix}"
                    )
                    has_children = True
            if not has_children:
                prefix = "^" if token.i == 0 else " "
                suffix = "$" if token.i == len(doc) - 1 else " "
                candidates.append(f"{prefix}{normalize_token(token)}{suffix}")
    return [regex for regex, _ in Counter(candidates).most_common(most_common)]
51+
52+
53+
def create_regex_fns(df, candidates, regex_col, label_col="label"):
    """Evaluate regex candidates against labeled records and print labeling
    functions for those that are precise and frequent enough.

    A candidate is turned into a labeling function when the precision of its
    majority label exceeds 0.75 and it covers at least 1% of the records.
    The generated functions are printed (not registered automatically).

    :param df: DataFrame with the manually labeled records.
    :param candidates: regex strings, e.g. from derive_regex_candidates.
    :param regex_col: column of `df` the regexes are matched against.
    :param label_col: column of `df` holding the manual label.
    """

    def regex_explainer(regex, attribute):
        # Build the human-readable description used as the docstring of the
        # generated labeling function.
        description = ""
        terms = regex.replace("^", "").replace("$", "").split(".*?")
        if "^" in regex:
            description += f"attribute '{attribute}' starts with term '{terms[0]}'"
            if len(terms) > 1:
                for term in terms[1:]:
                    description += f" (in-)directly followed by term '{term}'"
            if "$" in regex:
                description += " and then ends"
        elif "$" in regex:
            description += (
                f"attribute '{attribute}' somewhere contains term '{terms[0]}'"
            )
            if len(terms) > 1:
                for term in terms[1:]:
                    description += f" (in-)directly followed by term '{term}'"
            description += " and then ends"
        else:
            description += (
                f"attribute '{attribute}' somewhere contains term '{terms[0]}'"
            )
            if len(terms) > 1:
                for term in terms[1:]:
                    description += f" followed by term '{term}'"
        if "[0-9]" in regex:
            description += ", where [0-9] is an arbitrary number"
        description += "."
        return description

    def build_regex_lf(regex, attribute, prediction, iteration):
        # Source code of a ready-to-paste labeling function plus the call
        # that registers it via the SDK client.
        source_code = f"""
def regex_{iteration}(record):
    '''{regex_explainer(regex, attribute)}'''
    import re
    if re.search(r'{regex}', record['{attribute}'].lower()):
        return '{prediction}'

client.register_lf(regex_{iteration})
"""
        return source_code.strip()

    regex_nr = 1
    for regex in candidates:
        # Hoisted: compile once per candidate instead of re-parsing the
        # pattern for every record in the inner loop.
        pattern = re.compile(regex)
        labels = defaultdict(int)
        for text, label in zip(df[regex_col], df[label_col]):
            if pattern.search(text.lower()):
                labels[label] += 1
        coverage = sum(labels.values())
        if coverage > 0:
            # Majority label among covered records; like the original manual
            # loop, ties keep the first-seen label (dicts preserve insertion
            # order).
            regex_prediction = max(labels, key=labels.get)
            precision = np.round(labels[regex_prediction] / coverage, 2)
            coverage = np.round(coverage / len(df), 2)
            if precision > 0.75 and coverage >= 0.01:
                lf = build_regex_lf(regex, regex_col, regex_prediction, regex_nr)
                regex_nr += 1
                print(f"# Cov:\t{coverage}\tPrec:{precision}")
                print(lf)
                print()

0 commit comments

Comments
 (0)