y_pred = model(x)
# loss is fp32 but gradients are fp16
loss = torch.nn.functional.mse_loss(y_pred.float(), y.float())
# scaling keeps the gradients representable in fp16, preventing underflow
scaled_loss = scale_factor * loss.float()
model.zero_grad()
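The scaled loss is then backpropagated and the gradients are divided by the same factor before the weight update. A minimal sketch of that follow-up (assumptions: `optimizer` is a standard `torch.optim` optimizer over the model's parameters; in a full master-weights setup the unscaling would be applied to fp32 copies of the gradients rather than in place):

scaled_loss.backward()
# undo the loss scaling so the update uses the true gradient magnitudes
for param in model.parameters():
    if param.grad is not None:
        param.grad.data.div_(scale_factor)
optimizer.step()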
def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True, *args, **kwargs):
    """Custom variable getter that forces trainable variables to be stored in
    float32 precision and then casts them to the training precision."""
    storage_dtype = tf.float32 if trainable else dtype
    variable = getter(name, shape, dtype=storage_dtype,
                      initializer=initializer, regularizer=regularizer,
                      trainable=trainable, *args, **kwargs)
    if trainable and dtype != tf.float32:
        variable = tf.cast(variable, dtype)
    return variable
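As a quick illustration of the getter's effect, a variable requested in fp16 under such a scope is created with fp32 storage, while the getter returns an fp16 cast for use in the compute graph. A minimal sketch, assuming TensorFlow 1.x graph mode with `tf` imported as in the example below; the scope and variable names are only for demonstration:

with tf.variable_scope('demo_fp32_vars', custom_getter=float32_variable_storage_getter):
    w = tf.get_variable('w', shape=[256, 256], dtype=tf.float16)
print(w.dtype)                            # float16: the cast used by the fp16 compute graph
print(tf.trainable_variables()[0].dtype)  # float32_ref: the master weights stay in fp32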
Training a simple CNN model
import tensorflow as tf
import numpy as np

def build_forward_model(inputs):
    _, _, h, w = inputs.get_shape().as_list()
    top_layer = inputs
    top_layer = tf.layers.conv2d(top_layer, 64, 7, use_bias=False,
                                 data_format='channels_first', padding='SAME')
    top_layer = tf.contrib.layers.batch_norm(top_layer, data_format='NCHW', fused=True)
    top_layer = tf.layers.max_pooling2d(top_layer, 2, 2, data_format='channels_first')
    top_layer = tf.reshape(top_layer, (-1, 64 * (h // 2) * (w // 2)))
    top_layer = tf.layers.dense(top_layer, 128, activation=tf.nn.relu)
    return top_layer

def build_training_model(inputs, labels, nlabel):
    inputs = tf.cast(inputs, tf.float16)  # fp16
    # all variables created will be fp32 but cast to fp16
    with tf.variable_scope('fp32_vars', custom_getter=float32_variable_storage_getter):
        top_layer = build_forward_model(inputs)
        logits = tf.layers.dense(top_layer, nlabel, activation=None)
    logits = tf.cast(logits, tf.float32)  # fp32
    loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)
    optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
    loss_scale = 128.0  # Value may need tuning depending on the model (see the note after this example)
    # gradients and variables are fp32 but computation of gradients is fp16
    # gradvars = optimizer.compute_gradients(loss)
    gradients, variables = zip(*optimizer.compute_gradients(loss * loss_scale))
    gradients = [grad / loss_scale for grad in gradients]
    gradients, _ = tf.clip_by_global_norm(gradients, 5.0)  # clipping
    # train_op = optimizer.apply_gradients(gradvars)
    train_op = optimizer.apply_gradients(zip(gradients, variables))
    return inputs, loss, train_op

nchan, height, width, nlabel = 3, 224, 224, 1000
inputs = tf.placeholder(tf.float32, (None, nchan, height, width))
labels = tf.placeholder(tf.int32, (None,))
inputs, loss, train_op = build_training_model(inputs, labels, nlabel)
batch_size = 128
sess = tf.Session()
inputs_np = np.random.random(size=(batch_size, nchan, height, width)).astype(np.float32)
labels_np = np.random.randint(nlabel, size=(batch_size,)).astype(np.int32)
sess.run(tf.global_variables_initializer())
for step in range(20):
    loss_np, _ = sess.run([loss, train_op], {inputs: inputs_np, labels: labels_np})
    print("Loss =", loss_np)