Last active: December 15, 2020 23:53
Save benvanik/a2c1915c71dfb611f8e6d7ddcdc96539 to your computer and use it in GitHub Desktop.
tiled dispatch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Simple dispatch of static shapes. | |
// Example 1 (input form): a single dispatch with fully-static shapes.
// flow.dispatch.workgroups[%x, %y] launches the region over a 100x50 2D grid;
// the region communicates with the outside only through the ref-style
// !flow.dispatch.input / !flow.dispatch.output arguments.
// NOTE(review): the trailing "| |" on each line looks like table-scrape
// residue from the gist page, not MLIR syntax — confirm before compiling.
func @staticShapeDispatch(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> { | |
// Workgroup counts for the X and Y grid dimensions.
%x = constant 100 : index | |
%y = constant 50 : index | |
// %x, %y here are the workgroup counts along a 2D grid to dispatch; backends turn them into 3D XYZ. | |
%0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = ( | |
// I/O are modeled in the region as ref arguments that have some special ops available. | |
%arg : !flow.dispatch.input<8x4xf32>, %ret : !flow.dispatch.output<4x8xf32> | |
) { | |
// Loads a tensor from an input; can be tiled with offsets/sizes/strides. | |
%arg_value = flow.dispatch.input.load %arg : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> | |
// Shapes can be retrieved from the I/O arguments. | |
%arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<8x4xf32> -> !shapex.ranked_shape<[8,4]> | |
%ret_shape = flow.dispatch.shape %ret : !flow.dispatch.output<4x8xf32> -> !shapex.ranked_shape<[4,8]> | |
// Representative "produce a tile from an input with shape information" op. | |
%ret_value = "test.sink"(%arg_value, %arg_shape, %ret_shape) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> (tensor<4x8xf32>) | |
// Stores a tile into the output I/O argument. | |
flow.dispatch.output.store %ret_value, %ret : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> | |
// Region terminator: results leave only through the output refs above.
flow.return | |
} | |
return %0 : tensor<4x8xf32> | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Simple transformation that outlines the region; also ran canonicalize afterward. | |
// Example 1 (after outlining + canonicalize): the workgroup region from
// @staticShapeDispatch has been pulled into a private flow.executable.
// The inner func's two arguments mirror the region's I/O refs; the dynamic
// flow.dispatch.shape queries from the input form have been folded into
// shapex.const_ranked_shape since all dimensions are static.
flow.executable @staticShapeDispatch_dispatch_0 attributes {sym_visibility = "private"} { | |
flow.dispatch.entry @staticShapeDispatch_dispatch_0 attributes { | |
// Information that may be useful for logging/tracing/etc, but otherwise not needed. | |
signature = (tensor<8x4xf32>) -> tensor<4x8xf32>, | |
// The original rank of the workgroup grid (XY in this example). | |
workgroup_rank = 2 : index | |
} | |
module { | |
// Arguments match that of the region body (references to I/O). | |
func @staticShapeDispatch_dispatch_0(%arg0: !flow.dispatch.input<8x4xf32>, %arg1: !flow.dispatch.output<4x8xf32>) { | |
// Shapes are static and the queries canonicalized to constant values. | |
%rs8_4 = shapex.const_ranked_shape : !shapex.ranked_shape<[8,4]> | |
%rs4_8 = shapex.const_ranked_shape : !shapex.ranked_shape<[4,8]> | |
%0 = flow.dispatch.input.load %arg0 : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> | |
%1 = "test.sink"(%0, %rs8_4, %rs4_8) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> tensor<4x8xf32> | |
flow.dispatch.output.store %1, %arg1 : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> | |
return | |
} | |
} | |
} | |
// Example 1 (caller after outlining): the inline workgroups op is replaced by
// flow.dispatch2 targeting the outlined entry point, with the same [100, 50]
// workgroup counts now passed as constants at the dispatch site.
func @staticShapeDispatch(%arg0: tensor<8x4xf32>) -> tensor<4x8xf32> { | |
%c100 = constant 100 : index | |
%c50 = constant 50 : index | |
%0 = flow.dispatch2 @staticShapeDispatch_dispatch_0::@staticShapeDispatch_dispatch_0[%c100, %c50] (%arg0) : (tensor<8x4xf32>) -> tensor<4x8xf32> | |
return %0 : tensor<4x8xf32> | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Materializing the hal.interface and executables; note that this is a partial conversion: flow ops remain. | |
// Example 1 (HAL form): the flow.executable lowered onto a hal.executable with
// an explicit interface. Function arguments become set/binding slots
// (@arg0 read-only, @ret0 write+discard) fetched via
// hal.interface.binding.subspan; flow-level load/store ops remain because this
// is a partial conversion.
hal.executable @static_tiled_dispatch attributes {sym_visibility = "private"} { | |
hal.interface @legacy_io { | |
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" | |
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" | |
} | |
hal.executable.target @vmla, filter="vmla" { | |
hal.executable.entry_point @entry attributes { | |
interface = @legacy_io, | |
ordinal = 0 : i32, | |
signature = (!flow.dispatch.input<8x4xf32>, !flow.dispatch.output<4x8xf32>) -> () | |
} | |
module { | |
// Entry takes no arguments: all I/O comes through the interface bindings.
func @entry() { | |
// As with today, an arbitrary byte offset into the binding can be provided. | |
// Optionally a byte length can be provided too (if we can generate them); may be useful for bounds checking. | |
%c0 = constant 0 : index | |
// This op returns AnyType; here it's 1:1 with the original I/O arguments, but further lowering | |
// can turn it into memref<?xi8>/etc. | |
%0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<8x4xf32> | |
%1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<4x8xf32> | |
%rs8_4 = shapex.const_ranked_shape : !shapex.ranked_shape<[8,4]> | |
%rs4_8 = shapex.const_ranked_shape : !shapex.ranked_shape<[4,8]> | |
%2 = flow.dispatch.input.load %0 : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32> | |
%3 = "test.sink"(%2, %rs8_4, %rs4_8) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> tensor<4x8xf32> | |
flow.dispatch.output.store %3, %1 : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32> | |
return | |
} | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// More complicated example with dynamic shapes. | |
// Example 2 (input form): dispatch with dynamic shapes. The two dynamic dims
// of %arg0 (dims 1 and 3) are queried up front, packed into ranked_shape
// values, and tied to the tensors so the dispatch op can plumb them through.
// Note the result shape is built from %dim3, %dim1 — i.e. the input's dynamic
// dims in swapped order — as visible on the make_ranked_shape below.
func @dynamicShapeDispatch(%arg0 : tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> { | |
%c1 = constant 1 : index | |
%c3 = constant 3 : index | |
// Local query of the dimensions; these will end up getting turned into real values independent of this. | |
%dim1 = dim %arg0, %c1 : tensor<7x?x24x?xf32> | |
%dim3 = dim %arg0, %c3 : tensor<7x?x24x?xf32> | |
// Workgroup counts for the 2D dispatch grid.
%x = constant 1024 : index | |
%y = constant 512 : index | |
// Shape ties are used (as they are today) to indicate which shapes correspond to which tensors. | |
%arg0_shape = shapex.make_ranked_shape %dim1, %dim3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> | |
%arg0_shaped = shapex.tie_shape %arg0, %arg0_shape : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]> | |
%ret0_shape = shapex.make_ranked_shape %dim3, %dim1 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> | |
%ret0 = flow.dispatch.workgroups[%x, %y](%arg0_shaped) : (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> = ( | |
%arg : !flow.dispatch.input<7x?x24x?xf32>, %ret : !flow.dispatch.output<?x?x1024xf32> | |
) { | |
// Resolves to 2 when canonicalization runs. | |
%workgroup_rank = flow.dispatch.workgroup.rank : index | |
// Can get dynamic shape dimensions of %arg. | |
%arg_shape = flow.dispatch.shape %arg : !flow.dispatch.input<7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]> | |
%arg_dim1 = shapex.ranked_dim %arg_shape[1] : !shapex.ranked_shape<[7,?,24,?]> -> index | |
%arg_dim3 = shapex.ranked_dim %arg_shape[3] : !shapex.ranked_shape<[7,?,24,?]> -> index | |
"test.sink_shape_arg"(%arg_dim1, %arg_dim3) : (index, index) -> () | |
// Can get dynamic shape dimensions of %ret. | |
%ret_shape = flow.dispatch.shape %ret : !flow.dispatch.output<?x?x1024xf32> -> !shapex.ranked_shape<[?,?,1024]> | |
%ret_dim0 = shapex.ranked_dim %ret_shape[0] : !shapex.ranked_shape<[?,?,1024]> -> index | |
%ret_dim1 = shapex.ranked_dim %ret_shape[1] : !shapex.ranked_shape<[?,?,1024]> -> index | |
"test.sink_shape_ret"(%ret_dim0, %ret_dim1) : (index, index) -> () | |
// Load a tile (and get the tile size - which if we used offsets/sizes/strides may be smaller than the tensors). | |
%arg_tile = flow.dispatch.input.load %arg : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32> | |
%arg_tile_shape = shapex.get_ranked_shape %arg_tile : tensor<7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]> | |
// Produce a new tile. | |
%ret_tile = "test.tile_math"(%arg_tile, %arg_tile_shape, %ret_shape) : | |
(tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>, !shapex.ranked_shape<[?,?,1024]>) -> (tensor<?x?x1024xf32>) | |
// Store tile back. | |
flow.dispatch.output.store %ret_tile, %ret : tensor<?x?x1024xf32> -> !flow.dispatch.output<?x?x1024xf32> | |
flow.return | |
} | |
// Tie here allows us to know what the result shape is and feed it into the dispatch op. | |
%ret0_shaped = shapex.tie_shape %ret0, %ret0_shape : tensor<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]> | |
return %ret0_shaped : tensor<?x?x1024xf32> | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Outlining + canonicalization. | |
// Example 2 (after outlining + canonicalize): the dynamic-shape region as a
// flow.executable. Unlike the static case, the entry func grows four extra
// index arguments (%arg2..%arg5) carrying the dynamic dimensions, which are
// reassembled into ranked_shape values and tied onto the I/O refs.
flow.executable @dynamicShapeDispatch_dispatch_0 attributes {sym_visibility = "private"} { | |
flow.dispatch.entry @dynamicShapeDispatch_dispatch_0 attributes {signature = (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32>, workgroup_rank = 2 : index} | |
module { | |
func @dynamicShapeDispatch_dispatch_0( | |
%arg0: !flow.dispatch.input<7x?x24x?xf32>, %arg1: !flow.dispatch.output<?x?x1024xf32>, | |
// Dynamic dimensions for arg/ret, expanded to primitive indices here. | |
%arg2: index, %arg3: index, %arg4: index, %arg5: index | |
) { | |
// Constructs and ties shapes for the arg/ret so that any use of the %1/%3 I/O can fetch full dynamic shape values. | |
%0 = shapex.make_ranked_shape %arg2, %arg3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> | |
%1 = flow.dispatch.tie_shape %arg0, %0 : (!flow.dispatch.input<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>) -> !flow.dispatch.input<7x?x24x?xf32> | |
%2 = shapex.make_ranked_shape %arg4, %arg5 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> | |
%3 = flow.dispatch.tie_shape %arg1, %2 : (!flow.dispatch.output<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]>) -> !flow.dispatch.output<?x?x1024xf32> | |
// The flow.dispatch.shape + shapex.ranked_dim queries from the input form
// have canonicalized directly to the %arg2..%arg5 index arguments.
"test.sink_shape_arg"(%arg2, %arg3) : (index, index) -> () | |
"test.sink_shape_ret"(%arg4, %arg5) : (index, index) -> () | |
%4 = flow.dispatch.input.load %1 : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32> | |
%5 = shapex.get_ranked_shape %4 : tensor<7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]> | |
%6 = "test.tile_math"(%4, %5, %2) : (tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>, !shapex.ranked_shape<[?,?,1024]>) -> tensor<?x?x1024xf32> | |
flow.dispatch.output.store %6, %3 : tensor<?x?x1024xf32> -> !flow.dispatch.output<?x?x1024xf32> | |
return | |
} | |
} | |
} | |
// Example 2 (caller after outlining): flow.dispatch2 now takes the tied tensor
// plus the four dynamic dims as extra index operands — (%0, %1) for the input
// shape and (%1, %0) for the result shape, matching %4 built below.
func @dynamicShapeDispatch(%arg0: tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> { | |
%c1 = constant 1 : index | |
%c3 = constant 3 : index | |
%c1024 = constant 1024 : index | |
%c512 = constant 512 : index | |
%0 = dim %arg0, %c1 : tensor<7x?x24x?xf32> | |
%1 = dim %arg0, %c3 : tensor<7x?x24x?xf32> | |
%2 = shapex.make_ranked_shape %0, %1 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> | |
%3 = shapex.tie_shape %arg0, %2 : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]> | |
%4 = shapex.make_ranked_shape %1, %0 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> | |
// Note that the dynamic shape dimensions are passed in here. | |
%5 = flow.dispatch2 @dynamicShapeDispatch_dispatch_0::@dynamicShapeDispatch_dispatch_0[%c1024, %c512] (%3, %0, %1, %1, %0) : (tensor<7x?x24x?xf32>, index, index, index, index) -> tensor<?x?x1024xf32> | |
// Tying %4 onto the result lets callers recover the dynamic result shape.
%6 = shapex.tie_shape %5, %4 : tensor<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]> | |
return %6 : tensor<?x?x1024xf32> | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Interface materialized and dynamic dimensions ended up as push constants. | |
// Example 2 (HAL form): the four dynamic-dimension index arguments become
// push constants (push_constants = 4 on the interface, matching the four
// hal.interface.load.constant ops in the body) instead of function arguments.
hal.executable @dynamic_tiled_dispatch attributes {sym_visibility = "private"} { | |
hal.interface @legacy_io attributes {push_constants = 4 : i32} { | |
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" | |
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" | |
} | |
hal.executable.target @vmla, filter="vmla" { | |
hal.executable.entry_point @entry attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.input<7x?x24x?xf32>, !flow.dispatch.output<?x?x1024xf32>, index, index, index, index) -> ()} | |
module { | |
func @entry() { | |
%c0 = constant 0 : index | |
// I/O refs fetched from the interface bindings at byte offset 0.
%0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<7x?x24x?xf32> | |
%1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<?x?x1024xf32> | |
// These are the 4 (2 + 2) dynamic dimensions that were arguments, now fetched via the interface. | |
%2 = hal.interface.load.constant offset = 0 : index | |
%3 = hal.interface.load.constant offset = 1 : index | |
%4 = hal.interface.load.constant offset = 2 : index | |
%5 = hal.interface.load.constant offset = 3 : index | |
// Shapes are constructed using the dynamic dimensions and tied such that following code has everything. | |
%6 = shapex.make_ranked_shape %2, %3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]> | |
%7 = flow.dispatch.tie_shape %0, %6 : (!flow.dispatch.input<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>) -> !flow.dispatch.input<7x?x24x?xf32> | |
%8 = shapex.make_ranked_shape %4, %5 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> | |
%9 = flow.dispatch.tie_shape %1, %8 : (!flow.dispatch.output<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]>) -> !flow.dispatch.output<?x?x1024xf32> | |
%10 = flow.dispatch.input.load %7 : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32> | |
%11 = "test.tile_math"(%10) : (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> | |
flow.dispatch.output.store %11, %9 : tensor<?x?x1024xf32> -> !flow.dispatch.output<?x?x1024xf32> | |
return | |
} | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// WIP to canonicalize tie shape into the load/stores so that the shape ops aren't required. | |
// Done by hand here! | |
// Example 2 (hand-written sketch, per the author's note): same HAL executable
// as above, but with the tie_shape ops canonicalized away — the dynamic dims
// appear directly as `shape = [...]` operands on the load/store ops, mixing
// static constants (7, 24, 1024) with the push-constant values (%2..%5).
hal.executable @dynamic_tiled_dispatch attributes {sym_visibility = "private"} { | |
hal.interface @legacy_io attributes {push_constants = 4 : i32} { | |
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" | |
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" | |
} | |
hal.executable.target @vmla, filter="vmla" { | |
hal.executable.entry_point @entry attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.input<7x?x24x?xf32>, !flow.dispatch.output<?x?x1024xf32>, index, index, index, index) -> ()} | |
module { | |
func @entry() { | |
%c0 = constant 0 : index | |
%0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<7x?x24x?xf32> | |
%1 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<?x?x1024xf32> | |
// These are the 4 (2 + 2) dynamic dimensions that were arguments, now fetched via the interface. | |
%2 = hal.interface.load.constant offset = 0 : index | |
%3 = hal.interface.load.constant offset = 1 : index | |
%4 = hal.interface.load.constant offset = 2 : index | |
%5 = hal.interface.load.constant offset = 3 : index | |
// Shape tie is canonicalized away and becomes direct values/attrs on the load/stores. | |
%10 = flow.dispatch.input.load %0, shape = [7, %2, 24, %3] : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32> | |
%11 = "test.tile_math"(%10) : (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> | |
flow.dispatch.output.store %11, %1, shape = [%4, %5, 1024] : tensor<?x?x1024xf32> -> !flow.dispatch.output<?x?x1024xf32> | |
return | |
} | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.