demonstration of bert-base-uncased in huggingface
use_bert = function(phrase) {
  # setup: reticulate::py_install(c("torch", "transformers"), pip=TRUE)
  # devtools::source_gist() may produce some warnings related to GPU
  # note that the first run will populate .cache/huggingface/hub with model components
  my_bert_template = "
# the python used by reticulate must have transformers installed
from transformers import pipeline, logging
logging.set_verbosity_error()
# the model can be changed to another fill-mask model on the huggingface hub
unmasker = pipeline('fill-mask', model='google-bert/bert-base-uncased')
ans = unmasker(%%ONEMASKPHRASE%%)
"
  # quote the phrase as a string literal, escaping backslashes, single quotes
  # and control characters so that it cannot break the python code
  quoted = encodeString(phrase, quote = "'")
  # fixed=TRUE: placeholder and replacement are both taken literally
  my_bert_prog = sub("%%ONEMASKPHRASE%%", quoted, my_bert_template, fixed = TRUE)
  rans = reticulate::py_run_string(my_bert_prog, convert=TRUE, local=TRUE)
  # one row per candidate completion
  reticulate::py_to_r(rans$ans) |> dplyr::bind_rows()
}
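# install any R packages that are missing
ii = rownames(installed.packages())
needp = c("reticulate", "devtools", "dplyr")
todo = setdiff(needp, ii)
# this repository serves binaries for Ubuntu jammy; use another CRAN mirror on other platforms
if (length(todo) > 0) install.packages(todo, repos = "https://packagemanager.rstudio.com/all/__linux__/jammy/latest")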
library(reticulate)
# this discovery step fails too often, maybe time is needed to find the installation,
# so wait and retry once before concluding that transformers is missing
tt = try(import("transformers"))
if (inherits(tt, "try-error")) {
  Sys.sleep(15)
  tt = try(import("transformers"))
}
# still not importable: install the python dependencies
if (inherits(tt, "try-error")) {
  py_install(c("torch", "transformers"), pip=TRUE)
  Sys.sleep(5)
}
phrase <- "The capital of Spain is [MASK]."
print(use_bert(phrase))
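# A minimal sketch of a guard that could be run before use_bert(). The placeholder
# name ONEMASKPHRASE suggests the function is meant for phrases with a single
# [MASK] token, and bind_rows() assumes one flat list of candidates.
# count_masks() and check_phrase() are hypothetical helpers written for this
# illustration; they are not part of reticulate or transformers.
count_masks = function(phrase, mask = "[MASK]") {
  # fixed=TRUE so that the brackets are matched literally, not as a regex class
  hits = gregexpr(mask, phrase, fixed = TRUE)[[1]]
  if (hits[1] == -1) 0L else length(hits)
}
check_phrase = function(phrase, mask = "[MASK]") {
  n = count_masks(phrase, mask)
  if (n != 1L) stop("expected exactly one ", mask, " token, found ", n)
  invisible(phrase)
}
# passes silently for the example phrase; stops if the mask is missing or repeated
check_phrase("The capital of Spain is [MASK].")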