Skip to content

Instantly share code, notes, and snippets.

@a-agmon
Created December 21, 2023 09:11
Show Gist options
  • Save a-agmon/84a5e21c25c2af918220afc4f7e05631 to your computer and use it in GitHub Desktop.
Loading a BERT model (config, tokenizer, and safetensors weights)
// Load a BERT model together with its JSON config and tokenizer from a model repo.
//
// `model_name`, `revision`, `device`, and `DTYPE` come from the surrounding scope.
// Every fallible step is wrapped with context naming the file involved, so a failed
// download or parse reports *which* artifact broke instead of a bare error.
use anyhow::Context as _;

// Resolve the repository at the requested revision, then fetch the three files the
// model needs. Each `api.get` result is used as a local file path further below.
let repo = Repo::with_revision(model_name.parse()?, RepoType::Model, revision.parse()?);
let api = Api::new()
    .context("failed to create the model hub API client")?
    .repo(repo);
let config_filename = api
    .get("config.json")
    .context("failed to fetch config.json")?;
let tokenizer_filename = api
    .get("tokenizer.json")
    .context("failed to fetch tokenizer.json")?;
let weights_filename = api
    .get("model.safetensors")
    .context("failed to fetch model.safetensors")?;

// Load the model config (hyper-parameters) from JSON.
let config = std::fs::read_to_string(config_filename).context("failed to read config.json")?;
let config: Config = serde_json::from_str(&config).context("failed to parse config.json")?;

// Load the tokenizer. Its error type is not `?`-convertible into `anyhow::Error`,
// so it is converted explicitly with `anyhow::Error::msg` first.
let tokenizer = Tokenizer::from_file(tokenizer_filename)
    .map_err(anyhow::Error::msg)
    .context("failed to load tokenizer.json")?;

// Load the weights and build the model.
// SAFETY: memory-mapping a file is only sound if nothing modifies or truncates it while
// the mapping is alive. `weights_filename` presumably points into a local download
// cache that no other process writes to during inference — confirm for your deployment.
// The `unsafe` block covers only the mapping call; error handling stays outside it.
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device) }
    .context("failed to memory-map model.safetensors")?;
let model = BertModel::load(vb, &config).context("failed to build BertModel from weights")?;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment