Each IterVar (i.e., each axis) has three kinds of features:
- axis attribute
- arithmetic feature
- touch feature
// Edge record: two integer endpoint fields plus a fixed-size float feature vector.
struct Edge {
int src;  // NOTE(review): presumably the id of the source node -- confirm
int dst;  // NOTE(review): presumably the id of the destination node -- confirm
float feature[100];  // storage for 100 float feature values
}
struct Node {
int node_type;
int id;
import numpy as np | |
import tvm | |
from tvm import te, auto_scheduler, topi | |
@auto_scheduler.register_workload | |
def dense_layer(in_dim, out_dim): | |
data = te.placeholder((1, in_dim), name="data") | |
weight = te.placeholder((out_dim, in_dim), name="weight") | |
bias = te.placeholder((out_dim,), name="bias") | |
out = topi.nn.dense(data, weight) |
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# |
from functools import partial | |
import numpy as np | |
import jax | |
import jax.numpy as jnp | |
def split(a, axis, factor): | |
assert a.shape[axis] % factor == 0 | |
new_shape = a.shape[:axis] + (factor, a.shape[axis] // factor) + a.shape[axis+1:] | |
a = a.reshape(new_shape) |
from functools import partial | |
import numpy as np | |
import jax | |
import jax.numpy as jnp | |
from jax.nn import relu | |
from jax.experimental import PartitionSpec as P | |
from jax.experimental.maps import mesh | |
from jax.experimental.pjit import pjit, with_sharding_constraint |
# Style 1 | |
@auto_parallel
def step(batch, weight):
    """Run one update step and return the new weight.

    Builds a gradient function with ``grad(loss_func)``, calls it on
    ``(batch, weight)``, and hands the resulting gradients to
    ``optimier_step``.
    """
    grad_fn = grad(loss_func)
    gradients = grad_fn(batch, weight)
    # Open question: it is not known where a pmean should be inserted.
    new_weight = optimier_step(gradients)
    # REQUIREMENT: new_weight and weight maps
    return new_weight
import time | |
import cupy as cp | |
def benchmark(n, k, m, dtype): | |
warmup = 2 | |
number = 100 | |
a = cp.ones((n, k), dtype) | |
b = cp.ones((k, m), dtype) |
HloModule train_step_shard_parallel.3684, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36 |
// Original : https://github.com/alpa-projects/tensorflow-alpa/blob/d298f84474a04ecce02085332793e6115c0c8e0e/tensorflow/compiler/xla/service/spmd/auto_sharding_strategy.h#L854-L876 | |
if (adj_list.size() > 1) { | |
// Merge src to dst. | |
// | |
// Before: | |
// | |
// src ---- adj ---- dst | |
// | | | |
// ------------------- |