Last active
November 29, 2024 02:32
-
-
Save dylanpieper/a02cc4b009baa47fc0e0d7350197114f to your computer and use it in GitHub Desktop.
Compare Sequential and Parallel Chats Using Elmer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(shiny) | |
library(furrr) | |
library(elmer) | |
library(text2vec) | |
library(tm) | |
library(bslib) | |
library(waiter) | |
# Set your API keys upfront.
# Fill in the placeholders below, or leave them blank to keep whatever is
# already configured for the session (for example through .Renviron).
# Avoid committing real keys to version control.
local({
  api_keys <- c(
    OPENAI_API_KEY = "",
    GOOGLE_API_KEY = "",
    ANTHROPIC_API_KEY = ""
  )
  # Only export keys that were actually filled in: exporting "" would replace
  # a key that is already present in the environment with an empty value.
  provided <- api_keys[nzchar(api_keys)]
  if (length(provided) > 0) {
    do.call(Sys.setenv, as.list(provided))
  }
})
# Text preprocessing
#
# Normalises raw text into a lowercase string with stop words, punctuation
# and digits removed, ready to be split on whitespace.
#
# Args:
#   text: Character vector of raw text.
#
# Returns:
#   Character vector of the same length as `text`, with single spaces between
#   the remaining words and no leading or trailing whitespace.
preprocess_text <- function(text) {
  text <- tolower(text)
  # Remove stop words while apostrophes are still present: the English stop
  # word list includes contractions (e.g. "don't"), which can no longer match
  # once punctuation has been stripped.
  text <- removeWords(text, stopwords("en"))
  text <- removePunctuation(text)
  text <- removeNumbers(text)
  # Collapse whitespace last. The removals above leave gaps behind, and a
  # leading gap would otherwise become an empty token when the text is split.
  text <- stripWhitespace(text)
  trimws(text)
}
# Calculate Jaccard similarity with names
#
# Builds the pairwise Jaccard similarity (shared unique words divided by
# total unique words) between preprocessed texts.
#
# Args:
#   texts: List (or vector) of texts; each is passed to preprocess_text().
#   names: Labels used for both the row and column names of the result.
#
# Returns:
#   A symmetric numeric matrix with one row and column per text. A cell is
#   NA when both texts have no words left after preprocessing, because the
#   ratio is undefined in that case.
calculate_text_similarity <- function(texts, names) {
  texts <- lapply(texts, preprocess_text)
  tokenized_texts <- lapply(texts, function(text) {
    tokens <- unique(unlist(strsplit(text, "\\s+")))
    # strsplit() yields "" for leading whitespace; that is not a word and
    # would otherwise count as vocabulary shared by every text.
    tokens[nzchar(tokens)]
  })
  n <- length(tokenized_texts)
  jaccard_matrix <- matrix(0, n, n, dimnames = list(names, names))
  # seq_len() keeps the loop from running when there are no texts.
  for (i in seq_len(n)) {
    for (j in i:n) {
      shared <- length(intersect(tokenized_texts[[i]], tokenized_texts[[j]]))
      total <- length(union(tokenized_texts[[i]], tokenized_texts[[j]]))
      score <- if (total == 0) NA_real_ else shared / total
      # The measure is symmetric, so fill both triangles at once.
      jaccard_matrix[i, j] <- score
      jaccard_matrix[j, i] <- score
    }
  }
  jaccard_matrix
}
# Safe chat for string responses
#
# Calls a chat function and captures failure as data instead of raising, so
# one failing provider does not abort a whole batch.
#
# Args:
#   chat_function: Function called as chat_function(prompt, echo = FALSE, ...).
#   prompt: Prompt passed as the first argument.
#   ...: Additional arguments forwarded to `chat_function`. `echo` is always
#     set to FALSE here, so it must not be supplied again.
#
# Returns:
#   A list with `success` (TRUE/FALSE), `response` (the function's result, or
#   NULL on failure) and `error` (the error message, or NULL on success).
safe_chat <- function(chat_function, prompt, ...) {
  tryCatch(
    {
      result <- chat_function(prompt, echo = FALSE, ...)
      list(success = TRUE, response = result, error = NULL)
    },
    error = function(e) {
      # conditionMessage() gives the bare message, without the call prefix
      # and trailing newline that as.character() adds.
      list(success = FALSE, response = NULL, error = conditionMessage(e))
    }
  )
}
# Parallel chat processing
#
# Sends one prompt to each chat API through furrr, under whichever future
# plan is currently active.
#
# Args:
#   chat_apis: List of chat objects, each exposing a `$chat` function.
#   prompts: Prompts paired element-wise with `chat_apis`; recycled when
#     fewer prompts than APIs are supplied.
#
# Returns:
#   A list of safe_chat() results carrying the names of `chat_apis`.
parallel_chat <- function(chat_apis, prompts) {
  n_apis <- length(chat_apis)
  n_prompts <- length(prompts)
  if (n_apis == 0 || n_prompts == 0) {
    stop("Must provide at least one chat API and one prompt")
  }

  # Stretch a shorter prompt vector so that every API receives a prompt.
  if (n_prompts < n_apis) {
    prompts <- rep(prompts, length.out = n_apis)
  }

  ask_one <- function(chat_api, prompt) {
    safe_chat(chat_api$chat, prompt)
  }
  replies <- future_map2(
    chat_apis,
    prompts,
    ask_one,
    .options = furrr_options(seed = TRUE)
  )
  setNames(replies, names(chat_apis))
}
# Shiny UI
# A dark, monospace-themed page: prompt and run controls in the sidebar,
# result placeholders (filled by the server) in the main panel.
ui <- local({
  mono_theme <- bs_theme(
    bg = "#141414", fg = "white", primary = "#FCC780",
    base_font = font_google("Space Mono"),
    code_font = font_google("Space Mono")
  )

  # Display label -> value reported in input$selected_apis.
  api_choices <- c("OpenAI" = "openai", "Gemini" = "gemini", "Claude" = "claude")

  controls <- sidebarPanel(
    textInput(
      "prompt",
      "Enter Prompt:",
      value = "Explain the R programming language in one paragraph."
    ),
    numericInput(
      "num_simulations",
      "Number of Simulations:",
      value = 10,
      min = 1
    ),
    actionButton("process_chats", "Process Chats", class = "btn-primary"),
    actionButton("simulate", "Run Simulation", class = "btn-secondary"),
    checkboxGroupInput(
      "selected_apis",
      "Select APIs:",
      choices = api_choices,
      # Every provider is ticked by default.
      selected = unname(api_choices)
    )
  )

  results <- mainPanel(
    uiOutput("plan_comparison"),
    uiOutput("chat_results"),
    uiOutput("similarity_score"),
    uiOutput("simulation_results")
  )

  fluidPage(
    theme = mono_theme,
    use_waiter(),
    titlePanel("Compare Sequential and Parallel Chats Using Elmer"),
    sidebarLayout(controls, results)
  )
})
# Shiny Server
#
# Runs the selected chat APIs under a multisession and a sequential future
# plan, reports which was faster, shows the faster run's responses and their
# pairwise similarity, and can repeat the comparison to average the timings.
server <- function(input, output, session) {
  # Reactive state; each value is rendered by one of the outputs below.
  plan_timing <- reactiveVal(NULL)        # winner + elapsed seconds per plan
  winning_results <- reactiveVal(NULL)    # safe_chat() results of the winner
  similarity_score <- reactiveVal(NULL)   # Jaccard matrix of the responses
  simulation_summary <- reactiveVal(NULL) # run count + average timings

  # Create a global waiter
  w <- Waiter$new(
    html = spin_heartbeat(),
    color = transparent(0.7)
  )

  # ---- Helpers --------------------------------------------------------------

  # Builds one fresh chat object per selected provider, named by provider and
  # always ordered openai, gemini, claude.
  build_chat_apis <- function(selected) {
    system_prompt <- "You are a helpful assistant."
    constructors <- list(
      openai = function() {
        chat_openai(
          model = "gpt-4o-mini",
          system_prompt = system_prompt,
          echo = FALSE
        )
      },
      gemini = function() {
        chat_gemini(
          model = "gemini-1.5-pro",
          system_prompt = system_prompt,
          echo = FALSE
        )
      },
      claude = function() {
        chat_claude(
          model = "claude-3-sonnet-20240229",
          system_prompt = system_prompt,
          echo = FALSE
        )
      }
    )
    chosen <- intersect(names(constructors), selected)
    lapply(constructors[chosen], function(make) make())
  }

  # Times one batch of chats under the requested plan.
  # Chat objects are rebuilt for every run (outside the timed region): they
  # keep their conversation history, so reusing them in the main process
  # would make each later sequential run send a longer conversation.
  # As before, switching the plan is part of the measured time.
  timed_run <- function(use_multisession, selected, prompt) {
    chat_apis <- build_chat_apis(selected)
    results <- NULL
    timing <- system.time({
      if (use_multisession) {
        plan(multisession)
      } else {
        plan(sequential)
      }
      results <- parallel_chat(chat_apis, rep(prompt, length(chat_apis)))
    })
    list(elapsed = timing[["elapsed"]], results = results)
  }

  # Runs the multisession batch, then the sequential batch.
  compare_once <- function(selected, prompt) {
    list(
      multisession = timed_run(TRUE, selected, prompt),
      sequential = timed_run(FALSE, selected, prompt)
    )
  }

  # Simulation function: repeats the comparison `n` times and averages the
  # elapsed seconds of each plan.
  run_simulation <- function(n, selected, prompt) {
    multisession_times <- numeric(n)
    sequential_times <- numeric(n)
    for (i in seq_len(n)) {
      runs <- compare_once(selected, prompt)
      multisession_times[i] <- runs$multisession$elapsed
      sequential_times[i] <- runs$sequential$elapsed
    }
    list(
      multisession_avg = mean(multisession_times),
      sequential_avg = mean(sequential_times)
    )
  }

  # Evaluates `code` behind the loading screen. The screen is always hidden
  # and the plan reset to sequential afterwards, even on failure; an error is
  # reported as a notification and yields NULL.
  run_guarded <- function(code) {
    w$show()
    on.exit(w$hide(), add = TRUE)
    on.exit(plan(sequential), add = TRUE)
    tryCatch(
      code,
      error = function(e) {
        showNotification(
          paste("Request failed:", conditionMessage(e)),
          type = "error"
        )
        NULL
      }
    )
  }

  # Returns the selected providers, or warns and returns NULL when none is
  # ticked (parallel_chat() would otherwise stop with an error).
  selected_or_warn <- function() {
    selected <- input$selected_apis
    if (length(selected) == 0) {
      showNotification("Select at least one API.", type = "warning")
      return(NULL)
    }
    selected
  }

  # Card showing the elapsed time of one plan; green for the faster plan.
  timing_card <- function(title, seconds, is_winner) {
    card_class <- if (is_winner) {
      "card bg-success text-white"
    } else {
      "card bg-warning text-white"
    }
    div(
      class = card_class,
      div(class = "card-header", strong(title)),
      div(class = "card-body", p(sprintf("Time: %.2f seconds", seconds)))
    )
  }

  # ---- Single comparison ----------------------------------------------------

  observeEvent(input$process_chats, {
    selected <- selected_or_warn()
    req(selected)

    runs <- run_guarded(compare_once(selected, input$prompt))
    req(runs)

    # Determine winner and store results (ties go to multisession).
    if (runs$multisession$elapsed <= runs$sequential$elapsed) {
      winner <- "Multisession"
      winner_results <- runs$multisession$results
    } else {
      winner <- "Sequential"
      winner_results <- runs$sequential$results
    }
    plan_timing(list(
      winner = winner,
      multisession = runs$multisession$elapsed,
      sequential = runs$sequential$elapsed
    ))
    winning_results(winner_results)

    # Compare only the providers that actually answered; a failed call has no
    # response text to compare.
    answered <- Filter(function(res) isTRUE(res$success), winner_results)
    if (length(answered) > 1) {
      responses <- lapply(answered, `[[`, "response")
      similarity_score(tryCatch(
        calculate_text_similarity(responses, names(answered)),
        error = function(e) {
          message("Similarity calculation failed: ", conditionMessage(e))
          NULL
        }
      ))
    } else {
      similarity_score(NULL)
    }
  })

  # Output the comparison
  output$plan_comparison <- renderUI({
    timing <- req(plan_timing())
    tagList(
      timing_card(
        "Multisession Chats",
        timing$multisession,
        timing$winner == "Multisession"
      ),
      timing_card(
        "Sequential Chats",
        timing$sequential,
        timing$winner == "Sequential"
      )
    )
  })

  output$chat_results <- renderUI({
    results <- req(winning_results())
    result_panels <- lapply(names(results), function(api_name) {
      result <- results[[api_name]]
      if (result$success) {
        div(
          class = "card mb-3",
          div(
            class = "card-header bg-info text-white",
            strong(api_name)
          ),
          div(
            class = "card-body",
            # NOTE(review): the model's markdown is inserted as raw HTML;
            # consider sanitising it if responses cannot be trusted.
            HTML(markdown::markdownToHTML(
              text = result$response,
              fragment.only = TRUE
            ))
          )
        )
      } else {
        div(
          class = "card mb-3",
          div(
            class = "card-header bg-warning text-white",
            strong(paste(api_name, "Error"))
          ),
          div(
            class = "card-body text-danger",
            p(result$error)
          )
        )
      }
    })
    tagList(result_panels)
  })

  output$similarity_score <- renderUI({
    sim_matrix <- req(similarity_score())
    # Reuse R's own matrix printing for an aligned, labelled table.
    sim_text <- paste(capture.output(print(sim_matrix)), collapse = "\n")
    div(
      class = "card",
      div(
        class = "card-header bg-success text-white",
        strong("Response Similarity")
      ),
      div(
        class = "card-body",
        pre(sim_text)
      )
    )
  })

  # ---- Repeated simulation --------------------------------------------------

  observeEvent(input$simulate, {
    selected <- selected_or_warn()
    req(selected)

    # The numeric input is NA when cleared and may hold a fraction; the run
    # count must be a whole number of at least one.
    num_simulations <- as.integer(input$num_simulations)
    if (length(num_simulations) != 1 || is.na(num_simulations) ||
        num_simulations < 1) {
      showNotification(
        "Number of simulations must be a whole number of at least 1.",
        type = "warning"
      )
      req(FALSE)
    }

    averages <- run_guarded(
      run_simulation(num_simulations, selected, input$prompt)
    )
    req(averages)
    simulation_summary(c(list(runs = num_simulations), averages))
  })

  output$simulation_results <- renderUI({
    summary <- req(simulation_summary())
    div(
      class = "card",
      div(
        class = "card-header bg-info text-white",
        strong(sprintf("Simulation Results (%d Runs)", summary$runs))
      ),
      div(
        class = "card-body",
        p(sprintf(
          "Average Multisession Time: %.2f seconds",
          summary$multisession_avg
        )),
        p(sprintf(
          "Average Sequential Time: %.2f seconds",
          summary$sequential_avg
        ))
      )
    )
  })
}
# Run the application: pair the page definition with its server logic
shinyApp(ui = ui, server = server)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment