Add SQL Metrics Implementation #59
base: main
Changes from all commits: da8a897, 6b7bd38, b362c54, 5c405ba, c0920e4, 53cd8b5, 892bc2b, abcc1f9
New file `sql_deterministic_metrics.py` (36 lines):

```python
from typing import List, Union

import sqlparse

from continuous_eval.metrics.base import Metric


class SQLSyntaxMatch(Metric):
    """
    This metric evaluates the syntactic similarity between the generated SQL query
    and a set of ground truth queries. It uses the sqlparse library to format and
    compare the SQL queries.
    """

    def __call__(self, answer: str, ground_truth_answers: Union[List[str], str]):
        if isinstance(ground_truth_answers, str):
            ground_truth_answers = [ground_truth_answers]

        # Format the answer and ground truth answers using sqlparse for consistent comparison
        formatted_answer = sqlparse.format(answer, reindent=True, keyword_case="upper")
        formatted_ground_truths = [
            sqlparse.format(gt, reindent=True, keyword_case="upper")
            for gt in ground_truth_answers
        ]

        # Initialize the maximum match score
        max_match_score = 0

        # Compare the formatted answer with each formatted ground truth answer
        for formatted_gt in formatted_ground_truths:
            # Simple string comparison for now, can be improved with more sophisticated methods
            match_score = float(formatted_answer == formatted_gt)

            if match_score > max_match_score:
                max_match_score = match_score

        # Return the maximum match score
        return {"SQL_Syntax_Match_Score": max_match_score}
```

**Review comment** (on the string-equality comparison):

> Consider implementing a more sophisticated comparison method than simple string equality to handle cases where SQL queries might be functionally identical but differ in formatting or syntax. This could improve the robustness of the syntactic similarity evaluation.

**Reply:**

> @ellipsis-dev come up with a few more sophisticated ways to handle functionally identical cases
New documentation page (40 lines):

---
title: SQL Syntax Match
sidebar:
  order: 2
---

## Definitions

**SQL Syntax Match** evaluates the syntactic similarity between generated SQL queries and a set of ground truth queries. It compares the structure and syntax of SQL statements to determine how closely they match, considering the order and type of clauses, keywords, and conditions.

:::note
The metric requires syntactically correct SQL queries to function properly. If the queries contain syntax errors and cannot be parsed, the metric will yield a score of 0.0.
:::

## Example Usage

Required data items: `answer`, `ground_truth_answers`

```python
from continuous_eval.metrics.code.sql.sql_deterministic_metrics import SQLSyntaxMatch

# Instantiate the metric
sql_syntax_match = SQLSyntaxMatch()

# Evaluate syntactic similarity
result = sql_syntax_match(answer="SELECT * FROM users;", ground_truth_answers=["SELECT * FROM users;"])
print(result)  # Output: {"SQL_Syntax_Match_Score": 1.0}
```

## Example Output

```json
{
    "SQL_Syntax_Match_Score": 1.0
}
```

The `SQLSyntaxMatch` class returns a dictionary with a single key-value pair. The key is `SQL_Syntax_Match_Score`, and the value is a float representing the syntactic match score. A score of `1.0` indicates an exact match, while a score of `0.0` indicates no match.

For more detailed examples and advanced usage, please refer to the test cases in the `code_metrics_test.py` file in the `tests` directory.
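An aside on the behavior documented above (this example is illustrative and not part of either new file): because the implementation compares the sqlparse-formatted strings for exact equality, even a cosmetic difference such as a missing trailing semicolon is expected to drop the score to 0.0, which is the limitation the reviewers call out.

```python
from continuous_eval.metrics.code.sql.sql_deterministic_metrics import SQLSyntaxMatch

sql_syntax_match = SQLSyntaxMatch()

# Functionally identical queries, but the trailing semicolon survives
# sqlparse.format, so the strict string comparison should score 0.0.
result = sql_syntax_match(
    answer="select * from users",
    ground_truth_answers=["SELECT * FROM users;"],
)
print(result)  # Expected: {"SQL_Syntax_Match_Score": 0.0}
```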
**Review comment:**

> This would be better I think: https://github.com/tobymao/sqlglot?tab=readme-ov-file#ast-diff
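Building on that pointer, a graded score could be derived from sqlglot's AST diff between the parsed queries. This is only a sketch under the assumption that sqlglot is added as a dependency; the scoring heuristic (share of `Keep` edits in the edit script) is an illustrative choice, not part of this PR:

```python
from typing import Optional

from sqlglot import diff, parse_one
from sqlglot.diff import Keep
from sqlglot.errors import ParseError


def ast_similarity(answer: str, ground_truth: str, dialect: Optional[str] = None) -> float:
    """Sketch: score two SQL queries by how much of their AST diff is unchanged."""
    try:
        answer_ast = parse_one(answer, read=dialect)
        truth_ast = parse_one(ground_truth, read=dialect)
    except ParseError:
        # Unparseable SQL keeps the metric's existing behavior: no match.
        return 0.0

    edit_script = diff(answer_ast, truth_ast)
    if not edit_script:
        return 1.0
    kept = sum(isinstance(edit, Keep) for edit in edit_script)
    return kept / len(edit_script)


# Formatting-only differences parse to the same tree, so the score stays at 1.0.
print(ast_similarity("select a,b from t", "SELECT a, b\nFROM t;"))
```

Whether a token-level or AST-level comparison better fits the intent of a *syntax* match is a design decision left to the PR author.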