Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ default-optional-dependency-keys = [
]

[project.optional-dependencies]
cli = [
"rich>=14.3.2"
]
mssql = [
"mssql-python>=1.0.0"
]
Expand Down
98 changes: 88 additions & 10 deletions src/databao_context_engine/build_sources/build_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from databao_context_engine.datasources.types import DatasourceId
from databao_context_engine.pluginlib.build_plugin import DatasourceType
from databao_context_engine.plugins.plugin_loader import load_plugins
from databao_context_engine.progress.progress import DatasourceStatus, ProgressCallback, ProgressEmitter
from databao_context_engine.project.layout import ProjectLayout

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -50,9 +51,7 @@ class IndexSummary:


def build(
project_layout: ProjectLayout,
*,
build_service: BuildService,
project_layout: ProjectLayout, *, build_service: BuildService, progress: ProgressCallback | None = None
) -> list[BuildContextResult]:
"""Build the context for all datasources in the project.

Expand All @@ -70,34 +69,54 @@ def build(

datasource_ids = discover_datasources(project_layout)

emitter = ProgressEmitter(progress)

if not datasource_ids:
logger.info("No sources discovered under %s", project_layout.src_dir)
emitter.task_started(total_datasources=0)
emitter.task_finished(ok=0, failed=0, skipped=0)
return []

emitter.task_started(total_datasources=len(datasource_ids))

number_of_failed_builds = 0
number_of_skipped_builds = 0

build_result = []
reset_all_results(project_layout.output_dir)
for datasource_id in datasource_ids:
for datasource_index, datasource_id in enumerate(datasource_ids, start=1):
try:
prepared_source = prepare_source(project_layout, datasource_id)

logger.info(
f'Found datasource of type "{prepared_source.datasource_type.full_type}" with name {prepared_source.datasource_id.datasource_path}'
)

emitter.datasource_started(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
)
plugin = plugins.get(prepared_source.datasource_type)
if plugin is None:
logger.warning(
"No plugin for '%s' (datasource=%s) — skipping.",
prepared_source.datasource_type.full_type,
prepared_source.datasource_id.relative_path_to_config_file(),
)
number_of_failed_builds += 1
emitter.datasource_finished(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
status=DatasourceStatus.SKIPPED,
)
number_of_skipped_builds += 1
continue

result = build_service.process_prepared_source(
prepared_source=prepared_source,
plugin=plugin,
progress=progress,
)

output_dir = project_layout.output_dir
Expand All @@ -113,23 +132,46 @@ def build(
context_file_path=context_file_path,
)
)
emitter.datasource_finished(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
status=DatasourceStatus.OK,
)
except Exception as e:
logger.debug(str(e), exc_info=True, stack_info=True)
logger.info(f"Failed to build source at ({datasource_id.relative_path_to_config_file()}): {str(e)}")

emitter.datasource_finished(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
status=DatasourceStatus.FAILED,
error=str(e),
)
number_of_failed_builds += 1

logger.debug(
"Successfully built %d datasources. %s",
"Successfully built %d datasources. %s %s",
len(build_result),
f"Skipped {number_of_skipped_builds}." if number_of_skipped_builds > 0 else "",
f"Failed to build {number_of_failed_builds}." if number_of_failed_builds > 0 else "",
)

emitter.task_finished(
ok=len(build_result),
failed=number_of_failed_builds,
skipped=number_of_skipped_builds,
)

return build_result


def run_indexing(
*, project_layout: ProjectLayout, build_service: BuildService, contexts: list[DatasourceContext]
*,
project_layout: ProjectLayout,
build_service: BuildService,
contexts: list[DatasourceContext],
progress: ProgressCallback | None = None,
) -> IndexSummary:
"""Index a list of built datasource contexts.

Expand All @@ -144,10 +186,25 @@ def run_indexing(

summary = IndexSummary(total=len(contexts), indexed=0, skipped=0, failed=0)

for context in contexts:
emitter = ProgressEmitter(progress)

if not contexts:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this if-block is not needed. It will behave the exact same way without it.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'll remove.

emitter.task_started(total_datasources=0)
emitter.task_finished(ok=0, failed=0, skipped=0)
return summary

emitter.task_started(total_datasources=len(contexts))

for datasource_index, context in enumerate(contexts, start=1):
try:
logger.info(f"Indexing datasource {context.datasource_id}")

emitter.datasource_started(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
)

datasource_type = read_datasource_type_from_context(context)

plugin = plugins.get(datasource_type)
Expand All @@ -158,14 +215,34 @@ def run_indexing(
context.datasource_id,
)
summary.skipped += 1
emitter.datasource_finished(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
status=DatasourceStatus.SKIPPED,
)
continue

build_service.index_built_context(context=context, plugin=plugin)
build_service.index_built_context(context=context, plugin=plugin, progress=progress)
summary.indexed += 1

emitter.datasource_finished(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
status=DatasourceStatus.OK,
)
except Exception as e:
logger.debug(str(e), exc_info=True, stack_info=True)
logger.info(f"Failed to build source at ({context.datasource_id}): {str(e)}")
summary.failed += 1
emitter.datasource_finished(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
status=DatasourceStatus.FAILED,
error=str(e),
)

logger.debug(
"Successfully indexed %d/%d datasource(s). %s",
Expand All @@ -174,4 +251,5 @@ def run_indexing(
f"Skipped {summary.skipped}. Failed {summary.failed}." if (summary.skipped or summary.failed) else "",
)

emitter.task_finished(ok=summary.indexed, failed=summary.failed, skipped=summary.skipped)
return summary
8 changes: 7 additions & 1 deletion src/databao_context_engine/build_sources/build_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from databao_context_engine.pluginlib.build_plugin import (
BuildPlugin,
)
from databao_context_engine.progress.progress import ProgressCallback
from databao_context_engine.project.layout import ProjectLayout
from databao_context_engine.serialization.yaml import to_yaml_string
from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingService
Expand All @@ -35,6 +36,7 @@ def process_prepared_source(
*,
prepared_source: PreparedDatasource,
plugin: BuildPlugin,
progress: ProgressCallback | None = None,
) -> BuiltDatasourceContext:
"""Process a single source to build its context.

Expand All @@ -58,11 +60,14 @@ def process_prepared_source(
result=to_yaml_string(result.context),
full_type=prepared_source.datasource_type.full_type,
datasource_id=result.datasource_id,
progress=progress,
)

return result

def index_built_context(self, *, context: DatasourceContext, plugin: BuildPlugin) -> None:
def index_built_context(
self, *, context: DatasourceContext, plugin: BuildPlugin, progress: ProgressCallback | None = None
) -> None:
"""Index a context file using the given plugin.

1) Parses the yaml context file contents
Expand All @@ -85,6 +90,7 @@ def index_built_context(self, *, context: DatasourceContext, plugin: BuildPlugin
full_type=built.datasource_type,
datasource_id=built.datasource_id,
override=True,
progress=progress,
)

def _deserialize_built_context(
Expand Down
13 changes: 11 additions & 2 deletions src/databao_context_engine/build_sources/build_wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
create_ollama_embedding_provider,
create_ollama_service,
)
from databao_context_engine.progress.progress import ProgressCallback
from databao_context_engine.project.layout import ProjectLayout
from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingMode
from databao_context_engine.services.factories import create_chunk_embedding_service
Expand All @@ -23,7 +24,10 @@


def build_all_datasources(
project_layout: ProjectLayout, chunk_embedding_mode: ChunkEmbeddingMode
project_layout: ProjectLayout,
chunk_embedding_mode: ChunkEmbeddingMode,
*,
progress: ProgressCallback | None = None,
) -> list[BuildContextResult]:
"""Build the context for all datasources in the project.

Expand Down Expand Up @@ -62,13 +66,16 @@ def build_all_datasources(
return build(
project_layout=project_layout,
build_service=build_service,
progress=progress,
)


def index_built_contexts(
project_layout: ProjectLayout,
contexts: list[DatasourceContext],
chunk_embedding_mode: ChunkEmbeddingMode,
*,
progress: ProgressCallback | None = None,
) -> IndexSummary:
"""Index the contexts into the database.

Expand Down Expand Up @@ -101,7 +108,9 @@ def index_built_contexts(
description_provider=description_provider,
chunk_embedding_mode=chunk_embedding_mode,
)
return run_indexing(project_layout=project_layout, build_service=build_service, contexts=contexts)
return run_indexing(
project_layout=project_layout, build_service=build_service, contexts=contexts, progress=progress
)


def _create_build_service(
Expand Down
18 changes: 12 additions & 6 deletions src/databao_context_engine/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from databao_context_engine.cli.info import echo_info
from databao_context_engine.config.logging import configure_logging
from databao_context_engine.mcp.mcp_runner import McpTransport, run_mcp_server
from databao_context_engine.progress.rich_progress import rich_progress


@click.group()
Expand Down Expand Up @@ -151,9 +152,12 @@ def build(

Internally, this indexes the context to be used by the MCP server and the "retrieve" command.
"""
result = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).build_context(
datasource_ids=None, chunk_embedding_mode=ChunkEmbeddingMode(chunk_embedding_mode.upper())
)
with rich_progress() as progress_cb:
result = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).build_context(
datasource_ids=None,
chunk_embedding_mode=ChunkEmbeddingMode(chunk_embedding_mode.upper()),
progress=progress_cb,
)

click.echo(f"Build complete. Processed {len(result)} datasources.")

Expand All @@ -175,9 +179,11 @@ def index(ctx: Context, datasources_config_files: tuple[str, ...]) -> None:
[DatasourceId.from_string_repr(p) for p in datasources_config_files] if datasources_config_files else None
)

summary = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).index_built_contexts(
datasource_ids=datasource_ids
)
with rich_progress() as progress_cb:
summary = DatabaoContextProjectManager(project_dir=ctx.obj["project_dir"]).index_built_contexts(
datasource_ids=datasource_ids,
progress=progress_cb,
)

suffix = []
if summary.skipped:
Expand Down
18 changes: 16 additions & 2 deletions src/databao_context_engine/databao_context_project_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from databao_context_engine.datasources.datasource_discovery import get_datasource_list
from databao_context_engine.datasources.types import Datasource, DatasourceId
from databao_context_engine.pluginlib.build_plugin import DatasourceType
from databao_context_engine.progress.progress import ProgressCallback
from databao_context_engine.project.layout import (
ProjectLayout,
ensure_project_dir,
Expand Down Expand Up @@ -77,6 +78,8 @@ def build_context(
self,
datasource_ids: list[DatasourceId] | None = None,
chunk_embedding_mode: ChunkEmbeddingMode = ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY,
*,
progress: ProgressCallback | None = None,
) -> list[BuildContextResult]:
"""Build the context for datasources in the project.

Expand All @@ -85,17 +88,24 @@ def build_context(
Args:
datasource_ids: The list of datasource ids to build. If None, all datasources will be built.
chunk_embedding_mode: The mode to use for chunk embedding.
progress: Optional callback that receives progress events during execution.

Returns:
The list of all built results.
"""
# TODO: Filter which datasources to build by datasource_ids
return build_all_datasources(project_layout=self._project_layout, chunk_embedding_mode=chunk_embedding_mode)
return build_all_datasources(
project_layout=self._project_layout,
chunk_embedding_mode=chunk_embedding_mode,
progress=progress,
)

def index_built_contexts(
self,
datasource_ids: list[DatasourceId] | None = None,
chunk_embedding_mode: ChunkEmbeddingMode = ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY,
*,
progress: ProgressCallback | None = None,
) -> IndexSummary:
"""Index built datasource contexts into the embeddings database.

Expand All @@ -105,6 +115,7 @@ def index_built_contexts(
Args:
datasource_ids: The list of datasource ids to index. If None, all datasources will be indexed.
chunk_embedding_mode: The mode to use for chunk embedding.
progress: Optional callback that receives progress events during execution.

Returns:
The summary of the index operation.
Expand All @@ -117,7 +128,10 @@ def index_built_contexts(
contexts = [c for c in contexts if c.datasource_id.datasource_path in wanted_paths]

return index_built_contexts(
project_layout=self._project_layout, contexts=contexts, chunk_embedding_mode=chunk_embedding_mode
project_layout=self._project_layout,
contexts=contexts,
chunk_embedding_mode=chunk_embedding_mode,
progress=progress,
)

def check_datasource_connection(
Expand Down
Empty file.
Loading