
Commit 689e915

chore: add langchain recursive strategy

1 parent a7f8232

5 files changed: +187 -24 lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "rag-chunk"
-version = "0.2.0"
+version = "0.3.0"
 description = "CLI tool to parse, chunk, and evaluate Markdown documents for RAG pipelines with token-accurate chunking support"
 authors = [ { name = "messkan" } ]
 license = { text = "MIT" }
@@ -23,3 +23,5 @@ build-backend = "setuptools.build_meta"
 [project.optional-dependencies]
 rich = ["rich>=12.0.0"]
 tiktoken = ["tiktoken>=0.5.0"]
+langchain = ["langchain>=0.1.0", "langchain-text-splitters>=0.0.1"]
+all = ["rich>=12.0.0", "tiktoken>=0.5.0", "langchain>=0.1.0", "langchain-text-splitters>=0.0.1"]

src/chunker.py

Lines changed: 69 additions & 0 deletions

@@ -10,6 +10,14 @@
     TIKTOKEN_AVAILABLE = False
     tiktoken = None
 
+try:
+    from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+    LANGCHAIN_AVAILABLE = True
+except ImportError:
+    LANGCHAIN_AVAILABLE = False
+    RecursiveCharacterTextSplitter = None
+
 
 def tokenize(
     text: str, use_tiktoken: bool = False, model: str = "gpt-3.5-turbo"
@@ -128,6 +136,57 @@ def paragraph_chunks(text: str) -> List[Dict]:
     return chunks
 
 
+def recursive_character_chunks(
+    text: str,
+    chunk_size: int = 200,
+    overlap: int = 50,
+    use_tiktoken: bool = False,
+    model: str = "gpt-3.5-turbo",
+) -> List[Dict]:
+    """Split text using LangChain's RecursiveCharacterTextSplitter.
+
+    Recursively splits by paragraphs, sentences, then words for semantic coherence.
+
+    Args:
+        text: Text to chunk
+        chunk_size: Target size per chunk (words or tokens)
+        overlap: Overlap between chunks
+        use_tiktoken: If True, use tiktoken for token-based chunking
+        model: Model name for tiktoken encoding
+
+    Returns:
+        List of chunk dictionaries with 'id' and 'text' keys
+    """
+    if not LANGCHAIN_AVAILABLE:
+        raise ImportError(
+            "LangChain is required for recursive-character strategy. "
+            "Install with: pip install rag-chunk[langchain]"
+        )
+
+    if use_tiktoken:
+        if not TIKTOKEN_AVAILABLE:
+            raise ImportError(
+                "tiktoken is required for token-based chunking. "
+                "Install with: pip install rag-chunk[tiktoken]"
+            )
+        import tiktoken
+
+        enc = tiktoken.encoding_for_model(model)
+        splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            encoding_name=enc.name, chunk_size=chunk_size, chunk_overlap=overlap
+        )
+    else:
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=overlap,
+            length_function=len,
+            separators=["\n\n", "\n", ". ", " ", ""],
+        )
+
+    texts = splitter.split_text(text)
+    return [{"id": i, "text": t} for i, t in enumerate(texts)]
+
+
 STRATEGIES = {
     "fixed-size": (
         lambda text, chunk_size=200, overlap=0, use_tiktoken=False, model="gpt-3.5-turbo":
@@ -152,4 +211,14 @@ def paragraph_chunks(text: str) -> List[Dict]:
         lambda text, chunk_size=0, overlap=0, use_tiktoken=False, model="gpt-3.5-turbo":
             paragraph_chunks(text)
     ),
+    "recursive-character": (
+        lambda text, chunk_size=200, overlap=50, use_tiktoken=False, model="gpt-3.5-turbo":
+            recursive_character_chunks(
+                text,
+                chunk_size,
+                overlap,
+                use_tiktoken=use_tiktoken,
+                model=model
+            )
+    ),
 }
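
For orientation, a minimal usage sketch of the new strategy: it assumes the langchain extra is installed (pip install "rag-chunk[langchain]", as the error message above suggests) and that src/chunker.py is importable as chunker; values are illustrative.

# Illustrative only -- assumes `chunker` is importable and the langchain extra is installed.
from chunker import STRATEGIES, recursive_character_chunks

text = "# Notes\n\nFirst paragraph about retrieval.\n\nSecond paragraph about chunking."

# Direct call: character-based splitting with the default separators.
chunks = recursive_character_chunks(text, chunk_size=80, overlap=20)

# Equivalent call through the strategy registry, as the CLI does for
# --strategy recursive-character.
chunks = STRATEGIES["recursive-character"](text, chunk_size=80, overlap=20)

for c in chunks:
    print(c["id"], repr(c["text"]))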

src/cli.py

Lines changed: 49 additions & 15 deletions

@@ -107,21 +107,24 @@ def _run_strategy(text, func, strat, args):
     )
     outdir = write_chunks(chunks, strat)
 
-    avg_recall, per_questions = 0.0, []
+    metrics = {"avg_recall": 0.0, "avg_precision": 0.0, "avg_f1": 0.0}
+    per_questions = []
     questions = (
         scorer.load_test_file(args.test_file)
         if getattr(args, "test_file", None)
        else None
    )
    if questions:
-        avg_recall, per_questions = scorer.evaluate_strategy(
+        metrics, per_questions = scorer.evaluate_strategy(
            chunks, questions, args.top_k
        )
 
     return {
         "strategy": strat,
         "chunks": len(chunks),
-        "avg_recall": round(avg_recall, 4),
+        "avg_recall": round(metrics["avg_recall"], 4),
+        "avg_precision": round(metrics["avg_precision"], 4),
+        "avg_f1": round(metrics["avg_f1"], 4),
         "saved": str(outdir),
     }, per_questions
 
@@ -137,27 +140,55 @@ def _write_results(results, detail, output):
         table.add_column("strategy", style="cyan")
         table.add_column("chunks", justify="right")
         table.add_column("avg_recall", justify="right")
+        table.add_column("avg_precision", justify="right")
+        table.add_column("avg_f1", justify="right")
         table.add_column("saved")
         for r in results:
-            avg = r.get("avg_recall", 0.0)
+            recall = r.get("avg_recall", 0.0)
+            precision = r.get("avg_precision", 0.0)
+            f1 = r.get("avg_f1", 0.0)
+
+            # Format recall with color
             try:
-                pct = f"{avg*100:.2f}%"
+                recall_pct = f"{recall*100:.2f}%"
             except (TypeError, ValueError):
-                pct = str(avg)
-            if isinstance(avg, float):
-                if avg >= 0.85:
+                recall_pct = str(recall)
+            if isinstance(recall, float):
+                if recall >= 0.85:
                     color = "green"
-                elif avg >= 0.7:
+                elif recall >= 0.7:
                     color = "yellow"
                 else:
                     color = "red"
-                pct_cell = f"[{color}]{pct}[/{color}]"
+                recall_cell = f"[{color}]{recall_pct}[/{color}]"
             else:
-                pct_cell = pct
+                recall_cell = recall_pct
+
+            # Format precision
+            precision_pct = f"{precision*100:.2f}%" if isinstance(precision, float) else str(precision)
+
+            # Format F1 with color
+            try:
+                f1_pct = f"{f1*100:.2f}%"
+            except (TypeError, ValueError):
+                f1_pct = str(f1)
+            if isinstance(f1, float):
+                if f1 >= 0.85:
+                    color = "green"
+                elif f1 >= 0.7:
+                    color = "yellow"
+                else:
+                    color = "red"
+                f1_cell = f"[{color}]{f1_pct}[/{color}]"
+            else:
+                f1_cell = f1_pct
+
             table.add_row(
                 str(r.get("strategy", "")),
                 str(r.get("chunks", "")),
-                pct_cell,
+                recall_cell,
+                precision_pct,
+                f1_cell,
                 str(r.get("saved", "")),
             )
         console.print(table)
@@ -172,9 +203,11 @@ def _write_results(results, detail, output):
         wpath = Path("analysis_results.csv")
         with wpath.open("w", newline="", encoding="utf-8") as f:
             w = csv.writer(f)
-            w.writerow(["strategy", "chunks", "avg_recall", "saved"])
+            w.writerow(["strategy", "chunks", "avg_recall", "avg_precision",
+                        "avg_f1", "saved"])
             for r in results:
-                w.writerow([r["strategy"], r["chunks"], r["avg_recall"], r["saved"]])
+                w.writerow([r["strategy"], r["chunks"], r["avg_recall"],
+                            r["avg_precision"], r["avg_f1"], r["saved"]])
         print(str(wpath))
         return
     print("Unsupported output format")
@@ -191,7 +224,8 @@ def build_parser():
         "--strategy",
         type=str,
         default="fixed-size",
-        choices=["fixed-size", "sliding-window", "paragraph", "all"],
+        choices=["fixed-size", "sliding-window", "paragraph",
+                 "recursive-character", "all"],
         help="Chunking strategy or all",
     )
     analyze_p.add_argument(
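
To make the new report columns concrete, here is a sketch of one results row as _run_strategy now builds it and _write_results consumes it; the metric values and the saved path are invented for illustration.

# Illustrative only -- metric values and the output path are made up.
row = {
    "strategy": "recursive-character",
    "chunks": 37,
    "avg_recall": 0.8333,
    "avg_precision": 1.0,
    "avg_f1": 0.9091,
    "saved": "chunks/recursive-character",  # whatever write_chunks returned
}

# The CSV output carries the same fields, in this column order:
# strategy, chunks, avg_recall, avg_precision, avg_f1, saved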

src/parser.py

Lines changed: 2 additions & 2 deletions

@@ -4,9 +4,9 @@
 
 
 def read_markdown_folder(folder: str) -> list:
-    """Return list of (path, text) for all .md files in folder (non-recursive)."""
+    """Return list of (path, text) for all .md and .txt files in folder (non-recursive)."""
     p = Path(folder)
-    files = [f for f in p.iterdir() if f.is_file() and f.suffix.lower() == ".md"]
+    files = [f for f in p.iterdir() if f.is_file() and f.suffix.lower() in [".md", ".txt"]]
     result = []
     for f in files:
         try:
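
A short usage sketch, assuming src/parser.py is importable as parser: the folder is scanned non-recursively and both .md and .txt files come back as (path, text) pairs.

# Illustrative only -- assumes `parser` refers to src/parser.py.
from parser import read_markdown_folder

for path, text in read_markdown_folder("docs"):
    print(path, len(text))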

src/scorer.py

Lines changed: 64 additions & 6 deletions

@@ -45,18 +45,76 @@ def compute_recall(retrieved: List[Dict], relevant_phrases: List[str]) -> float:
     return found / len(relevant_phrases)
 
 
+def compute_precision_recall_f1(
+    retrieved: List[Dict], relevant_phrases: List[str]
+) -> Tuple[float, float, float]:
+    """Compute precision, recall, and F1 score.
+
+    Args:
+        retrieved: List of retrieved chunk dictionaries
+        relevant_phrases: List of phrases that should be found
+
+    Returns:
+        Tuple of (precision, recall, f1)
+    """
+    if not relevant_phrases:
+        return 0.0, 0.0, 0.0
+
+    lower_texts = [c["text"].lower() for c in retrieved]
+    found_phrases = set()
+    for phrase in relevant_phrases:
+        lp = phrase.lower()
+        if any(lp in t for t in lower_texts):
+            found_phrases.add(phrase)
+
+    tp = len(found_phrases)  # True positives
+    fn = len(relevant_phrases) - tp  # False negatives
+    # For precision: assume each relevant phrase found is a "correct" retrieval
+    # FP = 0 in this simplified model (we only check relevant phrases)
+    fp = 0
+
+    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+    f1 = (
+        2 * precision * recall / (precision + recall)
+        if (precision + recall) > 0
+        else 0.0
+    )
+
+    return precision, recall, f1
+
+
 def evaluate_strategy(
     chunks: List[Dict], questions: List[Dict], top_k: int
-) -> Tuple[float, List[Dict]]:
-    """Return average recall and per-question details."""
+) -> Tuple[Dict, List[Dict]]:
+    """Return average metrics and per-question details.
+
+    Returns:
+        Tuple of (metrics_dict, per_question_list)
+        metrics_dict contains: avg_recall, avg_precision, avg_f1
+    """
     per = []
     recalls = []
+    precisions = []
+    f1s = []
     for q in questions:
         question = q.get("question", "")
         relevant = q.get("relevant", [])
         retrieved = retrieve_top_k(chunks, question, top_k)
-        recall = compute_recall(retrieved, relevant)
+        precision, recall, f1 = compute_precision_recall_f1(retrieved, relevant)
         recalls.append(recall)
-        per.append({"question": question, "recall": recall})
-    avg = sum(recalls) / len(recalls) if recalls else 0.0
-    return avg, per
+        precisions.append(precision)
+        f1s.append(f1)
+        per.append({
+            "question": question,
+            "recall": recall,
+            "precision": precision,
+            "f1": f1
+        })
+
+    metrics = {
+        "avg_recall": sum(recalls) / len(recalls) if recalls else 0.0,
+        "avg_precision": sum(precisions) / len(precisions) if precisions else 0.0,
+        "avg_f1": sum(f1s) / len(f1s) if f1s else 0.0,
+    }
+    return metrics, per
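
A worked example of the simplified metric model above: because fp is hard-coded to 0, precision is 1.0 whenever at least one relevant phrase is found and 0.0 otherwise. The data is invented, and the import assumes src/scorer.py is importable as scorer.

# Illustrative only -- toy data; assumes `scorer` refers to src/scorer.py.
from scorer import compute_precision_recall_f1

retrieved = [
    {"id": 0, "text": "Chunking splits documents into overlapping windows."},
    {"id": 1, "text": "Recall measures how many relevant phrases were found."},
]
relevant = ["overlapping windows", "relevant phrases", "token budget"]

precision, recall, f1 = compute_precision_recall_f1(retrieved, relevant)
# tp = 2, fn = 1, fp = 0 (by construction)
# precision = 2 / 2 = 1.0
# recall    = 2 / 3
# f1        = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8
print(precision, recall, f1)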
