# Load necessary libraries
library(janitor)
library(httr)
library(jsonlite)
library(tidyverse)
library(furrr)
library(stringr)
library(glue)
# Set up parallel processing for furrr
plan(multisession)
# Define the URLs for the O*NET task ratings and task statements files
tasks_url <- "https://www.onetcenter.org/dl_files/database/db_28_3_text/Task%20Ratings.txt"
statements_url <- "https://www.onetcenter.org/dl_files/database/db_28_3_text/Task%20Statements.txt"
# Download the ratings data, keep only the importance ("IM") scale,
# and join on the task statements
tasks <- read_tsv(tasks_url) %>%
  clean_names() %>%
  filter(scale_id == "IM") %>%
  left_join(
    read_tsv(statements_url) %>%
      clean_names() %>%
      select(o_net_soc_code, task_id, task),
    by = c("o_net_soc_code", "task_id")
  )
# Keep the columns needed for classification; the filter restricts the run
# to a single occupation for testing -- remove it to process all tasks
onet_combined <- tasks %>%
  select(o_net_soc_code, task_id, task, task_importance = data_value) %>%
  filter(o_net_soc_code == "19-2012.00")
# Define the prompt generator function with few-shot chain-of-thought (CoT) examples
prompt_generator <- function(task) {
  task_info <- glue("Tasks: {paste(task, collapse = '; ')}")

  examples <- glue(
    "Example 1:\n",
    "Task Information: Given a set of routine tasks such as data entry and simple report generation, which are highly repetitive and well-defined.\n",
    "Classification: E1\n",
    "Explanation: These tasks are highly routine and procedural, making them suitable for automation using a large language model. The LLM can effectively reduce the time to complete these tasks by at least 50% while maintaining quality.\n\n",
    "Example 2:\n",
    "Task Information: Given a set of complex analytical tasks that require extensive domain knowledge and critical thinking.\n",
    "Classification: E2\n",
    "Explanation: While the LLM can assist with parts of these tasks, additional specialized software is required to handle the complexity and domain-specific requirements. The LLM alone cannot achieve a 50% reduction in time without compromising quality.\n\n",
    "Example 3:\n",
    "Task Information: Given a set of creative tasks such as content generation and brainstorming new ideas.\n",
    "Classification: E1\n",
    "Explanation: The LLM can significantly aid in generating new content and ideas, reducing the time required for these tasks by at least 50% while maintaining or enhancing quality."
  )

  glue(
    "Given the following task information: {task_info},\n",
    "Can a state-of-the-art large language model reduce the time to complete tasks associated with these skills and knowledge by at least 50% while maintaining quality?\n",
    "Classify as: E0 for no exposure, E1 for direct exposure by LLM alone, or E2 for exposure with additional software. Provide a brief explanation.\n\n",
    "{examples}\n",
    "Task Information: {task_info}\n",
    "Classification:\n",
    "Explanation:"
  )
}
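# A quick way to eyeball the assembled prompt before spending any API calls.
# The task string below is illustrative only; any task text works here:
# cat(prompt_generator("Collect and analyze field data samples."))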
# Generate prompts for classification
onet_combined <- onet_combined %>%
  mutate(prompt_task = map_chr(task, prompt_generator))
# Retrieve the OpenAI API key from the system keyring
# (replace YOUR_USERNAME with the username stored alongside the key)
openai_api_key <- keyring::key_get("OPENAI_API_KEY", username = "YOUR_USERNAME")
# Define a function that queries the OpenAI API n_times and keeps the modal
# classification (a simple bootstrap / self-consistency scheme)
openai_classifier_bootstrap <- function(text, api_key, n_times = 5) {
  url <- "https://api.openai.com/v1/completions"
  classifications <- character(n_times)
  explanations <- character(n_times)

  # The request body is identical across runs; temperature > 0 is what
  # produces variation between the sampled completions
  data <- list(
    model = "gpt-3.5-turbo-instruct",
    prompt = text,
    temperature = 0.7,
    max_tokens = 150,
    top_p = 0.9,
    frequency_penalty = 0.0,
    presence_penalty = 0.0
  )
  json_body <- toJSON(data, auto_unbox = TRUE)

  for (i in seq_len(n_times)) {
    response <- POST(
      url,
      add_headers(
        Authorization = paste("Bearer", api_key),
        "Content-Type" = "application/json"
      ),
      body = json_body
    )

    if (status_code(response) == 200) {
      response_content <- content(response, as = "parsed", type = "application/json")
      response_text <- response_content$choices[[1]]$text

      # Pull the E0/E1/E2 label and the explanation out of the completion text;
      # either may be NA if the model deviates from the requested format
      classifications[i] <- str_extract(response_text, "E[012]")
      explanations[i] <- str_extract(response_text, "(?<=Explanation: ).*")
    } else {
      warning(paste("Error in API request:", status_code(response)))
      classifications[i] <- NA_character_
      explanations[i] <- NA_character_
    }
  }

  # Determine the most frequent classification (table() drops NAs)
  mode_classification <- names(sort(table(classifications), decreasing = TRUE))[1]

  # Sample one explanation that accompanied the winning classification;
  # %in% keeps NA classifications out of the logical index, and the guard
  # avoids sample() erroring when no usable explanation survived
  relevant_explanations <- explanations[classifications %in% mode_classification & !is.na(explanations)]
  chosen_explanation <- if (length(relevant_explanations) > 0) {
    sample(relevant_explanations, 1)
  } else {
    NA_character_
  }

  list(classification = mode_classification, explanation = chosen_explanation)
}
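# Sanity-check a single prompt before launching the full parallel run
# (commented out so sourcing the script does not spend API calls):
# test_result <- openai_classifier_bootstrap(onet_combined$prompt_task[1], openai_api_key)
# test_result$classification
# test_result$explanation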
# Apply the classification function to each task using parallel processing;
# seed = TRUE is required because sample() draws random numbers inside the workers
results <- future_map(
  onet_combined$prompt_task,
  ~ openai_classifier_bootstrap(.x, openai_api_key),
  .options = furrr_options(seed = TRUE)
)
# Process the results
onet_combined <- onet_combined %>%
  mutate(
    classification = map_chr(results, "classification"),
    explanation = map_chr(results, "explanation"),
    exposure_level = classification
  )

# Print out the combined dataframe with the new classification column
print(onet_combined)
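# Optional: a quick look at how the exposure labels are distributed
# count(onet_combined, classification, sort = TRUE)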