- Movement between autocasting-enabled and autocasting-disabled regions
- Nesting FP32-enforced regions within autocasting regions
- Extension kernels, which might not route through the PyTorch dispatcher at all
- Avoid routing through an Amp dispatch layer wherever possible, because it incurs an extra lap through the dispatcher
- Avoid having the Amp dispatch layer handle all functions (we should only need it for a subset)
- All operations that take multiple inputs and are not safe to run with different precisions among those inputs (i.e., do not support builtin type promotion) must receive an explicit Amp backend function (see the sketch after this list)
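For instance, the Amp backend function for such an op might first promote all floating-point inputs to the widest floating-point dtype present. A rough sketch under that assumption (`promote_multi_input` is a hypothetical helper, not an existing API):

```python
import torch

def promote_multi_input(*tensors):
    # Hypothetical Amp-backend helper: for an op that lacks builtin type promotion,
    # cast every floating-point input to the widest floating-point dtype present,
    # so the kernel sees a single, consistent precision.
    widest = torch.float16
    for t in tensors:
        if t.is_floating_point():
            widest = torch.promote_types(widest, t.dtype)
    return tuple(t.to(widest) if t.is_floating_point() else t for t in tensors)

# e.g. an op called with a mixed half/float argument list would first run:
# bias, a, b = promote_multi_input(bias, a, b)
```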
All of the UXs below are possible with the current dispatch mechanism. Options 1 and 2 would be simpler to implement than 3 and 4.
An `amp.autocast` context manager flips a global flag that controls whether or not ops route through an Amp dispatch layer.
Tensors themselves are not given any special additional identity.
```python
with amp.autocast():
    ops...  # Safe to enter autocast-enabled region.
    with amp.autocast(enabled=False):
        ops...  # Ops here may need to deal with a mixture of float and half tensors,
                # and require manual casts to float.  Type promotion will smooth over
                # some of these.
    ops...  # Safe to reenter autocast-enabled region.
ops...  # Ops here have to deal with a mixture of float and half tensors created under
        # the context manager.  Errors will crop up one by one and require a manual
        # float conversion in each case.  The errors will be clear and easy to find on
        # a per-op basis, though.  With type promotion, there may not even be that many.
```
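For reference, a minimal sketch of how the context manager itself might flip the global flag. This is a pure-Python stand-in; the real flag would presumably live in C++ next to the dispatch layer, and `_autocast_enabled` is a hypothetical name:

```python
import contextlib

_autocast_enabled = False  # stand-in for a flag that would really live in the C++ dispatch layer

def is_autocasting():
    return _autocast_enabled

@contextlib.contextmanager
def autocast(enabled=True):
    # Flip the global flag for the duration of the block, then restore it.
    global _autocast_enabled
    previous = _autocast_enabled
    _autocast_enabled = enabled
    try:
        yield
    finally:
        _autocast_enabled = previous
```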
Advantages:
- Manual control
- Ops that don't need Amp special treatment can "fall through" to the next step along the dispatch path (autograd history recording in VariableType*.cpp, most likely), saving 1 round trip through the dispatch machinery. Ed has not implemented the fallthrough yet, but he is pushing for the idea (pytorch/pytorch#28386).
Disadvantages:
- Bleedover of tensors created with different types from autocasting-enabled regions to autocasting-disabled regions. People may have to insert manual casts. The places these manual casts must go will be easy to find, and minimized by kernels supporting type promotion. We don't know for certain how common/annoying this will be for typical networks. This could be regarded as a "documentation problem."
- backward() should not run under the context manager, which is a gotcha people may easily run into (see the sketch below).
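A sketch of the intended usage pattern (model, input, target, loss_fn, and optimizer are placeholders; loss scaling is omitted):

```python
# Forward pass and loss computation run under autocast.
with amp.autocast():
    output = model(input)
    loss = loss_fn(output, target)

# backward() is called outside the context manager, so backward ops are not autocast.
loss.backward()
optimizer.step()
```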
```python
class MyModule(torch.nn.Module):
    @amp.float
    def my_float_func(args...):
        ops...

    @amp.half  # Maybe this one should not exist at all.
    def my_half_func(args...):
        ops...

    @amp.autocast
    def forward(args...):
        ops...
```
```python
import functools

# amp.autocast would look like:
def autocast(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        already_autocasting = amp.is_autocasting()
        if already_autocasting:
            return func(*args, **kwargs)  # Go with the flow
        else:
            with amp._autocast():  # Enable autocasting for the body
                return cast_to_float(func(*args, **kwargs))  # Cast the output to float
    return wrapper

# amp.float would look like:
def float(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with amp._autocast(enabled=False):  # Disable autocasting
            return func(*cast_to_float(args), **cast_to_float(kwargs))
    return wrapper
```
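Both decorators lean on a `cast_to_float` helper that isn't spelled out above. A minimal sketch of what it might look like (hypothetical, recursing through common containers):

```python
import torch

def cast_to_float(value):
    # Hypothetical helper: cast any half-precision tensor(s) in `value` to float,
    # recursing through lists, tuples, and dicts, and passing everything else through.
    if isinstance(value, torch.Tensor):
        return value.float() if value.dtype == torch.float16 else value
    if isinstance(value, (list, tuple)):
        return type(value)(cast_to_float(v) for v in value)
    if isinstance(value, dict):
        return {k: cast_to_float(v) for k, v in value.items()}
    return value
```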
Advantages:
- Simplicity of use for common cases
- Since we are wrapping things that have explicitly-known inputs and outputs, we can minimize bleedover of tensors created with different types from autocasting-enabled to autocasting-disabled regions. We can ensure that outputs are always cast to float when exiting an autocasting region.
- The danger of running the backward pass with autocasting enabled is reduced.
- Fallthrough in the backend is just as viable as for the raw context manager.
- If the JIT script parser can handle decorator statements directly, as opposed to parsing the expanded form (which may include the with statement), maybe we won't need with-statement support in the JIT to make this API work. Someone could write:

  ```python
  @torch.jit.script
  @amp.autocast
  def my_jit_autocast_function...
  ```
Disadvantages:
- There is still danger of bleedover if `@amp.float` functions use data that doesn't come in through the argument list, or `@amp.autocast` functions create tensors that are supplied to the outside world by some means other than the function outputs. This could be regarded as a "documentation problem."
- The granularity at which regions can be autocast must coincide with functions. If people require finer granularity, they must either use the implementation-detail context manager to enable/disable that region (with all the bleedover that implies) or write a new function to encapsulate the desired region and decorate that function (see the sketch below).
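For example, the encapsulate-and-decorate workaround might look like this (a sketch; `_autocast_chunk`, `block_a`, and `block_b` are placeholders):

```python
class MyModule(torch.nn.Module):
    @amp.autocast
    def _autocast_chunk(self, x):
        # Only this encapsulated region runs with autocasting enabled.
        return self.block_a(x)

    def forward(self, x):
        x = self._autocast_chunk(x)
        return self.block_b(x)  # Runs without autocasting.
```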
AmpTensors become a unique datatype, as in the msnpu test (https://github.com/pytorch/pytorch/blob/master/test/test_cpp_extensions.py#L739). We can decide whether they are exposed as a separate device or whether Ampness is a separate tensor axis.
UX: `model.amp()`.
All operations where at least one tensor is an AmpTensor will route through the Amp dispatch layer and be subject to autocasting. These operations will output Tensors that are also AmpTensors. AmpTensor identity would be independent of floatness or halfness. Autocasting and Ampness would propagate wherever AmpTensors are used, just like grad history recording and requires_gradness.
Autocasting would not be triggered for extension calls. Extension ops would need to cast their inputs to float manually, whether in an autocasting-enabled or -disabled region. We should probably supply a decorator to help (see the sketch below). We would also need a context manager or decorator for nested FP32-enforced regions.
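A minimal sketch of such a helper (`float_extension` is a hypothetical name, and it reuses the `cast_to_float` sketch from above):

```python
import functools

def float_extension(func):
    # Hypothetical decorator for extension/custom ops that bypass the Amp dispatch
    # layer: cast any half inputs to float before calling the underlying kernel,
    # regardless of whether autocasting is currently enabled.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*cast_to_float(args), **cast_to_float(kwargs))
    return wrapper
```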
Fallthrough would still work.
AmpTensors could facilitate "self-cleaning" context managers/decorators.
The context managers control a global flag as usual. Under the context manager, whitelist/blacklist ops always dispatch through an Amp layer, AND return Tensors that have been given AmpTensor identity (or maybe only HalfTensors would need to be given AmpTensor identity, because they're the only ones that would need to be cleaned up?). After the context manager exits, the global flag is False, but any operation with at least one AmpTensor among its inputs will still route through the Amp dispatch layer. The Amp function, seeing that the global autocasting flag is False, will realize its autocasting shenanigans are no longer welcome, cast any AmpTensor arguments to float, run the op, and return ordinary float Tensors that do not have Amp identity (see the sketch below).
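In pseudocode, the Amp-layer logic for a whitelisted op might look something like this (`amp_dispatch`, `cast_to_half`, `mark_as_amp`, and `strip_amp_identity` are illustrative names, not an existing API):

```python
def amp_dispatch(op, *args):
    # Illustrative only: what an Amp-layer wrapper for a whitelisted op might do.
    if amp.is_autocasting():
        # Autocasting active: cast inputs per the whitelist policy (half here),
        # run the op, and tag the outputs with AmpTensor identity so they can
        # be cleaned up later.
        out = op(*cast_to_half(args))
        return mark_as_amp(out)
    else:
        # Autocasting no longer active, but an AmpTensor input routed us here:
        # cast AmpTensor args back to float and return plain float Tensors.
        out = op(*cast_to_float(args))
        return strip_amp_identity(out)
```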
Autocasting and "self-cleaning" would not be triggered for custom/extension ops. Extension ops would need to cast their inputs to float manually, whether in an an autocasting-enabled or disabled region. We should probably supply a decorator to help.