Skip to content

Instantly share code, notes, and snippets.

@dylanpieper
Last active November 29, 2024 02:32
Show Gist options
  • Save dylanpieper/a02cc4b009baa47fc0e0d7350197114f to your computer and use it in GitHub Desktop.
Save dylanpieper/a02cc4b009baa47fc0e0d7350197114f to your computer and use it in GitHub Desktop.
Compare Sequential and Parallel Chats Using Elmer
library(shiny)
library(furrr)
library(elmer)
library(text2vec)
library(tm)
library(bslib)
library(waiter)
# Set your API keys upfront. Paste a key between the quotes to use it; keys
# left as "" are skipped, so a key already defined in the environment (for
# example in .Renviron or the shell) is kept rather than overwritten with "".
invisible(local({
  api_keys <- c(
    OPENAI_API_KEY = "",
    GOOGLE_API_KEY = "",
    ANTHROPIC_API_KEY = ""
  )
  provided <- api_keys[nzchar(api_keys)]
  if (length(provided) > 0) {
    do.call(Sys.setenv, as.list(provided))
  }
}))
# Text preprocessing
#
# Normalise text for token-set comparison: lower-case it, strip punctuation
# and digits, drop English stopwords (tm's stopwords("en")), then collapse
# and trim whitespace so that splitting on whitespace yields no empty tokens.
#
# @param text Character vector to clean.
# @return Character vector of the same length with the cleaned text.
preprocess_text <- function(text) {
  text <- tolower(text)
  text <- removePunctuation(text)
  text <- removeNumbers(text)
  text <- removeWords(text, stopwords("en"))
  # removeWords() deletes only the words and leaves their surrounding spaces
  # behind, so whitespace is collapsed and trimmed after every removal step.
  trimws(stripWhitespace(text))
}
# Calculate Jaccard similarity with names
#
# Pairwise Jaccard similarity |A intersect B| / |A union B| between the
# unique whitespace-separated token sets of each text after preprocess_text().
#
# @param texts List (or vector) of texts; NULL entries count as empty text.
# @param names Row/column labels for the result, same length as `texts`.
# @return Symmetric n x n numeric matrix with dimnames `names`. A pair whose
#   combined token set is empty has an undefined similarity and gets NA.
calculate_text_similarity <- function(texts, names) {
  token_sets <- lapply(texts, function(text) {
    tokens <- unlist(strsplit(preprocess_text(text), "\\s+"))
    # Leading whitespace makes strsplit() emit "", which would otherwise be
    # counted as a token shared by every text and inflate the scores.
    unique(tokens[nzchar(tokens)])
  })
  n <- length(token_sets)
  jaccard_matrix <- matrix(NA_real_, n, n, dimnames = list(names, names))
  for (i in seq_len(n)) {
    for (j in seq.int(i, n)) {
      n_shared <- length(intersect(token_sets[[i]], token_sets[[j]]))
      n_total <- length(union(token_sets[[i]], token_sets[[j]]))
      score <- if (n_total > 0) n_shared / n_total else NA_real_
      jaccard_matrix[i, j] <- score
      jaccard_matrix[j, i] <- score
    }
  }
  jaccard_matrix
}
# Safe chat for string responses
#
# Call `chat_function(prompt, ..., echo = FALSE)` and capture the outcome
# instead of letting an error propagate.
#
# @param chat_function A chat function, e.g. an elmer Chat object's `$chat`.
# @param prompt The prompt to send.
# @param ... Further arguments forwarded to `chat_function`.
# @return list(success, response, error). On success `response` holds the
#   call's return value and `error` is NULL; on failure `response` is NULL
#   and `error` is the condition message (without the "Error in" prefix).
safe_chat <- function(chat_function, prompt, ...) {
  tryCatch(
    list(
      success = TRUE,
      response = chat_function(prompt, ..., echo = FALSE),
      error = NULL
    ),
    error = function(e) {
      list(success = FALSE, response = NULL, error = conditionMessage(e))
    }
  )
}
# Parallel chat processing
#
# Send one prompt to each chat API through furrr::future_map2(), so the
# calls are dispatched according to whatever future plan() is active.
#
# @param chat_apis Named list of chat objects that expose a `$chat` method.
# @param prompts Character vector of prompts paired with `chat_apis` by
#   position. A shorter vector is recycled; a longer one is an error.
# @return List of safe_chat() results, named like `chat_apis`.
parallel_chat <- function(chat_apis, prompts) {
  if (length(chat_apis) == 0 || length(prompts) == 0) {
    stop("Must provide at least one chat API and one prompt")
  }
  if (length(prompts) > length(chat_apis)) {
    # future_map2() would otherwise fail with an opaque length error.
    stop(
      sprintf(
        "Got %d prompts for %d chat APIs; supply at most one prompt per chat API",
        length(prompts), length(chat_apis)
      ),
      call. = FALSE
    )
  }
  prompts <- rep_len(prompts, length(chat_apis))
  results <- future_map2(
    chat_apis,
    prompts,
    function(chat_api, prompt) {
      safe_chat(chat_api$chat, prompt)
    },
    .options = furrr_options(seed = TRUE)
  )
  names(results) <- names(chat_apis)
  results
}
# Shiny UI
#
# Dark bslib theme. The sidebar holds the prompt, the simulation count, the
# two action buttons and the API selector. The main panel holds placeholders
# that the server fills through renderUI().
ui <- local({
  dark_theme <- bs_theme(
    bg = "#141414",
    fg = "white",
    primary = "#FCC780",
    base_font = font_google("Space Mono"),
    code_font = font_google("Space Mono")
  )

  api_choices <- c("OpenAI" = "openai", "Gemini" = "gemini", "Claude" = "claude")

  controls <- sidebarPanel(
    textInput(
      inputId = "prompt",
      label = "Enter Prompt:",
      value = "Explain the R programming language in one paragraph."
    ),
    numericInput(
      inputId = "num_simulations",
      label = "Number of Simulations:",
      value = 10,
      min = 1
    ),
    actionButton("process_chats", "Process Chats", class = "btn-primary"),
    actionButton("simulate", "Run Simulation", class = "btn-secondary"),
    checkboxGroupInput(
      inputId = "selected_apis",
      label = "Select APIs:",
      choices = api_choices,
      selected = unname(api_choices)
    )
  )

  results_area <- mainPanel(
    uiOutput("plan_comparison"),
    uiOutput("chat_results"),
    uiOutput("similarity_score"),
    uiOutput("simulation_results")
  )

  fluidPage(
    theme = dark_theme,
    use_waiter(),
    titlePanel("Compare Sequential and Parallel Chats Using Elmer"),
    sidebarLayout(controls, results_area)
  )
})
# Shiny Server
#
# Per session, the server does the following:
# - builds the selected elmer chat objects;
# - times one round of chats under plan(multisession) and then under
#   plan(sequential);
# - shows the responses from the faster run with a pairwise Jaccard
#   similarity matrix of the successful responses;
# - on request, repeats the timing `input$num_simulations` times and
#   reports the averages.
# State lives in reactiveVals. Outputs are defined once and re-render when
# that state changes.
server <- function(input, output, session) {
  # safe_chat() results from whichever plan finished first
  chat_results <- reactiveVal(NULL)
  # Pairwise similarity matrix of the successful responses, or NULL
  similarity_score <- reactiveVal(NULL)
  # list(multisession =, sequential =, winner =) from the last comparison
  plan_timing <- reactiveVal(NULL)
  # list(n =, multisession_avg =, sequential_avg =) from the last simulation
  simulation_summary <- reactiveVal(NULL)

  # Create a global waiter
  w <- Waiter$new(
    html = spin_heartbeat(),
    color = transparent(0.7)
  )

  # Helpers ----

  # Build a named list of elmer chat objects for the selected API ids, in the
  # fixed order openai, gemini, claude. Unknown ids are ignored.
  build_chat_apis <- function(selected) {
    system_prompt <- "You are a helpful assistant."
    constructors <- list(
      openai = function() {
        chat_openai(
          model = "gpt-4o-mini",
          system_prompt = system_prompt,
          echo = FALSE
        )
      },
      gemini = function() {
        chat_gemini(
          model = "gemini-1.5-pro",
          system_prompt = system_prompt,
          echo = FALSE
        )
      },
      claude = function() {
        chat_claude(
          model = "claude-3-sonnet-20240229",
          system_prompt = system_prompt,
          echo = FALSE
        )
      }
    )
    wanted <- intersect(names(constructors), selected)
    lapply(constructors[wanted], function(make_chat) make_chat())
  }

  # Tell the user and silently cancel the calling observer when no API is
  # selected (parallel_chat() would otherwise stop() and end the session).
  require_selected_apis <- function() {
    if (length(input$selected_apis) == 0) {
      showNotification("Select at least one API.", type = "warning")
      req(FALSE)
    }
  }

  # Run every chat once under `strategy` (a future plan such as multisession
  # or sequential). Returns list(results, elapsed) with elapsed wall-clock
  # seconds. The plan() switch is inside the timed region, so any plan
  # set-up cost counts toward `elapsed`.
  time_chats_under_plan <- function(strategy, chat_apis, prompt) {
    results <- NULL
    elapsed <- system.time({
      plan(strategy)
      results <- parallel_chat(chat_apis, rep(prompt, length(chat_apis)))
    })[["elapsed"]]
    list(results = results, elapsed = elapsed)
  }

  # Pairwise similarity of the successful responses only. Returns NULL when
  # fewer than two chats succeeded or the calculation fails.
  similarity_of_successes <- function(results) {
    succeeded <- Filter(function(result) isTRUE(result$success), results)
    if (length(succeeded) < 2) {
      return(NULL)
    }
    tryCatch(
      calculate_text_similarity(
        lapply(succeeded, `[[`, "response"),
        names(succeeded)
      ),
      error = function(e) {
        warning(
          "Similarity calculation failed: ", conditionMessage(e),
          call. = FALSE
        )
        NULL
      }
    )
  }

  # Card showing one plan's elapsed time, highlighted when it was faster.
  timing_card <- function(title, seconds, is_winner) {
    card_class <- if (is_winner) {
      "card bg-success text-white"
    } else {
      "card bg-warning text-white"
    }
    div(
      class = card_class,
      div(class = "card-header", strong(title)),
      div(class = "card-body", p(sprintf("Time: %.2f seconds", seconds)))
    )
  }

  # Simulation function: mean elapsed seconds over `n` rounds, each round
  # timing multisession first and then sequential.
  run_simulation <- function(n, chat_apis, prompt) {
    multisession_times <- numeric(n)
    sequential_times <- numeric(n)
    for (i in seq_len(n)) {
      multisession_times[i] <-
        time_chats_under_plan(multisession, chat_apis, prompt)$elapsed
      sequential_times[i] <-
        time_chats_under_plan(sequential, chat_apis, prompt)$elapsed
    }
    list(
      multisession_avg = mean(multisession_times),
      sequential_avg = mean(sequential_times)
    )
  }

  # Observers ----

  observeEvent(input$process_chats, {
    require_selected_apis()
    w$show()
    # finally guarantees the spinner is removed even when a chat constructor
    # or request fails; errors are surfaced instead of ending the session.
    tryCatch({
      chat_apis <- build_chat_apis(input$selected_apis)
      multisession_run <-
        time_chats_under_plan(multisession, chat_apis, input$prompt)
      sequential_run <-
        time_chats_under_plan(sequential, chat_apis, input$prompt)
      # Ties go to multisession.
      multisession_wins <- multisession_run$elapsed <= sequential_run$elapsed
      winner_results <- if (multisession_wins) {
        multisession_run$results
      } else {
        sequential_run$results
      }
      chat_results(winner_results)
      similarity_score(similarity_of_successes(winner_results))
      plan_timing(list(
        multisession = multisession_run$elapsed,
        sequential = sequential_run$elapsed,
        winner = if (multisession_wins) "Multisession" else "Sequential"
      ))
    }, error = function(e) {
      showNotification(conditionMessage(e), type = "error")
    }, finally = {
      w$hide()
    })
  })

  observeEvent(input$simulate, {
    num_simulations <- input$num_simulations
    # numericInput()'s `min` only guides the browser widget; the value may
    # still be NA (cleared field), below 1, or fractional, which would break
    # seq_len() and the %d format below.
    if (length(num_simulations) != 1 || is.na(num_simulations) ||
      num_simulations < 1 || num_simulations %% 1 != 0) {
      showNotification(
        "Number of Simulations must be a whole number of at least 1.",
        type = "warning"
      )
      req(FALSE)
    }
    num_simulations <- as.integer(num_simulations)
    require_selected_apis()
    w$show()
    tryCatch({
      chat_apis <- build_chat_apis(input$selected_apis)
      timings <- run_simulation(num_simulations, chat_apis, input$prompt)
      simulation_summary(c(list(n = num_simulations), timings))
    }, error = function(e) {
      showNotification(conditionMessage(e), type = "error")
    }, finally = {
      w$hide()
    })
  })

  # Outputs ----

  output$plan_comparison <- renderUI({
    timing <- plan_timing()
    req(timing)
    tagList(
      timing_card(
        "Multisession Chats",
        timing$multisession,
        timing$winner == "Multisession"
      ),
      timing_card(
        "Sequential Chats",
        timing$sequential,
        timing$winner == "Sequential"
      )
    )
  })

  output$chat_results <- renderUI({
    results <- chat_results()
    req(results)
    result_panels <- lapply(names(results), function(api_name) {
      result <- results[[api_name]]
      if (result$success) {
        div(
          class = "card mb-3",
          div(class = "card-header bg-info text-white", strong(api_name)),
          div(
            class = "card-body",
            HTML(markdown::markdownToHTML(
              text = result$response,
              fragment.only = TRUE
            ))
          )
        )
      } else {
        div(
          class = "card mb-3",
          div(
            class = "card-header bg-warning text-white",
            strong(paste(api_name, "Error"))
          ),
          div(class = "card-body text-danger", p(result$error))
        )
      }
    })
    do.call(tagList, result_panels)
  })

  output$similarity_score <- renderUI({
    req(similarity_score())
    sim_matrix <- similarity_score()
    sim_text <- paste(capture.output(print(sim_matrix)), collapse = "\n")
    div(
      class = "card",
      div(
        class = "card-header bg-success text-white",
        strong("Response Similarity")
      ),
      div(
        class = "card-body",
        pre(sim_text)
      )
    )
  })

  output$simulation_results <- renderUI({
    sim <- simulation_summary()
    req(sim)
    div(
      class = "card",
      div(
        class = "card-header bg-info text-white",
        strong(sprintf("Simulation Results (%d Runs)", sim$n))
      ),
      div(
        class = "card-body",
        p(sprintf(
          "Average Multisession Time: %.2f seconds", sim$multisession_avg
        )),
        p(sprintf(
          "Average Sequential Time: %.2f seconds", sim$sequential_avg
        ))
      )
    )
  })
}
# Run the application: pair the UI definition with the server function.
shinyApp(ui = ui, server = server)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment