andrewor14

andrewor14 / compare_backward_numerics_pr_116092
Created March 1, 2024 16:39
Compare numerics: batch_norm_backward vs native_batch_norm_backward (#116092)
# Debug test failure in https://github.com/pytorch/pytorch/pull/116092 for:
# python test/test_decomp.py -k test_comprehensive_batch_norm_with_update_cuda_bfloat16
# Set up args (these are the exact tensors saved from the decomp test)
# All tensors in args16 are bfloat16
# All tensors in args64 hold the same values as args16, upcast to float64
>>> args16
[tensor([[-0.5468750000],
[ 0.7812500000]], device='cuda:0', dtype=torch.bfloat16), tensor([[-1.5234375000],
[-4.1875000000]], device='cuda:0', dtype=torch.bfloat16,
requires_grad=True), tensor([8.8125000000], device='cuda:0', dtype=torch.bfloat16,
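The comments above say that args64 holds the args16 values upcast to float64. That upcast is lossless, because every bfloat16 value is exactly representable in float64. The sketch below checks this for the values visible in the dump. It uses a hypothetical pure-Python helper `to_bfloat16` (not a PyTorch API) and assumes round-to-nearest-even conversion from float32.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value and return it as a float.

    Hypothetical helper for illustration. It assumes round-to-nearest-even.
    Because it first rounds float64 -> float32, it can double-round in rare tie cases.
    """
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # bfloat16 keeps the top 16 bits. Round the dropped low 16 bits to nearest, ties to even.
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    (y,) = struct.unpack("<f", struct.pack("<I", bits))
    return y


# Values visible in the args16 dump above.
values = [-0.546875, 0.78125, -1.5234375, -4.1875, 8.8125]
for v in values:
    # Each value is already a bfloat16 value, so the float64 copy is exact.
    assert to_bfloat16(v) == v

# By contrast, 0.1 is not a bfloat16 value and gets rounded.
print(to_bfloat16(0.1))  # 0.10009765625
```

Since the upcast adds no error, any numeric mismatch in the test comes from the bfloat16 computation, not from how the inputs were copied.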
// Code generated by ColumnarBatchScan.scala when reading the column buffers
public Object generate(Object[] references) {
  return new GeneratedIterator(references);
}

final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
  private Object[] references;
  private scala.collection.Iterator inmemorytablescan_input;
  private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_numOutputRows;
// Code generated by GenerateColumnarBatch.scala when building the column buffers
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;

public GeneratedColumnarBatchIterator generate(Object[] references) {
  return new GeneratedColumnarBatchIterator(references);
}


This gist attempts to show that the following two code snippets are equivalent, as suggested by @davies.

// === (1): The original code in Spark 1.6 before PR 10240

val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem))
val toGrant = math.min(maxToGrant, memoryFree)

if (curMem < minMemoryPerTask) {
  if (memoryFree >= math.min(maxToGrant, minMemoryPerTask - curMem)) {