|
import os |
|
import pickle |
|
|
|
|
|
def is_pickle_file(filepath):
    """Return True if *filepath* is a regular file holding a well-formed pickle.

    The file is only *parsed*, never unpickled.  ``pickletools.genops`` walks
    the opcode stream up to the first STOP opcode.  It does not import
    modules, look up globals or call constructors.  Calling ``pickle.load``
    here would execute arbitrary code embedded in an untrusted model file,
    which is exactly the kind of file this checker is pointed at.

    As a consequence, a pickle is reported as valid even when unpickling it
    would fail in the current environment (e.g. it references a class that
    is not installed).  That says nothing about the file's format.

    Args:
        filepath: Path of the file to inspect.

    Returns:
        True if the opcode stream parses through a STOP opcode.  False if the
        path is not a regular file, cannot be read, or is not a valid pickle
        (unknown opcode, truncated argument, or no STOP before end of file).
    """
    import pickletools

    if not os.path.isfile(filepath):
        return False

    try:
        with open(filepath, "rb") as file:
            # Exhaust the generator.  It raises ValueError on an unknown
            # opcode, a truncated argument, or end-of-file before STOP.
            for _ in pickletools.genops(file):
                pass
    except (OSError, ValueError, MemoryError, OverflowError):
        # OSError: the file could not be read.  MemoryError/OverflowError: a
        # corrupt length prefix requested an absurdly large read.
        return False
    return True
|
|
|
|
|
def is_pytorch_file(filepath):
    """Return True if ``torch.load`` can load *filepath* in weights-only mode.

    ``weights_only=True`` restricts unpickling to tensors and plain
    containers, so a malicious checkpoint cannot run code during the probe.
    This is PyTorch's default from 2.6 onward.  Passing it explicitly gives
    the same safe behaviour on earlier releases that support the parameter.
    Checkpoints that pickle arbitrary Python objects (e.g. a whole
    ``nn.Module``) are therefore reported as False.

    Args:
        filepath: Path of the file to inspect.

    Returns:
        True on a successful load.  False if the path is not a regular file,
        PyTorch is not importable, or loading raises for any reason (corrupt
        or foreign file, disallowed pickle content, a PyTorch too old to
        accept ``weights_only``).
    """
    if not os.path.isfile(filepath):
        return False

    try:
        import torch

        torch.load(filepath, map_location="cpu", weights_only=True)
    except Exception:
        # Deliberately broad: this is a best-effort probe, and torch.load
        # raises many unrelated exception types on foreign input.
        # KeyboardInterrupt and SystemExit are not caught.
        return False
    return True
|
|
|
|
|
def is_safetensor_file(filepath):
    """Return True if *filepath* opens as a safetensors file with >= 1 tensor.

    The file is opened with ``safetensors.safe_open`` (PyTorch framework,
    CPU device), and only its tensor names are listed.  A file whose header
    lists no tensors is reported as False.

    Args:
        filepath: Path of the file to inspect.

    Returns:
        True if at least one tensor key is present.  False if the path is not
        a regular file, the ``safetensors`` package cannot be used, or opening
        the file raises (e.g. it is not in safetensors format).
    """
    if not os.path.isfile(filepath):
        return False

    try:
        from safetensors import safe_open

        with safe_open(filepath, framework="pt", device="cpu") as f:
            return len(f.keys()) > 0
    except Exception:
        # Best-effort probe: any failure means "not a safetensors file".
        # KeyboardInterrupt and SystemExit are not caught.
        return False
|
|
|
|
|
def is_keras_file(filepath):
    """Return True if Keras' ``load_model`` can load *filepath*.

    ``compile=False`` skips restoring the training configuration (optimizer,
    loss, metrics).  Without it, a model compiled with a custom loss or
    metric that is not importable here fails to load and would be wrongly
    reported as "not Keras".  Skipping compilation also makes the probe
    cheaper.

    Args:
        filepath: Path of the file to inspect.

    Returns:
        True on a successful load.  False if the path is not a regular file,
        TensorFlow/Keras is not importable, or loading raises for any reason.
    """
    if not os.path.isfile(filepath):
        return False

    try:
        from tensorflow.keras.models import load_model

        load_model(filepath, compile=False)
    except Exception:
        # Best-effort probe: any failure means "not a Keras model file".
        # KeyboardInterrupt and SystemExit are not caught.
        return False
    return True
|
|
|
|
|
def is_h5_file(filepath):
    """Return True if h5py can open *filepath* read-only as an HDF5 file.

    The file is opened and immediately closed; no datasets are read.

    Args:
        filepath: Path of the file to inspect.

    Returns:
        True if ``h5py.File(filepath, "r")`` succeeds.  False if the path is
        not a regular file, h5py is not importable, or opening raises (e.g.
        the file is not HDF5).
    """
    if not os.path.isfile(filepath):
        return False

    try:
        import h5py

        with h5py.File(filepath, "r"):
            return True
    except Exception:
        # Best-effort probe: any failure means "not an HDF5 file".
        # KeyboardInterrupt and SystemExit are not caught.
        return False
|
|
|
|
|
def is_onnx_file(filepath):
    """Return True if *filepath* decodes as an ONNX ``ModelProto`` with a graph.

    Protobuf decoding is permissive.  An empty file, and some arbitrary byte
    strings, decode to a ``ModelProto`` with every field unset, so a
    successful ``onnx.load`` alone does not prove the file is an ONNX model.
    The model must also carry a positive ``ir_version`` and a ``graph``;
    the ONNX IR specification requires both.

    ``load_external_data=False`` limits the probe to the file itself.
    Tensors stored in external side files are neither read nor required to
    exist.

    Args:
        filepath: Path of the file to inspect.

    Returns:
        True if the file decodes and has both required fields.  False if the
        path is not a regular file, onnx is not importable, decoding raises,
        or the decoded message lacks ``ir_version``/``graph``.
    """
    if not os.path.isfile(filepath):
        return False

    try:
        import onnx

        model = onnx.load(filepath, load_external_data=False)
        return model.ir_version > 0 and model.HasField("graph")
    except Exception:
        # Best-effort probe: any failure means "not an ONNX model".
        # KeyboardInterrupt and SystemExit are not caught.
        return False
|
|
|
|
|
def check_model_type(filepath):
    """Run every format detector on *filepath* and collect the ones that match.

    Detectors run in a fixed order: pickle, pytorch, safetensor, keras, h5,
    onnx.  A progress line is printed to stdout before each one.  Several
    detectors may match the same file; all matches are kept.

    Args:
        filepath: Path of the file to inspect.

    Returns:
        list[str]: Labels of the matching formats in probe order.  Empty when
        no detector matched.
    """
    print(f"\nChecking {filepath}...\n")

    # (label, detector) pairs.  The label is both the name in the progress
    # message and the value recorded on a match.
    probes = (
        ("pickle", is_pickle_file),
        ("pytorch", is_pytorch_file),
        ("safetensor", is_safetensor_file),
        ("keras", is_keras_file),
        ("h5", is_h5_file),
        ("onnx", is_onnx_file),
    )

    matched = []
    for label, probe in probes:
        print(f"Checking {label} file...")
        if probe(filepath):
            matched.append(label)
    return matched
|
|
|
|
|
# Example usage: run this file as a script, optionally passing the path to
# check (defaults to "pytorch_model.bin").  The guard keeps the check from
# running when the module is imported for its detector functions.
if __name__ == "__main__":
    import sys

    filepath = sys.argv[1] if len(sys.argv) > 1 else "pytorch_model.bin"
    model_types = check_model_type(filepath)

    if model_types:
        print(f"\n\nResults: {filepath} is a {', '.join(model_types)} file")
    else:
        print(f"{filepath} is not a recognized model file")