
Commit b0f424b

feat: fix persistence and enable in cli

Parent: 085deea

2 files changed (+71 / -34 lines)

gptme_rag/cli.py

Lines changed: 16 additions & 21 deletions
@@ -1,23 +1,32 @@
+import logging
 import os
+import signal
 import sys
+import time
 from pathlib import Path

 import click
 from rich.console import Console

+from .benchmark import RagBenchmark
 from .indexing.indexer import Indexer
+from .indexing.watcher import FileWatcher
 from .query.context_assembler import ContextAssembler

 console = Console()

 # TODO: change this to a more appropriate location
-default_persist_dir = Path(__file__).parent / "data"
+default_persist_dir = Path.home() / ".cache" / "gptme" / "rag"


 @click.group()
-def cli():
+@click.option("--verbose/-v", is_flag=True, help="Enable verbose output")
+def cli(verbose: bool):
     """RAG implementation for gptme context management."""
-    pass
+    logging.basicConfig(
+        level=logging.DEBUG if verbose else logging.INFO,
+        format="%(levelname)s - %(name)s - %(message)s",
+    )


 @cli.command()
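The new `--verbose` flag wires basic logging into every subcommand. As an illustration only (not part of the commit), a minimal sketch of exercising the flag through click's test runner; the `gptme_rag.cli` import path follows the file shown here, while the `index` arguments are placeholder assumptions.

```python
# Minimal sketch (assumes the package is installed as gptme_rag).
from click.testing import CliRunner

from gptme_rag.cli import cli

runner = CliRunner()
# --verbose switches logging.basicConfig to DEBUG, per the cli() body above.
result = runner.invoke(cli, ["--verbose", "index", "."])
print(result.exit_code, result.output)
```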
@@ -36,20 +45,13 @@ def cli():
 def index(directory: Path, pattern: str, persist_dir: Path):
     """Index documents in a directory."""
     try:
-        indexer = Indexer(persist_directory=persist_dir)
+        indexer = Indexer(persist_directory=persist_dir, enable_persist=True)
         console.print(f"Indexing files in {directory} with pattern {pattern}")

-        # List files that will be indexed
-        files = list(directory.glob(pattern))
-        console.print(f"Found {len(files)} files:")
-        for file in files:
-            console.print(f"  - {file}")
-
         # Index the files
-        with console.status(f"Indexing {len(files)} files..."):
-            indexer.index_directory(directory, pattern)
+        n_indexed = indexer.index_directory(directory, pattern)

-        console.print(f"✅ Successfully indexed {len(files)} files", style="green")
+        console.print(f"✅ Successfully indexed {n_indexed} files", style="green")
     except Exception as e:
         console.print(f"❌ Error indexing directory: {e}", style="red")

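Outside the CLI, the same persistence-enabled indexing path can be driven directly. A hedged sketch using only names visible in this diff; the glob pattern is a placeholder, and the assumption that `index_directory` returns the number of indexed files is inferred from `n_indexed = ...` above rather than confirmed by this hunk.

```python
# Sketch: library-level equivalent of the `index` command above.
from pathlib import Path

from gptme_rag.indexing.indexer import Indexer

indexer = Indexer(
    persist_directory=Path.home() / ".cache" / "gptme" / "rag",  # new default from this commit
    enable_persist=True,  # keep the index on disk between runs (what the commit enables)
)
# Assumed to return the number of files indexed (see n_indexed above).
n_indexed = indexer.index_directory(Path("docs"), "**/*.md")
print(f"Indexed {n_indexed} files")
```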
@@ -80,7 +82,7 @@ def search(
     stdout = sys.stdout
     sys.stdout = open(os.devnull, "w")
     try:
-        indexer = Indexer(persist_directory=persist_dir)
+        indexer = Indexer(persist_directory=persist_dir, enable_persist=True)
         assembler = ContextAssembler(max_tokens=max_tokens)
         documents, distances = indexer.search(query, n_results=n_results)
     finally:
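For the search path, a similarly hedged sketch of what the CLI wraps, using only the calls visible in this hunk; the query, token budget, and result count are placeholder values.

```python
# Sketch: the search calls from the hunk above, outside the CLI wrapper.
from pathlib import Path

from gptme_rag.indexing.indexer import Indexer
from gptme_rag.query.context_assembler import ContextAssembler

indexer = Indexer(
    persist_directory=Path.home() / ".cache" / "gptme" / "rag",
    enable_persist=True,
)
assembler = ContextAssembler(max_tokens=2000)  # placeholder budget
documents, distances = indexer.search("persistence", n_results=5)
for doc, distance in zip(documents, distances):
    print(distance, doc.doc_id)
```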
@@ -165,7 +167,6 @@ def watch(directory: Path, pattern: str, persist_dir: Path, ignore_patterns: lis
     indexer.index_directory(directory, pattern)

     console.print("Starting file watcher...")
-    from .indexing.watcher import FileWatcher

     try:
         file_watcher = FileWatcher(
@@ -174,14 +175,11 @@ def watch(directory: Path, pattern: str, persist_dir: Path, ignore_patterns: lis
         with file_watcher:
             console.print("Watching for changes. Press Ctrl+C to stop.")
             # Keep the main thread alive
-            import signal

             try:
                 signal.pause()
             except AttributeError:  # Windows doesn't have signal.pause
                 while True:
-                    import time
-
                     time.sleep(1)
     except KeyboardInterrupt:
         console.print("\nStopping file watcher...")
@@ -212,7 +210,6 @@ def benchmark():
 )
 def indexing(directory: Path, pattern: str, persist_dir: Path | None):
     """Benchmark document indexing performance."""
-    from .benchmark import RagBenchmark

     benchmark = RagBenchmark(index_dir=persist_dir)

@@ -252,7 +249,6 @@ def search_benchmark(
     persist_dir: Path | None,
 ):
     """Benchmark search performance."""
-    from .benchmark import RagBenchmark

     benchmark = RagBenchmark(index_dir=persist_dir)

@@ -296,7 +292,6 @@ def watch_perf(
     persist_dir: Path | None,
 ):
     """Benchmark file watching performance."""
-    from .benchmark import RagBenchmark

     benchmark = RagBenchmark(index_dir=persist_dir)

gptme_rag/indexing/indexer.py

Lines changed: 55 additions & 13 deletions
@@ -172,7 +172,7 @@ def add_documents(self, documents: list[Document], batch_size: int = 100) -> Non

     def _load_gitignore(self, directory: Path) -> list[str]:
         """Load gitignore patterns from all .gitignore files up to root."""
-        patterns: list[str] = []
+        patterns: list[str] = [".git/", ".sqlite3", ".db"]
         current_dir = directory.resolve()
         max_depth = 10  # Limit traversal to avoid infinite loops

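Seeding `patterns` with `.git/`, `.sqlite3`, and `.db` means the database files that previously needed a hard-coded suffix check (removed in the next hunk) are now filtered through the same ignore mechanism as .gitignore entries. As an illustration only, and not the project's `_is_ignored` implementation (which is not shown in this diff), a minimal gitignore-style matcher might look like this:

```python
# Illustrative gitignore-style matching (not the project's _is_ignored).
# Directory patterns end with "/" and match any path component; bare
# suffixes like ".sqlite3" are treated as filename suffix matches here.
from fnmatch import fnmatch
from pathlib import Path


def is_ignored(path: Path, patterns: list[str]) -> bool:
    for pattern in patterns:
        if pattern.endswith("/"):
            if pattern.rstrip("/") in path.parts:
                return True
        elif path.name.endswith(pattern) or fnmatch(path.name, pattern):
            return True
    return False


print(is_ignored(Path(".git/config"), [".git/", ".sqlite3", ".db"]))         # True
print(is_ignored(Path("data/index.sqlite3"), [".git/", ".sqlite3", ".db"]))  # True
print(is_ignored(Path("README.md"), [".git/", ".sqlite3", ".db"]))           # False
```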
@@ -225,14 +225,10 @@ def index_directory(
         gitignore_patterns = self._load_gitignore(directory)

         # Filter files
-        valid_files = []
+        valid_files = set()
         for f in files:
-            if (
-                f.is_file()
-                and not f.name.endswith((".sqlite3", ".db"))
-                and not self._is_ignored(f, gitignore_patterns)
-            ):
-                valid_files.append(f)
+            if f.is_file() and not self._is_ignored(f, gitignore_patterns):
+                valid_files.add(f)

         # Check file limit
         if len(valid_files) >= file_limit:
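Switching `valid_files` from a list to a set presumably guards against the same path being collected more than once (for example from overlapping globs) before the `file_limit` check; a small illustration of the deduplication effect, with hypothetical patterns:

```python
# Hypothetical example: overlapping globs yield duplicate paths, a set removes them.
from pathlib import Path

directory = Path(".")
files = list(directory.glob("**/*.py")) + list(directory.glob("*.py"))

valid_files = {f for f in files if f.is_file()}
print(f"{len(files)} globbed paths -> {len(valid_files)} unique files")
```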
@@ -257,20 +253,20 @@ def index_directory(
         current_batch = []

         for file_path in valid_files:
+            logger.debug(f"Processing file: {file_path}")
             # Process each file into chunks
             for doc in Document.from_file(file_path, processor=self.processor):
+                logger.debug(f"Processing chunk: {doc.source_path} ({doc.chunk_index})")
                 current_batch.append(doc)
                 if len(current_batch) >= batch_size:
+                    logger.info(f"Adding {len(current_batch)} documents")
                     self.add_documents(current_batch)
                     current_batch = []

         # Add any remaining documents
         if current_batch:
-            logger.debug(
-                f"Adding {len(current_batch)} remaining documents. "
-                f"First doc preview: {current_batch[0].content[:100]}. "
-                f"Paths: {[doc.source_path for doc in current_batch]}"
-            )
+            self.add_documents(current_batch)
+            logger.info(f"Adding {len(current_batch)} documents.")
             self.add_documents(current_batch)

         logger.info(f"Indexed {len(valid_files)} documents from {directory}")
@@ -340,6 +336,52 @@ def search(

         return documents, distances[: len(documents)]

+    def list_documents(self, group_by_source: bool = True) -> list[Document]:
+        """List all documents in the index.
+
+        Args:
+            group_by_source: Whether to group chunks from the same document
+
+        Returns:
+            List of documents
+        """
+        # Get all documents from collection
+        results = self.collection.get()
+
+        if not results["ids"]:
+            return []
+
+        if group_by_source:
+            # Group chunks by source document
+            doc_groups: dict[str, list[Document]] = {}
+
+            for i, doc_id in enumerate(results["ids"]):
+                doc = Document(
+                    content=results["documents"][i],
+                    metadata=results["metadatas"][i],
+                    doc_id=doc_id,
+                )
+
+                # Get source document ID (remove chunk suffix if present)
+                source_id = doc_id.split("#chunk")[0]
+
+                if source_id not in doc_groups:
+                    doc_groups[source_id] = []
+                doc_groups[source_id].append(doc)
+
+            # Return first chunk from each document group
+            return [chunks[0] for chunks in doc_groups.values()]
+        else:
+            # Return all documents/chunks
+            return [
+                Document(
+                    content=results["documents"][i],
+                    metadata=results["metadatas"][i],
+                    doc_id=doc_id,
+                )
+                for i, doc_id in enumerate(results["ids"])
+            ]
+
     def get_document_chunks(self, doc_id: str) -> list[Document]:
         """Get all chunks for a document.

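A hedged usage sketch of the new `list_documents` method; the constructor arguments mirror the CLI defaults above, and the `Document` attribute names are taken from the constructor calls in this hunk.

```python
# Sketch: enumerating indexed documents, one representative chunk per source.
from pathlib import Path

from gptme_rag.indexing.indexer import Indexer

indexer = Indexer(
    persist_directory=Path.home() / ".cache" / "gptme" / "rag",
    enable_persist=True,
)
for doc in indexer.list_documents(group_by_source=True):
    print(doc.doc_id, f"({len(doc.content)} chars)")
```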