Created
October 17, 2024 18:06
-
-
Save cpfiffer/9aa894b2e302a96cd792d58c067472b7 to your computer and use it in GitHub Desktop.
Outlines code to generate text-to-sql from a context-free grammar
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Text-to-SQL generation constrained by a context-free grammar, using outlines.
import outlines

# SQL grammar in Lark EBNF syntax (note the %import / %ignore directives at the
# end). A trailing `i` on a string or regex terminal makes it case-insensitive.
#
# The literal is a raw string so that the `\s` inside the regex terminals
# (JOIN_TYPE and the quoted-string literal) reaches the grammar parser verbatim
# instead of being an invalid Python escape sequence.
sql_grammar = r"""
start: set_expr -> final
set_expr: query_expr
| set_expr "UNION"i ["DISTINCT"i] set_expr -> union_distinct
| set_expr "UNION"i "ALL"i set_expr -> union_all
| set_expr "INTERSECT"i ["DISTINCT"i] set_expr -> intersect_distinct
| set_expr "EXCEPT"i ["DISTINCT"i] set_expr -> except_distinct
| set_expr "EXCEPT"i "ALL"i set_expr -> except_all
query_expr: select [ "ORDER"i "BY"i (order_by_expr ",")* order_by_expr] [ "LIMIT"i limit_count [ "OFFSET"i skip_rows ] ]
select: "SELECT"i [SELECT_CONSTRAINT] [(select_expr ",")*] select_expr "FROM"i [(from_expr ",")*] from_expr [ "WHERE"i where_expr ] [ "GROUP"i "BY"i [(groupby_expr ",")*] groupby_expr ] [ "HAVING"i having_expr] [ "WINDOW"i window_expr ]
where_expr: bool_expression
select_expr.0: expression_math [ [ "AS"i ] alias ] -> select_expression
?from_expr: from_item -> from_expression
order_by_expr: order -> order_by_expression
having_expr: bool_expression
groupby_expr: expression -> group_by
window_expr: [window_expr ","] _window_name "AS"i ( window_definition )
from_item: name [ [ "AS"i ] alias ] -> table
| join -> join
| cross_join -> cross_join_expression
| subquery
subquery: ( "(" (query_expr | join | cross_join) ")" ) [ [ "AS"i ] alias ]
cross_join: from_item "CROSS"i "JOIN"i from_item
join: from_item [ JOIN_TYPE ] "JOIN"i from_item [ "ON"i bool_expression ] -> join_expression
JOIN_TYPE.5: "INNER"i | /FULL\sOUTER/i | /LEFT\sOUTER/i | /RIGHT\sOUTER/i | "FULL"i | "LEFT"i | "RIGHT"i
?expression_math: expression_product
| expression_math "+" expression_product -> expression_add
| expression_math "-" expression_product -> expression_sub
| "CASE"i (when_then)+ "ELSE"i expression_math "END"i -> case_expression
| "CAST"i "(" expression_math "AS"i TYPENAME ")" -> as_type
| "CAST"i "(" literal "AS"i TYPENAME ")" -> literal_cast
| AGGREGATION expression_math ")" [window_form] -> sql_aggregation
| "RANK"i "(" ")" window_form -> rank_expression
| "DENSE_RANK"i "(" ")" window_form -> dense_rank_expression
| "COALESCE"i "(" [(expression_math ",")*] expression_math ")" -> coalesce_expression
window_form: "OVER"i "(" ["PARTITION"i "BY"i (partition_by ",")* partition_by] ["ORDER"i "BY"i (order ",")* order [ row_range_clause ] ] ")"
partition_by: expression_math
row_range_clause: ( ROWS | RANGE ) frame_extent
frame_extent: frame_between | frame_preceding
frame_between: "BETWEEN"i frame_bound "AND"i frame_bound
frame_bound: frame_preceding | frame_following | "CURRENT"i "ROW"i
frame_preceding: UNBOUNDED PRECEDING | integer_ PRECEDING
frame_following: UNBOUNDED FOLLOWING | integer_ FOLLOWING
RANGE: "RANGE"i
ROWS: "ROWS"i
UNBOUNDED: "UNBOUNDED"i
PRECEDING: "PRECEDING"i
FOLLOWING: "FOLLOWING"i
when_then: "WHEN"i bool_expression "THEN"i expression_math
order: expression_math ["ASC"i] -> order_asc
| expression_math "DESC"i -> order_desc
column_name: [name "."] name
?expression_product: expression_parens
| expression_product "*" expression_parens -> expression_mul
| expression_product "/" expression_parens -> expression_div
?expression_parens: expression
| "(" expression_parens "*" expression ")" -> expression_mul
| "(" expression_parens "/" expression ")" -> expression_div
| "(" expression_parens "+" expression ")" -> expression_add
| "(" expression_parens "-" expression ")" -> expression_sub
?expression: [name "."] (name | STAR) -> column_name
| literal
SELECT_CONSTRAINT.9: "ALL"i | "DISTINCT"i
TYPENAME: "object"i
| "varchar"i
| "integer"i
| "int16"i
| "smallint"i
| "int32"i
| "int64"i
| "int"i
| "bigint"i
| "float16"i
| "float32"i
| "float64"i
| "float"i
| "bool"i
| "datetime64"i
| "timestamp"i
| "time"i
| "date"i
| "category"i
| "string"i
AGGREGATION.8: ("sum("i | "avg("i | "min("i | "max("i | "count("i "distinct"i | "count("i)
alias: name -> alias_string
_window_name: name
limit_count: integer_ -> limit_count
skip_rows: integer_
bool_expression: bool_parentheses
| bool_expression "AND"i bool_parentheses -> bool_and
| bool_expression "OR"i bool_parentheses -> bool_or
bool_parentheses: comparison_type
| "(" bool_expression "AND"i comparison_type ")" -> bool_and
| "(" bool_expression "OR"i comparison_type ")" -> bool_or
comparison_type: equals | not_equals | greater_than | less_than | greater_than_or_equal
| less_than_or_equal | between | in_expr | not_in_expr | subquery_in | is_null | is_not_null
equals: expression_math "=" expression_math
is_null: expression_math "is"i "null"i
is_not_null: expression_math "is"i "not"i "null"i
not_equals: expression_math ("<>" | "!=") expression_math
greater_than: expression_math ">" expression_math
less_than: expression_math "<" expression_math
greater_than_or_equal: expression_math ">=" expression_math
less_than_or_equal: expression_math "<=" expression_math
between: expression_math "BETWEEN"i expression_math "AND"i expression_math
in_expr: expression_math "IN"i "(" [expression_math ","]* expression_math ")"
subquery_in: expression_math "IN"i subquery
not_in_expr: expression_math "NOT"i "IN"i "(" [expression_math ","]* expression_math ")"
?literal: boolean -> bool
| number_expr -> number
| /'([^']|\s)+'|''/ -> string
| timestamp_expression -> timestamp_expression
boolean: "true"i -> true
| "false"i -> false
?number_expr: product
?product: NUMBER
integer_: /[1-9][0-9]*/
STAR: "*"
window_definition:
timestamp_expression: "NOW"i "(" ")" -> datetime_now
| "TODAY"i "(" ")" -> date_today
| "TIMESTAMP"i "(" "'" date "'" "," "'" time "'" ")" -> custom_timestamp
date: YEAR "-" MONTH "-" DAY
YEAR: /[0-9]{4}/
MONTH: /[0-9]{2}/
DAY: /[0-9]{2}/
time: HOURS ":" MINUTES ":" SECONDS
HOURS: /[0-9]{2}/
MINUTES: /[0-9]{2}/
SECONDS: /[0-9]{2}/
name: CNAME | ESCAPED_STRING
%import common.ESCAPED_STRING
%import common.CNAME
%import common.NUMBER
%import common.WS
%ignore WS
"""

# Load the language model through outlines' transformers integration, on CPU.
model = outlines.models.transformers(
    "microsoft/Phi-3-mini-4k-instruct",
    device="cpu"
)

# Build a generator from the model and the grammar above.
generator = outlines.generate.cfg(model, sql_grammar)

# Run the generator on a natural-language request and print what it produced.
sequence = generator(
    "Write valid SQL to select the name and age from the users table where the age is greater than 20"
)
print(sequence)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment