Last active
October 15, 2020 21:26
-
-
Save comaniac/5b9f11c6096aff980d9a5366656d4535 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import numpy as np | |
import logging | |
import tvm | |
from tvm import auto_scheduler, te, topi | |
from tvm.topi.nn.util import get_pad_tuple | |
from tvm.auto_scheduler.compute_dag import ComputeDAG | |
# Route INFO-level (and above) records from the root logger to time.log,
# so the per-task timing messages emitted below end up in that file.
logging.basicConfig(level=logging.INFO, filename='time.log')
# Conv2d layer configurations, grouped under the string keys '18' and '50'
# (presumably the ResNet depth -- confirm against the model definitions).
#
# Every entry is a 9-tuple laid out as:
#   (H, W, CI, CO, KH, KW, strides, padding, dilation)
# The batch size is not part of the tuple; callers prepend it.
resnet_conv2d_configs = {
    '18': [
        (224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)),
        (56, 56, 64, 128, 3, 3, (2, 2), (1, 1), (1, 1)),
        (56, 56, 64, 128, 1, 1, (2, 2), (0, 0), (1, 1)),
        (56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)),
        (56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)),
        (28, 28, 128, 256, 3, 3, (2, 2), (1, 1), (1, 1)),
        (28, 28, 128, 256, 1, 1, (2, 2), (0, 0), (1, 1)),
        (28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)),
        (14, 14, 256, 512, 3, 3, (2, 2), (1, 1), (1, 1)),
        (14, 14, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)),
        (14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)),
        (7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)),
    ],
    '50': [
        (224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)),
        (56, 56, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)),
        (56, 56, 256, 128, 1, 1, (2, 2), (0, 0), (1, 1)),
        (56, 56, 256, 64, 1, 1, (1, 1), (0, 0), (1, 1)),
        (56, 56, 64, 256, 1, 1, (1, 1), (0, 0), (1, 1)),
        (56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)),
        (56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)),
        (28, 28, 512, 1024, 1, 1, (2, 2), (0, 0), (1, 1)),
        (28, 28, 512, 256, 1, 1, (2, 2), (0, 0), (1, 1)),
        (28, 28, 512, 128, 1, 1, (1, 1), (0, 0), (1, 1)),
        (28, 28, 128, 512, 1, 1, (1, 1), (0, 0), (1, 1)),
        (28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)),
        (14, 14, 1024, 2048, 1, 1, (2, 2), (0, 0), (1, 1)),
        (14, 14, 1024, 512, 1, 1, (2, 2), (0, 0), (1, 1)),
        (14, 14, 1024, 256, 1, 1, (1, 1), (0, 0), (1, 1)),
        (14, 14, 256, 1024, 1, 1, (1, 1), (0, 0), (1, 1)),
        (14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)),
        (7, 7, 2048, 512, 1, 1, (1, 1), (0, 0), (1, 1)),
        (7, 7, 512, 2048, 1, 1, (1, 1), (0, 0), (1, 1)),
        (7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)),
    ],
}
def get_log_file_name_from_task(task):
    """Build the tuning-log file name for ``task`` from its workload key.

    Square brackets, double quotes and spaces are removed from
    ``task.workload_key``, every comma becomes an underscore, and ``.json``
    is appended.

    Parameters
    ----------
    task : object
        Any object exposing a string ``workload_key`` attribute.

    Returns
    -------
    str
        The sanitized key followed by ``.json``.
    """
    # A single translate() pass does the work of the five chained
    # str.replace() calls: map "," -> "_" and delete '[', ']', '"' and ' '.
    table = str.maketrans(",", "_", '[]" ')
    return "{}.json".format(task.workload_key.translate(table))
@auto_scheduler.register_workload
def conv2d_nchw(N, H, W, CI, CO, KH, KW, stride, padding, dilation):
    """Forward conv2d workload with an (N, CI, H, W) input placeholder.

    Registered with ``auto_scheduler.register_workload`` so that tasks can
    be created from it.

    Parameters
    ----------
    N, H, W, CI : int
        Dimensions of the ``data`` placeholder, declared as (N, CI, H, W).
    CO, KH, KW : int
        Together with ``CI`` they give the ``kernel`` placeholder shape
        (CO, CI, KH, KW).
    stride, padding, dilation
        Forwarded unchanged to ``topi.nn.conv2d_nchw``.

    Returns
    -------
    list
        ``[data, kernel, out]`` where ``out`` is the float32 result of
        ``topi.nn.conv2d_nchw``.
    """
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    out = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation, out_dtype="float32")
    return [data, kernel, out]
@auto_scheduler.register_workload
def conv2d_nchw_gd(N, H, W, CI, CO, KH, KW, stride, padding, dilation):
    """Gradient workload for ``conv2d_nchw``.

    Builds the forward computation with ``conv2d_nchw``, declares a ``dy``
    placeholder with the forward output's shape, and asks ``te.gradient``
    for the gradients of the forward output with respect to ``data`` and
    ``kernel``, using ``dy`` as the head.

    Parameters
    ----------
    N, H, W, CI, CO, KH, KW, stride, padding, dilation
        Passed through unchanged to ``conv2d_nchw``.

    Returns
    -------
    list
        ``[data, kernel, dy, *grads]`` where ``grads`` is whatever
        ``te.gradient`` returns for the inputs ``[data, kernel]``.
    """
    data, kernel, f_out = conv2d_nchw(N, H, W, CI, CO, KH, KW, stride, padding, dilation)
    dy = te.placeholder(f_out.shape, name="dy")
    out = te.gradient(f_out, [data, kernel], head=dy)
    return [data, kernel, dy, *out]
@auto_scheduler.register_workload
def conv2d_nhwc(N, H, W, CI, CO, KH, KW, stride, padding, dilation=1):
    """Forward conv2d workload with an (N, H, W, CI) input placeholder.

    Registered with ``auto_scheduler.register_workload`` so that tasks can
    be created from it.

    Parameters
    ----------
    N, H, W, CI : int
        Dimensions of the ``data`` placeholder, declared as (N, H, W, CI).
    KH, KW, CO : int
        Together with ``CI`` they give the ``kernel`` placeholder shape
        (KH, KW, CI, CO).
    stride, padding
        Forwarded unchanged to ``topi.nn.conv2d_nhwc``.
    dilation : optional
        Forwarded to ``topi.nn.conv2d_nhwc``. Defaults to 1, the value that
        used to be hard-coded, so existing nine-argument callers behave
        exactly as before. Accepting it lets the ten-field argument lists
        used for the NCHW workloads be reused here.

    Returns
    -------
    list
        ``[data, kernel, out]`` where ``out`` is the float32 result of
        ``topi.nn.conv2d_nhwc``.
    """
    data = te.placeholder((N, H, W, CI), name="data")
    kernel = te.placeholder((KH, KW, CI, CO), name="kernel")
    out = topi.nn.conv2d_nhwc(
        data, kernel, stride, padding, dilation=dilation, out_dtype="float32"
    )
    return [data, kernel, out]


@auto_scheduler.register_workload
def conv2d_nhwc_gd(N, H, W, CI, CO, KH, KW, stride, padding, dilation=1):
    """Gradient workload for ``conv2d_nhwc``.

    Builds the forward computation with ``conv2d_nhwc``, declares a ``dy``
    placeholder with the forward output's shape, and asks ``te.gradient``
    for the gradients of the forward output with respect to ``data`` and
    ``kernel``, using ``dy`` as the head.

    Parameters
    ----------
    N, H, W, CI, CO, KH, KW, stride, padding
        Passed through unchanged to ``conv2d_nhwc``.
    dilation : optional
        Passed through to ``conv2d_nhwc``; defaults to 1 (previous behavior).

    Returns
    -------
    list
        ``[data, kernel, dy, *grads]`` where ``grads`` is whatever
        ``te.gradient`` returns for the inputs ``[data, kernel]``.
    """
    data, kernel, f_out = conv2d_nhwc(N, H, W, CI, CO, KH, KW, stride, padding, dilation)
    dy = te.placeholder(f_out.shape, name="dy")
    out = te.gradient(f_out, [data, kernel], head=dy)
    return [data, kernel, dy, *out]
target = tvm.target.Target("cuda -model=t4")
batch = 32

# One forward task and one gradient task per configuration in the '18' group.
# num_outs[i] records how many trailing tensors of tasks[i] are outputs:
# conv2d_nchw returns [data, kernel, out] (1 output), while conv2d_nchw_gd
# returns [data, kernel, dy, *grads] with grads taken w.r.t. [data, kernel]
# (2 outputs). A single global count of 2 would make the forward task's
# kernel an uninitialised output buffer instead of a random input.
tasks = []
num_outs = []
for cfg in resnet_conv2d_configs["18"]:
    tasks.append(auto_scheduler.create_task(conv2d_nchw, (batch, *cfg), target))
    num_outs.append(1)
    tasks.append(auto_scheduler.create_task(conv2d_nchw_gd, (batch, *cfg), target))
    num_outs.append(2)

print('Getting device...')
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
try:
    ctx = tvm.gpu()
    for idx, (task, num_out) in enumerate(zip(tasks, num_outs)):
        log_file = get_log_file_name_from_task(task)
        logging.info("[%d / %d Tasks] Log to %s", idx + 1, len(tasks), log_file)

        # Resume from an existing log when there is one: update the cost
        # model from it and preload the states that were already measured.
        cost_model = auto_scheduler.XGBModel()
        if os.path.exists(log_file):
            cost_model.update_from_file(log_file)
            search_policy = auto_scheduler.SketchPolicy(
                task,
                cost_model,
                init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)],
            )
        else:
            search_policy = auto_scheduler.SketchPolicy(task, cost_model)

        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=1500,
            runner=measure_ctx.runner,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )

        # NOTE: to skip the search and replay the best record in log_file,
        # use auto_scheduler.load_best(log_file, task.workload_key) and
        # task.compute_dag.apply_steps_from_state(inp.state) instead.
        sch, args = auto_scheduler.auto_schedule(
            task, search_policy, tuning_options=tune_option
        )
        func = tvm.build(sch, args, target)

        # Random float32 inputs for the leading tensors, empty device
        # buffers for the trailing num_out output tensors.
        in_nps = [
            np.random.uniform(size=[v.value for v in a.shape]).astype(np.float32)
            for a in args[:-num_out]
        ]
        in_args = [tvm.nd.array(dnp, ctx=ctx) for dnp in in_nps]
        out_args = [
            tvm.nd.empty([v.value for v in a.shape], ctx=ctx)
            for a in args[-num_out:]
        ]

        # Evaluate execution time
        evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500)
        logging.info(
            "Median execution time: %.3f ms",
            np.median(evaluator(*in_args, *out_args).results) * 1000,
        )
finally:
    # Drop the measurement context even when tuning or building fails, so
    # it is not kept alive after an error.
    del measure_ctx
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment