Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ default-optional-dependency-keys = [
]

[project.optional-dependencies]
cli = [
"rich>=14.3.2"
]
mssql = [
"mssql-python>=1.0.0"
]
Expand Down
4 changes: 2 additions & 2 deletions src/databao_context_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from databao_context_engine.build_sources.types import (
BuildDatasourceResult,
DatasourceExecutionStatus,
DatasourceResult,
DatasourceStatus,
IndexDatasourceResult,
)
from databao_context_engine.databao_context_engine import ContextSearchResult, DatabaoContextEngine
Expand Down Expand Up @@ -68,7 +68,7 @@
"OllamaPermanentError",
"BuildDatasourceResult",
"DatasourceResult",
"DatasourceStatus",
"DatasourceExecutionStatus",
"IndexDatasourceResult",
"CheckDatasourceConnectionResult",
]
4 changes: 2 additions & 2 deletions src/databao_context_engine/build_sources/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from databao_context_engine.build_sources.build_wiring import build_all_datasources
from databao_context_engine.build_sources.types import (
BuildDatasourceResult,
DatasourceExecutionStatus,
DatasourceResult,
DatasourceStatus,
IndexDatasourceResult,
)

__all__ = [
"build_all_datasources",
"DatasourceStatus",
"DatasourceExecutionStatus",
"DatasourceResult",
"BuildDatasourceResult",
"IndexDatasourceResult",
Expand Down
119 changes: 101 additions & 18 deletions src/databao_context_engine/build_sources/build_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from databao_context_engine.build_sources.types import (
BuildDatasourceResult,
DatasourceStatus,
DatasourceExecutionStatus,
IndexDatasourceResult,
)
from databao_context_engine.datasources.datasource_context import (
Expand All @@ -18,13 +18,18 @@
from databao_context_engine.datasources.datasource_discovery import discover_datasources, prepare_source
from databao_context_engine.pluginlib.build_plugin import DatasourceType
from databao_context_engine.plugins.plugin_loader import load_plugins
from databao_context_engine.progress.progress import DatasourceStatus, ProgressCallback, ProgressEmitter
from databao_context_engine.project.layout import ProjectLayout

logger = logging.getLogger(__name__)


def build(
project_layout: ProjectLayout, *, build_service: BuildService, generate_embeddings: bool = True
project_layout: ProjectLayout,
*,
build_service: BuildService,
generate_embeddings: bool = True,
progress: ProgressCallback | None = None,
) -> list[BuildDatasourceResult]:
"""Build the context for all datasources in the project.

Expand All @@ -42,35 +47,58 @@ def build(

datasource_ids = discover_datasources(project_layout)

emitter = ProgressEmitter(progress)

if not datasource_ids:
logger.info("No sources discovered under %s", project_layout.src_dir)
emitter.task_started(total_datasources=0)
emitter.task_finished(ok=0, failed=0, skipped=0)
return []

emitter.task_started(total_datasources=len(datasource_ids))

results: list[BuildDatasourceResult] = []
failed = 0
skipped = 0
reset_all_results(project_layout.output_dir)
for datasource_id in datasource_ids:
for datasource_index, datasource_id in enumerate(datasource_ids, start=1):
try:
prepared_source = prepare_source(project_layout, datasource_id)

logger.info(
f'Found datasource of type "{prepared_source.datasource_type.full_type}" with name {prepared_source.datasource_id.datasource_path}'
)

emitter.datasource_started(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
)
plugin = plugins.get(prepared_source.datasource_type)
if plugin is None:
logger.warning(
"No plugin for '%s' (datasource=%s) — skipping.",
prepared_source.datasource_type.full_type,
prepared_source.datasource_id.relative_path_to_config_file(),
)

emitter.datasource_finished(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
status=DatasourceStatus.SKIPPED,
)
results.append(
BuildDatasourceResult(datasource_id=datasource_id, status=DatasourceExecutionStatus.SKIPPED)
)
skipped += 1
results.append(BuildDatasourceResult(datasource_id=datasource_id, status=DatasourceStatus.SKIPPED))
continue

result = build_service.process_prepared_source(
prepared_source=prepared_source, plugin=plugin, generate_embeddings=generate_embeddings
prepared_source=prepared_source,
plugin=plugin,
generate_embeddings=generate_embeddings,
progress=progress,
)

output_dir = project_layout.output_dir
Expand All @@ -81,34 +109,59 @@ def build(
results.append(
BuildDatasourceResult(
datasource_id=datasource_id,
status=DatasourceStatus.OK,
status=DatasourceExecutionStatus.OK,
datasource_type=DatasourceType(full_type=result.datasource_type),
context_built_at=result.context_built_at,
context_file_path=context_file_path,
)
)
emitter.datasource_finished(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
status=DatasourceStatus.OK,
)
except Exception as e:
logger.debug(str(e), exc_info=True, stack_info=True)
logger.info(f"Failed to build source at ({datasource_id.relative_path_to_config_file()}): {str(e)}")

emitter.datasource_finished(
datasource_id=str(datasource_id),
index=datasource_index,
total=len(datasource_ids),
status=DatasourceStatus.FAILED,
error=str(e),
)
failed += 1
results.append(
BuildDatasourceResult(datasource_id=datasource_id, status=DatasourceStatus.FAILED, error=str(e))
BuildDatasourceResult(
datasource_id=datasource_id, status=DatasourceExecutionStatus.FAILED, error=str(e)
)
)

ok = sum(1 for result in results if result.status == DatasourceStatus.OK)
ok = sum(1 for result in results if result.status == DatasourceExecutionStatus.OK)
logger.debug(
"Successfully built %d/%d datasources. %s",
"Successfully built %d datasources. %s %s",
ok,
len(datasource_ids),
f"Skipped {skipped}. Failed {failed}." if (skipped or failed) else "",
f"Skipped {skipped}." if skipped > 0 else "",
f"Failed to build {failed}." if failed > 0 else "",
)

emitter.task_finished(
ok=ok,
failed=failed,
skipped=skipped,
)

return results


def run_indexing(
*, project_layout: ProjectLayout, build_service: BuildService, contexts: list[DatasourceContext]
*,
project_layout: ProjectLayout,
build_service: BuildService,
contexts: list[DatasourceContext],
progress: ProgressCallback | None = None,
) -> list[IndexDatasourceResult]:
"""Index a list of built datasource contexts.

Expand All @@ -126,10 +179,19 @@ def run_indexing(
skipped = 0
failed = 0

for context in contexts:
emitter = ProgressEmitter(progress)
emitter.task_started(total_datasources=len(contexts))

for datasource_index, context in enumerate(contexts, start=1):
try:
logger.info(f"Indexing datasource {context.datasource_id}")

emitter.datasource_started(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
)

datasource_type = read_datasource_type_from_context(context)

plugin = plugins.get(datasource_type)
Expand All @@ -140,20 +202,40 @@ def run_indexing(
context.datasource_id,
)
skipped += 1
results.append(
IndexDatasourceResult(datasource_id=context.datasource_id, status=DatasourceStatus.SKIPPED)
emitter.datasource_finished(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
status=DatasourceStatus.SKIPPED,
)
continue

build_service.index_built_context(context=context, plugin=plugin)
build_service.index_built_context(context=context, plugin=plugin, progress=progress)
ok += 1
results.append(IndexDatasourceResult(datasource_id=context.datasource_id, status=DatasourceStatus.OK))
emitter.datasource_finished(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
status=DatasourceStatus.OK,
)
results.append(
IndexDatasourceResult(datasource_id=context.datasource_id, status=DatasourceExecutionStatus.OK)
)
except Exception as e:
logger.debug(str(e), exc_info=True, stack_info=True)
logger.info(f"Failed to build source at ({context.datasource_id}): {str(e)}")
failed += 1
results.append(
IndexDatasourceResult(datasource_id=context.datasource_id, status=DatasourceStatus.FAILED, error=str(e))
IndexDatasourceResult(
datasource_id=context.datasource_id, status=DatasourceExecutionStatus.FAILED, error=str(e)
)
)
emitter.datasource_finished(
datasource_id=str(context.datasource_id),
index=datasource_index,
total=len(contexts),
status=DatasourceStatus.FAILED,
error=str(e),
)

logger.debug(
Expand All @@ -163,4 +245,5 @@ def run_indexing(
f"Skipped {skipped}. Failed {failed}." if (skipped or failed) else "",
)

emitter.task_finished(ok=ok, failed=failed, skipped=skipped)
return results
14 changes: 12 additions & 2 deletions src/databao_context_engine/build_sources/build_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from databao_context_engine.pluginlib.build_plugin import (
BuildPlugin,
)
from databao_context_engine.progress.progress import ProgressCallback
from databao_context_engine.project.layout import ProjectLayout
from databao_context_engine.serialization.yaml import to_yaml_string
from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingService
Expand All @@ -31,7 +32,12 @@ def __init__(
self._chunk_embedding_service = chunk_embedding_service

def process_prepared_source(
self, *, prepared_source: PreparedDatasource, plugin: BuildPlugin, generate_embeddings: bool = True
self,
*,
prepared_source: PreparedDatasource,
plugin: BuildPlugin,
generate_embeddings: bool = True,
progress: ProgressCallback | None = None,
) -> BuiltDatasourceContext:
"""Process a single source to build its context.

Expand All @@ -58,11 +64,14 @@ def process_prepared_source(
result=to_yaml_string(result.context),
full_type=prepared_source.datasource_type.full_type,
datasource_id=result.datasource_id,
progress=progress,
)

return result

def index_built_context(self, *, context: DatasourceContext, plugin: BuildPlugin) -> None:
def index_built_context(
self, *, context: DatasourceContext, plugin: BuildPlugin, progress: ProgressCallback | None = None
) -> None:
"""Index a context file using the given plugin.

1) Parses the yaml context file contents
Expand All @@ -85,6 +94,7 @@ def index_built_context(self, *, context: DatasourceContext, plugin: BuildPlugin
full_type=built.datasource_type,
datasource_id=built.datasource_id,
override=True,
progress=progress,
)

def _deserialize_built_context(
Expand Down
14 changes: 12 additions & 2 deletions src/databao_context_engine/build_sources/build_wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
create_ollama_embedding_provider,
create_ollama_service,
)
from databao_context_engine.progress.progress import ProgressCallback
from databao_context_engine.project.layout import ProjectLayout
from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingMode
from databao_context_engine.services.factories import create_chunk_embedding_service
Expand All @@ -29,9 +30,11 @@
def build_all_datasources(
project_layout: ProjectLayout,
chunk_embedding_mode: ChunkEmbeddingMode,
*,
generate_embeddings: bool = True,
ollama_model_id: str | None = None,
ollama_model_dim: int | None = None,
progress: ProgressCallback | None = None,
) -> list[BuildDatasourceResult]:
"""Build the context for all datasources in the project.
Expand Down Expand Up @@ -70,14 +73,19 @@ def build_all_datasources(
chunk_embedding_mode=chunk_embedding_mode,
)
return build(
project_layout=project_layout, build_service=build_service, generate_embeddings=generate_embeddings
project_layout=project_layout,
build_service=build_service,
generate_embeddings=generate_embeddings,
progress=progress,
)


def index_built_contexts(
project_layout: ProjectLayout,
contexts: list[DatasourceContext],
chunk_embedding_mode: ChunkEmbeddingMode,
*,
progress: ProgressCallback | None = None,
ollama_model_id: str | None = None,
ollama_model_dim: int | None = None,
) -> list[IndexDatasourceResult]:
Expand Down Expand Up @@ -114,7 +122,9 @@ def index_built_contexts(
description_provider=description_provider,
chunk_embedding_mode=chunk_embedding_mode,
)
return run_indexing(project_layout=project_layout, build_service=build_service, contexts=contexts)
return run_indexing(
project_layout=project_layout, build_service=build_service, contexts=contexts, progress=progress
)


def _create_build_service(
Expand Down
4 changes: 2 additions & 2 deletions src/databao_context_engine/build_sources/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from databao_context_engine.pluginlib.build_plugin import DatasourceType


class DatasourceStatus(Enum):
class DatasourceExecutionStatus(Enum):
"""Status of an operation for a single datasource."""

OK = "ok"
Expand All @@ -29,7 +29,7 @@ class DatasourceResult:
"""

datasource_id: DatasourceId
status: DatasourceStatus
status: DatasourceExecutionStatus
error: str | None = None


Expand Down
Loading