+import json
 import logging
 import os
 import signal
...
 from rich.console import Console
 from rich.logging import RichHandler
 from rich.syntax import Syntax
+from tqdm import tqdm
 
 from .benchmark import RagBenchmark
 from .indexing.indexer import Indexer
 from .indexing.watcher import FileWatcher
 from .query.context_assembler import ContextAssembler
 
 console = Console()
+logger = logging.getLogger(__name__)
 
 # TODO: change this to a more appropriate location
 default_persist_dir = Path.home() / ".cache" / "gptme" / "rag"
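
Note: alongside the existing `RichHandler` import, this commit adds a module-level `logger`. How the handler is attached is not shown in these hunks; a minimal sketch of the usual wiring, assuming it happens in the CLI entry point, would be:

```python
import logging

from rich.console import Console
from rich.logging import RichHandler

console = Console()

# Assumed setup, not part of this diff: route stdlib logging through rich
# so logger.info(...) renders in the same console as the progress UI.
logging.basicConfig(
    level="INFO",
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(console=console, rich_tracebacks=True)],
)

logger = logging.getLogger(__name__)
```
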
@@ -53,23 +56,45 @@ def index(paths: list[Path], pattern: str, persist_dir: Path):
 
     try:
         indexer = Indexer(persist_directory=persist_dir, enable_persist=True)
-        total_indexed = 0
-
-        for path in paths:
-            if path.is_file():
-                console.print(f"Indexing file: {path}")
-                n_indexed = indexer.index_file(path)
-                if n_indexed is not None:
-                    total_indexed += n_indexed
-            else:
-                console.print(f"Indexing files in {path} with pattern {pattern}")
-                n_indexed = indexer.index_directory(path, pattern)
-                if n_indexed is not None:
-                    total_indexed += n_indexed
 
-        console.print(f"✅ Successfully indexed {total_indexed} files", style="green")
+        # First, collect all documents
+        all_documents = []
+        with console.status("Collecting documents...") as status:
+            for path in paths:
+                if path.is_file():
+                    status.update(f"Processing file: {path}")
+                else:
+                    status.update(f"Processing directory: {path}")
+                documents = indexer.collect_documents(path)
+                all_documents.extend(documents)
+
+        if not all_documents:
+            console.print("No documents found to index", style="yellow")
+            return
+
+        # Then process them with a progress bar
+        n_files = len(set(doc.metadata.get("source", "") for doc in all_documents))
+        n_chunks = len(all_documents)
+
+        logger.info(f"Found {n_files} files to index ({n_chunks} chunks)")
+
+        with tqdm(
+            total=n_chunks,
+            desc="Indexing documents",
+            unit="chunk",
+            disable=not sys.stdout.isatty(),
+        ) as pbar:
+            for progress in indexer.add_documents_progress(all_documents):
+                pbar.update(progress)
+
+        console.print(
+            f"✅ Successfully indexed {n_files} files ({n_chunks} chunks)",
+            style="green",
+        )
     except Exception as e:
         console.print(f"❌ Error indexing directory: {e}", style="red")
+        if logger.isEnabledFor(logging.DEBUG):
+            console.print_exception()
 
 
 @cli.command()
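
Note: the tqdm loop above relies on `Indexer.add_documents_progress` yielding incremental counts. That method is not part of this diff; a sketch of the assumed contract (batch the documents, yield each batch's size so `pbar.update(progress)` advances correctly) would be:

```python
from collections.abc import Iterator

def add_documents_progress(indexer, documents: list, batch_size: int = 10) -> Iterator[int]:
    """Index documents in batches, yielding the number of chunks just added.

    Sketch only: the real method lives on Indexer, and add_documents
    (the assumed bulk-add call) is not shown in this diff.
    """
    for i in range(0, len(documents), batch_size):
        batch = documents[i : i + batch_size]
        indexer.add_documents(batch)  # hypothetical bulk-add call
        yield len(batch)
```
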
@@ -111,8 +136,6 @@ def search(
     scoring_weights = None
     if weights:
         try:
-            import json
-
             scoring_weights = json.loads(weights)
         except json.JSONDecodeError as e:
             console.print(f"❌ Invalid weights JSON: {e}", style="red")