Created April 5, 2024 00:25
-
-
Save CoffeeVampir3/9b561245f1fa66905d5939ef229f70ad to your computer and use it in GitHub Desktop.
a
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_cfg_from_json(json_file):
    """Read a UTF-8 JSON file and return the decoded value.

    Args:
        json_file: Path (str or path-like) of the JSON file to read.

    Returns:
        The decoded JSON value. For a ``config.json`` this is expected to be
        a dict, since ``load_cfg`` indexes it by key.
    """
    # Imported locally: this file has no top-level ``import json``, so the
    # bare module-level name would raise NameError at call time.
    import json

    with open(json_file, "r", encoding="utf-8") as reader:
        # Decode straight from the file handle instead of read() + loads().
        return json.load(reader)
def load_cfg(model_id, cfg_path):
    """Load a timm-style ``config.json`` and normalise it for model creation.

    Appears to be adapted from timm's ``load_model_config_from_hf``, but it
    reads a local file instead of downloading one.

    Args:
        model_id: Stored in the returned pretrained cfg as ``hf_hub_id``.
        cfg_path: Path of the JSON config file to read.

    Returns:
        A tuple ``(pretrained_cfg, model_name, model_args)``:
        - ``model_name`` is the config's ``architecture`` entry.
        - ``model_args`` is its ``model_args`` entry, or ``{}`` when absent.

    Raises:
        KeyError: if the config has no ``architecture`` entry.
    """
    config = load_cfg_from_json(cfg_path)

    if 'pretrained_cfg' not in config:
        # Legacy flat layout: the whole file is the pretrained cfg, with the
        # architecture name (and optionally num_features) mixed in. Lift
        # those keys out and nest the remainder under 'pretrained_cfg'.
        legacy = config
        config = {
            'architecture': legacy.pop('architecture'),
            'num_features': legacy.pop('num_features', None),
        }
        if 'labels' in legacy:
            # 'labels' is the deprecated spelling of 'label_names'.
            legacy['label_names'] = legacy.pop('labels')
        config['pretrained_cfg'] = legacy

    pretrained = config['pretrained_cfg']
    # Tag the cfg with its hub id and source for later weight loading.
    pretrained.update(hf_hub_id=model_id, source='hf-hub')

    # A top-level num_classes takes precedence over the saved one.
    if 'num_classes' in config:
        pretrained['num_classes'] = config['num_classes']

    # Top-level label metadata likewise overrides the saved copy. It is
    # moved (popped) out of the top-level dict rather than copied.
    for label_key in ('label_names', 'label_descriptions'):
        if label_key in config:
            pretrained[label_key] = config.pop(label_key)

    return pretrained, config['architecture'], config.get('model_args', {})
def load_model_but_good_this_time_srsly(
    model_path,
    *,
    checkpoint_path="/home/blackroot/Desktop/ViTagger/model/model.safetensors",
    hub_model="hf-hub:SmilingWolf/wd-vit-tagger-v3",
    pretrained=True,
):
    """Create the WD ViT tagger with timm and load a local checkpoint.

    Args:
        model_path: Path of the local ``config.json``, e.g.
            ``/home/blackroot/Desktop/ViTagger/model/config.json``. Only its
            ``model_args`` entry is used; it is forwarded as keyword
            arguments to ``timm.create_model``.
        checkpoint_path: Local weights file handed to ``timm.create_model``.
            The default is the original author's machine-specific path, so
            other callers should pass their own.
        hub_model: Model name handed to ``timm.create_model``. With the
            default ``hf-hub:`` name, timm presumably resolves the model
            config from the Hugging Face Hub (network access) — confirm.
        pretrained: Forwarded to ``timm.create_model``. NOTE(review): with
            True, timm presumably fetches hub weights before applying
            ``checkpoint_path``. False may avoid that download — verify.

    Returns:
        The model returned by ``timm.create_model``.
    """
    # Imported locally: this file has no top-level ``import timm``.
    import timm

    # The pretrained cfg and architecture name from the local config are
    # currently unused; the model is chosen by ``hub_model`` instead.
    _pretrained_cfg, _model_name, model_args = load_cfg(
        "vit_base_patch16_224", model_path
    )
    return timm.create_model(
        hub_model,
        checkpoint_path=checkpoint_path,
        pretrained=pretrained,
        **model_args,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment