
Commit 9995843

fix: now if eval_llm returns something unparsable or with finish_reason=length we always retry
Signed-off-by: thiswillbeyourgithub <[email protected]>
1 parent: 29a9dae
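
For background on the failure mode (not part of the commit itself): finish_reason="length" means the model stopped because it hit its token limit, so the generation is truncated and the eval verdict may be missing entirely. A minimal sketch, with a hypothetical stand-in for wdoc's parse_eval_output:

    # Minimal sketch: "parse_eval_output" here is a hypothetical stand-in for
    # wdoc's real parser; it only shows why a truncated generation fails to parse.
    def parse_eval_output(text: str) -> str:
        words = text.strip().split()
        if not words or not words[-1].isdigit():
            raise ValueError(f"Unparsable eval output: {text!r}")
        return words[-1]

    complete = "Relevance score: 8"         # finish_reason == "stop"
    truncated = "The document seems to be"  # finish_reason == "length": cut off early

    print(parse_eval_output(complete))      # -> "8"
    try:
        parse_eval_output(truncated)
    except ValueError as err:
        print(err)                          # the verdict never made it into the text

Before this change, finish_reason="length" was tolerated and parsing ran after the retry block, so a ValueError like the one above crashed the evaluation instead of triggering a retry.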

File tree

1 file changed (+21 / -12 lines)

wdoc/wdoc.py

Lines changed: 21 additions & 12 deletions
@@ -1497,25 +1497,34 @@ def evaluate_doc_chain(
             new_c = 0

         elif "n" in self.eval_llm_params or self.query_eval_check_number == 1:
+
+            def _parse_outputs(out) -> List[str]:
+                reasons = [
+                    gen.generation_info["finish_reason"] for gen in out.generations
+                ]
+                outputs = [gen.text for gen in out.generations]
+                # don't always crash if finish_reason is not stop, because it can sometimes still be parsed.
+                if not all(r == "stop" for r in reasons):
+                    red(
+                        f"Unexpected generation finish_reason: '{reasons}' for generations: '{outputs}'. Expected 'stop'"
+                    )
+                assert outputs, "No generations found by query eval llm"
+                # parse_eval_output will crash if the output is bad anyway
+                outputs = [parse_eval_output(o) for o in outputs]
+                return outputs
+
             try:
                 out = self.eval_llm._generate_with_cache(
                     prompts.evaluate.format_messages(**inputs)
                 )
+                outputs = _parse_outputs(out)
             except Exception:  # retry without cache
+                yel(f"Failed to run eval_llm on an input. Retrying without cache.")
                 out = self.eval_llm._generate(
                     prompts.evaluate.format_messages(**inputs)
                 )
-            reasons = [
-                gen.generation_info["finish_reason"] for gen in out.generations
-            ]
-            outputs = [gen.text for gen in out.generations]
-            # don't crash if finish_reason is not stop, because it can sometimes still be parsed.
-            if not all(r in ["stop", "length"] for r in reasons):
-                red(
-                    f"Unexpected generation finish_reason: '{reasons}' for generations: '{outputs}'"
-                )
-            assert outputs, "No generations found by query eval llm"
-            outputs = [parse_eval_output(o) for o in outputs]
+                outputs = _parse_outputs(out)
+
             if out.llm_output:
                 new_p = out.llm_output["token_usage"]["prompt_tokens"]
                 new_c = out.llm_output["token_usage"]["completion_tokens"]
@@ -1533,7 +1542,7 @@ async def do_eval(subinputs):
                 val = await self.eval_llm._agenerate_with_cache(
                     prompts.evaluate.format_messages(**subinputs)
                 )
-            except Exception:  # retrywithout cache
+            except Exception:  # retry without cache
                 val = await self.eval_llm._agenerate(
                     prompts.evaluate.format_messages(**subinputs)
                 )
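
The substance of the change is where parsing happens: the old code parsed after the try/except, so an unparsable generation crashed evaluate_doc_chain even though a no-cache retry existed one line above; calling the new _parse_outputs helper inside both branches means a parse failure on the cached result now falls through to the retry. A self-contained sketch of that control flow (only the _generate_with_cache/_generate method names come from the diff; the fake LLM, the parser, and print standing in for wdoc's yel are illustrative stand-ins):

    # Sketch of the retry control flow this commit introduces. Everything here
    # is a stand-in except the _generate_with_cache/_generate method names.
    class FakeEvalLLM:
        """Stand-in LLM whose cache holds a truncated completion."""

        def _generate_with_cache(self, messages):
            return "The document seems to be"  # truncated (finish_reason == "length")

        def _generate(self, messages):
            return "Relevance score: 8"        # fresh, complete completion

    def parse_outputs(text):
        # stand-in for _parse_outputs: raises ValueError when the verdict is missing
        return [int(text.split()[-1])]

    def evaluate_with_retry(eval_llm, messages):
        try:
            out = eval_llm._generate_with_cache(messages)
            outputs = parse_outputs(out)       # parse errors now raise inside the try
        except Exception:
            # a cached but unparsable answer is regenerated instead of crashing
            print("Failed to run eval_llm on an input. Retrying without cache.")
            out = eval_llm._generate(messages)
            outputs = parse_outputs(out)       # a second failure still propagates
        return outputs

    print(evaluate_with_retry(FakeEvalLLM(), messages=[]))  # -> [8]

Note the deliberate asymmetry in the real _parse_outputs: a non-"stop" finish_reason only emits a red() warning, because a truncated reply can sometimes still be parsed; it is parse_eval_output raising that actually drives the retry.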
