Skip to content

Instantly share code, notes, and snippets.

View pashu123's full-sized avatar
😇
Working from home

Prashant Kumar pashu123

😇
Working from home
View GitHub Profile
This file has been truncated, but you can view the full file.
// Affine-map attribute aliases used by the (truncated) linalg IR below.
// Each alias names an index mapping from iteration-space dims to operand dims;
// maps that drop dims (e.g. (d0) -> ()) or pin them to 0 implement
// broadcast-style access, and (d0, d1) -> (d1, d0) is a transpose access.
#map0 = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (d0)>
#map2 = affine_map<(d0) -> ()>
#map3 = affine_map<() -> ()>
#map4 = affine_map<(d0, d1) -> ()>
#map5 = affine_map<(d0, d1) -> (d0, d1)>
#map6 = affine_map<(d0, d1) -> (d0, 0)>
#map7 = affine_map<(d0, d1) -> (0, d1)>
#map8 = affine_map<(d0, d1) -> (d1, d0)>
#map9 = affine_map<(d0, d1) -> (d1)>
# Fetch the pre-imported "resnet_50_fp16_old" model from the SHARK tank GCS
# bucket and wrap it for inference through the linalg-dialect path.
import numpy as np
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.shark_downloader import download_torch_model
# download_torch_model returns the MLIR module text, the entry-function name,
# sample inputs, and golden (reference) outputs — presumably for validating
# the compiled module against; the validation step is truncated here.
mlir_model, func_name, inputs, golden_out = download_torch_model(
"resnet_50_fp16_old", tank_url="gs://shark_tank/prashant_nod"
)
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
# Same flow as the fp16 "old" variant above, but pulling the
# "resnet_50_fp16_torch" artifact from the SHARK tank bucket.
import numpy as np
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.shark_downloader import download_torch_model
# Returns (mlir module, entry function name, sample inputs, golden outputs).
mlir_model, func_name, inputs, golden_out = download_torch_model(
"resnet_50_fp16_torch", tank_url="gs://shark_tank/prashant_nod"
)
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
This file has been truncated, but you can view the full file.
// Affine-map attribute aliases for another (truncated) linalg module.
// NOTE: the alias numbering differs from the earlier fragment — e.g. here
// #map0 is the scalar map () -> () — so the two alias sets are not
// interchangeable even though they cover the same shapes of mapping.
#map0 = affine_map<() -> ()>
#map1 = affine_map<(d0) -> (0)>
#map2 = affine_map<(d0) -> (d0)>
#map3 = affine_map<(d0) -> ()>
#map4 = affine_map<(d0, d1) -> (d0, 0)>
#map5 = affine_map<(d0, d1) -> (0, d1)>
#map6 = affine_map<(d0, d1) -> (d0, d1)>
#map7 = affine_map<(d0, d1) -> ()>
#map8 = affine_map<(d0, d1) -> (d1, d0)>
#map9 = affine_map<(d0, d1) -> (d1)>
# Fetch the quantized stable-diffusion model from the SHARK tank bucket and
# set up inference on the Vulkan backend via the linalg-dialect path.
import numpy as np
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.shark_downloader import download_torch_model
# Returns (mlir module, entry function name, sample inputs, golden outputs).
mlir_model, func_name, inputs, golden_out = download_torch_model(
"stable_diff_quant", tank_url="gs://shark_tank/prashant_nod"
)
# device="vulkan" targets a Vulkan-capable GPU rather than the default device.
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg", device="vulkan")
# Variant of the stable_diff_quant setup that also imports torch.
# Only the download step is visible here — the SharkInference construction
# (and whatever torch is used for) is truncated from this paste.
import torch
import numpy as np
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.shark_downloader import download_torch_model
# Returns (mlir module, entry function name, sample inputs, golden outputs).
mlir_model, func_name, inputs, golden_out = download_torch_model(
"stable_diff_quant", tank_url="gs://shark_tank/prashant_nod"
)
This file has been truncated, but you can view the full file.
#loc0 = loc(unknown)
module attributes {torch.debug_module_name = "_lambda"} {
func.func private @__torch__.torch.fx.graph_module._lambda.forward(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> loc(unknown), %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[2,4,64,64],f32>} loc(unknown), %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[],si64>} loc(unknown), %arg3: !torch.tensor {torch.type_bound = !torch.vtensor<[2,77,768],f32>} loc(unknown)) -> !torch.tensor {
%3919 = torch.tensor_static_info_cast %arg1 : !torch.tensor to !torch.tensor<[2,4,64,64],f32> loc(#loc0)
%3920 = torch.tensor_static_info_cast %arg2 : !torch.tensor to !torch.tensor<[],si64> loc(#loc0)
%3921 = torch.tensor_static_info_cast %arg3 : !torch.tensor to !torch.tensor<[2,77,768],f32> loc(#loc0)
%3922 = torch.prim.GetAttr %arg0["_param_constant365"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor loc(#loc0)
%3923 = torch.prim.GetAttr %arg0["_param_constant3
# CLI prologue: parse the location of a serialized MLIR model for inference.
import argparse
from shark.shark_inference import SharkInference
import numpy as np

p = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# BUG FIX: the help text previously read "the device to use" — copy-pasted
# from a device flag. --mlir_loc is the filesystem path of the MLIR model.
p.add_argument(
    "--mlir_loc",
    type=str,
    default=None,
    help="path to the MLIR model file to load",
)
# NOTE(review): parse_args() runs at import time with no __main__ guard;
# left as-is to preserve the script's module-level behavior.
args = p.parse_args()
from iree import runtime as ireert
from iree.compiler import tf as tfc
from iree.compiler import compile_str
import sys
from absl import app
import numpy as np
import os
import tempfile
from stable_diffusion_tf.stable_diffusion import get_models
from iree import runtime as ireert
from iree.compiler import tf as tfc
from iree.compiler import compile_str
import sys
from absl import app
import numpy as np
import os