import json

import timm


def load_cfg_from_json(json_file):
    # Read a JSON config file from disk and return it as a dict.
    with open(json_file, "r", encoding="utf-8") as reader:
        text = reader.read()
    return json.loads(text)
def load_cfg(model_id, cfg_path):
    # Load a timm/HF-hub style model config and normalize it into
    # (pretrained_cfg, model_name, model_args).
    hf_config = load_cfg_from_json(cfg_path)
    if 'pretrained_cfg' not in hf_config:
        # Old-style config: the whole dict is the pretrained_cfg,
        # so pull the base fields out of it.
        pretrained_cfg = hf_config
        hf_config = {}
        hf_config['architecture'] = pretrained_cfg.pop('architecture')
        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
        hf_config['pretrained_cfg'] = pretrained_cfg

    pretrained_cfg = hf_config['pretrained_cfg']
    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
    pretrained_cfg['source'] = 'hf-hub'

    # The model should be created with the base config's num_classes if it exists.
    if 'num_classes' in hf_config:
        pretrained_cfg['num_classes'] = hf_config['num_classes']

    # Label metadata in the base config overrides the saved pretrained_cfg on load.
    if 'label_names' in hf_config:
        pretrained_cfg['label_names'] = hf_config.pop('label_names')
    if 'label_descriptions' in hf_config:
        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')

    model_args = hf_config.get('model_args', {})
    model_name = hf_config['architecture']
    return pretrained_cfg, model_name, model_args
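
# Illustration (not part of the original gist): load_cfg accepts two config.json
# shapes. Newer timm hub configs nest a 'pretrained_cfg' dict beside top-level
# fields such as 'architecture'; older configs are a single flat dict, which the
# branch above unpacks into that structure. The field values here are made-up
# placeholders, not the actual wd-vit-tagger-v3 config.
EXAMPLE_NEW_STYLE_CONFIG = {
    "architecture": "vit_base_patch16_224",
    "num_classes": 1000,
    "num_features": 768,
    "pretrained_cfg": {
        "input_size": [3, 224, 224],
        "interpolation": "bicubic",
    },
}
EXAMPLE_OLD_STYLE_CONFIG = {
    "architecture": "vit_base_patch16_224",
    "num_features": 768,
    "labels": ["tag_0", "tag_1"],  # deprecated key, renamed to 'label_names'
    "input_size": [3, 224, 224],
}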
#"/home/blackroot/Desktop/ViTagger/model/config.json"
def load_model_but_good_this_time_srsly(model_path):
pretrained_cfg, model_name, model_args = load_cfg("vit_base_patch16_224", model_path)
model = timm.create_model(
"hf-hub:SmilingWolf/wd-vit-tagger-v3",
checkpoint_path="/home/blackroot/Desktop/ViTagger/model/model.safetensors",
pretrained=True,
**model_args)
return model
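
# A hedged alternative sketch (not from the original gist): load_cfg already
# returns model_name and pretrained_cfg, so the model can also be built without
# resolving "hf-hub:SmilingWolf/wd-vit-tagger-v3" at creation time. Recent timm
# versions accept a pretrained_cfg dict plus a checkpoint_path; whether the
# .safetensors file loads this way depends on the installed timm and safetensors
# versions. The function name and its arguments here are hypothetical.
def load_model_fully_local(config_path, weights_path):
    pretrained_cfg, model_name, model_args = load_cfg("SmilingWolf/wd-vit-tagger-v3", config_path)
    model = timm.create_model(
        model_name,
        pretrained=False,               # skip the hub download...
        pretrained_cfg=pretrained_cfg,  # ...but keep the config metadata
        checkpoint_path=weights_path,   # load the local weights instead
        **model_args)
    return model

# Usage sketch with the same local paths used above:
# model = load_model_fully_local(
#     "/home/blackroot/Desktop/ViTagger/model/config.json",
#     "/home/blackroot/Desktop/ViTagger/model/model.safetensors",
# )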