Created
May 14, 2024 06:06
-
-
Save pashu123/c8ea48a27ddfc434b9910cd0692be429 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import re | |
# Command-line interface: three required positional string arguments.
parser = argparse.ArgumentParser(description='Convert parameter data type')
parser.add_argument(
    'mlir',
    type=str,
    help='MLIR file where all parameters are mentioned',
)
parser.add_argument(
    'dtype',
    type=str,
    help='Required data type of parameters',
)
parser.add_argument(
    'irpa',
    type=str,
    help='destination irpa file',
)
# Parsed once at module load; the rest of the script reads `args` directly.
args = parser.parse_args()
def extract_tensor_info(statement, convert_dtype=...):
    """Parse a parameter-backed global declaration from one line of MLIR.

    Searches *statement* for
    ``#stream.parameter.named<"model"::"NAME"> : tensor<TYPE>`` and returns
    ``(NAME, TYPE)``. TYPE is the ``x``-separated shape-and-element-type
    string, e.g. ``4096x4096xf16``.

    Args:
        statement: One line of MLIR text.
        convert_dtype: Element type that replaces TYPE's element type. The
            leading dimensions (``[0-9?]+x`` groups) are kept. When omitted,
            the ``dtype`` command-line argument (``args.dtype``) is used.
            A falsy value (``None``, ``''``) returns TYPE unchanged.

    Returns:
        ``(name, tensor_type)`` on a match, otherwise ``(None, None)``.
    """
    if convert_dtype is ...:
        # Resolve the CLI default at call time instead of definition time, so
        # the function does not depend on argparse state when it is defined.
        convert_dtype = args.dtype
    match = re.search(
        r'#stream\.parameter\.named<"model"::"([^"]+)"> : tensor<([^>]+)>',
        statement,
    )
    if match is None:
        # Best effort: not every global is parameter-backed; callers check
        # for None rather than catching an exception.
        return None, None
    name, tensor_type = match.groups()
    if convert_dtype:
        # Keep only genuine dimensions ("4096x", "?x", ...). Splitting on every
        # 'x' would corrupt element types that contain one (e.g. "index").
        shape_prefix = re.match(r'(?:[0-9?]+x)*', tensor_type).group(0)
        tensor_type = shape_prefix + convert_dtype
    return name, tensor_type
def get_param_info(file_name):
    """Build ``--splat`` flags for each parameter-backed global in an MLIR file.

    Scans *file_name* line by line for ``util.global private @__``
    declarations at any indentation. Each one is parsed with
    ``extract_tensor_info`` using its default dtype conversion, and the name
    and tensor type found are printed. For each parameter, one
    ``--splat=NAME=TYPE=1.0`` argument is collected for
    ``iree-create-parameters``.

    Globals for which ``extract_tensor_info`` finds no named parameter are
    skipped with a message. Previously they produced a bogus
    ``--splat=None=None=1.0`` flag.

    Args:
        file_name: Path to the MLIR text file.

    Returns:
        List of ``--splat=...`` argument strings, in file order.
    """
    splat_params = []
    with open(file_name, 'r', encoding='utf-8') as f:
        for line in f:
            # Match regardless of the leading indentation of the global.
            if not line.lstrip().startswith('util.global private @__'):
                continue
            name, tensor_type = extract_tensor_info(line)
            if name is None:
                print(f"Skipping global without a named parameter: {line.strip()}")
                continue
            print(f"Name: {name}")
            print(f"Tensor Type: {tensor_type}")
            splat_params.append(f"--splat={name}={tensor_type}=1.0")
    return splat_params
# Script entry: collect splat flags from the MLIR file and invoke the tool.
splat_params = get_param_info(args.mlir)
if not splat_params:
    print(f"Warning: no parameters found in {args.mlir}")
final_command = ["iree-create-parameters"] + splat_params + [f"--output={args.irpa}"]
import subprocess
# Argument list with the default shell=False: names from the MLIR file are
# never interpreted by a shell.
completed = subprocess.run(final_command)
if completed.returncode != 0:
    # Propagate the tool's failure instead of exiting 0.
    raise SystemExit(completed.returncode)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment