Last active
July 23, 2025 17:45
-
-
Save jrosell/94ae6c3c6b8fabb8f67574f67325e530 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 1. Preparations & helper functions ----- | |
rlang::check_installed(c("tidyverse", "repurrrsive", "DBI", "duckdb", "ollamar", "ellmer", "glue", "testthat")) | |
library(tidyverse) | |
library(glue) | |
library(repurrrsive) | |
library(DBI) | |
library(duckdb) | |
library(ellmer) | |
library(testthat) | |
#' Build a `CREATE TABLE` statement for a data frame
#'
#' Asks DBI for the SQL type of every column of `df` and assembles one
#' `<column> <type>` line per column.
#'
#' @param df A data frame whose columns define the table fields.
#' @param name Table name to put in the statement.
#' @param conn Optional DBI connection used to resolve the column types. When
#'   `NULL` (the default) a throwaway in-memory DuckDB connection is opened and
#'   closed again, so the helper does not depend on a global `con` that the
#'   script only creates further down.
#' @return A length-one glue string containing the `CREATE TABLE` statement.
build_create_table_sql <- \(df, name, conn = NULL) {
  if (is.null(conn)) {
    conn <- dbConnect(duckdb::duckdb(), dbdir = ":memory:")
    on.exit(dbDisconnect(conn, shutdown = TRUE), add = TRUE)
  }
  # Named character vector: names are the columns, values are the SQL types
  col_types <- dbDataType(conn, df)
  fields_str <- glue(" {names(col_types)} {unname(col_types)}") |>
    paste0(collapse = ",\n")
  glue("CREATE TABLE {name} (\n{fields_str}\n)")
}
#' Ask a local Ollama model to translate a question into DuckDB SQL
#'
#' Fills a prompt template with the question and the schema, sends it to the
#' model through `chat_ollama()`, strips any `<think>...</think>` section from
#' the reply, prints the remaining reply with `cat()`, and extracts the SQL.
#'
#' @param question Natural-language question to answer.
#' @param schema_str DDL text describing the database; embedded in the prompt.
#' @param model Ollama model name.
#' @return A length-one character vector with the SQL text, or `NA_character_`
#'   when the reply contains nothing usable.
generate_sql <- \(question, schema_str, model = "qwen3:4b") {
  prompt_sql_template <- glue::glue(
    "\
    ### Instructions:
    Your task is to convert a question into a SQL query, given a duckdb database schema.
    Adhere to these rules:
    - **Deliberately go through the question and database schema word by word** to appropriately answer the question
    - **Use Table Aliases** to prevent ambiguity. For example, `SELECT t1.col1, t2.col1 FROM table1 as t1 JOIN table2 as t2 ON t1.id = t2.id`.
    - When creating a ratio, always cast the numerator as float
    - Ensure the generated SQL query is directly runnable in DuckDB.
    - Do not include any explanatory text or markdown outside of the SQL query block.
    - If the database have more than one table doesn't mean that you have or you don't have to join them.
    ### Input:
    Generate a SQL query that answers the question '{question}'.
    This query will run on a database whose schema is represented in this string:
    {schema_str}
    ### Response:
    Based on your instructions, here is the SQL query I have generated to answer the question '{question}':
    ```sql
    "
  )
  chat <- chat_ollama(model = model)
  chat$chat(prompt_sql_template, echo = "none")
  turns <- chat$get_turns()
  last_turn <- turns[[length(turns)]]
  # Drop the model's reasoning section, if it emitted one
  last_text <- str_replace_all(last_turn@text, "(?s)<think>.*?</think>", "")
  cat(last_text)

  # Preferred shape: a complete ```sql ... ``` fenced block
  sql_chunk <- str_match(last_text, "```sql\\s*([\\s\\S]+?)```")[, 2]
  if (is.na(sql_chunk)) {
    # The prompt already ends on an open ```sql fence, so the model may just
    # continue it (SQL followed by a closing fence) or answer with bare SQL.
    # Take everything up to the first closing fence, skipping a leading fence.
    sql_chunk <- str_match(
      last_text,
      "^\\s*(?:```\\w*\\s*)?([\\s\\S]*?)\\s*(?:```|$)"
    )[, 2]
  }
  sql_chunk <- trimws(sql_chunk)
  if (is.na(sql_chunk) || !nzchar(sql_chunk)) NA_character_ else sql_chunk
}
# 2. Get the data ----- | |
# Row-bind the elements of `gh_users` into a single data frame
users_df <- gh_users |> map_dfr(`[`)
# Print the DDL derived from the data frame (output kept as a comment below)
cat(build_create_table_sql(users_df, "gh_users"))
# CREATE TABLE gh_users ( | |
# login STRING, | |
# id INTEGER, | |
# avatar_url STRING, | |
# gravatar_id STRING, | |
# url STRING, | |
# html_url STRING, | |
# followers_url STRING, | |
# following_url STRING, | |
# gists_url STRING, | |
# starred_url STRING, | |
# subscriptions_url STRING, | |
# organizations_url STRING, | |
# repos_url STRING, | |
# events_url STRING, | |
# received_events_url STRING, | |
# type STRING, | |
# site_admin BOOLEAN, | |
# name STRING, | |
# company STRING, | |
# blog STRING, | |
# location STRING, | |
# email STRING, | |
# public_repos INTEGER, | |
# public_gists INTEGER, | |
# followers INTEGER, | |
# following INTEGER, | |
# created_at STRING, | |
# updated_at STRING, | |
# bio STRING, | |
# hireable BOOLEAN | |
# ) | |
# DDL for the users table: the generated statement above, with `login`
# promoted to PRIMARY KEY. This text is both executed against DuckDB and
# embedded in the LLM prompt, so it must be clean SQL.
users_schema <- "CREATE TABLE gh_users (
login STRING PRIMARY KEY,
id INTEGER,
avatar_url STRING,
gravatar_id STRING,
url STRING,
html_url STRING,
followers_url STRING,
following_url STRING,
gists_url STRING,
starred_url STRING,
subscriptions_url STRING,
organizations_url STRING,
repos_url STRING,
events_url STRING,
received_events_url STRING,
type STRING,
site_admin BOOLEAN,
name STRING,
company STRING,
blog STRING,
location STRING,
email STRING,
public_repos INTEGER,
public_gists INTEGER,
followers INTEGER,
following INTEGER,
created_at STRING,
updated_at STRING,
bio STRING,
hireable BOOLEAN
)"
# One row per repository: owner name, that owner's repo count, and a few
# repo fields.
repos_df <- gh_repos |>
  # Name each element by the value at path 1 -> 4 -> 1 of that element
  # (presumably the owner's login; verify against the gh_repos structure)
  set_names(map_chr(gh_repos, c(1, 4, 1))) |>
  enframe(name = "username", value = "gh_repos") |>
  mutate(
    n_repos = map_int(gh_repos, length),
    repo_info = map(
      gh_repos,
      \(user_repos) map_df(user_repos, `[`, c("name", "fork", "open_issues"))
    )
  ) |>
  select(!gh_repos) |>
  unnest(cols = repo_info)
# Print the DDL derived from the data frame (output kept as a comment below)
cat(build_create_table_sql(repos_df, "gh_repos"))
# CREATE TABLE gh_repos ( | |
# username STRING, | |
# n_repos INTEGER, | |
# name STRING, | |
# fork BOOLEAN, | |
# open_issues INTEGER | |
# ) | |
# DDL for the repos table: the generated statement above plus a composite
# primary key on (username, name). This text is both executed against DuckDB
# and embedded in the LLM prompt, so it must be clean SQL.
repos_schema <- "CREATE TABLE gh_repos (
username STRING,
n_repos INTEGER,
name STRING,
fork BOOLEAN,
open_issues INTEGER ,
PRIMARY KEY (username, name)
)"
# Both table definitions, repos first, separated by a newline. This is the
# schema text passed to generate_sql() in the examples below.
schema_str <- glue("{repos_schema}\n{users_schema}")
# Close a connection left over from a previous run of this script, if any.
# The class check keeps an unrelated object that happens to be named `con`
# from being passed to dbIsValid().
if (exists("con") && inherits(con, "DBIConnection") && DBI::dbIsValid(con)) {
  dbDisconnect(con, shutdown = TRUE)
}
con <- dbConnect(duckdb::duckdb(), dbdir = ":memory:")
# Deliberately no on.exit() here: it registers cleanup for the enclosing
# function call, and at the top level of a script there is none, so it would
# not act as an end-of-script finalizer (under source() it may even fire right
# after the statement and close the connection before it is used). The guard
# above closes the connection on the next run instead.

# Create both tables from the hand-edited DDL, then load the data frames
dbExecute(con, users_schema)
dbWriteTable(con, "gh_users", users_df, append = TRUE, row.names = FALSE)
dbExecute(con, repos_schema)
dbWriteTable(con, "gh_repos", repos_df, append = TRUE, row.names = FALSE)
# 3. Examples ---- | |
## 3.1 What are the 3 repos with more open_issues? ---- | |
# Compares the rows returned by the LLM-written query with the same answer
# computed in dplyr from repos_df.
test_that("LLM correctly gets the top 3 repos by open_issues", {
  top_issues_question <- "What are the 3 repos with more open_issues?"
  llm_sql <- generate_sql(top_issues_question, schema_str = schema_str)
  result_df <- as_tibble(dbGetQuery(con, llm_sql))

  # Reference answer: total open issues per repo name, three largest first
  issues_per_repo <- summarize(
    repos_df,
    .by = name,
    open_issues = sum(open_issues)
  )
  expected_df <- head(arrange(issues_per_repo, desc(open_issues)), 3)

  expect_equal(result_df, expected_df)
})
# # A tibble: 3 × 2 | |
# name open_issues | |
# <chr> <int> | |
# 1 datasharing 399 | |
# 2 parr 14 | |
# 3 crandatapkgs 12 | |
## 3.2 For each GitHub username in the repos table, find their top 3 original (non-forked) repositories with the most open issues ----- | |
# Compares the rows returned by the LLM-written query with the same answer
# computed in dplyr from repos_df.
test_that("LLM correctly gets the top 3 repos with most open issues per user", {
  per_user_question <- "For each GitHub username in the repos table, find their top 3 original (non-forked) repositories with the most open issues"
  llm_sql <- generate_sql(per_user_question, schema_str = schema_str)
  result_df <- as_tibble(dbGetQuery(con, llm_sql))

  # Reference answer: non-forks only, sorted by user then by issue count,
  # keeping the first three rows of every user
  original_repos <- filter(repos_df, !fork)
  expected_df <- original_repos |>
    arrange(username, desc(open_issues)) |>
    slice_head(n = 3, by = username) |>
    select(username, name, open_issues)

  expect_equal(result_df, expected_df)
})
# # A tibble: 18 × 3 | |
# username name open_issues | |
# <chr> <chr> <int> | |
# 1 gaborcsardi gh 8 | |
# 2 gaborcsardi crayon 7 | |
# 3 gaborcsardi argufy 6 | |
# 4 jennybc 2014-01-27-miami 4 | |
# 5 jennybc bingo 3 | |
# 6 jennybc candy 2 | |
# 7 jtleek datasharing 399 | |
# 8 jtleek dataanalysis 5 | |
# 9 jtleek genstats 3 | |
# 10 juliasilge tidytext 5 | |
# 11 juliasilge choroplethrUTCensusTract 0 | |
# 12 juliasilge CountyHealthApp 0 | |
# 13 leeper crandatapkgs 12 | |
# 14 leeper csvy 2 | |
# 15 leeper ciplotm 1 | |
# 16 masalmon cpcb 5 | |
# 17 masalmon rtimicropem 5 | |
# 18 masalmon laads 4 | |
# 4. TODO: RAG? Tools/functions? -----
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment