diff --git a/extensions/positron-assistant/package.json b/extensions/positron-assistant/package.json index 9e41379c35db..14ece2795a3b 100644 --- a/extensions/positron-assistant/package.json +++ b/extensions/positron-assistant/package.json @@ -360,6 +360,39 @@ "positron-assistant" ] }, + { + "name": "getTableSummary", + "displayName": "Get Table Summary", + "modelDescription": "Get structured information about table variables in the current session.", + "inputSchema": { + "type": "object", + "properties": { + "sessionIdentifier": { + "type": "string", + "description": "The identifier of the session that contains the tables." + }, + "accessKeys": { + "type": "array", + "description": "An array of table variables to summarize.", + "items": { + "type": "array", + "description": "A list of access keys that identify a variable by specifying its path.", + "items": { + "type": "string", + "description": "An access key that uniquely identifies a variable among its siblings." + } + } + } + }, + "required": [ + "sessionIdentifier", + "accessKeys" + ] + }, + "tags": [ + "positron-assistant" + ] + }, { "name": "getProjectTree", "displayName": "Get Project Tree", diff --git a/extensions/positron-assistant/src/md/prompts/chat/agent.md b/extensions/positron-assistant/src/md/prompts/chat/agent.md index b02624f28269..63e1326e905b 100644 --- a/extensions/positron-assistant/src/md/prompts/chat/agent.md +++ b/extensions/positron-assistant/src/md/prompts/chat/agent.md @@ -20,6 +20,29 @@ If the user asks you _how_ to do something, or asks for code rather than results, generate the code and return it directly without trying to execute it. + + +**Data Object Information Workflow:** + +When the user asks questions that require detailed information about tabular +data objects (DataFrames, arrays, matrices, etc.), use the `getTableSummary` +tool to retrieve structured information such as data summaries and statistics. +Currently, this tool is only available for Python so in R sessions you will need +to execute code to query in-memory data. + +To use the tool effectively: + +1. First ensure you have the correct `sessionIdentifier` from the user context +2. Provide the `accessKeys` array with the path to the specific data objects + - Each access key is an array of strings representing the path to the variable + - If the user references a variable by name, determine the access key from context or previous tool results +3. Do not call this tool when: + - The variables do not appear in the user context + - There is no active session + - The user only wants to see the structure/children of objects (use `inspectVariables` instead) + + + You adhere to the following workflow when dealing with package management: diff --git a/extensions/positron-assistant/src/participants.ts b/extensions/positron-assistant/src/participants.ts index 52960daa7c26..7454e4b9322c 100644 --- a/extensions/positron-assistant/src/participants.ts +++ b/extensions/positron-assistant/src/participants.ts @@ -283,6 +283,11 @@ abstract class PositronAssistantParticipant implements IPositronAssistantPartici // Only include the documentCreate tool in the chat pane in edit or agent mode. case PositronAssistantToolName.DocumentCreate: return inChatPane && (isEditMode || isAgentMode); + // Only include the getTableSummary tool for Python sessions until supported in R + case PositronAssistantToolName.GetTableSummary: + // TODO: Remove this restriction when the tool is supported in R https://github.com/posit-dev/positron/issues/8343 + // The logic above with TOOL_TAG_REQUIRES_ACTIVE_SESSION will handle checking for active sessions once this is removed. + return activeSessions.has('python'); // Otherwise, include the tool if it is tagged for use with Positron Assistant. // Allow all tools in Agent mode. default: diff --git a/extensions/positron-assistant/src/tools.ts b/extensions/positron-assistant/src/tools.ts index 9e8aaa564fd7..21ec3755d457 100644 --- a/extensions/positron-assistant/src/tools.ts +++ b/extensions/positron-assistant/src/tools.ts @@ -299,6 +299,56 @@ export function registerAssistantTools( context.subscriptions.push(inspectVariablesTool); + const getTableSummaryTool = vscode.lm.registerTool<{ sessionIdentifier: string; accessKeys: Array> }>(PositronAssistantToolName.GetTableSummary, { + /** + * Called to get a summary information for one or more tabular datasets in the current session. + * @param options The options for the tool invocation. + * @param token The cancellation token. + * @returns A vscode.LanguageModelToolResult containing the data summary. + */ + invoke: async (options, token) => { + + // If no session identifier is provided, return an empty array. + if (!options.input.sessionIdentifier || options.input.sessionIdentifier === 'undefined') { + return new vscode.LanguageModelToolResult([ + new vscode.LanguageModelTextPart('[[]]') + ]); + } + + // temporarily only enable for Python sessions + let session: positron.LanguageRuntimeSession | undefined; + const sessions = await positron.runtime.getActiveSessions(); + if (sessions && sessions.length > 0) { + session = sessions.find( + (session) => session.metadata.sessionId === options.input.sessionIdentifier, + ); + } + if (!session) { + return new vscode.LanguageModelToolResult([ + new vscode.LanguageModelTextPart('[[]]') + ]); + } + + if (session.runtimeMetadata.languageId !== 'python') { + return new vscode.LanguageModelToolResult([ + new vscode.LanguageModelTextPart('[[]]') + ]); + } + + // Call the Positron API to get the session variable data summaries + const result = await positron.runtime.querySessionTables( + options.input.sessionIdentifier, + options.input.accessKeys, + ['summary_stats']); + + // Return the result as a JSON string to the model + return new vscode.LanguageModelToolResult([ + new vscode.LanguageModelTextPart(JSON.stringify(result)) + ]); + } + }); + context.subscriptions.push(getTableSummaryTool); + const installPythonPackageTool = vscode.lm.registerTool<{ packages: string[]; }>(PositronAssistantToolName.InstallPythonPackage, { diff --git a/extensions/positron-assistant/src/types.ts b/extensions/positron-assistant/src/types.ts index 01e2550a4151..cfed6669a29e 100644 --- a/extensions/positron-assistant/src/types.ts +++ b/extensions/positron-assistant/src/types.ts @@ -7,6 +7,7 @@ export enum PositronAssistantToolName { DocumentEdit = 'documentEdit', EditFile = 'positron_editFile_internal', ExecuteCode = 'executeCode', + GetTableSummary = 'getTableSummary', GetPlot = 'getPlot', InstallPythonPackage = 'installPythonPackage', InspectVariables = 'inspectVariables', diff --git a/extensions/positron-python/python_files/posit/positron/data_explorer.py b/extensions/positron-python/python_files/posit/positron/data_explorer.py index 8afa9a3a75e7..7fb44ecc5408 100644 --- a/extensions/positron-python/python_files/posit/positron/data_explorer.py +++ b/extensions/positron-python/python_files/posit/positron/data_explorer.py @@ -52,7 +52,7 @@ DataSelectionRange, DataSelectionSingleCell, ExportDataSelectionFeatures, - ExportDataSelectionRequest, + ExportDataSelectionParams, ExportedData, ExportFormat, FilterBetween, @@ -64,24 +64,23 @@ FilterTextSearch, FormatOptions, GetColumnProfilesFeatures, - GetColumnProfilesRequest, - GetDataValuesRequest, - GetRowLabelsRequest, - GetSchemaRequest, - GetStateRequest, + GetColumnProfilesParams, + GetDataValuesParams, + GetRowLabelsParams, + GetSchemaParams, RowFilter, RowFilterCondition, RowFilterType, RowFilterTypeSupportStatus, SearchSchemaFeatures, - SearchSchemaRequest, + SearchSchemaParams, SearchSchemaResult, SetColumnFiltersFeatures, - SetColumnFiltersRequest, + SetColumnFiltersParams, SetRowFiltersFeatures, - SetRowFiltersRequest, + SetRowFiltersParams, SetSortColumnsFeatures, - SetSortColumnsRequest, + SetSortColumnsParams, SummaryStatsBoolean, SummaryStatsDate, SummaryStatsDatetime, @@ -261,10 +260,10 @@ def _update_row_view_indices(self): self._sort_data() # Gets the schema from a list of column indices. - def get_schema(self, request: GetSchemaRequest): + def get_schema(self, params: GetSchemaParams): # Loop over the sorted column indices to get the column schemas the user requested. column_schemas = [] - for column_index in sorted(request.params.column_indices): + for column_index in sorted(params.column_indices): # Validate that the column index isn't negative. if column_index < 0: raise IndexError @@ -277,15 +276,15 @@ def get_schema(self, request: GetSchemaRequest): column_schemas.append(self._get_single_column_schema(column_index)) # Return the column schemas. - return TableSchema(columns=column_schemas).dict() + return TableSchema(columns=column_schemas) def _get_single_column_schema(self, column_index: int) -> ColumnSchema: raise NotImplementedError - def search_schema(self, request: SearchSchemaRequest): - filters = request.params.filters - start_index = request.params.start_index - max_results = request.params.max_results + def search_schema(self, params: SearchSchemaParams): + filters = params.filters + start_index = params.start_index + max_results = params.max_results if self._search_schema_last_result is not None: last_filters, matches = self._search_schema_last_result if last_filters != filters: @@ -299,7 +298,7 @@ def search_schema(self, request: SearchSchemaRequest): return SearchSchemaResult( matches=TableSchema(columns=[self._get_single_column_schema(i) for i in matches_slice]), total_num_matches=len(matches), - ).dict() + ) def _column_filter_get_matches(self, filters: list[ColumnFilter]): matchers = self._get_column_filter_functions(filters) @@ -352,18 +351,18 @@ def matcher(index): def _get_column_name(self, column_index: int) -> str: raise NotImplementedError - def get_data_values(self, request: GetDataValuesRequest): + def get_data_values(self, params: GetDataValuesParams): self._recompute_if_needed() return self._get_data_values( - request.params.columns, - request.params.format_options, + params.columns, + params.format_options, ) - def get_row_labels(self, request: GetRowLabelsRequest): + def get_row_labels(self, params: GetRowLabelsParams): self._recompute_if_needed() return self._get_row_labels( - request.params.selection, - request.params.format_options, + params.selection, + params.format_options, ) def _get_row_labels(self, _selection: ArraySelection, _format_options: FormatOptions): @@ -371,11 +370,11 @@ def _get_row_labels(self, _selection: ArraySelection, _format_options: FormatOpt # be implemented for pandas return {"row_labels": []} - def export_data_selection(self, request: ExportDataSelectionRequest): + def export_data_selection(self, params: ExportDataSelectionParams): self._recompute_if_needed() - kind = request.params.selection.kind - sel = request.params.selection.selection - fmt = request.params.format + kind = params.selection.kind + sel = params.selection.selection + fmt = params.format if kind == TableSelectionKind.SingleCell: assert isinstance(sel, DataSelectionSingleCell) row_index = sel.row_index @@ -418,14 +417,14 @@ def _export_cell(self, row_index: int, column_index: int, fmt: ExportFormat): def _export_tabular(self, row_selector, column_selector, fmt: ExportFormat): raise NotImplementedError - def set_column_filters(self, request: SetColumnFiltersRequest): - return self._set_column_filters(request.params.filters) + def set_column_filters(self, params: SetColumnFiltersParams): + return self._set_column_filters(params.filters) def _set_column_filters(self, filters: list[ColumnFilter]): raise NotImplementedError - def set_row_filters(self, request: SetRowFiltersRequest): - return self._set_row_filters(request.params.filters) + def set_row_filters(self, params: SetRowFiltersParams): + return self._set_row_filters(params.filters) def _set_row_filters(self, filters: list[RowFilter]): self.state.row_filters = filters @@ -439,7 +438,7 @@ def _set_row_filters(self, filters: list[RowFilter]): # Simply reset if empty filter set passed self.filtered_indices = None self._update_row_view_indices() - return FilterResult(selected_num_rows=len(self.table), had_errors=False).dict() + return FilterResult(selected_num_rows=len(self.table), had_errors=False) # Evaluate all the filters and combine them using the # indicated conditions @@ -479,7 +478,7 @@ def _set_row_filters(self, filters: list[RowFilter]): # Update the view indices, re-sorting if needed self._update_row_view_indices() - return FilterResult(selected_num_rows=selected_num_rows, had_errors=had_errors).dict() + return FilterResult(selected_num_rows=selected_num_rows, had_errors=had_errors) def _mask_to_indices(self, mask): raise NotImplementedError @@ -487,8 +486,8 @@ def _mask_to_indices(self, mask): def _eval_filter(self, filt: RowFilter): raise NotImplementedError - def set_sort_columns(self, request: SetSortColumnsRequest): - self._set_sort_keys(request.params.sort_keys) + def set_sort_columns(self, params: SetSortColumnsParams): + self._set_sort_keys(params.sort_keys) if not self._recompute_if_needed(): # If a re-filter is pending, then it will automatically @@ -498,22 +497,22 @@ def set_sort_columns(self, request: SetSortColumnsRequest): def _sort_data(self): raise NotImplementedError - def get_column_profiles(self, request: GetColumnProfilesRequest): + def get_column_profiles(self, params: GetColumnProfilesParams): # Launch task to compute column profiles and return them # asynchronously, and return an empty result right away - self.job_queue.submit(self._get_column_profiles_task, request) + self.job_queue.submit(self._get_column_profiles_task, params) return {} - def _get_column_profiles_task(self, request: GetColumnProfilesRequest): + def _get_column_profiles_task(self, params: GetColumnProfilesParams): self._recompute_if_needed() results = [] - for req in request.params.profiles: + for req in params.profiles: try: result = self._compute_profiles( req.column_index, req.profiles, - request.params.format_options, + params.format_options, ) results.append(result.dict()) except Exception as e: # noqa: PERF203 @@ -524,7 +523,7 @@ def _get_column_profiles_task(self, request: GetColumnProfilesRequest): self.comm.send_event( DataExplorerFrontendEvent.ReturnColumnProfiles.value, - {"callback_id": request.params.callback_id, "profiles": results}, + {"callback_id": params.callback_id, "profiles": results}, ) def _compute_profiles( @@ -564,7 +563,7 @@ def _compute_profiles( raise NotImplementedError(profile_type) return ColumnProfileResult(**results) - def get_state(self, _: GetStateRequest): + def get_state(self, _unused): self._recompute_if_needed() num_rows, num_columns = self.table.shape @@ -592,7 +591,7 @@ def get_state(self, _: GetStateRequest): row_filters=self.state.row_filters, sort_keys=self.state.sort_keys, supported_features=self.FEATURES, - ).dict() + ) def _recompute(self): # Re-setting the column filters will trigger filtering AND @@ -751,7 +750,7 @@ def _prof_null_count(self, column_index: int) -> int: _SUMMARIZERS: MappingProxyType[str, SummarizerType] = MappingProxyType({}) - def _prof_summary_stats(self, column_index: int, options: FormatOptions): + def _prof_summary_stats(self, column_index: int, options: FormatOptions) -> ColumnSummaryStats: col_schema = self._get_single_column_schema(column_index) col = self._get_column(column_index) @@ -1593,10 +1592,10 @@ def _export_tabular(self, row_selector, column_selector, fmt: ExportFormat): if result[-1] == "\n": result = result[:-1] - return ExportedData(data=result, format=fmt).dict() + return ExportedData(data=result, format=fmt) def _export_cell(self, row_index: int, column_index: int, fmt: ExportFormat): - return ExportedData(data=str(self.table.iloc[row_index, column_index]), format=fmt).dict() + return ExportedData(data=str(self.table.iloc[row_index, column_index]), format=fmt) def _mask_to_indices(self, mask): if mask is not None: @@ -2439,10 +2438,10 @@ def _export_tabular(self, row_selector, column_selector, fmt: ExportFormat): elif fmt == ExportFormat.Html: raise NotImplementedError(f"Unsupported export format {fmt}") - return ExportedData(data=result, format=fmt).dict() + return ExportedData(data=result, format=fmt) def _export_cell(self, row_index: int, column_index: int, fmt: ExportFormat): - return ExportedData(data=str(self.table[row_index, column_index]), format=fmt).dict() + return ExportedData(data=str(self.table[row_index, column_index]), format=fmt) SUPPORTED_FILTERS = frozenset( { @@ -3051,10 +3050,14 @@ def handle_msg(self, msg: CommMessage[DataExplorerBackendMessageContent], _raw_m comm = self.comms[comm_id] table = self.table_views[comm_id] - result = getattr(table, request.method.value)(request) + # GetState is the only method that doesn't have params + result = getattr(table, request.method.value)(getattr(request, "params", None)) # To help remember to convert pydantic types to dicts if result is not None: + # Convert pydantic types to dict + if not isinstance(result, dict): + result = result.dict() if isinstance(result, list): for x in result: assert isinstance(x, dict) @@ -3062,3 +3065,29 @@ def handle_msg(self, msg: CommMessage[DataExplorerBackendMessageContent], _raw_m assert isinstance(result, dict) comm.send_result(result) + + +def _get_column_profiles(table_view, schema, query_types, format_options): + """Generate column profiles for a table view.""" + profiles = [] + skipped_columns = [] + + for i, column in enumerate(schema.columns): + summary_stats = None + if "summary_stats" in query_types: + try: + summary_stats = table_view._prof_summary_stats(i, format_options) # noqa: SLF001 + except Exception as e: + # Collect failed columns for later logging + skipped_columns.append((i, column.column_name, e)) + continue + + profiles.append( + { + "column_name": column.column_name, + "type_display": column.type_display, + "summary_stats": summary_stats.dict() if summary_stats else None, + } + ) + + return profiles, skipped_columns diff --git a/extensions/positron-python/python_files/posit/positron/tests/test_data_explorer.py b/extensions/positron-python/python_files/posit/positron/tests/test_data_explorer.py index db1d4fc9c8a4..d09178399cdf 100644 --- a/extensions/positron-python/python_files/posit/positron/tests/test_data_explorer.py +++ b/extensions/positron-python/python_files/posit/positron/tests/test_data_explorer.py @@ -55,8 +55,8 @@ ) from ..utils import guid from .conftest import DummyComm, PositronShell -from .test_variables import BIG_ARRAY_LENGTH -from .utils import json_rpc_notification, json_rpc_request +from .test_variables import BIG_ARRAY_LENGTH, _assign_variables +from .utils import dummy_rpc_request, json_rpc_notification, json_rpc_request TARGET_NAME = "positron.dataExplorer" @@ -143,13 +143,9 @@ def test_service_properties(de_service: DataExplorerService): assert de_service.comm_target == TARGET_NAME -def _dummy_rpc_request(*args): - return json_rpc_request(*args, comm_id="dummy_comm_id") - - -def _open_viewer(variables_comm, path): +def _open_viewer(variables_comm: DummyComm, path: list[str]): path = [encode_access_key(p) for p in path] - msg = _dummy_rpc_request("view", {"path": path}) + msg = dummy_rpc_request("view", {"path": path}) variables_comm.handle_msg(msg) # We should get a string back as a result, naming the ID of the viewer comm # that was opened @@ -191,15 +187,6 @@ def test_explorer_open_close_delete( assert len(de_service.table_views) == 0 -def _assign_variables(shell: PositronShell, variables_comm: DummyComm, **variables): - # A hack to make sure that change events are fired when we - # manipulate user_ns - shell.kernel.variables_service.snapshot_user_ns() - shell.user_ns.update(**variables) - shell.kernel.variables_service.poll_variables() - variables_comm.messages.clear() - - def test_explorer_delete_variable( shell: PositronShell, de_service: DataExplorerService, @@ -226,7 +213,7 @@ def test_explorer_delete_variable( # Delete x, check cleaned up and def _check_delete_variable(name): - msg = _dummy_rpc_request("delete", {"names": [name]}) + msg = dummy_rpc_request("delete", {"names": [name]}) paths = de_service.get_paths_for_variable(name) assert len(paths) > 0 diff --git a/extensions/positron-python/python_files/posit/positron/tests/test_variables.py b/extensions/positron-python/python_files/posit/positron/tests/test_variables.py index 6b3cf37eeba9..09df5933698c 100644 --- a/extensions/positron-python/python_files/posit/positron/tests/test_variables.py +++ b/extensions/positron-python/python_files/posit/positron/tests/test_variables.py @@ -25,6 +25,7 @@ from .utils import ( assert_register_table_called, comm_open_message, + dummy_rpc_request, json_rpc_error, json_rpc_notification, json_rpc_request, @@ -803,7 +804,7 @@ def _do_view( mock_dataexplorer_service: Mock, ): path = _encode_path([name]) - msg = json_rpc_request("view", {"path": path}, comm_id="dummy_comm_id") + msg = dummy_rpc_request("view", {"path": path}) variables_comm.handle_msg(msg) # An acknowledgment message is sent @@ -902,6 +903,45 @@ def fail_is_supported(_value): ] +def _assign_variables(shell: PositronShell, variables_comm: DummyComm, **variables): + # A hack to make sure that change events are fired when we + # manipulate user_ns + shell.kernel.variables_service.snapshot_user_ns() + shell.user_ns.update(**variables) + shell.kernel.variables_service.poll_variables() + variables_comm.messages.clear() + + +def test_query_table_summary(shell: PositronShell, variables_comm: DummyComm): + from .test_data_explorer import SIMPLE_PANDAS_DF + + _assign_variables(shell, variables_comm, df=SIMPLE_PANDAS_DF.iloc[:, :2]) + + msg = json_rpc_request( + "query_table_summary", + {"path": ["df"], "query_types": ["summary_stats"]}, + comm_id="dummy_comm_id", + ) + variables_comm.handle_msg(msg) + + assert variables_comm.messages == [ + json_rpc_response( + { + "num_rows": 5, + "num_columns": 2, + "column_schemas": [ + '{"column_name": "a", "column_index": 0, "type_name": "int64", "type_display": "number", "description": null, "children": null, "precision": null, "scale": null, "timezone": null, "type_size": null}', + '{"column_name": "b", "column_index": 1, "type_name": "bool", "type_display": "boolean", "description": null, "children": null, "precision": null, "scale": null, "timezone": null, "type_size": null}', + ], + "column_profiles": [ + '{"column_name": "a", "type_display": "number", "summary_stats": {"type_display": "number", "number_stats": {"min_value": "1.0000", "max_value": "5.0000", "mean": "3.0000", "median": "3.0000", "stdev": "1.5811"}, "string_stats": null, "boolean_stats": null, "date_stats": null, "datetime_stats": null, "other_stats": null}}', + '{"column_name": "b", "type_display": "boolean", "summary_stats": {"type_display": "boolean", "number_stats": null, "string_stats": null, "boolean_stats": {"true_count": 3, "false_count": 1}, "date_stats": null, "datetime_stats": null, "other_stats": null}}', + ], + } + ) + ] + + def test_unknown_method(variables_comm: DummyComm) -> None: msg = json_rpc_request("unknown_method", comm_id="dummy_comm_id") variables_comm.handle_msg(msg, raise_errors=False) diff --git a/extensions/positron-python/python_files/posit/positron/tests/utils.py b/extensions/positron-python/python_files/posit/positron/tests/utils.py index e32858d0d37d..514dd8fecc8a 100644 --- a/extensions/positron-python/python_files/posit/positron/tests/utils.py +++ b/extensions/positron-python/python_files/posit/positron/tests/utils.py @@ -137,3 +137,7 @@ def get_type_as_str(value: Any) -> str: def percent_difference(actual: float, expected: float) -> float: return abs(actual - expected) / actual + + +def dummy_rpc_request(*args): + return json_rpc_request(*args, comm_id="dummy_comm_id") diff --git a/extensions/positron-python/python_files/posit/positron/variables.py b/extensions/positron-python/python_files/posit/positron/variables.py index 4d68edd6c138..64b5b7cba09c 100644 --- a/extensions/positron-python/python_files/posit/positron/variables.py +++ b/extensions/positron-python/python_files/posit/positron/variables.py @@ -6,6 +6,7 @@ import contextlib import copy +import json import logging import time import types @@ -30,6 +31,7 @@ InspectedVariable, InspectRequest, ListRequest, + QueryTableSummaryRequest, RefreshParams, UpdateParams, Variable, @@ -137,6 +139,9 @@ def handle_msg( elif isinstance(request, ViewRequest): self._perform_view_action(request.params.path) + elif isinstance(request, QueryTableSummaryRequest): + self._perform_get_table_summary(request.params.path, request.params.query_types) + else: logger.warning(f"Unhandled request: {request}") @@ -707,6 +712,77 @@ def _send_details(self, _path: list[str], value: Any = None): msg = InspectedVariable(children=children, length=len(children)) self._send_result(msg.dict()) + def _perform_get_table_summary(self, path: list[str], query_types: list[str]) -> None: + """RPC handler for getting table summary.""" + import traceback + + try: + self._get_table_summary(path, query_types) + except Exception as err: + self._send_error( + JsonRpcErrorCode.INTERNAL_ERROR, + f"Error summarizing table at '{path}': {err}\n{traceback.format_exc()}", + ) + + def _get_table_summary(self, path: list[str], query_types: list[str]) -> None: + """Compute statistical summary for a table without opening a data explorer.""" + from .data_explorer import ( + DataExplorerState, + _get_column_profiles, + _get_table_view, + _value_type_is_supported, + ) + from .data_explorer_comm import FormatOptions, GetSchemaParams + + is_known, value = self._find_var(path) + if not is_known: + raise ValueError(f"Cannot find table at '{path}' to summarize") + + if not _value_type_is_supported(value): + raise ValueError(f"Variable at '{path}' is not supported for table summary") + + try: + # Create a temporary table view with a temporary comm + temp_state = DataExplorerState("temp_summary") + temp_comm = PositronComm.create(target_name="temp_summary", comm_id="temp_summary_comm") + table_view = _get_table_view(value, temp_comm, temp_state, self.kernel.job_queue) + except Exception as e: + raise ValueError(f"Failed to create table view: {e}") from e + + # Get schema using the helper function + num_rows = table_view.table.shape[0] + num_columns = table_view.table.shape[1] + schema = table_view.get_schema(GetSchemaParams(column_indices=list(range(num_columns)))) + + # Create default format options + format_options = FormatOptions( + large_num_digits=4, + small_num_digits=6, + max_integral_digits=7, + max_value_length=1000, + thousands_sep=None, + ) + + # Get column profiles using the helper function + profiles, skipped_columns = _get_column_profiles( + table_view, schema, query_types, format_options + ) + + # Log all skipped columns at once + for i, column_name, error in skipped_columns: + logger.warning(f"Skipping summary stats for column {i} ({column_name}): {error}") + + self._send_result( + { + "num_rows": num_rows, + "num_columns": num_columns, + # convert each column schema to serialized JSON + "column_schemas": [json.dumps(x.dict()) for x in schema.columns], + # convert each column profile to serialized JSON + "column_profiles": [json.dumps(x) for x in profiles], + } + ) + def _summarize_variable(key: Any, value: Any, display_name: str | None = None) -> Variable | None: """ diff --git a/extensions/positron-python/python_files/posit/positron/variables_comm.py b/extensions/positron-python/python_files/posit/positron/variables_comm.py index 6c432b4c0c35..ce011f15cdee 100644 --- a/extensions/positron-python/python_files/posit/positron/variables_comm.py +++ b/extensions/positron-python/python_files/posit/positron/variables_comm.py @@ -105,6 +105,28 @@ class FormattedVariable(BaseModel): ) +class QueryTableSummaryResult(BaseModel): + """ + Result of the summarize operation + """ + + num_rows: StrictInt = Field( + description="The total number of rows in the table.", + ) + + num_columns: StrictInt = Field( + description="The total number of columns in the table.", + ) + + column_schemas: List[StrictStr] = Field( + description="The column schemas in the table.", + ) + + column_profiles: List[StrictStr] = Field( + description="The column profiles in the table.", + ) + + class Variable(BaseModel): """ A single variable in the runtime. @@ -183,6 +205,9 @@ class VariablesBackendRequest(str, enum.Enum): # Request a viewer for a variable View = "view" + # Query table summary + QueryTableSummary = "query_table_summary" + class ListRequest(BaseModel): """ @@ -352,6 +377,39 @@ class ViewRequest(BaseModel): ) +class QueryTableSummaryParams(BaseModel): + """ + Request a data summary for a table variable. + """ + + path: List[StrictStr] = Field( + description="The path to the table to summarize, as an array of access keys.", + ) + + query_types: List[StrictStr] = Field( + description="A list of query types.", + ) + + +class QueryTableSummaryRequest(BaseModel): + """ + Request a data summary for a table variable. + """ + + params: QueryTableSummaryParams = Field( + description="Parameters to the QueryTableSummary method", + ) + + method: Literal[VariablesBackendRequest.QueryTableSummary] = Field( + description="The JSON-RPC method name (query_table_summary)", + ) + + jsonrpc: str = Field( + default="2.0", + description="The JSON-RPC version specifier", + ) + + class VariablesBackendMessageContent(BaseModel): comm_id: str data: Union[ @@ -361,6 +419,7 @@ class VariablesBackendMessageContent(BaseModel): InspectRequest, ClipboardFormatRequest, ViewRequest, + QueryTableSummaryRequest, ] = Field(..., discriminator="method") @@ -423,6 +482,8 @@ class RefreshParams(BaseModel): FormattedVariable.update_forward_refs() +QueryTableSummaryResult.update_forward_refs() + Variable.update_forward_refs() ListRequest.update_forward_refs() @@ -447,6 +508,10 @@ class RefreshParams(BaseModel): ViewRequest.update_forward_refs() +QueryTableSummaryParams.update_forward_refs() + +QueryTableSummaryRequest.update_forward_refs() + UpdateParams.update_forward_refs() RefreshParams.update_forward_refs() diff --git a/positron/comms/variables-backend-openrpc.json b/positron/comms/variables-backend-openrpc.json index b4a6eaec0169..8c68d4dd8388 100644 --- a/positron/comms/variables-backend-openrpc.json +++ b/positron/comms/variables-backend-openrpc.json @@ -189,6 +189,70 @@ }, "required": false } + }, + { + "name": "query_table_summary", + "summary": "Query table summary", + "description": "Request a data summary for a table variable.", + "params": [ + { + "name": "path", + "description": "The path to the table to summarize, as an array of access keys.", + "schema": { + "type": "array", + "items": { + "type": "string" + } + } + }, + { + "name": "query_types", + "description": "A list of query types.", + "schema": { + "type": "array", + "items": { + "type": "string" + } + } + } + ], + "result": { + "schema": { + "name": "query_table_summary_result", + "description": "Result of the summarize operation", + "type": "object", + "properties": { + "num_rows": { + "type": "integer", + "description": "The total number of rows in the table." + }, + "num_columns": { + "type": "integer", + "description": "The total number of columns in the table." + }, + "column_schemas": { + "type": "array", + "description": "The column schemas in the table.", + "items": { + "type": "string" + } + }, + "column_profiles": { + "type": "array", + "description": "The column profiles in the table.", + "items": { + "type": "string" + } + } + }, + "required": [ + "num_rows", + "num_columns", + "column_schemas", + "column_profiles" + ] + } + } } ], "components": { diff --git a/src/positron-dts/positron.d.ts b/src/positron-dts/positron.d.ts index 607a89e709a1..3db0da272feb 100644 --- a/src/positron-dts/positron.d.ts +++ b/src/positron-dts/positron.d.ts @@ -708,6 +708,32 @@ declare module 'positron' { has_children: boolean; } + /** + * A result of calling the QuerySessionTables API. + */ + export interface QueryTableSummaryResult { + /** + * The total number of rows in the table. + */ + num_rows: number; + + /** + * The total number of columns in the table. + */ + num_columns: number; + + /** + * The column schemas in the table. + */ + column_schemas: Array; + + /** + * The column profiles in the table. + */ + column_profiles: Array; + + } + /** * The possible types of language model that can be used with the Positron Assistant. */ @@ -1788,6 +1814,22 @@ declare module 'positron' { accessKeys?: Array>): Thenable>>; + /** + * Query a table in a session. + * + * @param sessionId The session ID of the session to query tables. + * @param accessKeys The access keys of the tables to query. + * @param queryTypes The types of data to query for the tables. + * + * @returns A Thenable that resolves with an array of runtime + * table query results. + */ + export function querySessionTables( + sessionId: string, + accessKeys: Array>, + queryTypes: Array): + Thenable>; + /** * Register a handler for runtime client instances. This handler will be called * whenever a new client instance is created by a language runtime of the given diff --git a/src/vs/workbench/api/browser/positron/mainThreadLanguageRuntime.ts b/src/vs/workbench/api/browser/positron/mainThreadLanguageRuntime.ts index 22983eb1239c..d04f9ce6e1ff 100644 --- a/src/vs/workbench/api/browser/positron/mainThreadLanguageRuntime.ts +++ b/src/vs/workbench/api/browser/positron/mainThreadLanguageRuntime.ts @@ -46,7 +46,7 @@ import { basename } from '../../../../base/common/resources.js'; import { RuntimeOnlineState } from '../../common/extHostTypes.js'; import { VSBuffer } from '../../../../base/common/buffer.js'; import { CodeAttributionSource, IConsoleCodeAttribution } from '../../../services/positronConsole/common/positronConsoleCodeExecution.js'; -import { Variable } from '../../../services/languageRuntime/common/positronVariablesComm.js'; +import { QueryTableSummaryResult, Variable } from '../../../services/languageRuntime/common/positronVariablesComm.js'; import { IPositronVariablesInstance } from '../../../services/positronVariables/common/interfaces/positronVariablesInstance.js'; import { isWebviewPreloadMessage, isWebviewReplayMessage } from '../../../services/positronIPyWidgets/common/webviewPreloadUtils.js'; @@ -1590,6 +1590,33 @@ export class MainThreadLanguageRuntime } } + $querySessionTables(handle: number, accessKeys: Array>, queryTypes: Array): Promise> { + const sessionId = this.findSession(handle).sessionId; + const instances = this._positronVariablesService.positronVariablesInstances; + for (const instance of instances) { + if (instance.session.sessionId === sessionId) { + return this.querySessionTables(instance, accessKeys, queryTypes); + } + } + throw new Error(`No variables provider found for session ${sessionId}`); + } + + async querySessionTables(instance: IPositronVariablesInstance, accessKeys: Array>, queryTypes: Array): + Promise> { + const client = instance.getClientInstance(); + if (!client) { + throw new Error(`No variables provider available for session ${instance.session.sessionId}`); + } + if (accessKeys.length === 0) { + throw new Error('No access keys provided for variable data retrieval'); + } + const result = []; + for (const accessKey of accessKeys) { + result.push(await client.comm.queryTableSummary(accessKey, queryTypes)); + } + return result; + } + // Signals that language runtime discovery is complete. $completeLanguageRuntimeDiscovery(): void { this._runtimeStartupService.completeDiscovery(this._id); diff --git a/src/vs/workbench/api/common/positron/extHost.positron.api.impl.ts b/src/vs/workbench/api/common/positron/extHost.positron.api.impl.ts index 57178f87c5e3..ed5e423c3a1a 100644 --- a/src/vs/workbench/api/common/positron/extHost.positron.api.impl.ts +++ b/src/vs/workbench/api/common/positron/extHost.positron.api.impl.ts @@ -140,6 +140,10 @@ export function createPositronApiFactoryAndRegisterActors(accessor: ServicesAcce Thenable>> { return extHostLanguageRuntime.getSessionVariables(sessionId, accessKeys); }, + querySessionTables(sessionId: string, accessKeys: Array>, queryTypes: Array): + Thenable> { + return extHostLanguageRuntime.querySessionTables(sessionId, accessKeys, queryTypes); + }, registerClientHandler(handler: positron.RuntimeClientHandler): vscode.Disposable { return extHostLanguageRuntime.registerClientHandler(handler); }, diff --git a/src/vs/workbench/api/common/positron/extHost.positron.protocol.ts b/src/vs/workbench/api/common/positron/extHost.positron.protocol.ts index 9dc1cf1d636d..e258fb5d07f8 100644 --- a/src/vs/workbench/api/common/positron/extHost.positron.protocol.ts +++ b/src/vs/workbench/api/common/positron/extHost.positron.protocol.ts @@ -16,7 +16,7 @@ import { IAvailableDriverMethods } from '../../browser/positron/mainThreadConnec import { IChatRequestData, IPositronChatContext, IPositronLanguageModelConfig, IPositronLanguageModelSource } from '../../../contrib/positronAssistant/common/interfaces/positronAssistantService.js'; import { IChatAgentData } from '../../../contrib/chat/common/chatAgents.js'; import { PlotRenderSettings } from '../../../services/positronPlots/common/positronPlots.js'; -import { Variable } from '../../../services/languageRuntime/common/positronVariablesComm.js'; +import { QueryTableSummaryResult, Variable } from '../../../services/languageRuntime/common/positronVariablesComm.js'; import { ILanguageRuntimeCodeExecutedEvent } from '../../../services/positronConsole/common/positronConsoleCodeExecution.js'; // NOTE: This check is really to ensure that extHost.protocol is included by the TypeScript compiler @@ -53,6 +53,7 @@ export interface MainThreadLanguageRuntimeShape extends IDisposable { $interruptSession(handle: number): Promise; $focusSession(handle: number): void; $getSessionVariables(handle: number, accessKeys?: Array>): Promise>>; + $querySessionTables(handle: number, accessKeys: Array>, queryTypes: Array): Promise>; $emitLanguageRuntimeMessage(handle: number, handled: boolean, message: SerializableObjectWithBuffers): void; $emitLanguageRuntimeState(handle: number, clock: number, state: RuntimeState): void; $emitLanguageRuntimeExit(handle: number, exit: ILanguageRuntimeExit): void; diff --git a/src/vs/workbench/api/common/positron/extHostLanguageRuntime.ts b/src/vs/workbench/api/common/positron/extHostLanguageRuntime.ts index ee6fb245aed2..2f227f22c2c1 100644 --- a/src/vs/workbench/api/common/positron/extHostLanguageRuntime.ts +++ b/src/vs/workbench/api/common/positron/extHostLanguageRuntime.ts @@ -21,7 +21,7 @@ import { SerializableObjectWithBuffers } from '../../../services/extensions/comm import { VSBuffer } from '../../../../base/common/buffer.js'; import { generateUuid } from '../../../../base/common/uuid.js'; import { CancellationToken } from '../../../../base/common/cancellation.js'; -import { Variable } from '../../../services/languageRuntime/common/positronVariablesComm.js'; +import { QueryTableSummaryResult, Variable } from '../../../services/languageRuntime/common/positronVariablesComm.js'; import { ILanguageRuntimeCodeExecutedEvent } from '../../../services/positronConsole/common/positronConsoleCodeExecution.js'; /** @@ -1224,6 +1224,16 @@ export class ExtHostLanguageRuntime implements extHostProtocol.ExtHostLanguageRu throw new Error(`Session with ID '${sessionId}' not found`); } + public querySessionTables(sessionId: string, accessKeys: Array>, queryTypes: Array): + Promise> { + for (let i = 0; i < this._runtimeSessions.length; i++) { + if (this._runtimeSessions[i].metadata.sessionId === sessionId) { + return this._proxy.$querySessionTables(i, accessKeys, queryTypes); + } + } + throw new Error(`Session with ID '${sessionId}' not found`); + } + /** * Interrupts an active session. * diff --git a/src/vs/workbench/api/common/positron/extHostTypes.positron.ts b/src/vs/workbench/api/common/positron/extHostTypes.positron.ts index 2254c76f4dae..730f865bbb9c 100644 --- a/src/vs/workbench/api/common/positron/extHostTypes.positron.ts +++ b/src/vs/workbench/api/common/positron/extHostTypes.positron.ts @@ -411,5 +411,5 @@ export enum CodeAttributionSource { Script = 'script', } -export { UiRuntimeNotifications } from '../../../services/languageRuntime/common/languageRuntimeService.js' +export { UiRuntimeNotifications } from '../../../services/languageRuntime/common/languageRuntimeService.js'; export { PlotRenderSettings, PlotRenderFormat } from '../../../services/positronPlots/common/positronPlots.js'; diff --git a/src/vs/workbench/services/languageRuntime/common/positronVariablesComm.ts b/src/vs/workbench/services/languageRuntime/common/positronVariablesComm.ts index 09117f30ac5d..5b6dadceebaa 100644 --- a/src/vs/workbench/services/languageRuntime/common/positronVariablesComm.ts +++ b/src/vs/workbench/services/languageRuntime/common/positronVariablesComm.ts @@ -62,6 +62,32 @@ export interface FormattedVariable { } +/** + * Result of the summarize operation + */ +export interface QueryTableSummaryResult { + /** + * The total number of rows in the table. + */ + num_rows: number; + + /** + * The total number of columns in the table. + */ + num_columns: number; + + /** + * The column schemas in the table. + */ + column_schemas: Array; + + /** + * The column profiles in the table. + */ + column_profiles: Array; + +} + /** * A single variable in the runtime. */ @@ -216,6 +242,21 @@ export interface ViewParams { path: Array; } +/** + * Parameters for the QueryTableSummary method. + */ +export interface QueryTableSummaryParams { + /** + * The path to the table to summarize, as an array of access keys. + */ + path: Array; + + /** + * A list of query types. + */ + query_types: Array; +} + /** * Parameters for the Update method. */ @@ -323,7 +364,8 @@ export enum VariablesBackendRequest { Delete = 'delete', Inspect = 'inspect', ClipboardFormat = 'clipboard_format', - View = 'view' + View = 'view', + QueryTableSummary = 'query_table_summary' } export class PositronVariablesComm extends PositronBaseComm { @@ -419,6 +461,21 @@ export class PositronVariablesComm extends PositronBaseComm { return super.performRpc('view', ['path'], [path]); } + /** + * Query table summary + * + * Request a data summary for a table variable. + * + * @param path The path to the table to summarize, as an array of access + * keys. + * @param queryTypes A list of query types. + * + * @returns Result of the summarize operation + */ + queryTableSummary(path: Array, queryTypes: Array): Promise { + return super.performRpc('query_table_summary', ['path', 'query_types'], [path, queryTypes]); + } + /** * Update variables