Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created May 14, 2024 06:06
Show Gist options
  • Save pashu123/c8ea48a27ddfc434b9910cd0692be429 to your computer and use it in GitHub Desktop.
Save pashu123/c8ea48a27ddfc434b9910cd0692be429 to your computer and use it in GitHub Desktop.
import argparse
import re
# Command-line interface. All three arguments are required positionals and
# are kept as plain strings; the parsed namespace is exposed as `args`.
parser = argparse.ArgumentParser(
    description='Convert parameter data type',
)
for _arg_name, _arg_help in (
    ('mlir', 'MLIR file where all parameters are mentioned'),
    ('dtype', 'Required data type of parameters'),
    ('irpa', 'destination irpa file'),
):
    parser.add_argument(_arg_name, type=str, help=_arg_help)
# Drop the loop variables so the module namespace matches a flat definition.
del _arg_name, _arg_help
args = parser.parse_args()
# Matches a parameter-backed global such as
#   ... = #stream.parameter.named<"model"::"blk.0.weight"> : tensor<4096x4096xf16>
# Group 1 is the parameter name; group 2 is the text inside tensor<...>.
_PARAM_PATTERN = re.compile(
    r'#stream\.parameter\.named<"model"::"([^"]+)"> : tensor<([^>]+)>'
)

# Splits a tensor type body into its leading dimensions (each followed by 'x')
# and the element type. For example, "4096x11008xf16" splits into "4096x11008x"
# and "f16", and "4xindex" splits into "4x" and "index".
_SHAPE_PATTERN = re.compile(r'((?:[0-9?]+x)*)(.+)', re.DOTALL)

# Sentinel meaning "use the dtype given on the command line". It is resolved at
# call time, so this function can be defined and called without the CLI `args`
# global as long as an explicit convert_dtype is passed.
_CLI_DTYPE = object()


def extract_tensor_info(statement, convert_dtype=_CLI_DTYPE):
    """Extract the parameter name and tensor type from an MLIR statement.

    Args:
        statement: MLIR text, typically one line declaring a global.
        convert_dtype: Element type that replaces the tensor's original
            element type (e.g. ``"f32"``). It defaults to the ``dtype``
            command-line argument. A falsy value (``None`` or ``""``) keeps
            the original element type.

    Returns:
        A tuple ``(name, tensor_type)``. ``tensor_type`` is the
        ``x``-separated dimensions followed by the element type (e.g.
        ``"4096x4096xf32"``). Returns ``(None, None)`` when ``statement``
        does not reference a ``"model"`` parameter.
    """
    if convert_dtype is _CLI_DTYPE:
        convert_dtype = args.dtype

    match = _PARAM_PATTERN.search(statement)
    if match is None:
        # The previous code built a KeyError here but never raised it.
        # Callers only ever received (None, None), so that contract is kept explicitly.
        return None, None

    name, tensor_type = match.group(1), match.group(2)
    if convert_dtype:
        # Replace only the element type. Splitting on every 'x' would corrupt
        # element types that themselves contain 'x' (e.g. "index").
        dims = _SHAPE_PATTERN.fullmatch(tensor_type).group(1)
        tensor_type = f"{dims}{convert_dtype}"
    return name, tensor_type
def get_param_info(file_name, splat_value='1.0'):
    """Build ``--splat`` flags for every parameter global in an MLIR file.

    Each line declaring a ``util.global private @__...`` is parsed with
    ``extract_tensor_info`` using its default dtype. The parameter's name
    and type are then printed.

    Args:
        file_name: Path to the MLIR text file.
        splat_value: Value used to fill every parameter (default ``'1.0'``).

    Returns:
        A list of ``--splat=<name>=<type>=<splat_value>`` strings, in file
        order. Globals that are not backed by a ``"model"`` parameter are
        skipped.
    """
    splat_params = []
    with open(file_name, 'r', encoding='utf-8') as f:
        for line in f:
            # Accept the global at any indentation. An exact-whitespace prefix
            # silently misses lines indented differently.
            if not line.lstrip().startswith('util.global private @__'):
                continue
            name, tensor_type = extract_tensor_info(line)
            if name is None:
                # Without this check the flag "--splat=None=None=..." would be emitted.
                print(f"Skipping global without a model parameter: {line.strip()}")
                continue
            print(f"Name: {name}")
            print(f"Tensor Type: {tensor_type}")
            splat_params.append(f"--splat={name}={tensor_type}={splat_value}")
    return splat_params
# Script entry: collect splat flags from the MLIR file, then run
# iree-create-parameters to write them to the requested .irpa file.
import subprocess

splat_params = get_param_info(args.mlir)
if not splat_params:
    # Nothing matched. Fail loudly rather than invoke the tool with no --splat flags.
    raise SystemExit(f"No parameters found in {args.mlir}")
# Argument list (not a shell string), so names are passed through verbatim.
# The executable is looked up on PATH.
final_command = ["iree-create-parameters"] + splat_params + [f"--output={args.irpa}"]
# check=True raises CalledProcessError on a non-zero exit instead of ignoring it.
subprocess.run(final_command, check=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment