Skip to content

[ES|QL] RERANK command validation support #221004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/platform/packages/shared/kbn-esql-ast/src/ast/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import type {
ESQLIdentifier,
ESQLIntegerLiteral,
ESQLLiteral,
ESQLLocation,
ESQLParamLiteral,
ESQLProperNode,
ESQLSource,
Expand Down Expand Up @@ -45,6 +46,10 @@ export const isFunctionExpression = (node: unknown): node is ESQLFunction =>
/**
 * Type guard that narrows `node` to a binary expression: a function
 * expression whose `subtype` is `'binary-expression'`.
 */
export const isBinaryExpression = (node: unknown): node is ESQLBinaryExpression => {
  if (!isFunctionExpression(node)) {
    return false;
  }

  return node.subtype === 'binary-expression';
};

/**
 * Type guard that narrows `node` to an assignment: a binary expression
 * whose operator name is `=`.
 */
export const isAssignment = (node: unknown): node is ESQLBinaryExpression<'='> =>
  isBinaryExpression(node) && node.name === '=';

export const isWhereExpression = (
node: unknown
): node is ESQLBinaryExpression<BinaryExpressionWhereOperator> =>
Expand Down Expand Up @@ -84,6 +89,21 @@ export const isSource = (node: unknown): node is ESQLSource =>
/**
 * Type guard that narrows `node` to an identifier: a proper AST node whose
 * `type` is `'identifier'`.
 */
export const isIdentifier = (node: unknown): node is ESQLIdentifier => {
  if (!isProperNode(node)) {
    return false;
  }

  return node.type === 'identifier';
};

/**
 * Checks whether the `contained` range lies fully within the `container`
 * range. Both bounds are compared inclusively, so two identical ranges are
 * considered contained.
 *
 * @param container The outer range.
 * @param contained The range expected to fit inside `container`.
 * @returns `true` when `contained` does not extend past either bound of
 *     `container`.
 */
export const isContainedLocation = (container: ESQLLocation, contained: ESQLLocation): boolean => {
  const startsInside = contained.min >= container.min;
  const endsInside = contained.max <= container.max;

  return startsInside && endsInside;
};

/**
 * Checks whether the location of `contained` lies fully within the location
 * of `container`.
 *
 * @returns `false` when either object has no `location`; otherwise the
 *     result of comparing the two locations with `isContainedLocation`.
 */
export const isContained = (
  container: { location?: ESQLLocation },
  contained: { location?: ESQLLocation }
): boolean => {
  const outer = container.location;
  const inner = contained.location;

  return outer && inner ? isContainedLocation(outer, inner) : false;
};

/**
* Returns the group of a binary expression:
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ export class ESQLErrorListener extends ErrorListener<any> {
}

const textMessage = `SyntaxError: ${message}`;

const tokenPosition = getPosition(offendingSymbol);
const startColumn = offendingSymbol && tokenPosition ? tokenPosition.min + 1 : column + 1;
const endColumn = offendingSymbol && tokenPosition ? tokenPosition.max + 1 : column + 2;
Expand All @@ -54,6 +53,10 @@ export class ESQLErrorListener extends ErrorListener<any> {
startColumn,
endColumn,
message: textMessage,
location: {
min: tokenPosition.min,
max: tokenPosition.max - 1,
},
severity: 'error',
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ export const createRerankCommand = (ctx: RerankCommandContext): ESQLAstRerankCom
const fields = visitRerankFields(fieldsCtx);
const inferenceIdCtx = ctx._inferenceId;
const maybeInferenceId = inferenceIdCtx ? createIdentifierOrParam(inferenceIdCtx) : undefined;
const inferenceId = maybeInferenceId ?? Builder.identifier('', { incomplete: true });
const inferenceId =
maybeInferenceId ??
Builder.identifier('', {
incomplete: true,
location: inferenceIdCtx ? getPosition(inferenceIdCtx.start, inferenceIdCtx.stop) : undefined,
});
const command = createCommand<'rerank', ESQLAstRerankCommand>('rerank', ctx, {
query,
fields,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ export const parse = (text: string | undefined, options: ParseOptions = {}): Par
endColumn: 0,
message: `Invalid query [${text}]`,
severity: 'error',
location: {
min: 0,
max: 0,
},
},
],
tokens: [],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,12 @@ export class BasicPrettyPrinter {
args += (args ? ', ' : '') + arg;
}

const formatted = `${operator}(${args})`;
const isParentShowCommand =
!args.length &&
(ctx.parent?.node as any)?.type === 'command' &&
(ctx.parent?.node as any)?.name === 'show';

const formatted = isParentShowCommand ? operator : `${operator}(${args})`;

return this.decorateWithComments(ctx.node, formatted);
}
Expand Down
1 change: 1 addition & 0 deletions src/platform/packages/shared/kbn-esql-ast/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ export interface EditorError {
startColumn: number;
endColumn: number;
message: string;
location: ESQLLocation;
code?: string;
severity: 'error' | 'warning' | number;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ import { METADATA_FIELDS } from '../shared/constants';
import { getMessageFromId } from '../validation/errors';
import { isNumericType } from '../shared/esql_types';

import { definition as rerankDefinition } from './commands/rerank';

const statsValidator = (command: ESQLCommand) => {
const messages: ESQLMessage[] = [];
const commandName = command.name.toUpperCase();
Expand Down Expand Up @@ -706,4 +708,5 @@ export const commandDefinitions: Array<CommandDefinition<any>> = [

fieldsSuggestionsAfter: fieldsSuggestionsAfterFork,
},
rerankDefinition,
];
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
import { i18n } from '@kbn/i18n';
import { ESQLAstExpression, ESQLAstRerankCommand, EditorError } from '@kbn/esql-ast/src/types';
import {
isAssignment,
isBinaryExpression,
isColumn,
isContained,
isStringLiteral,
} from '@kbn/esql-ast/src/ast/helpers';
import { Walker, type ESQLCommand, type ESQLMessage } from '@kbn/esql-ast';
import { isParam } from '../../shared/helpers';
import { errors } from '../../validation/errors';
import type { CommandDefinition } from '../types';
import { validateColumnForCommand } from '../../validation/validation';
import { ReferenceMaps } from '../../validation/types';

/**
 * Translates raw parser errors of a RERANK command into validation messages.
 *
 * Currently only one situation is recognised: if any parser error is located
 * inside the inference ID node, a single "inference ID must be an identifier"
 * message is produced for that node.
 *
 * @param parsingErrors Parser errors reported for this command.
 * @param cmd The command node; treated as a RERANK command.
 * @returns Zero or one message.
 */
const parsingErrorsToMessages = (parsingErrors: EditorError[], cmd: ESQLCommand): ESQLMessage[] => {
  const { inferenceId } = cmd as ESQLAstRerankCommand;

  const hasInferenceIdError = parsingErrors.some((parsingError) =>
    isContained(inferenceId, parsingError)
  );

  return hasInferenceIdError ? [errors.rerankInferenceIdMustBeIdentifier(inferenceId)] : [];
};

/**
 * Returns true if a field is *named*. A named field is either a column used
 * directly, e.g. `field.name`, or an assignment whose left-hand side is a
 * column, e.g. `field.name = AVG(1, 2, 4)`.
 */
const isNamedField = (field: ESQLAstExpression) =>
  isColumn(field) || (isAssignment(field) && isColumn(field.args[0]));

/**
 * Validates a parsed RERANK command and collects every problem found.
 *
 * Checks performed, in order:
 *   1. The command has at least three arguments, otherwise it is reported
 *      as incomplete.
 *   2. `<query>` is a string literal or a parameter.
 *   3. Every entry in `<fields>` is a named field (a column, or an
 *      assignment to a column).
 *   4. Every column referenced by a field passes `validateColumnForCommand`.
 *      For assignments only the right-hand side is walked.
 *
 * @param cmd The command node; treated as a RERANK command.
 * @param references Reference maps forwarded to the column validator.
 * @returns All validation messages; empty when nothing was found.
 */
const validate = (cmd: ESQLCommand, references: ReferenceMaps) => {
  const command = cmd as ESQLAstRerankCommand;
  const messages: ESQLMessage[] = [];

  if (command.args.length < 3) {
    messages.push({
      location: command.location,
      // The message id is RERANK-specific so it does not clash with the
      // FORK command's "too few arguments" translation entry.
      text: i18n.translate(
        'kbn-esql-validation-autocomplete.esql.validation.rerankTooFewArguments',
        {
          defaultMessage: '[RERANK] Command is not complete.',
        }
      ),
      type: 'error',
      code: 'rerankTooFewArguments',
    });
  }

  const { query, fields } = command;

  // Check that <query> is a string literal or a parameter
  if (!isStringLiteral(query) && !isParam(query)) {
    messages.push(errors.rerankQueryMustBeString(query));
  }

  for (const field of fields) {
    // Check that <fields> are either columns or new column definitions
    if (!isNamedField(field)) {
      messages.push(errors.rerankFieldMustBeNamed(field));
    }

    // Check if all (deeply nested) columns exist. For an assignment the
    // left-hand side is skipped; only the right-hand side is walked.
    const columnExpressionToCheck = isAssignment(field) ? field.args[1] : field;

    Walker.walk(columnExpressionToCheck, {
      visitColumn: (node) => {
        const fieldErrors = validateColumnForCommand(node, 'rerank', references);

        if (fieldErrors.length > 0) {
          messages.push(...fieldErrors);
        }
      },
    });
  }

  return messages;
};

/**
 * Command definition for `RERANK`. It is flagged as both `hidden` and
 * `preview`, and registers this command's parser-error translation
 * (`parsingErrorsToMessages`) and validation (`validate`) callbacks.
 * Autocomplete is not supported yet: `suggest` always throws.
 */
export const definition = {
hidden: true,
name: 'rerank',
preview: true,
description: i18n.translate('kbn-esql-validation-autocomplete.esql.definitions.rerankDoc', {
defaultMessage: 'Reorder results using a semantic reranker.',
}),
declaration: 'RERANK <query> ON <field1> [, <field2> [, ...]] WITH <inferenceID>',
examples: [],
// Suggestions are not implemented for this command; calling this throws.
suggest: () => {
throw new Error('Not implemented');
},
parsingErrorsToMessages,
validate,
// TODO: implement `.fieldsSuggestionsAfter()`
} satisfies CommandDefinition<'rerank'>;
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import type {
ESQLMessage,
ESQLSource,
ESQLAstCommand,
EditorError,
} from '@kbn/esql-ast';
import { ESQLControlVariable } from '@kbn/esql-types';
import { GetColumnsByTypeFn, SuggestionRawDefinition } from '../autocomplete/types';
Expand Down Expand Up @@ -414,6 +415,25 @@ export interface CommandDefinition<CommandName extends string> {
*/
hidden?: boolean;

/**
* Return nicely formatted, human-readable messages built from parser errors. This callback
* lets commands construct their own error messages out of parser errors,
* as parser errors have the following drawbacks:
*
* 1. Not human-readable, even hard to read for developers.
* 2. Not translated to other languages, e.g. Chinese.
* 3. Depend on ANTLR grammar, which is not stable and may change in the future.
*
* @param parsingErrors List of parsing errors returned by the ANTLR parser
* for this command.
* @returns Human-readable, translatable messages for the user.
*/
parsingErrorsToMessages?: (
parsingErrors: EditorError[],
command: ESQLCommand<CommandName>,
references: ReferenceMaps
) => ESQLMessage[];

/**
* This method is run when the command is being validated, but it does not
* prevent the default behavior. If you need a full override, we are currently
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ function buildCommandLookup(): Map<string, CommandDefinition<string>> {
/**
 * Looks up the definition registered for a command name.
 *
 * The name is lower-cased before the lookup, so matching is
 * case-insensitive with respect to the caller's input.
 *
 * NOTE(review): the result is cast unconditionally. If no entry exists for
 * `name`, `Map.get` yields `undefined` at runtime even though the declared
 * return type does not include it; callers may need to guard for that.
 *
 * @param name Command name in any letter case.
 * @returns The command definition stored under the lower-cased name.
 */
export function getCommandDefinition<CommandName extends string>(
  name: CommandName
): CommandDefinition<CommandName> {
  const map = buildCommandLookup();

  return map.get(name.toLowerCase()) as unknown as CommandDefinition<CommandName>;
}

export function getAllCommands() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { EditorError, ESQLMessage } from '@kbn/esql-ast';
import { ESQLCallbacks } from '../../shared/types';
import { getCallbackMocks } from '../../__tests__/helpers';
import { ValidationOptions } from '../types';
import { validateQuery } from '../validation';
import * as validation from '../validation';

/** Validation test API factory, can be called at the start of each unit test. */
export type Setup = typeof setup;
Expand All @@ -29,7 +29,7 @@ export const setup = async () => {
opts: ValidationOptions = {},
cb: ESQLCallbacks = callbacks
) => {
return await validateQuery(query, opts, cb);
return await validation.validateQuery(query, opts, cb);
};

const assertErrors = (errors: unknown[], expectedErrors: string[], query?: string) => {
Expand Down Expand Up @@ -66,7 +66,7 @@ export const setup = async () => {
opts: ValidationOptions = {},
cb: ESQLCallbacks = callbacks
) => {
const { errors, warnings } = await validateQuery(query, opts, cb);
const { errors, warnings } = await validation.validateQuery(query, opts, cb);
assertErrors(errors, expectedErrors, query);
if (expectedWarnings) {
assertErrors(warnings, expectedWarnings, query);
Expand Down
Loading