Skip to content

Instantly share code, notes, and snippets.

@geohot
Created October 27, 2025 10:36
Show Gist options
  • Save geohot/cb8c6ea335dfed87a707618d7fff39af to your computer and use it in GitHub Desktop.
Save geohot/cb8c6ea335dfed87a707618d7fff39af to your computer and use it in GitHub Desktop.
beautiful mnist rendered to uops
def k0():
    """Build the UOp AST that adds 512 to a single uint in place.

    Global buffer 0 holds one ``uint``; element 0 is loaded, increased by
    512 and stored back to the same slot.

    Returns:
        The result of ``.sink()`` on the store.
    """
    counter_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(1), arg=0)
    slot = counter_buf.index(UOp.const(dtypes.index, 0))
    bumped = (slot.load()+512)
    return slot.store(bumped).sink()
def k1():
    """Build the UOp AST that fills a 512-float buffer from THREEFRY output.

    Buffer 1 supplies a uint counter (used minus 512) and buffer 2 a
    two-word uint value that is packed into the second THREEFRY operand
    (presumably the seed/key -- confirm against the caller).  Lanes 0..255
    keep the low 32 bits of a THREEFRY result, lanes 256..511 the high 32
    bits.  The bits are divided by 512, OR-ed with 1065353216 (0x3F800000,
    the bit pattern of float 1.0), bitcast to float and offset by -1.0
    before being stored to buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=0)
    lane = UOp.range(512, 0, AxisType.LOOP, tag=())
    is_low = (lane<256)
    low_lane = is_low.where(lane, UOp.const(dtypes.index, Invalid))

    counter_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(1), arg=1)
    counter = counter_buf.index(UOp.const(dtypes.index, 0)).load()
    base = (counter+-512)
    low_ctr = ((low_lane+1).cast(dtypes.uint)+base)

    # Pack the two 32-bit words of buffer 2 into one 64-bit operand.
    key_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(2), arg=2)
    key_hi = key_buf.index(UOp.const(dtypes.index, 1)).load()
    key_lo = key_buf.index(UOp.const(dtypes.index, 0)).load()
    key = ((key_hi.cast(dtypes.ulong)*4294967296)|key_lo.cast(dtypes.ulong))

    # Low half: low 32 bits of the 64-bit THREEFRY result.
    low_packed = (((low_ctr+255).cast(dtypes.ulong)*4294967296)|(low_ctr+-1).cast(dtypes.ulong))
    low_rand = UOp(Ops.THREEFRY, dtypes.ulong, (low_packed, key))
    low_bits = is_low.where((low_rand.cast(dtypes.uint)&4294967295), UOp.const(dtypes.uint, 0))

    # High half: high 32 bits of the 64-bit THREEFRY result.
    # `!= True` is a UOp-level expression, not a Python truth test.
    high_lane = (is_low!=True).where((lane+-256), UOp.const(dtypes.index, Invalid))
    high_ctr = ((high_lane+1).cast(dtypes.uint)+base)
    high_packed = (((high_ctr+255).cast(dtypes.ulong)*4294967296)|(high_ctr+-1).cast(dtypes.ulong))
    high_rand = UOp(Ops.THREEFRY, dtypes.ulong, (high_packed, key))
    high_bits = is_low.where(UOp.const(dtypes.uint, 0), ((high_rand//4294967296).cast(dtypes.uint)&4294967295))

    unit = ((((low_bits+high_bits)//512)|1065353216).bitcast(dtypes.float)+-1.0)
    return out_buf.index(lane).store(unit).end(lane).sink()
def k2():
    """Build the UOp AST for chunked partial sums of a masked uchar gather.

    For each of 512 samples a float from buffer 1 is scaled by 60000.0, cast
    to int and wrapped by +60000 when negative.  Each of 10 chunks scans 6000
    rows of buffer 2 (784 uchar per row, addressed as 28x28), zeroes every
    pixel whose row number differs from that value, and ADD-reduces the
    chunk into buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(4014080), arg=0)
    px_x = UOp.range(28, 3, AxisType.LOOP, tag=())
    chunk = UOp.range(10, 4, AxisType.LOOP, tag=())
    px_y = UOp.range(28, 2, AxisType.LOOP, tag=())
    sample = UOp.range(512, 1, AxisType.LOOP, tag=())

    rand_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=1)
    rand_val = rand_buf.index(sample).load()
    raw_pick = (60000.0*rand_val).cast(dtypes.int)
    pick = (raw_pick<0).where((raw_pick+60000), raw_pick)

    step = UOp.range(6000, 0, AxisType.REDUCE, tag=())
    row = ((chunk*6000)+step)
    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(47040000), arg=2)
    pixel = src_buf.index((((px_y*28)+px_x)+(row*784))).load()
    masked = (pick!=row.cast(dtypes.int)).where(UOp.const(dtypes.uchar, 0), pixel.cast(dtypes.uint).cast(dtypes.uchar))
    partial = masked.reduce(step, arg=Ops.ADD)

    dst = ((((px_x*10)+chunk)+(px_y*280))+(sample*7840))
    return out_buf.index(dst).store(partial).end(sample, px_y, px_x, chunk).sink()
def k3():
    """Build the UOp AST that sums 10 uchar partials per pixel into a uint.

    Reads buffer 1 (4014080 uchar), ADD-reduces over the size-10 axis, casts
    the sum to uint and writes a 512x28x28 layout into buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(401408), arg=0)
    px_y = UOp.range(28, 2, AxisType.LOOP, tag=())
    px_x = UOp.range(28, 3, AxisType.LOOP, tag=())
    sample = UOp.range(512, 1, AxisType.LOOP, tag=())

    partial_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(4014080), arg=1)
    chunk = UOp.range(10, 0, AxisType.REDUCE, tag=())
    part = partial_buf.index(((((px_x*10)+chunk)+(px_y*280))+(sample*7840))).load()
    total = part.reduce(chunk, arg=Ops.ADD).cast(dtypes.uint)

    dst = (((px_y*28)+px_x)+(sample*784))
    return out_buf.index(dst).store(total).end(sample, px_y, px_x).sink()
def k4():
    """Build the UOp AST for a 5x5 windowed multiply-accumulate plus offset.

    For each of 512 samples, 32 output channels and a 24x24 output grid,
    the 5x5 window of the 28x28 uint input (buffer 1) is cast to float,
    multiplied by the matching entry of buffer 2 (800 floats), ADD-reduced,
    and the per-channel value from buffer 3 is added.  Looks like a
    convolution with bias -- confirm against the model definition.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=0)
    out_y = UOp.range(24, 4, AxisType.LOOP, tag=())
    out_x = UOp.range(24, 5, AxisType.LOOP, tag=())
    out_ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    sample = UOp.range(512, 2, AxisType.LOOP, tag=())

    img_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(401408), arg=1)
    tap_x = UOp.range(5, 1, AxisType.REDUCE, tag=())
    tap_y = UOp.range(5, 0, AxisType.REDUCE, tag=())
    pixel = img_buf.index((((out_x+tap_x)+((out_y+tap_y)*28))+(sample*784))).load()

    weight_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(800), arg=2)
    weight = weight_buf.index((((tap_y*5)+tap_x)+(out_ch*25))).load()
    term = (pixel.cast(dtypes.uchar).cast(dtypes.uint).cast(dtypes.uchar).cast(dtypes.float)*weight)

    offset_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3)
    offset = offset_buf.index(out_ch).load()
    acc = (term.reduce(tap_y, tap_x, arg=Ops.ADD)+offset)

    dst = ((((out_y*24)+out_x)+(out_ch*576))+(sample*18432))
    return out_buf.index(dst).store(acc).end(sample, out_ch, out_y, out_x).sink()
def k5():
    """Build the UOp AST for a 32-channel 5x5 multiply-accumulate on clamped input.

    The 24x24x32 float input (buffer 1) is clamped at zero (values not
    greater than 0.0 become 0.0), multiplied by entries of buffer 2 (25600
    floats) and ADD-reduced over input channel and the 5x5 window; buffer 3
    adds a per-output-channel value.  Output is 512x32x20x20 in buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=0)
    out_y = UOp.range(20, 5, AxisType.LOOP, tag=())
    out_x = UOp.range(20, 6, AxisType.LOOP, tag=())
    out_ch = UOp.range(32, 4, AxisType.LOOP, tag=())
    sample = UOp.range(512, 3, AxisType.LOOP, tag=())

    in_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=1)
    tap_x = UOp.range(5, 2, AxisType.REDUCE, tag=())
    tap_y = UOp.range(5, 1, AxisType.REDUCE, tag=())
    in_ch = UOp.range(32, 0, AxisType.REDUCE, tag=())
    x = in_buf.index(((((out_x+tap_x)+((out_y+tap_y)*24))+(in_ch*576))+(sample*18432))).load()
    clamped = (0.0<x).where(x, UOp.const(dtypes.float, 0.0))

    weight_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25600), arg=2)
    weight = weight_buf.index(((((tap_y*5)+tap_x)+(in_ch*25))+(out_ch*800))).load()
    term = (clamped*weight)

    offset_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3)
    offset = offset_buf.index(out_ch).load()
    acc = (term.reduce(in_ch, tap_y, tap_x, arg=Ops.ADD)+offset)

    dst = ((((out_y*20)+out_x)+(out_ch*400))+(sample*12800))
    return out_buf.index(dst).store(acc).end(sample, out_ch, out_y, out_x).sink()
def k6():
    """Build the UOp AST for per-channel partial sums of a zero-clamped tensor.

    For each of 32 channels and 256 sample pairs, the 20x20 plane of two
    consecutive samples from buffer 1 is clamped at zero and ADD-reduced
    into one float of buffer 0 (32x256).

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    pair = UOp.range(256, 4, AxisType.LOOP, tag=())

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    py = UOp.range(20, 1, AxisType.REDUCE, tag=())
    px = UOp.range(20, 2, AxisType.REDUCE, tag=())
    half = UOp.range(2, 0, AxisType.REDUCE, tag=())
    x = src_buf.index(((((py*20)+px)+(ch*400))+(((pair*2)+half)*12800))).load()
    clamped = (0.0<x).where(x, UOp.const(dtypes.float, 0.0))
    total = clamped.reduce(half, py, px, arg=Ops.ADD)

    return out_buf.index(((ch*256)+pair)).store(total).end(ch, pair).sink()
def k7():
    """Build the UOp AST that folds 256 partials per channel and scales them.

    Each of 32 rows of buffer 1 (256 floats) is ADD-reduced and multiplied
    by 4.8828125e-06 (= 1/204800) before being written to buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    ch = UOp.range(32, 1, AxisType.LOOP, tag=())
    partial_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=1)
    k = UOp.range(256, 0, AxisType.REDUCE, tag=())
    part = partial_buf.index(((ch*256)+k)).load()
    scaled = (part.reduce(k, arg=Ops.ADD)*4.8828125e-06)
    return out_buf.index(ch).store(scaled).end(ch).sink()
def k8():
    """Build the UOp AST for per-channel partial sums of squared deviations.

    Same traversal as the 32x256 partial-sum kernel: the zero-clamped value
    from buffer 1 has the per-channel value of buffer 2 subtracted, is
    squared, and is ADD-reduced over two samples and the 20x20 plane.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    pair = UOp.range(256, 4, AxisType.LOOP, tag=())

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    py = UOp.range(20, 1, AxisType.REDUCE, tag=())
    px = UOp.range(20, 2, AxisType.REDUCE, tag=())
    half = UOp.range(2, 0, AxisType.REDUCE, tag=())
    x = src_buf.index(((((py*20)+px)+(ch*400))+(((pair*2)+half)*12800))).load()
    clamped = (0.0<x).where(x, UOp.const(dtypes.float, 0.0))

    center_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    center = center_buf.index(ch).load()
    deviation = (clamped+(center*-1.0))
    squared = (deviation*deviation)
    total = squared.reduce(half, py, px, arg=Ops.ADD)

    return out_buf.index(((ch*256)+pair)).store(total).end(ch, pair).sink()
def k9():
    """Build the UOp AST for a per-channel affine transform with 2x2 regrouping.

    Computes ``(max(x, 0) - b2) * b3 * rsqrt(b4 + 1e-05) + b5`` where x comes
    from buffer 1 and b2..b5 are per-channel floats from buffers 2..5 (this
    has the shape of a batch-norm step -- confirm against the model).  The
    20x20 plane is read row-major and written so that each 2x2 window is
    contiguous in buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=0)
    sub_y = UOp.range(2, 4, AxisType.LOOP, tag=())
    sub_x = UOp.range(2, 5, AxisType.LOOP, tag=())
    blk_x = UOp.range(10, 3, AxisType.LOOP, tag=())
    blk_y = UOp.range(10, 2, AxisType.LOOP, tag=())
    ch = UOp.range(32, 1, AxisType.LOOP, tag=())
    ch_off = (ch*400)
    sample = UOp.range(512, 0, AxisType.LOOP, tag=())
    sample_off = (sample*12800)

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    x = src_buf.index((((((blk_x*2)+sub_x)+(((blk_y*2)+sub_y)*20))+ch_off)+sample_off)).load()
    clamped = (0.0<x).where(x, UOp.const(dtypes.float, 0.0))

    shift_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    shift = shift_buf.index(ch).load()
    gain_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3)
    gain = gain_buf.index(ch).load()
    spread_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=4)
    spread = spread_buf.index(ch).load()
    offset_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=5)
    offset = offset_buf.index(ch).load()
    y = ((((clamped+(shift*-1.0))*gain)*(spread+1e-05).sqrt().reciprocal())+offset)

    dst = ((((((sub_y*2)+sub_x)+(blk_x*4))+(blk_y*40))+ch_off)+sample_off)
    return out_buf.index(dst).store(y).end(sample, ch, blk_y, blk_x, sub_y, sub_x).sink()
def k10():
    """Build the UOp AST that MAX-reduces contiguous groups of four floats.

    Buffer 1 (512x32x10x10 groups of 2x2) is reduced with MAX over the two
    size-2 axes, producing the 512x32x10x10 buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=0)
    out_y = UOp.range(10, 4, AxisType.LOOP, tag=())
    out_x = UOp.range(10, 5, AxisType.LOOP, tag=())
    ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    sample = UOp.range(512, 2, AxisType.LOOP, tag=())

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    win_a = UOp.range(2, 0, AxisType.REDUCE, tag=())
    win_b = UOp.range(2, 1, AxisType.REDUCE, tag=())
    x = src_buf.index(((((((win_a*2)+win_b)+(out_x*4))+(out_y*40))+(ch*400))+(sample*12800))).load()
    best = x.reduce(win_a, win_b, arg=Ops.MAX)

    dst = ((((out_y*10)+out_x)+(ch*100))+(sample*3200))
    return out_buf.index(dst).store(best).end(sample, ch, out_y, out_x).sink()
def k11():
    """Build the UOp AST for a 3x3 windowed multiply-accumulate, 32 -> 64 channels.

    The 10x10x32 float input (buffer 1) is multiplied by entries of buffer 2
    (18432 floats), ADD-reduced over input channel and the 3x3 window, and
    the per-output-channel value of buffer 3 is added.  Output is
    512x64x8x8 in buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=0)
    out_y = UOp.range(8, 5, AxisType.LOOP, tag=())
    out_x = UOp.range(8, 6, AxisType.LOOP, tag=())
    out_ch = UOp.range(64, 4, AxisType.LOOP, tag=())
    sample = UOp.range(512, 3, AxisType.LOOP, tag=())

    in_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=1)
    tap_x = UOp.range(3, 2, AxisType.REDUCE, tag=())
    tap_y = UOp.range(3, 1, AxisType.REDUCE, tag=())
    in_ch = UOp.range(32, 0, AxisType.REDUCE, tag=())
    x = in_buf.index(((((out_x+tap_x)+((out_y+tap_y)*10))+(in_ch*100))+(sample*3200))).load()

    weight_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18432), arg=2)
    weight = weight_buf.index(((((tap_y*3)+tap_x)+(in_ch*9))+(out_ch*288))).load()
    term = (x*weight)

    offset_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3)
    offset = offset_buf.index(out_ch).load()
    acc = (term.reduce(in_ch, tap_y, tap_x, arg=Ops.ADD)+offset)

    dst = ((((out_y*8)+out_x)+(out_ch*64))+(sample*4096))
    return out_buf.index(dst).store(acc).end(sample, out_ch, out_y, out_x).sink()
def k12():
    """Build the UOp AST for a 64-channel 3x3 multiply-accumulate on clamped input.

    The 8x8x64 float input (buffer 1) is clamped at zero, multiplied by
    entries of buffer 2 (36864 floats), ADD-reduced over input channel and
    the 3x3 window, and offset by the per-output-channel value of buffer 3.
    Output is 512x64x6x6 in buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=0)
    out_y = UOp.range(6, 5, AxisType.LOOP, tag=())
    out_x = UOp.range(6, 6, AxisType.LOOP, tag=())
    out_ch = UOp.range(64, 4, AxisType.LOOP, tag=())
    sample = UOp.range(512, 3, AxisType.LOOP, tag=())

    in_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=1)
    tap_x = UOp.range(3, 2, AxisType.REDUCE, tag=())
    tap_y = UOp.range(3, 1, AxisType.REDUCE, tag=())
    in_ch = UOp.range(64, 0, AxisType.REDUCE, tag=())
    x = in_buf.index(((((out_x+tap_x)+((out_y+tap_y)*8))+(in_ch*64))+(sample*4096))).load()
    clamped = (0.0<x).where(x, UOp.const(dtypes.float, 0.0))

    weight_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=2)
    weight = weight_buf.index(((((tap_y*3)+tap_x)+(in_ch*9))+(out_ch*576))).load()
    term = (clamped*weight)

    offset_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3)
    offset = offset_buf.index(out_ch).load()
    acc = (term.reduce(in_ch, tap_y, tap_x, arg=Ops.ADD)+offset)

    dst = ((((out_y*6)+out_x)+(out_ch*36))+(sample*2304))
    return out_buf.index(dst).store(acc).end(sample, out_ch, out_y, out_x).sink()
def k13():
    """Build the UOp AST for a scaled per-channel sum of a zero-clamped tensor.

    For each of 64 channels, buffer 1 is clamped at zero and ADD-reduced over
    512 samples and the 6x6 plane, then multiplied by 5.425347222222222e-05
    (= 1/18432, i.e. 1/(512*36)).

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    py = UOp.range(6, 1, AxisType.REDUCE, tag=())
    px = UOp.range(6, 2, AxisType.REDUCE, tag=())
    sample = UOp.range(512, 0, AxisType.REDUCE, tag=())
    x = src_buf.index(((((py*6)+px)+(ch*36))+(sample*2304))).load()
    clamped = (0.0<x).where(x, UOp.const(dtypes.float, 0.0))
    scaled = (clamped.reduce(sample, py, px, arg=Ops.ADD)*5.425347222222222e-05)

    return out_buf.index(ch).store(scaled).end(ch).sink()
def k14():
    """Build the UOp AST for a scaled per-channel sum of squared deviations.

    For each of 64 channels, the zero-clamped value from buffer 1 has the
    per-channel value of buffer 2 subtracted, is squared, ADD-reduced over
    512 samples and the 6x6 plane, and multiplied by 5.425347222222222e-05
    (= 1/18432).

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    py = UOp.range(6, 1, AxisType.REDUCE, tag=())
    px = UOp.range(6, 2, AxisType.REDUCE, tag=())
    sample = UOp.range(512, 0, AxisType.REDUCE, tag=())
    x = src_buf.index(((((py*6)+px)+(ch*36))+(sample*2304))).load()
    clamped = (0.0<x).where(x, UOp.const(dtypes.float, 0.0))

    center_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2)
    center = center_buf.index(ch).load()
    deviation = (clamped+(center*-1.0))
    squared = (deviation*deviation)
    scaled = (squared.reduce(sample, py, px, arg=Ops.ADD)*5.425347222222222e-05)

    return out_buf.index(ch).store(scaled).end(ch).sink()
def k15():
    """Build the UOp AST for a 64-channel affine transform with 2x2 regrouping.

    Computes ``(max(x, 0) - b2) * b3 * rsqrt(b4 + 1e-05) + b5`` where x comes
    from buffer 1 and b2..b5 are per-channel floats from buffers 2..5 (shape
    of a batch-norm step -- confirm against the model).  The 6x6 plane is
    read row-major and written so each 2x2 window is contiguous in buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=0)
    sub_y = UOp.range(2, 4, AxisType.LOOP, tag=())
    sub_x = UOp.range(2, 5, AxisType.LOOP, tag=())
    blk_x = UOp.range(3, 3, AxisType.LOOP, tag=())
    blk_y = UOp.range(3, 2, AxisType.LOOP, tag=())
    ch = UOp.range(64, 1, AxisType.LOOP, tag=())
    ch_off = (ch*36)
    sample = UOp.range(512, 0, AxisType.LOOP, tag=())
    sample_off = (sample*2304)

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    x = src_buf.index((((((blk_x*2)+sub_x)+(((blk_y*2)+sub_y)*6))+ch_off)+sample_off)).load()
    clamped = (0.0<x).where(x, UOp.const(dtypes.float, 0.0))

    shift_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2)
    shift = shift_buf.index(ch).load()
    gain_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3)
    gain = gain_buf.index(ch).load()
    spread_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=4)
    spread = spread_buf.index(ch).load()
    offset_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=5)
    offset = offset_buf.index(ch).load()
    y = ((((clamped+(shift*-1.0))*gain)*(spread+1e-05).sqrt().reciprocal())+offset)

    dst = ((((((sub_y*2)+sub_x)+(blk_x*4))+(blk_y*12))+ch_off)+sample_off)
    return out_buf.index(dst).store(y).end(sample, ch, blk_y, blk_x, sub_y, sub_x).sink()
def k16():
    """Build the UOp AST that MAX-reduces contiguous groups of four floats.

    Buffer 1 (512x64x3x3 groups of 2x2) is reduced with MAX over the two
    size-2 axes, producing the 512x64x3x3 buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=0)
    out_y = UOp.range(3, 4, AxisType.LOOP, tag=())
    out_x = UOp.range(3, 5, AxisType.LOOP, tag=())
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())
    sample = UOp.range(512, 2, AxisType.LOOP, tag=())

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    win_a = UOp.range(2, 0, AxisType.REDUCE, tag=())
    win_b = UOp.range(2, 1, AxisType.REDUCE, tag=())
    x = src_buf.index(((((((win_a*2)+win_b)+(out_x*4))+(out_y*12))+(ch*36))+(sample*2304))).load()
    best = x.reduce(win_a, win_b, arg=Ops.MAX)

    dst = ((((out_y*3)+out_x)+(ch*9))+(sample*576))
    return out_buf.index(dst).store(best).end(sample, ch, out_y, out_x).sink()
def k17():
    """Build the UOp AST for a 576 -> 10 dot product per sample plus offset.

    For each of 512 samples and 10 outputs, 576 floats of buffer 1 are
    multiplied by the matching row of buffer 2 (10x576), ADD-reduced, and
    the per-output value of buffer 3 is added.  The buffer-1 index is
    rebuilt from ``k`` via //3, %3 and //9 terms exactly as generated.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=0)
    sample = UOp.range(512, 1, AxisType.LOOP, tag=())
    out_col = UOp.range(10, 2, AxisType.LOOP, tag=())

    feat_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=1)
    k = UOp.range(576, 0, AxisType.REDUCE, tag=())
    feat = feat_buf.index(((((((k//3)%3)*3)+(k%3))+((k//9)*9))+(sample*576))).load()

    weight_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5760), arg=2)
    weight = weight_buf.index(((out_col*576)+k)).load()
    term = (feat*weight)

    offset_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=3)
    offset = offset_buf.index(out_col).load()
    acc = (term.reduce(k, arg=Ops.ADD)+offset)

    return out_buf.index(((sample*10)+out_col)).store(acc).end(sample, out_col).sink()
def k18():
    """Build the UOp AST that takes the MAX of each 10-float row.

    Buffer 1 is read as 512 rows of 10 floats; each row's maximum is written
    to buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=0)
    row = UOp.range(512, 1, AxisType.LOOP, tag=())
    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=1)
    col = UOp.range(10, 0, AxisType.REDUCE, tag=())
    value = src_buf.index(((row*10)+col)).load()
    row_max = value.reduce(col, arg=Ops.MAX)
    return out_buf.index(row).store(row_max).end(row).sink()
def k19():
    """Build the UOp AST that sums exp2-scaled shifted values per row.

    For each of 512 rows, the per-row value of buffer 2 is subtracted from
    each of the 10 floats of buffer 1, the difference is multiplied by
    1.4426950408889634 (log2 of e) and passed through ``exp2``; the ten
    results are ADD-reduced into buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=0)
    row = UOp.range(512, 1, AxisType.LOOP, tag=())
    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=1)
    col = UOp.range(10, 0, AxisType.REDUCE, tag=())
    value = src_buf.index(((row*10)+col)).load()

    shift_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=2)
    shift = shift_buf.index(row).load()
    expo = ((value+(shift*-1.0))*1.4426950408889634).exp2()
    total = expo.reduce(col, arg=Ops.ADD)

    return out_buf.index(row).store(total).end(row).sink()
def k20():
    """Build the UOp AST for chunked partial sums of a masked uchar lookup.

    For each of 512 samples a float from buffer 1 is scaled by 60000.0, cast
    to int and wrapped by +60000 when negative.  Each of 250 chunks scans
    240 entries of buffer 2 (60000 uchar), zeroes every entry whose position
    differs from that value, and ADD-reduces the chunk into buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(128000), arg=0)
    sample = UOp.range(512, 1, AxisType.LOOP, tag=())
    chunk = UOp.range(250, 2, AxisType.LOOP, tag=())

    rand_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=1)
    rand_val = rand_buf.index(sample).load()
    raw_pick = (60000.0*rand_val).cast(dtypes.int)
    pick = (raw_pick<0).where((raw_pick+60000), raw_pick)

    step = UOp.range(240, 0, AxisType.REDUCE, tag=())
    pos = ((chunk*240)+step)
    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(60000), arg=2)
    entry = src_buf.index(pos).load()
    masked = (pick!=pos.cast(dtypes.int)).where(UOp.const(dtypes.uchar, 0), entry.cast(dtypes.uint).cast(dtypes.uchar))
    partial = masked.reduce(step, arg=Ops.ADD)

    return out_buf.index(((sample*250)+chunk)).store(partial).end(sample, chunk).sink()
def k21():
    """Build the UOp AST that sums 250 uchar partials per sample into an int.

    Buffer 1 is read as 512 rows of 250 uchar; each row is ADD-reduced, cast
    to int and written to buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=0)
    sample = UOp.range(512, 1, AxisType.LOOP, tag=())
    partial_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(128000), arg=1)
    chunk = UOp.range(250, 0, AxisType.REDUCE, tag=())
    part = partial_buf.index(((sample*250)+chunk)).load()
    total = part.reduce(chunk, arg=Ops.ADD).cast(dtypes.int)
    return out_buf.index(sample).store(total).end(sample).sink()
def k22():
    """Build the UOp AST that reduces a 512x10 table to one scaled scalar.

    Per element: ``(x - b2 - log2(b3) * 0.6931471805599453)`` multiplied by a
    0/1 float that is 1 where the int in buffer 4 equals the column index.
    The products are ADD-reduced over columns, then rows, and multiplied by
    -0.001953125 (= -1/512).  This resembles a mean negative log-likelihood
    -- confirm against the training script.

    Returns:
        The result of ``.sink()`` on the store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0)
    table_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=1)
    row = UOp.range(512, 1, AxisType.REDUCE, tag=())
    col = UOp.range(10, 0, AxisType.REDUCE, tag=())
    value = table_buf.index(((row*10)+col)).load()

    shift_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=2)
    shift = shift_buf.index(row).load()
    denom_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=3)
    denom = denom_buf.index(row).load()
    target_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=4)
    target = target_buf.index(row).load()

    # `(a != b) != True` is the generated spelling of an equality mask.
    term = (((value+(shift*-1.0))+((denom.log2()*0.6931471805599453)*-1.0))*((target!=col.cast(dtypes.int))!=True).cast(dtypes.float))
    per_row = term.reduce(col, arg=Ops.ADD)
    total = (per_row.reduce(row, arg=Ops.ADD)*-0.001953125)

    return out_buf.index(UOp.const(dtypes.index, 0)).store(total).sink()
def k23():
    """Build the UOp AST that multiplies a single float by 0.9 in place.

    Element 0 of global buffer 0 is loaded, scaled by 0.9 and stored back.

    Returns:
        The result of ``.sink()`` on the store.
    """
    scalar_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0)
    slot = scalar_buf.index(UOp.const(dtypes.index, 0))
    decayed = (slot.load()*0.9)
    return slot.store(decayed).sink()
def k24():
    """Build the UOp AST that multiplies a single float by 0.999 in place.

    Element 0 of global buffer 0 is loaded, scaled by 0.999 and stored back.

    Returns:
        The result of ``.sink()`` on the store.
    """
    scalar_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0)
    slot = scalar_buf.index(UOp.const(dtypes.index, 0))
    decayed = (slot.load()*0.999)
    return slot.store(decayed).sink()
def k25():
    """Build the UOp AST for a per-row scaled reciprocal term.

    For each of 512 rows, a 0/1 mask (1 where the int from buffer 1 equals
    the column index) times -0.001953125 is negated and ADD-reduced over 10
    columns, then multiplied by ``1 / (b2 * 0.6931471805599453)`` and by
    0.6931471805599453, with b2 the per-row float of buffer 2.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=0)
    row = UOp.range(512, 1, AxisType.LOOP, tag=())
    target_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1)
    target = target_buf.index(row).load()
    col = UOp.range(10, 0, AxisType.REDUCE, tag=())
    term = (-1.0*(((target!=col.cast(dtypes.int))!=True).cast(dtypes.float)*-0.001953125))

    denom_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=2)
    denom = denom_buf.index(row).load()
    result = ((term.reduce(col, arg=Ops.ADD)*(denom*0.6931471805599453).reciprocal())*0.6931471805599453)

    return out_buf.index(row).store(result).end(row).sink()
def k26():
    """Build the UOp AST for an elementwise 512x10 combination of mask and exp2.

    Each element is ``mask * -0.001953125 + exp2((x - b3) * log2(e)) * b4``
    where mask is 1 where the int from buffer 1 equals the column index, x
    comes from buffer 2 and b3/b4 are per-row floats from buffers 3 and 4.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=0)
    row = UOp.range(512, 0, AxisType.LOOP, tag=())
    col = UOp.range(10, 1, AxisType.LOOP, tag=())
    flat = ((row*10)+col)

    target_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1)
    target = target_buf.index(row).load()
    table_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=2)
    value = table_buf.index(flat).load()
    shift_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=3)
    shift = shift_buf.index(row).load()
    scale_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=4)
    scale = scale_buf.index(row).load()

    result = ((((target!=col.cast(dtypes.int))!=True).cast(dtypes.float)*-0.001953125)+(((value+(shift*-1.0))*1.4426950408889634).exp2()*scale))
    return out_buf.index(flat).store(result).end(row, col).sink()
def k27():
    """Build the UOp AST that counts window entries equal to a reference value.

    For each 512x64x3x3 position, the four floats of the matching 2x2 group
    in buffer 1 are compared with the reference float in buffer 2; the
    number of equal entries (as float) is written to buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=0)
    out_y = UOp.range(3, 4, AxisType.LOOP, tag=())
    out_x = UOp.range(3, 5, AxisType.LOOP, tag=())
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())
    sample = UOp.range(512, 2, AxisType.LOOP, tag=())
    dst = ((((out_y*3)+out_x)+(ch*9))+(sample*576))

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    win_a = UOp.range(2, 0, AxisType.REDUCE, tag=())
    win_b = UOp.range(2, 1, AxisType.REDUCE, tag=())
    x = src_buf.index(((((((win_a*2)+win_b)+(out_x*4))+(out_y*12))+(ch*36))+(sample*2304))).load()

    ref_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=2)
    ref = ref_buf.index(dst).load()
    is_equal = ((x!=ref)!=True).cast(dtypes.float)
    count = is_equal.reduce(win_a, win_b, arg=Ops.ADD)

    return out_buf.index(dst).store(count).end(sample, ch, out_y, out_x).sink()
def k28():
    """Build the UOp AST for a 10-term dot product per (sample, feature).

    For each of 512 samples and 576 features, buffer 1 (10x576) column
    entries are multiplied by the sample's 10 floats in buffer 2 and
    ADD-reduced into buffer 0 (512x576).

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=0)
    sample = UOp.range(512, 1, AxisType.LOOP, tag=())
    feat = UOp.range(576, 2, AxisType.LOOP, tag=())

    weight_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5760), arg=1)
    k = UOp.range(10, 0, AxisType.REDUCE, tag=())
    weight = weight_buf.index(((k*576)+feat)).load()

    table_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=2)
    value = table_buf.index(((sample*10)+k)).load()
    total = (weight*value).reduce(k, arg=Ops.ADD)

    return out_buf.index(((sample*576)+feat)).store(total).end(sample, feat).sink()
def k29():
    """Build the UOp AST that spreads a 3x3 value back over a 6x6 plane.

    Each 6x6 position reads its entry from the 2x2-grouped buffer 1 and the
    3x3 entries of buffers 2, 3 and 4 for its enclosing block.  The output
    is ``(x == b2) / b3 * b4``, written row-major into buffer 0.  This looks
    like the backward pass of a 2x2 max-pool -- confirm against the model.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=0)
    dst_y = UOp.range(6, 2, AxisType.LOOP, tag=())
    dst_x = UOp.range(6, 3, AxisType.LOOP, tag=())
    ch = UOp.range(64, 1, AxisType.LOOP, tag=())
    ch_off = (ch*36)
    sample = UOp.range(512, 0, AxisType.LOOP, tag=())
    sample_off = (sample*2304)

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    blk_x = (dst_x//2)
    blk_y = (dst_y//2)
    x = src_buf.index((((((((dst_y%2)*2)+(dst_x%2))+(blk_x*4))+(blk_y*12))+ch_off)+sample_off)).load()

    ref_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=2)
    small = ((((blk_y*3)+blk_x)+(ch*9))+(sample*576))
    ref = ref_buf.index(small).load()
    count_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=3)
    count = count_buf.index(small).load()
    upstream_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=4)
    upstream = upstream_buf.index(small).load()
    value = ((((x!=ref)!=True).cast(dtypes.float)*count.reciprocal())*upstream)

    dst = ((((dst_y*6)+dst_x)+ch_off)+sample_off)
    return out_buf.index(dst).store(value).end(sample, ch, dst_y, dst_x).sink()
def k30():
    """Build the UOp AST for a per-channel reduction scaled by rsqrt terms.

    For each of 64 channels, ``(max(x, 0) - b2) * b3 * b4`` is ADD-reduced
    over 512 samples and the 6x6 plane (x from buffer 1, b4 from buffer 4 at
    the same index, b2/b3 per channel).  With ``s = sqrt(b5 + 1e-05)`` the
    sum is multiplied by ``1/s``, ``1/s``, ``1/(2*s)`` and
    -5.425347222222222e-05 (= -1/18432).

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    py = UOp.range(6, 1, AxisType.REDUCE, tag=())
    px = UOp.range(6, 2, AxisType.REDUCE, tag=())
    sample = UOp.range(512, 0, AxisType.REDUCE, tag=())
    flat = ((((py*6)+px)+(ch*36))+(sample*2304))
    x = src_buf.index(flat).load()
    clamped = (0.0<x).where(x, UOp.const(dtypes.float, 0.0))

    shift_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2)
    shift = shift_buf.index(ch).load()
    gain_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3)
    gain = gain_buf.index(ch).load()
    upstream_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=4)
    upstream = upstream_buf.index(flat).load()
    term = (((clamped+(shift*-1.0))*gain)*upstream)

    spread_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=5)
    spread = spread_buf.index(ch).load()
    root = (spread+1e-05).sqrt()
    inv_root = root.reciprocal()
    result = ((((term.reduce(sample, py, px, arg=Ops.ADD)*inv_root)*inv_root)*(root*2.0).reciprocal())*-5.425347222222222e-05)

    return out_buf.index(ch).store(result).end(ch).sink()
def k31():
    """Build the UOp AST for a per-channel negated, rsqrt-scaled sum.

    For each of 64 channels, ``-(b1 * (rsqrt(b2 + 1e-05) * x))`` is
    ADD-reduced over 512 samples and the 6x6 plane of buffer 3, then
    multiplied by 5.425347222222222e-05 (= 1/18432).

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())

    gain_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=1)
    gain = gain_buf.index(ch).load()
    spread_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2)
    spread = spread_buf.index(ch).load()

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=3)
    py = UOp.range(6, 1, AxisType.REDUCE, tag=())
    px = UOp.range(6, 2, AxisType.REDUCE, tag=())
    sample = UOp.range(512, 0, AxisType.REDUCE, tag=())
    x = src_buf.index(((((py*6)+px)+(ch*36))+(sample*2304))).load()

    term = (-1.0*(gain*((spread+1e-05).sqrt().reciprocal()*x)))
    result = (5.425347222222222e-05*term.reduce(sample, py, px, arg=Ops.ADD))

    return out_buf.index(ch).store(result).end(ch).sink()
def k32():
    """Build the UOp AST for a gated elementwise combination over 512x64x6x6.

    Where the buffer-1 value x is greater than 0.0 the output is
    ``(max(x, 0) - b2) * b3 * 2.0 + b4 * (rsqrt(b5 + 1e-05) * g) + b7``,
    otherwise 0.0.  b2, b3, b4, b5 and b7 are per-channel floats and g is
    the same-index float of buffer 6.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=0)
    py = UOp.range(6, 2, AxisType.LOOP, tag=())
    px = UOp.range(6, 3, AxisType.LOOP, tag=())
    ch = UOp.range(64, 1, AxisType.LOOP, tag=())
    sample = UOp.range(512, 0, AxisType.LOOP, tag=())
    flat = ((((py*6)+px)+(ch*36))+(sample*2304))

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    x = src_buf.index(flat).load()
    is_positive = (0.0<x)
    clamped = is_positive.where(x, UOp.const(dtypes.float, 0.0))

    shift_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2)
    shift = shift_buf.index(ch).load()
    coeff_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3)
    coeff = coeff_buf.index(ch).load()
    gain_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=4)
    gain = gain_buf.index(ch).load()
    spread_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=5)
    spread = spread_buf.index(ch).load()
    upstream_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=6)
    upstream = upstream_buf.index(flat).load()
    extra_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=7)
    extra = extra_buf.index(ch).load()

    result = is_positive.where((((((clamped+(shift*-1.0))*coeff)*2.0)+(gain*((spread+1e-05).sqrt().reciprocal()*upstream)))+extra), UOp.const(dtypes.float, 0.0))
    return out_buf.index(flat).store(result).end(sample, ch, py, px).sink()
def k33():
    """Build the UOp AST for a masked 3x3 scatter-style accumulation, gated by sign.

    For each 512x64x8x8 position the two size-4 reduce axes are flattened
    with the position (``step*8 + pos``) and split by 9: the quotient picks
    a tap of buffer 2 (36864 floats) and the remainder a 6x6 coordinate of
    buffer 3.  Combinations that fall outside ``< 27`` or ``rem < 6`` use an
    Invalid index and contribute 0.0.  The sum is kept only where the
    same-index value in buffer 1 is greater than 0.0.  This resembles a
    convolution input-gradient -- confirm against the model.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=0)
    dst_y = UOp.range(8, 5, AxisType.LOOP, tag=())
    dst_x = UOp.range(8, 6, AxisType.LOOP, tag=())
    ch = UOp.range(64, 4, AxisType.LOOP, tag=())
    sample = UOp.range(512, 3, AxisType.LOOP, tag=())
    dst = ((((dst_y*8)+dst_x)+(ch*64))+(sample*4096))

    gate_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=1)
    gate = gate_buf.index(dst).load()

    # Vertical axis: flatten, bound-check, split into tap (//9) and coord (%9).
    y_step = UOp.range(4, 1, AxisType.REDUCE, tag=())
    y_flat = ((y_step*8)+dst_y)
    y_in = (y_flat<27)
    # Horizontal axis, same construction.
    x_step = UOp.range(4, 2, AxisType.REDUCE, tag=())
    x_flat = ((x_step*8)+dst_x)
    x_in = (x_flat<27)
    y_rem = (y_flat%9)
    y_ok = (y_in&(y_rem<6))
    x_rem = (x_flat%9)
    x_rem_ok = (x_rem<6)

    weight_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=2)
    tap_y = y_ok.where((y_flat//9), UOp.const(dtypes.index, Invalid))
    x_ok = (x_in&x_rem_ok)
    tap_x = x_ok.where((x_flat//9), UOp.const(dtypes.index, Invalid))
    red_ch = UOp.range(64, 0, AxisType.REDUCE, tag=())
    weight = weight_buf.index(((((tap_y*3)+tap_x)+(ch*9))+(red_ch*576))).load()

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=3)
    src_idx = (x_ok&y_ok).where(((((y_rem*6)+x_rem)+(red_ch*36))+(sample*2304)), UOp.const(dtypes.index, Invalid))
    src = src_buf.index(src_idx).load()
    product = (weight*src)

    valid = ((y_in&x_in)&((y_ok&x_in)&x_rem_ok))
    window_sum = valid.where(product.reduce(red_ch, arg=Ops.ADD), UOp.const(dtypes.float, 0.0))
    result = (0.0<gate).where(window_sum.reduce(y_step, x_step, arg=Ops.ADD), UOp.const(dtypes.float, 0.0))

    return out_buf.index(dst).store(result).end(sample, ch, dst_y, dst_x).sink()
def k34():
    """Build the UOp AST that counts window entries equal to a reference value.

    For each 512x32x10x10 position, the four floats of the matching 2x2
    group in buffer 1 are compared with the reference float in buffer 2;
    the number of equal entries (as float) is written to buffer 0.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=0)
    out_y = UOp.range(10, 4, AxisType.LOOP, tag=())
    out_x = UOp.range(10, 5, AxisType.LOOP, tag=())
    ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    sample = UOp.range(512, 2, AxisType.LOOP, tag=())
    dst = ((((out_y*10)+out_x)+(ch*100))+(sample*3200))

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    win_a = UOp.range(2, 0, AxisType.REDUCE, tag=())
    win_b = UOp.range(2, 1, AxisType.REDUCE, tag=())
    x = src_buf.index(((((((win_a*2)+win_b)+(out_x*4))+(out_y*40))+(ch*400))+(sample*12800))).load()

    ref_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=2)
    ref = ref_buf.index(dst).load()
    is_equal = ((x!=ref)!=True).cast(dtypes.float)
    count = is_equal.reduce(win_a, win_b, arg=Ops.ADD)

    return out_buf.index(dst).store(count).end(sample, ch, out_y, out_x).sink()
def k35():
    """Build the UOp AST for a masked 3x3 scatter-style accumulation.

    For each 512x32x10x10 position the two size-4 reduce axes are flattened
    with the position (``step*10 + pos``) and split by 11: the quotient
    picks a tap of buffer 1 (18432 floats) and the remainder an 8x8
    coordinate of buffer 2.  Combinations outside ``< 33`` or ``rem < 8``
    use an Invalid index and contribute 0.0.  This resembles a convolution
    input-gradient -- confirm against the model.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=0)
    dst_y = UOp.range(10, 5, AxisType.LOOP, tag=())
    dst_x = UOp.range(10, 6, AxisType.LOOP, tag=())
    ch = UOp.range(32, 4, AxisType.LOOP, tag=())
    sample = UOp.range(512, 3, AxisType.LOOP, tag=())

    # Vertical axis: flatten, bound-check, split into tap (//11) and coord (%11).
    y_step = UOp.range(4, 1, AxisType.REDUCE, tag=())
    y_flat = ((y_step*10)+dst_y)
    y_in = (y_flat<33)
    # Horizontal axis, same construction.
    x_step = UOp.range(4, 2, AxisType.REDUCE, tag=())
    x_flat = ((x_step*10)+dst_x)
    x_in = (x_flat<33)
    y_rem = (y_flat%11)
    y_ok = (y_in&(y_rem<8))
    x_rem = (x_flat%11)
    x_rem_ok = (x_rem<8)

    weight_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18432), arg=1)
    tap_y = y_ok.where((y_flat//11), UOp.const(dtypes.index, Invalid))
    x_ok = (x_in&x_rem_ok)
    tap_x = x_ok.where((x_flat//11), UOp.const(dtypes.index, Invalid))
    red_ch = UOp.range(64, 0, AxisType.REDUCE, tag=())
    weight = weight_buf.index(((((tap_y*3)+tap_x)+(ch*9))+(red_ch*288))).load()

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=2)
    src_idx = (x_ok&y_ok).where(((((y_rem*8)+x_rem)+(red_ch*64))+(sample*4096)), UOp.const(dtypes.index, Invalid))
    src = src_buf.index(src_idx).load()
    product = (weight*src)

    valid = ((y_in&x_in)&((y_ok&x_in)&x_rem_ok))
    window_sum = valid.where(product.reduce(red_ch, arg=Ops.ADD), UOp.const(dtypes.float, 0.0))
    result = window_sum.reduce(y_step, x_step, arg=Ops.ADD)

    dst = ((((dst_y*10)+dst_x)+(ch*100))+(sample*3200))
    return out_buf.index(dst).store(result).end(sample, ch, dst_y, dst_x).sink()
def k36():
    """Build the UOp AST that spreads a 10x10 value back over a 20x20 plane.

    Each 2x2 sub-position reads its entry from the 2x2-grouped buffer 1 and
    the 10x10 entries of buffers 2, 3 and 4 for its block.  The output is
    ``(x == b2) / b3 * b4``, written to buffer 0 with the generated strides.
    This looks like the backward pass of a 2x2 max-pool -- confirm.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=0)
    blk_x = UOp.range(10, 4, AxisType.LOOP, tag=())
    sub_x = UOp.range(2, 5, AxisType.LOOP, tag=())
    sub_y = UOp.range(2, 3, AxisType.LOOP, tag=())
    blk_y = UOp.range(10, 2, AxisType.LOOP, tag=())
    blk_y_off = (blk_y*40)
    ch = UOp.range(32, 1, AxisType.LOOP, tag=())
    ch_off = (ch*400)
    sample = UOp.range(512, 0, AxisType.LOOP, tag=())
    sample_off = (sample*12800)

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    x = src_buf.index(((((((sub_y*2)+sub_x)+(blk_x*4))+blk_y_off)+ch_off)+sample_off)).load()

    ref_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=2)
    small = ((((blk_y*10)+blk_x)+(ch*100))+(sample*3200))
    ref = ref_buf.index(small).load()
    count_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=3)
    count = count_buf.index(small).load()
    upstream_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=4)
    upstream = upstream_buf.index(small).load()
    value = ((((x!=ref)!=True).cast(dtypes.float)*count.reciprocal())*upstream)

    dst = ((((((blk_x*2)+sub_x)+(sub_y*20))+blk_y_off)+ch_off)+sample_off)
    return out_buf.index(dst).store(value).end(sample, ch, blk_y, sub_y, blk_x, sub_x).sink()
def k37():
    """Build the UOp AST for per-channel partial sums of a three-way product.

    For each of 32 channels and 256 sample pairs, ``(max(x, 0) - b2) * b3 *
    g`` is ADD-reduced over two samples and the 20x20 plane, with x from
    buffer 1, g from buffer 4 at the same index, and b2/b3 per channel.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    pair = UOp.range(256, 4, AxisType.LOOP, tag=())

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    py = UOp.range(20, 1, AxisType.REDUCE, tag=())
    px = UOp.range(20, 2, AxisType.REDUCE, tag=())
    half = UOp.range(2, 0, AxisType.REDUCE, tag=())
    flat = ((((py*20)+px)+(ch*400))+(((pair*2)+half)*12800))
    x = src_buf.index(flat).load()
    clamped = (0.0<x).where(x, UOp.const(dtypes.float, 0.0))

    shift_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    shift = shift_buf.index(ch).load()
    gain_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3)
    gain = gain_buf.index(ch).load()
    upstream_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=4)
    upstream = upstream_buf.index(flat).load()

    term = (((clamped+(shift*-1.0))*gain)*upstream)
    total = term.reduce(half, py, px, arg=Ops.ADD)

    return out_buf.index(((ch*256)+pair)).store(total).end(ch, pair).sink()
def k38():
    """Build the UOp AST that folds 256 partials and applies rsqrt scaling.

    For each of 32 channels, the row of buffer 1 is ADD-reduced; with
    ``s = sqrt(b2 + 1e-05)`` the sum is multiplied by ``1/s``, ``1/s``,
    ``1/(2*s)`` and -4.8828125e-06 (= -1/204800).

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    ch = UOp.range(32, 1, AxisType.LOOP, tag=())
    partial_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=1)
    k = UOp.range(256, 0, AxisType.REDUCE, tag=())
    part = partial_buf.index(((ch*256)+k)).load()

    spread_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    spread = spread_buf.index(ch).load()
    root = (spread+1e-05).sqrt()
    inv_root = root.reciprocal()
    result = ((((part.reduce(k, arg=Ops.ADD)*inv_root)*inv_root)*(root*2.0).reciprocal())*-4.8828125e-06)

    return out_buf.index(ch).store(result).end(ch).sink()
def k39():
    """Build the UOp AST for per-channel partial sums of a negated scaled value.

    For each of 32 channels and 256 sample pairs,
    ``-(b1 * (rsqrt(b2 + 1e-05) * x))`` is ADD-reduced over two samples and
    the 20x20 plane of buffer 3.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    pair = UOp.range(256, 4, AxisType.LOOP, tag=())

    gain_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=1)
    gain = gain_buf.index(ch).load()
    spread_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    spread = spread_buf.index(ch).load()

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=3)
    py = UOp.range(20, 1, AxisType.REDUCE, tag=())
    px = UOp.range(20, 2, AxisType.REDUCE, tag=())
    half = UOp.range(2, 0, AxisType.REDUCE, tag=())
    x = src_buf.index(((((py*20)+px)+(ch*400))+(((pair*2)+half)*12800))).load()

    term = (-1.0*(gain*((spread+1e-05).sqrt().reciprocal()*x)))
    total = term.reduce(half, py, px, arg=Ops.ADD)

    return out_buf.index(((ch*256)+pair)).store(total).end(ch, pair).sink()
def k40():
    """Build the UOp AST that folds 256 partials per channel and scales them.

    Each of 32 rows of buffer 1 (256 floats) is ADD-reduced and multiplied
    by 4.8828125e-06 (= 1/204800), with the constant as the left operand.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    ch = UOp.range(32, 1, AxisType.LOOP, tag=())
    partial_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=1)
    k = UOp.range(256, 0, AxisType.REDUCE, tag=())
    part = partial_buf.index(((ch*256)+k)).load()
    scaled = (4.8828125e-06*part.reduce(k, arg=Ops.ADD))
    return out_buf.index(ch).store(scaled).end(ch).sink()
def k41():
    """Build the UOp AST for a gated elementwise combination over 512x32x20x20.

    Where the buffer-1 value x is greater than 0.0 the output is
    ``(max(x, 0) - b2) * b3 * 2.0 + b4 * (rsqrt(b5 + 1e-05) * g) + b7``,
    otherwise 0.0.  b2, b3, b4, b5 and b7 are per-channel floats and g is
    the same-index float of buffer 6.

    Returns:
        The result of ``.sink()`` on the closed store.
    """
    out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=0)
    py = UOp.range(20, 2, AxisType.LOOP, tag=())
    px = UOp.range(20, 3, AxisType.LOOP, tag=())
    ch = UOp.range(32, 1, AxisType.LOOP, tag=())
    sample = UOp.range(512, 0, AxisType.LOOP, tag=())
    flat = ((((py*20)+px)+(ch*400))+(sample*12800))

    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    x = src_buf.index(flat).load()
    is_positive = (0.0<x)
    clamped = is_positive.where(x, UOp.const(dtypes.float, 0.0))

    shift_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    shift = shift_buf.index(ch).load()
    coeff_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3)
    coeff = coeff_buf.index(ch).load()
    gain_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=4)
    gain = gain_buf.index(ch).load()
    spread_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=5)
    spread = spread_buf.index(ch).load()
    upstream_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=6)
    upstream = upstream_buf.index(flat).load()
    extra_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=7)
    extra = extra_buf.index(ch).load()

    result = is_positive.where((((((clamped+(shift*-1.0))*coeff)*2.0)+(gain*((spread+1e-05).sqrt().reciprocal()*upstream)))+extra), UOp.const(dtypes.float, 0.0))
    return out_buf.index(flat).store(result).end(sample, ch, py, px).sink()
def k42():
    """Build the kernel graph writing global 0 (9437184 floats) from a gated double reduction.

    The output index is ``r5*24 + r6 + r4*576 + r3*18432``.  Inside, products
    of loads from global 2 and global 3 are summed over reduce range r0, the
    sum is kept only where the derived positions pass the ``< 125`` and
    ``% 25 < 20`` checks, and that is summed again over r1 and r2.  The final
    value is stored only where the load from global 1 at the output index is
    greater than 0.0; otherwise 0.0 is stored.
    Returns the result of ``.sink()`` on the ended store.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=0)
    r5 = UOp.range(24, 5, AxisType.LOOP, tag=())
    r6 = UOp.range(24, 6, AxisType.LOOP, tag=())
    r4 = UOp.range(32, 4, AxisType.LOOP, tag=())
    r3 = UOp.range(512, 3, AxisType.LOOP, tag=())
    out_index = r5 * 24 + r6 + r4 * 576 + r3 * 18432
    gate_val = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=1).index(out_index).load()

    # Positions derived from the two small reduce ranges and their validity checks.
    r1 = UOp.range(6, 1, AxisType.REDUCE, tag=())
    pos_a = r1 * 24 + r5
    a_below = pos_a < 125
    r2 = UOp.range(6, 2, AxisType.REDUCE, tag=())
    pos_b = r2 * 24 + r6
    b_below = pos_b < 125
    a_rem = pos_a % 25
    a_valid = a_below & (a_rem < 20)
    b_rem = pos_b % 25
    b_rem_ok = b_rem < 20

    src2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25600), arg=2)
    a_quot = a_valid.where(pos_a // 25, UOp.const(dtypes.index, Invalid))
    b_valid = b_below & b_rem_ok
    b_quot = b_valid.where(pos_b // 25, UOp.const(dtypes.index, Invalid))
    r0 = UOp.range(32, 0, AxisType.REDUCE, tag=())
    lhs = src2.index(a_quot * 5 + b_quot + r4 * 25 + r0 * 800).load()

    src3 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=3)
    rhs_index = (b_valid & a_valid).where(
        a_rem * 20 + b_rem + r0 * 400 + r3 * 12800,
        UOp.const(dtypes.index, Invalid),
    )
    rhs = src3.index(rhs_index).load()

    product = lhs * rhs
    keep = (a_below & b_below) & ((a_valid & b_below) & b_rem_ok)
    inner = keep.where(product.reduce(r0, arg=Ops.ADD), UOp.const(dtypes.float, 0.0))
    result = (0.0 < gate_val).where(inner.reduce(r1, r2, arg=Ops.ADD), UOp.const(dtypes.float, 0.0))
    return dst.index(out_index).store(result).end(r3, r4, r5, r6).sink()
def k43():
    """Build the kernel graph writing global 0 (204800 floats) as a sum of products.

    Each output at ``r5*256 + r6 + r4*1280 + r3*6400`` is the sum over reduce
    ranges r0, r1, r2 of a value from global 1 (uint buffer, passed through
    the uchar/uint/uchar/float cast chain) times a value from global 2.
    Returns the result of ``.sink()`` on the ended store.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(204800), arg=0)
    r5 = UOp.range(5, 5, AxisType.LOOP, tag=())
    r6 = UOp.range(256, 6, AxisType.LOOP, tag=())
    r4 = UOp.range(5, 4, AxisType.LOOP, tag=())
    r3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    src1 = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(401408), arg=1)
    r2 = UOp.range(24, 2, AxisType.REDUCE, tag=())
    r1 = UOp.range(24, 1, AxisType.REDUCE, tag=())
    r0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    group = r6 * 2 + r0

    raw = src1.index(r5 + r2 + (r4 + r1) * 28 + group * 784).load()
    src2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=2)
    other = src2.index(r1 * 24 + r2 + r3 * 576 + group * 18432).load()

    # The cast chain is reproduced exactly as generated.
    as_float = raw.cast(dtypes.uchar).cast(dtypes.uint).cast(dtypes.uchar).cast(dtypes.float)
    summed = (as_float * other).reduce(r0, r1, r2, arg=Ops.ADD)
    out_index = r5 * 256 + r6 + r4 * 1280 + r3 * 6400
    return dst.index(out_index).store(summed).end(r3, r4, r5, r6).sink()
def k44():
    """Build the kernel graph that sums global 1 over reduce range r0 into global 0.

    Reads global 1 (204800 floats) at ``r3*256 + r0 + r2*1280 + r1*6400``,
    adds over r0 (extent 256) and stores to global 0 (800 floats) at
    ``r2*5 + r3 + r1*25``.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(800), arg=0)
    r2 = UOp.range(5, 2, AxisType.LOOP, tag=())
    r3 = UOp.range(5, 3, AxisType.LOOP, tag=())
    r1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(204800), arg=1)
    r0 = UOp.range(256, 0, AxisType.REDUCE, tag=())
    partial = src.index(r3 * 256 + r0 + r2 * 1280 + r1 * 6400).load()
    summed = partial.reduce(r0, arg=Ops.ADD)
    return dst.index(r2 * 5 + r3 + r1 * 25).store(summed).end(r1, r2, r3).sink()
def k45():
    """Build the kernel graph that partially sums global 1 into global 0 (8192 floats).

    Reads global 1 (9437184 floats) at
    ``r1*24 + r2 + r3*576 + (r4*2 + r0)*18432``, adds over reduce ranges
    r0, r1, r2 and stores at ``r3*256 + r4``.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    r3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    r4 = UOp.range(256, 4, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=1)
    r1 = UOp.range(24, 1, AxisType.REDUCE, tag=())
    r2 = UOp.range(24, 2, AxisType.REDUCE, tag=())
    r0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    element = src.index(r1 * 24 + r2 + r3 * 576 + (r4 * 2 + r0) * 18432).load()
    summed = element.reduce(r0, r1, r2, arg=Ops.ADD)
    return dst.index(r3 * 256 + r4).store(summed).end(r3, r4).sink()
def k46():
    """Build the kernel graph that sums 256 consecutive floats of global 1 per output.

    Global 1 (8192 floats) is read at ``r1*256 + r0``, added over reduce range
    r0 and stored to global 0 (32 floats) at ``r1``.  Returns the result of
    ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    r1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=1)
    r0 = UOp.range(256, 0, AxisType.REDUCE, tag=())
    summed = src.index(r1 * 256 + r0).load().reduce(r0, arg=Ops.ADD)
    return dst.index(r1).store(summed).end(r1).sink()
def k47():
    """Build the kernel graph writing global 0 (3276800 floats) as a sum of products.

    The left factor is the load from global 1 with values not greater than 0.0
    replaced by 0.0; the right factor is the load from global 2.  Products are
    added over reduce ranges r0, r1, r2 and stored at
    ``r6*128 + r7 + r5*640 + r4*3200 + r3*102400``.
    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(3276800), arg=0)
    r6 = UOp.range(5, 6, AxisType.LOOP, tag=())
    r7 = UOp.range(128, 7, AxisType.LOOP, tag=())
    r5 = UOp.range(5, 5, AxisType.LOOP, tag=())
    r4 = UOp.range(32, 4, AxisType.LOOP, tag=())
    r3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    src1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=1)
    r2 = UOp.range(20, 2, AxisType.REDUCE, tag=())
    r1 = UOp.range(20, 1, AxisType.REDUCE, tag=())
    r0 = UOp.range(4, 0, AxisType.REDUCE, tag=())
    group = r7 * 4 + r0

    left_raw = src1.index(r6 + r2 + (r5 + r1) * 24 + r4 * 576 + group * 18432).load()
    left = (0.0 < left_raw).where(left_raw, UOp.const(dtypes.float, 0.0))
    src2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=2)
    right = src2.index(r1 * 20 + r2 + r3 * 400 + group * 12800).load()

    summed = (left * right).reduce(r0, r1, r2, arg=Ops.ADD)
    out_index = r6 * 128 + r7 + r5 * 640 + r4 * 3200 + r3 * 102400
    return dst.index(out_index).store(summed).end(r3, r4, r5, r6, r7).sink()
def k48():
    """Build the kernel graph that sums global 1 over reduce range r0 into global 0.

    Reads global 1 (3276800 floats) at
    ``r4*128 + r0 + r3*640 + r2*3200 + r1*102400``, adds over r0 (extent 128)
    and stores to global 0 (25600 floats) at ``r3*5 + r4 + r2*25 + r1*800``.
    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25600), arg=0)
    r3 = UOp.range(5, 3, AxisType.LOOP, tag=())
    r4 = UOp.range(5, 4, AxisType.LOOP, tag=())
    r2 = UOp.range(32, 2, AxisType.LOOP, tag=())
    r1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(3276800), arg=1)
    r0 = UOp.range(128, 0, AxisType.REDUCE, tag=())
    partial = src.index(r4 * 128 + r0 + r3 * 640 + r2 * 3200 + r1 * 102400).load()
    summed = partial.reduce(r0, arg=Ops.ADD)
    out_index = r3 * 5 + r4 + r2 * 25 + r1 * 800
    return dst.index(out_index).store(summed).end(r1, r2, r3, r4).sink()
def k49():
    """Build the kernel graph that partially sums global 1 into global 0 (8192 floats).

    Reads global 1 (6553600 floats) at
    ``r1*20 + r2 + r3*400 + (r4*2 + r0)*12800``, adds over reduce ranges
    r0, r1, r2 and stores at ``r3*256 + r4``.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    r3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    r4 = UOp.range(256, 4, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    r1 = UOp.range(20, 1, AxisType.REDUCE, tag=())
    r2 = UOp.range(20, 2, AxisType.REDUCE, tag=())
    r0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    element = src.index(r1 * 20 + r2 + r3 * 400 + (r4 * 2 + r0) * 12800).load()
    summed = element.reduce(r0, r1, r2, arg=Ops.ADD)
    return dst.index(r3 * 256 + r4).store(summed).end(r3, r4).sink()
def k50():
    """Build the kernel graph that partially sums a normalised product into global 0.

    With ``flat = r1*20 + r2 + r3*400 + (r4*2 + r0)*12800`` the summed term is
    ``(clamp(g1[flat]) + g2[r3]*-1.0) * ((g3[r3] + 1e-05).sqrt().reciprocal() * g4[flat])``
    where ``clamp`` replaces values not greater than 0.0 by 0.0.  The sum over
    reduce ranges r0, r1, r2 is stored to global 0 at ``r3*256 + r4``.
    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    r3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    r4 = UOp.range(256, 4, AxisType.LOOP, tag=())
    src1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    r1 = UOp.range(20, 1, AxisType.REDUCE, tag=())
    r2 = UOp.range(20, 2, AxisType.REDUCE, tag=())
    r0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    flat = r1 * 20 + r2 + r3 * 400 + (r4 * 2 + r0) * 12800

    raw = src1.index(flat).load()
    clamped = (0.0 < raw).where(raw, UOp.const(dtypes.float, 0.0))
    v2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2).index(r3).load()
    v3 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3).index(r3).load()
    v4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=4).index(flat).load()

    term = (clamped + v2 * -1.0) * ((v3 + 1e-05).sqrt().reciprocal() * v4)
    summed = term.reduce(r0, r1, r2, arg=Ops.ADD)
    return dst.index(r3 * 256 + r4).store(summed).end(r3, r4).sink()
def k51():
    """Build the kernel graph writing global 0 (2359296 floats) as a sum of products.

    Multiplies loads from global 1 and global 2, adds over reduce ranges
    r0, r1, r2 and stores at ``r6*128 + r7 + r5*384 + r4*1152 + r3*36864``.
    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2359296), arg=0)
    r6 = UOp.range(3, 6, AxisType.LOOP, tag=())
    r7 = UOp.range(128, 7, AxisType.LOOP, tag=())
    r5 = UOp.range(3, 5, AxisType.LOOP, tag=())
    r4 = UOp.range(32, 4, AxisType.LOOP, tag=())
    r3 = UOp.range(64, 3, AxisType.LOOP, tag=())
    src1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=1)
    r2 = UOp.range(8, 2, AxisType.REDUCE, tag=())
    r1 = UOp.range(8, 1, AxisType.REDUCE, tag=())
    r0 = UOp.range(4, 0, AxisType.REDUCE, tag=())
    group = r7 * 4 + r0

    left = src1.index(r6 + r2 + (r5 + r1) * 10 + r4 * 100 + group * 3200).load()
    src2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=2)
    right = src2.index(r1 * 8 + r2 + r3 * 64 + group * 4096).load()

    summed = (left * right).reduce(r0, r1, r2, arg=Ops.ADD)
    out_index = r6 * 128 + r7 + r5 * 384 + r4 * 1152 + r3 * 36864
    return dst.index(out_index).store(summed).end(r3, r4, r5, r6, r7).sink()
def k52():
    """Build the kernel graph that sums global 1 over reduce range r0 into global 0.

    Reads global 1 (2359296 floats) at
    ``r4*128 + r0 + r3*384 + r2*1152 + r1*36864``, adds over r0 (extent 128)
    and stores to global 0 (18432 floats) at ``r3*3 + r4 + r2*9 + r1*288``.
    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18432), arg=0)
    r3 = UOp.range(3, 3, AxisType.LOOP, tag=())
    r4 = UOp.range(3, 4, AxisType.LOOP, tag=())
    r2 = UOp.range(32, 2, AxisType.LOOP, tag=())
    r1 = UOp.range(64, 1, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2359296), arg=1)
    r0 = UOp.range(128, 0, AxisType.REDUCE, tag=())
    partial = src.index(r4 * 128 + r0 + r3 * 384 + r2 * 1152 + r1 * 36864).load()
    summed = partial.reduce(r0, arg=Ops.ADD)
    out_index = r3 * 3 + r4 + r2 * 9 + r1 * 288
    return dst.index(out_index).store(summed).end(r1, r2, r3, r4).sink()
def k53():
    """Build the kernel graph that partially sums global 1 into global 0 (16384 floats).

    Reads global 1 (2097152 floats) at
    ``r1*8 + r2 + r3*64 + (r4*2 + r0)*4096``, adds over reduce ranges
    r0, r1, r2 and stores at ``r3*256 + r4``.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(16384), arg=0)
    r3 = UOp.range(64, 3, AxisType.LOOP, tag=())
    r4 = UOp.range(256, 4, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=1)
    r1 = UOp.range(8, 1, AxisType.REDUCE, tag=())
    r2 = UOp.range(8, 2, AxisType.REDUCE, tag=())
    r0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    element = src.index(r1 * 8 + r2 + r3 * 64 + (r4 * 2 + r0) * 4096).load()
    summed = element.reduce(r0, r1, r2, arg=Ops.ADD)
    return dst.index(r3 * 256 + r4).store(summed).end(r3, r4).sink()
def k54():
    """Build the kernel graph that sums 256 consecutive floats of global 1 per output.

    Global 1 (16384 floats) is read at ``r1*256 + r0``, added over reduce
    range r0 and stored to global 0 (64 floats) at ``r1``.  Returns the result
    of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    r1 = UOp.range(64, 1, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(16384), arg=1)
    r0 = UOp.range(256, 0, AxisType.REDUCE, tag=())
    summed = src.index(r1 * 256 + r0).load().reduce(r0, arg=Ops.ADD)
    return dst.index(r1).store(summed).end(r1).sink()
def k55():
    """Build the kernel graph writing global 0 (36864 floats) as a sum of products.

    The left factor is the load from global 1 with values not greater than 0.0
    replaced by 0.0; the right factor is the load from global 2.  Products are
    added over reduce ranges r0, r1, r2 and stored at
    ``r5*3 + r6 + r4*9 + r3*576``.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=0)
    r5 = UOp.range(3, 5, AxisType.LOOP, tag=())
    r6 = UOp.range(3, 6, AxisType.LOOP, tag=())
    r4 = UOp.range(64, 4, AxisType.LOOP, tag=())
    r3 = UOp.range(64, 3, AxisType.LOOP, tag=())
    src1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=1)
    r2 = UOp.range(6, 2, AxisType.REDUCE, tag=())
    r1 = UOp.range(6, 1, AxisType.REDUCE, tag=())
    r0 = UOp.range(512, 0, AxisType.REDUCE, tag=())

    left_raw = src1.index(r6 + r2 + (r5 + r1) * 8 + r4 * 64 + r0 * 4096).load()
    left = (0.0 < left_raw).where(left_raw, UOp.const(dtypes.float, 0.0))
    src2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=2)
    right = src2.index(r1 * 6 + r2 + r3 * 36 + r0 * 2304).load()

    summed = (left * right).reduce(r0, r1, r2, arg=Ops.ADD)
    out_index = r5 * 3 + r6 + r4 * 9 + r3 * 576
    return dst.index(out_index).store(summed).end(r3, r4, r5, r6).sink()
def k56():
    """Build the kernel graph that sums global 1 over three reduce ranges into global 0.

    Reads global 1 (1179648 floats) at ``r1*6 + r2 + r3*36 + r0*2304``, adds
    over reduce ranges r0, r1, r2 and stores to global 0 (64 floats) at ``r3``.
    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    r3 = UOp.range(64, 3, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    r1 = UOp.range(6, 1, AxisType.REDUCE, tag=())
    r2 = UOp.range(6, 2, AxisType.REDUCE, tag=())
    r0 = UOp.range(512, 0, AxisType.REDUCE, tag=())
    element = src.index(r1 * 6 + r2 + r3 * 36 + r0 * 2304).load()
    summed = element.reduce(r0, r1, r2, arg=Ops.ADD)
    return dst.index(r3).store(summed).end(r3).sink()
def k57():
    """Build the kernel graph that sums a normalised product into global 0 (64 floats).

    With ``flat = r1*6 + r2 + r3*36 + r0*2304`` the summed term is
    ``(clamp(g1[flat]) + g2[r3]*-1.0) * ((g3[r3] + 1e-05).sqrt().reciprocal() * g4[flat])``
    where ``clamp`` replaces values not greater than 0.0 by 0.0.  The sum over
    reduce ranges r0, r1, r2 is stored at ``r3``.
    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    r3 = UOp.range(64, 3, AxisType.LOOP, tag=())
    src1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    r1 = UOp.range(6, 1, AxisType.REDUCE, tag=())
    r2 = UOp.range(6, 2, AxisType.REDUCE, tag=())
    r0 = UOp.range(512, 0, AxisType.REDUCE, tag=())
    flat = r1 * 6 + r2 + r3 * 36 + r0 * 2304

    raw = src1.index(flat).load()
    clamped = (0.0 < raw).where(raw, UOp.const(dtypes.float, 0.0))
    v2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2).index(r3).load()
    v3 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3).index(r3).load()
    v4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=4).index(flat).load()

    term = (clamped + v2 * -1.0) * ((v3 + 1e-05).sqrt().reciprocal() * v4)
    summed = term.reduce(r0, r1, r2, arg=Ops.ADD)
    return dst.index(r3).store(summed).end(r3).sink()
def k58():
    """Build the kernel graph writing global 0 (5760 floats) as a sum of products.

    For output index ``r1*576 + r2`` it multiplies global 1 at
    ``((r2//3)%3)*3 + r2%3 + (r2//9)*9 + r0*576`` by global 2 at
    ``r0*10 + r1`` and adds over reduce range r0 (extent 512).
    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5760), arg=0)
    r1 = UOp.range(10, 1, AxisType.LOOP, tag=())
    r2 = UOp.range(576, 2, AxisType.LOOP, tag=())
    src1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=1)
    r0 = UOp.range(512, 0, AxisType.REDUCE, tag=())
    left_index = (r2 // 3) % 3 * 3 + r2 % 3 + r2 // 9 * 9 + r0 * 576
    left = src1.index(left_index).load()
    src2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=2)
    right = src2.index(r0 * 10 + r1).load()
    summed = (left * right).reduce(r0, arg=Ops.ADD)
    return dst.index(r1 * 576 + r2).store(summed).end(r1, r2).sink()
def k59():
    """Build the kernel graph that sums global 1 with stride 10 into global 0.

    Global 1 (5120 floats) is read at ``r0*10 + r1``, added over reduce range
    r0 (extent 512) and stored to global 0 (10 floats) at ``r1``.
    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=0)
    r1 = UOp.range(10, 1, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=1)
    r0 = UOp.range(512, 0, AxisType.REDUCE, tag=())
    summed = src.index(r0 * 10 + r1).load().reduce(r0, arg=Ops.ADD)
    return dst.index(r1).store(summed).end(r1).sink()
def k60():
    """Build the kernel graph that packs globals 1..14 into global 0 (87850 floats).

    The flat output position is split into fourteen consecutive intervals by
    the thresholds 800, 832, 26432, 26464, 26496, 26528, 44960, 45024, 81888,
    81952, 82016, 82080 and 87840.  Each interval reads one source buffer
    through an index that is ``Invalid`` outside the interval and contributes
    0.0 there; the fourteen contributions are added left to right and the sum
    is stored at the flat position.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=0)
    flat = UOp.range(87850, 0, AxisType.LOOP, tag=())

    def gated(mask, size, arg, index):
        # Value of buffer `arg` at `index` where `mask` holds, 0.0 elsewhere.
        buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(size), arg=arg)
        guarded = mask.where(index, UOp.const(dtypes.index, Invalid))
        return mask.where(buf.index(guarded).load(), UOp.const(dtypes.float, 0.0))

    plus1 = flat + 1
    plus3 = flat + 3
    # (exclusive upper threshold, buffer size, buffer arg, source index)
    middle = (
        (832, 32, 2, flat + -800),
        (
            26432, 25600, 3,
            (plus3 // 5 + 3) % 5 * 5 + plus3 % 5
            + ((flat + 18) // 25 + 30) % 32 * 25
            + (flat + 768) // 800 * 800
            + -1600,
        ),
        (26464, 32, 4, flat + -26432),
        (26496, 32, 5, flat + -26464),
        (26528, 32, 6, flat + -26496),
        (
            44960, 18432, 7,
            (plus1 // 3 + 1) % 3 * 3 + plus1 % 3
            + ((flat + 4) // 9 + 28) % 32 * 9
            + (flat + 256) // 288 * 288
            + -26784,
        ),
        (45024, 64, 8, flat + -44960),
        (
            81888, 36864, 9,
            (flat // 3 + 1) % 3 * 3 + flat % 3
            + (plus3 // 9 + 53) % 64 * 9
            + (flat + 480) // 576 * 576
            + -45504,
        ),
        (81952, 64, 10, flat + -81888),
        (82016, 64, 11, flat + -81952),
        (82080, 64, 12, flat + -82016),
        (87840, 5760, 13, flat + -82080),
    )

    below = flat < 800
    total = gated(below, 800, 1, (flat // 5) % 5 * 5 + flat % 5 + flat // 25 * 25)
    for upper, size, arg, index in middle:
        below_next = flat < upper
        # `!= True` is kept verbatim: it builds the negated mask as a UOp.
        total = total + gated((below != True) & below_next, size, arg, index)
        below = below_next

    # Last interval: everything at or above 87840, with the branches of the
    # final `where` swapped exactly as in the generated code.
    tail_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=14)
    tail_index = (below != True).where((flat + -87840), UOp.const(dtypes.index, Invalid))
    tail = below.where(UOp.const(dtypes.float, 0.0), tail_buf.index(tail_index).load())
    total = total + tail

    return dst.index(flat).store(total).end(flat).sink()
def k61():
    """Build the kernel graph that blends global 1 into global 0 in place.

    Each of the 87850 floats of global 0 becomes
    ``0.9*old + 0.09999999999999998*g1`` with ``g1`` read at the same index.
    Returns the result of ``.sink()``.
    """
    target = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=0)
    r0 = UOp.range(87850, 0, AxisType.LOOP, tag=())
    slot = target.index(r0)
    previous = slot.load()
    incoming = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1).index(r0).load()
    blended = 0.9 * previous + 0.09999999999999998 * incoming
    return slot.store(blended).end(r0).sink()
def k62():
    """Build the kernel graph that blends the square of global 1 into global 0 in place.

    Each of the 87850 floats of global 0 becomes
    ``0.999*old + 0.0010000000000000009*(g1*g1)`` with ``g1`` read at the same
    index.  Returns the result of ``.sink()``.
    """
    target = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=0)
    r0 = UOp.range(87850, 0, AxisType.LOOP, tag=())
    slot = target.index(r0)
    previous = slot.load()
    incoming = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1).index(r0).load()
    blended = 0.999 * previous + 0.0010000000000000009 * (incoming * incoming)
    return slot.store(blended).end(r0).sink()
def k63():
    """Build the kernel graph that packs globals 1..14 and adds a scaled step term.

    The flat output position (extent 87850) is split into fourteen consecutive
    intervals; interval k reads global k at ``flat`` minus the interval start
    and contributes 0.0 elsewhere.  The contributions are added left to right,
    then the term
    ``g15 * (g16[flat] * ((1.0 + g17*-1.0) * ((g18[flat] * (1.0 + g19*-1.0).reciprocal()).sqrt() + 1e-08)).reciprocal()) * -1.0``
    is added and the result stored to global 0.
    NOTE(review): the added term resembles an Adam-style parameter step — confirm.
    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=0)
    flat = UOp.range(87850, 0, AxisType.LOOP, tag=())

    def gated(mask, size, arg, index):
        # Value of buffer `arg` at `index` where `mask` holds, 0.0 elsewhere.
        buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(size), arg=arg)
        guarded = mask.where(index, UOp.const(dtypes.index, Invalid))
        return mask.where(buf.index(guarded).load(), UOp.const(dtypes.float, 0.0))

    # (exclusive upper threshold, buffer size, buffer arg) for intervals 2..13;
    # each interval starts where the previous one ended.
    middle = (
        (832, 32, 2),
        (26432, 25600, 3),
        (26464, 32, 4),
        (26496, 32, 5),
        (26528, 32, 6),
        (44960, 18432, 7),
        (45024, 64, 8),
        (81888, 36864, 9),
        (81952, 64, 10),
        (82016, 64, 11),
        (82080, 64, 12),
        (87840, 5760, 13),
    )

    below = flat < 800
    total = gated(below, 800, 1, flat)
    start = 800
    for upper, size, arg in middle:
        below_next = flat < upper
        # `!= True` is kept verbatim: it builds the negated mask as a UOp.
        total = total + gated((below != True) & below_next, size, arg, (flat + -start))
        below, start = below_next, upper

    # Last interval: everything at or above 87840, with the branches of the
    # final `where` swapped exactly as in the generated code.
    tail_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=14)
    tail_index = (below != True).where((flat + -87840), UOp.const(dtypes.index, Invalid))
    tail = below.where(UOp.const(dtypes.float, 0.0), tail_buf.index(tail_index).load())
    total = total + tail

    # Scalars come from 1-element buffers; the two vectors are read at `flat`.
    s15 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=15).index(UOp.const(dtypes.index, 0)).load()
    v16 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=16).index(flat).load()
    s17 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=17).index(UOp.const(dtypes.index, 0)).load()
    v18 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=18).index(flat).load()
    s19 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=19).index(UOp.const(dtypes.index, 0)).load()

    denominator = (1.0 + s17 * -1.0) * ((v18 * (1.0 + s19 * -1.0).reciprocal()).sqrt() + 1e-08)
    step = s15 * (v16 * denominator.reciprocal()) * -1.0
    return dst.index(flat).store(total + step).end(flat).sink()
def k64():
    """Build the kernel graph that copies global 1 into global 0 (800 floats).

    Source and destination share the index ``r1*5 + r2 + r0*25``; global 1 is
    the 87850-float buffer.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(800), arg=0)
    r1 = UOp.range(5, 1, AxisType.LOOP, tag=())
    r2 = UOp.range(5, 2, AxisType.LOOP, tag=())
    r0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    flat = r1 * 5 + r2 + r0 * 25
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(flat).load()
    return dst.index(flat).store(value).end(r0, r1, r2).sink()
def k65():
    """Build the kernel graph that copies 32 floats of global 1, from offset 800, into global 0.

    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    r0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(r0 + 800).load()
    return dst.index(r0).store(value).end(r0).sink()
def k66():
    """Build the kernel graph that copies global 1, from offset 832, into global 0 (25600 floats).

    The destination index is ``r2*5 + r3 + r1*25 + r0*800``; the source index
    is the same plus 832.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25600), arg=0)
    r2 = UOp.range(5, 2, AxisType.LOOP, tag=())
    r3 = UOp.range(5, 3, AxisType.LOOP, tag=())
    r1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    r0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    flat = r2 * 5 + r3 + r1 * 25 + r0 * 800
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(flat + 832).load()
    return dst.index(flat).store(value).end(r0, r1, r2, r3).sink()
def k67():
    """Build the kernel graph that copies 32 floats of global 1, from offset 26432, into global 0.

    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    r0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(r0 + 26432).load()
    return dst.index(r0).store(value).end(r0).sink()
def k68():
    """Build the kernel graph that copies 32 floats of global 1, from offset 26464, into global 0.

    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    r0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(r0 + 26464).load()
    return dst.index(r0).store(value).end(r0).sink()
def k69():
    """Build the kernel graph that copies 32 floats of global 1, from offset 26496, into global 0.

    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    r0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(r0 + 26496).load()
    return dst.index(r0).store(value).end(r0).sink()
def k70():
    """Build the kernel graph that copies global 1, from offset 26528, into global 0 (18432 floats).

    The destination index is ``r2*3 + r3 + r1*9 + r0*288``; the source index
    is the same plus 26528.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18432), arg=0)
    r2 = UOp.range(3, 2, AxisType.LOOP, tag=())
    r3 = UOp.range(3, 3, AxisType.LOOP, tag=())
    r1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    r0 = UOp.range(64, 0, AxisType.LOOP, tag=())
    flat = r2 * 3 + r3 + r1 * 9 + r0 * 288
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(flat + 26528).load()
    return dst.index(flat).store(value).end(r0, r1, r2, r3).sink()
def k71():
    """Build the kernel graph that copies 64 floats of global 1, from offset 44960, into global 0.

    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    r0 = UOp.range(64, 0, AxisType.LOOP, tag=())
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(r0 + 44960).load()
    return dst.index(r0).store(value).end(r0).sink()
def k72():
    """Build the kernel graph that copies global 1, from offset 45024, into global 0 (36864 floats).

    The destination index is ``r2*3 + r3 + r1*9 + r0*576``; the source index
    is the same plus 45024.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=0)
    r2 = UOp.range(3, 2, AxisType.LOOP, tag=())
    r3 = UOp.range(3, 3, AxisType.LOOP, tag=())
    r1 = UOp.range(64, 1, AxisType.LOOP, tag=())
    r0 = UOp.range(64, 0, AxisType.LOOP, tag=())
    flat = r2 * 3 + r3 + r1 * 9 + r0 * 576
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(flat + 45024).load()
    return dst.index(flat).store(value).end(r0, r1, r2, r3).sink()
def k73():
    """Build the kernel graph that copies 64 floats of global 1, from offset 81888, into global 0.

    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    r0 = UOp.range(64, 0, AxisType.LOOP, tag=())
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(r0 + 81888).load()
    return dst.index(r0).store(value).end(r0).sink()
def k74():
    """Build the kernel graph that copies 64 floats of global 1, from offset 81952, into global 0.

    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    r0 = UOp.range(64, 0, AxisType.LOOP, tag=())
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(r0 + 81952).load()
    return dst.index(r0).store(value).end(r0).sink()
def k75():
    """Build the kernel graph that copies 64 floats of global 1, from offset 82016, into global 0.

    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    r0 = UOp.range(64, 0, AxisType.LOOP, tag=())
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(r0 + 82016).load()
    return dst.index(r0).store(value).end(r0).sink()
def k76():
    """Build the kernel graph that copies global 1, from offset 82080, into global 0 (5760 floats).

    The destination index is ``r0*576 + r1``; the source index is the same
    plus 82080.  Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5760), arg=0)
    r0 = UOp.range(10, 0, AxisType.LOOP, tag=())
    r1 = UOp.range(576, 1, AxisType.LOOP, tag=())
    flat = r0 * 576 + r1
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(flat + 82080).load()
    return dst.index(flat).store(value).end(r0, r1).sink()
def k77():
    """Build the kernel graph that copies 10 floats of global 1, from offset 87840, into global 0.

    Returns the result of ``.sink()``.
    """
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=0)
    r0 = UOp.range(10, 0, AxisType.LOOP, tag=())
    packed = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    value = packed.index(r0 + 87840).load()
    return dst.index(r0).store(value).end(r0).sink()
def k78():
    """Build the kernel graph that adds 1 to the single ``long`` in global 0.

    There is no range here: element 0 is loaded, incremented and stored back.
    Returns the result of ``.sink()`` on the store.
    """
    counter = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(1), arg=0)
    slot = counter.index(UOp.const(dtypes.index, 0))
    incremented = slot.load() + 1
    return slot.store(incremented).sink()
def k79():
    """Build the kernel graph that blends global 1 into global 0 in place (32 floats).

    Each element of global 0 becomes ``0.9*old + 0.1*g1`` with ``g1`` read at
    the same index.  Returns the result of ``.sink()``.
    """
    target = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    r0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    slot = target.index(r0)
    previous = slot.load()
    incoming = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=1).index(r0).load()
    blended = 0.9 * previous + 0.1 * incoming
    return slot.store(blended).end(r0).sink()
def k80():
    """Build the kernel graph that blends global 1 into global 0 in place (32 floats).

    Each element of global 0 becomes ``0.9*old + 0.1000004882836342*g1`` with
    ``g1`` read at the same index.  Returns the result of ``.sink()``.
    """
    target = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    r0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    slot = target.index(r0)
    previous = slot.load()
    incoming = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=1).index(r0).load()
    blended = 0.9 * previous + 0.1000004882836342 * incoming
    return slot.store(blended).end(r0).sink()
def k81():
    """Build the kernel graph that blends global 1 into global 0 in place (64 floats).

    Each element of global 0 becomes ``0.9*old + 0.1*g1`` with ``g1`` read at
    the same index.  Returns the result of ``.sink()``.
    """
    target = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    r0 = UOp.range(64, 0, AxisType.LOOP, tag=())
    slot = target.index(r0)
    previous = slot.load()
    incoming = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=1).index(r0).load()
    blended = 0.9 * previous + 0.1 * incoming
    return slot.store(blended).end(r0).sink()
def k82():
    """Build the UOp graph for an in-place weighted blend of two 64-float buffers.

    For each of the 64 elements, global 0 is overwritten with
    ``0.9*old + 0.10000542564158212*new`` where ``old`` is its current value
    and ``new`` is the matching element of global 1.

    Returns the sink UOp of the store.
    """
    acc = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    lane = UOp.range(64, 0, AxisType.LOOP, tag=())
    # The same indexed slot is read and then written.
    acc_slot = acc.index(lane)
    old = acc_slot.load()
    incoming = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=1)
    new = incoming.index(lane).load()
    # NOTE(review): 0.10000542564158212 equals 0.1 * 18432/18431; presumably an
    # n/(n-1) correction folded into the 0.1 weight -- confirm against the caller.
    blended = ((0.9*old)+(0.10000542564158212*new))
    return acc_slot.store(blended).end(lane).sink()
# --- Random-number stage (both kernels carry 'randint' metadata) ---
# c2 is a 1-element uint buffer and the only source of the k0 kernel; c4 is
# that buffer taken "after" the kernel (presumably an ordering dependency).
c2 = UOp.new_buffer('METAL', 1, dtypes.uint, 17)
c3 = UOp(Ops.KERNEL, src=(c2,), arg=Kernel(k0(), (Metadata(name='randint', caller='', backward=False),)))
c4 = c2.after(c3)
c7 = c4.forced_reshape((1,)).forced_reshape((1,))
# k1 takes the 512-float buffer c9, the post-k0 value c4, and the 2-element
# uint buffer c11; its result c13 is consumed later by k2 and k20.
c9 = UOp.new_buffer('METAL', 512, dtypes.float, 126)
c11 = UOp.new_buffer('METAL', 2, dtypes.uint, 18)
c12 = UOp(Ops.KERNEL, src=(c9, c4, c11), arg=Kernel(k1(), (Metadata(name='randint', caller='', backward=False),)))
c13 = c9.after(c12)
c15 = c13.forced_reshape((512,))
# --- Forward pass, first half: input gather, two conv kernels, two 32-float statistics ---
c17 = UOp.new_buffer('METAL', 9437184, dtypes.float, 127)
c19 = UOp.new_buffer('METAL', 401408, dtypes.uint, 128)
c21 = UOp.new_buffer('METAL', 4014080, dtypes.uchar, 129)
c23 = UOp.new_buffer('METAL', 47040000, dtypes.uchar, 20)
# k2 ('randint' + '__getitem__') takes c21, the random floats c13 and the large uchar buffer c23.
c24 = UOp(Ops.KERNEL, src=(c21, c13, c23), arg=Kernel(k2(), (Metadata(name='randint', caller='', backward=False), Metadata(name='__getitem__', caller='', backward=False))))
# k3 and k4 are both tagged 'conv2d'; k4 additionally takes the 800-float c29 and 32-float c31.
c26 = UOp(Ops.KERNEL, src=(c19, c21.after(c24)), arg=Kernel(k3(), (Metadata(name='conv2d', caller='', backward=False),)))
c27 = c19.after(c26)
c29 = UOp.new_buffer('METAL', 800, dtypes.float, 107)
c31 = UOp.new_buffer('METAL', 32, dtypes.float, 108)
c32 = UOp(Ops.KERNEL, src=(c17, c27, c29, c31), arg=Kernel(k4(), (Metadata(name='conv2d', caller='', backward=False),)))
c33 = c17.after(c32)
c35 = c33.reshape((512,32,24,24))
# k5 ('relu' + 'conv2d') consumes c33 with the 25600-float c39 and 32-float c41.
c37 = UOp.new_buffer('METAL', 6553600, dtypes.float, 130)
c39 = UOp.new_buffer('METAL', 25600, dtypes.float, 109)
c41 = UOp.new_buffer('METAL', 32, dtypes.float, 110)
c42 = UOp(Ops.KERNEL, src=(c37, c33, c39, c41), arg=Kernel(k5(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=False))))
c43 = c37.after(c42)
c45 = c43.reshape((512,32,20,20))
# Two-step path from c43 to a 32-float result: k6 fills the 8192-float scratch c49, then k7 ('mean') fills c47.
c47 = UOp.new_buffer('METAL', 32, dtypes.float, 131)
c49 = UOp.new_buffer('METAL', 8192, dtypes.float, 132)
c50 = UOp(Ops.KERNEL, src=(c49, c43), arg=Kernel(k6(), (Metadata(name='relu', caller='', backward=False),)))
c52 = UOp(Ops.KERNEL, src=(c47, c49.after(c50)), arg=Kernel(k7(), (Metadata(name='mean', caller='', backward=False),)))
c53 = c47.after(c52)
c55 = c53.forced_reshape((32,))
# Second 32-float statistic: k8 ('relu', '__sub__', '__mul__') also reads c53, and k7 is reused for the final step.
c57 = UOp.new_buffer('METAL', 32, dtypes.float, 133)
c59 = UOp.new_buffer('METAL', 8192, dtypes.float, 134)
c60 = UOp(Ops.KERNEL, src=(c59, c43, c53), arg=Kernel(k8(), (Metadata(name='relu', caller='', backward=False), Metadata(name='__sub__', caller='', backward=False), Metadata(name='__mul__', caller='', backward=False))))
c62 = UOp(Ops.KERNEL, src=(c57, c59.after(c60)), arg=Kernel(k7(), (Metadata(name='mean', caller='', backward=False),)))
c63 = c57.after(c62)
c64 = c63.forced_reshape((32,))
# --- Forward pass, second half: batchnorm/pool kernels, two more conv kernels, two 64-float statistics ---
c66 = UOp.new_buffer('METAL', 2097152, dtypes.float, 135)
c68 = UOp.new_buffer('METAL', 1638400, dtypes.float, 136)
c70 = UOp.new_buffer('METAL', 6553600, dtypes.float, 137)
c72 = UOp.new_buffer('METAL', 32, dtypes.float, 111)
c74 = UOp.new_buffer('METAL', 32, dtypes.float, 112)
# k9 combines c43 with the statistics c53/c63 and the 32-float buffers c72/c74
# (metadata: add, rsqrt, relu, batchnorm, max_pool2d); k10 ('max_pool2d') then fills c68.
c75 = UOp(Ops.KERNEL, src=(c70, c43, c53, c72, c63, c74), arg=Kernel(k9(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='max_pool2d', caller='', backward=False))))
c76 = c70.after(c75)
c77 = UOp(Ops.KERNEL, src=(c68, c76), arg=Kernel(k10(), (Metadata(name='max_pool2d', caller='', backward=False),)))
c78 = c68.after(c77)
# k11 ('conv2d') takes c78 with the 18432-float c80 and 64-float c82.
c80 = UOp.new_buffer('METAL', 18432, dtypes.float, 113)
c82 = UOp.new_buffer('METAL', 64, dtypes.float, 114)
c83 = UOp(Ops.KERNEL, src=(c66, c78, c80, c82), arg=Kernel(k11(), (Metadata(name='conv2d', caller='', backward=False),)))
c84 = c66.after(c83)
c86 = c84.reshape((512,64,8,8))
# k12 ('relu' + 'conv2d') takes c84 with the 36864-float c90 and 64-float c92.
c88 = UOp.new_buffer('METAL', 1179648, dtypes.float, 138)
c90 = UOp.new_buffer('METAL', 36864, dtypes.float, 115)
c92 = UOp.new_buffer('METAL', 64, dtypes.float, 116)
c93 = UOp(Ops.KERNEL, src=(c88, c84, c90, c92), arg=Kernel(k12(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=False))))
c94 = c88.after(c93)
c96 = c94.reshape((512,64,6,6))
# Unlike the 32-float statistics earlier, each 64-float statistic comes from a single kernel (k13, k14).
c98 = UOp.new_buffer('METAL', 64, dtypes.float, 139)
c99 = UOp(Ops.KERNEL, src=(c98, c94), arg=Kernel(k13(), (Metadata(name='relu', caller='', backward=False), Metadata(name='mean', caller='', backward=False))))
c100 = c98.after(c99)
c102 = c100.forced_reshape((64,))
c104 = UOp.new_buffer('METAL', 64, dtypes.float, 140)
c105 = UOp(Ops.KERNEL, src=(c104, c94, c100), arg=Kernel(k14(), (Metadata(name='relu', caller='', backward=False), Metadata(name='__sub__', caller='', backward=False), Metadata(name='__mul__', caller='', backward=False), Metadata(name='mean', caller='', backward=False))))
c106 = c104.after(c105)
c107 = c106.forced_reshape((64,))
# --- Forward pass tail: batchnorm/pool, linear layer, then the loss kernels ---
c109 = UOp.new_buffer('METAL', 5120, dtypes.float, 141)
c111 = UOp.new_buffer('METAL', 294912, dtypes.float, 142)
c113 = UOp.new_buffer('METAL', 1179648, dtypes.float, 143)
c115 = UOp.new_buffer('METAL', 64, dtypes.float, 117)
c117 = UOp.new_buffer('METAL', 64, dtypes.float, 118)
# k15 has the same metadata set as k9 but operates on the 64-wide buffers; k16 ('max_pool2d') fills c111.
c118 = UOp(Ops.KERNEL, src=(c113, c94, c100, c115, c106, c117), arg=Kernel(k15(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='max_pool2d', caller='', backward=False))))
c119 = c113.after(c118)
c120 = UOp(Ops.KERNEL, src=(c111, c119), arg=Kernel(k16(), (Metadata(name='max_pool2d', caller='', backward=False),)))
c121 = c111.after(c120)
# k17 ('linear') takes c121 with the 5760-float c123 and 10-float c125; the result is viewed as (512, 10).
c123 = UOp.new_buffer('METAL', 5760, dtypes.float, 119)
c125 = UOp.new_buffer('METAL', 10, dtypes.float, 120)
c126 = UOp(Ops.KERNEL, src=(c109, c121, c123, c125), arg=Kernel(k17(), (Metadata(name='linear', caller='', backward=False),)))
c127 = c109.after(c126)
c129 = c127.reshape((512,10))
# Loss: kernels k18, k19, k21 and k22 are all tagged 'sparse_categorical_crossentropy'.
# c131 is the 1-float buffer passed first to k22; c133 and c137 are 512-float intermediates.
c131 = UOp.new_buffer('METAL', 1, dtypes.float, 144)
c133 = UOp.new_buffer('METAL', 512, dtypes.float, 145)
c134 = UOp(Ops.KERNEL, src=(c133, c127), arg=Kernel(k18(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False),)))
c135 = c133.after(c134)
c137 = UOp.new_buffer('METAL', 512, dtypes.float, 146)
c138 = UOp(Ops.KERNEL, src=(c137, c127, c135), arg=Kernel(k19(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False),)))
c139 = c137.after(c138)
# k20 ('randint' + '__getitem__') reuses the random floats c13, this time with the 60000-uchar buffer c145;
# k21 then fills the 512-int buffer c141.
c141 = UOp.new_buffer('METAL', 512, dtypes.int, 147)
c143 = UOp.new_buffer('METAL', 128000, dtypes.uchar, 148)
c145 = UOp.new_buffer('METAL', 60000, dtypes.uchar, 47)
c146 = UOp(Ops.KERNEL, src=(c143, c13, c145), arg=Kernel(k20(), (Metadata(name='randint', caller='', backward=False), Metadata(name='__getitem__', caller='', backward=False))))
c148 = UOp(Ops.KERNEL, src=(c141, c143.after(c146)), arg=Kernel(k21(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False),)))
c149 = c141.after(c148)
c150 = UOp(Ops.KERNEL, src=(c131, c127, c135, c139, c149), arg=Kernel(k22(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False),)))
# NOTE(review): c152 is not referenced again in the visible part of this script.
c152 = UOp(Ops.VECTORIZE, dtypes.index.vec(0))
# The 1-element result is viewed with an empty shape, i.e. as a scalar.
c153 = c131.after(c150).reshape(())
# --- Two 1-float buffers, each the sole source of its own '__imul__' kernel (k23, k24) ---
# Their post-kernel values c157 and c163 are passed to k63 later in the script.
c155 = UOp.new_buffer('METAL', 1, dtypes.float, 55)
c156 = UOp(Ops.KERNEL, src=(c155,), arg=Kernel(k23(), (Metadata(name='__imul__', caller='', backward=False),)))
c157 = c155.after(c156)
c159 = c157.forced_reshape((1,)).forced_reshape((1,))
c161 = UOp.new_buffer('METAL', 1, dtypes.float, 56)
c162 = UOp(Ops.KERNEL, src=(c161,), arg=Kernel(k24(), (Metadata(name='__imul__', caller='', backward=False),)))
c163 = c161.after(c162)
c165 = c163.forced_reshape((1,)).forced_reshape((1,))
# --- Backward pass, part 1: kernels whose metadata includes backward=True entries ---
# Loss backward: k25 fills the 512-float c169, k26 fills the 5120-float c167 (viewed as (512, 10)).
c167 = UOp.new_buffer('METAL', 5120, dtypes.float, 149)
c169 = UOp.new_buffer('METAL', 512, dtypes.float, 150)
c170 = UOp(Ops.KERNEL, src=(c169, c149, c139), arg=Kernel(k25(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False), Metadata(name='sparse_categorical_crossentropy', caller='', backward=True))))
c172 = UOp(Ops.KERNEL, src=(c167, c149, c127, c135, c169.after(c170)), arg=Kernel(k26(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False), Metadata(name='sparse_categorical_crossentropy', caller='', backward=True))))
c173 = c167.after(c172)
c174 = c173.reshape((512,10))
# k27 (max_pool2d backward) and k28 (linear backward) each fill a 294912-float buffer; k29 consumes both.
c176 = UOp.new_buffer('METAL', 64, dtypes.float, 151)
c178 = UOp.new_buffer('METAL', 1179648, dtypes.float, 152)
c180 = UOp.new_buffer('METAL', 294912, dtypes.float, 153)
c181 = UOp(Ops.KERNEL, src=(c180, c119, c121), arg=Kernel(k27(), (Metadata(name='max_pool2d', caller='', backward=True),)))
c184 = UOp.new_buffer('METAL', 294912, dtypes.float, 154)
c185 = UOp(Ops.KERNEL, src=(c184, c123, c173), arg=Kernel(k28(), (Metadata(name='linear', caller='', backward=True),)))
c187 = UOp(Ops.KERNEL, src=(c178, c119, c121, c180.after(c181), c184.after(c185)), arg=Kernel(k29(), (Metadata(name='max_pool2d', caller='', backward=True),)))
c188 = c178.after(c187)
# k30 and k31 each fill a 64-float buffer from c188; k32 then combines both with the forward values.
c189 = UOp(Ops.KERNEL, src=(c176, c94, c100, c115, c188, c106), arg=Kernel(k30(), (Metadata(name='rsqrt', caller='', backward=True), Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True), Metadata(name='mean', caller='', backward=True))))
c190 = c176.after(c189)
c191 = c190.forced_reshape((64,))
c193 = UOp.new_buffer('METAL', 64, dtypes.float, 155)
c194 = UOp(Ops.KERNEL, src=(c193, c115, c106, c188), arg=Kernel(k31(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True), Metadata(name='mean', caller='', backward=True))))
c195 = c193.after(c194)
c196 = c195.forced_reshape((64,))
c198 = UOp.new_buffer('METAL', 1179648, dtypes.float, 156)
c199 = UOp(Ops.KERNEL, src=(c198, c94, c100, c190, c115, c106, c188, c195), arg=Kernel(k32(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True), Metadata(name='__sub__', caller='', backward=False), Metadata(name='__mul__', caller='', backward=True), Metadata(name='relu', caller='', backward=False), Metadata(name='relu', caller='', backward=True))))
c200 = c198.after(c199)
c201 = c200.reshape((512,64,6,6))
# k33 (conv2d backward) takes the forward activation c84, the buffer c90 and c200.
c203 = UOp.new_buffer('METAL', 2097152, dtypes.float, 157)
c204 = UOp(Ops.KERNEL, src=(c203, c84, c90, c200), arg=Kernel(k33(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=True))))
c205 = c203.after(c204)
c206 = c205.reshape((512,64,8,8))
# --- Backward pass, part 2: same pattern as part 1, applied to the 32-channel buffers ---
c208 = UOp.new_buffer('METAL', 32, dtypes.float, 158)
c210 = UOp.new_buffer('METAL', 8192, dtypes.float, 159)
c212 = UOp.new_buffer('METAL', 6553600, dtypes.float, 160)
c214 = UOp.new_buffer('METAL', 1638400, dtypes.float, 161)
# k34 (max_pool2d backward) and k35 (conv2d backward) each fill a 1638400-float buffer; k36 consumes both.
c215 = UOp(Ops.KERNEL, src=(c214, c76, c78), arg=Kernel(k34(), (Metadata(name='max_pool2d', caller='', backward=True),)))
c218 = UOp.new_buffer('METAL', 1638400, dtypes.float, 162)
c219 = UOp(Ops.KERNEL, src=(c218, c80, c205), arg=Kernel(k35(), (Metadata(name='conv2d', caller='', backward=True),)))
c221 = UOp(Ops.KERNEL, src=(c212, c76, c78, c214.after(c215), c218.after(c219)), arg=Kernel(k36(), (Metadata(name='max_pool2d', caller='', backward=True),)))
c222 = c212.after(c221)
# Two-kernel paths through 8192-float scratch buffers: k37 -> k38 fills c208, k39 -> k40 fills c229.
c223 = UOp(Ops.KERNEL, src=(c210, c43, c53, c72, c222), arg=Kernel(k37(), (Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True))))
c225 = UOp(Ops.KERNEL, src=(c208, c210.after(c223), c63), arg=Kernel(k38(), (Metadata(name='rsqrt', caller='', backward=True), Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='mean', caller='', backward=True))))
c226 = c208.after(c225)
c227 = c226.forced_reshape((32,))
c229 = UOp.new_buffer('METAL', 32, dtypes.float, 163)
c231 = UOp.new_buffer('METAL', 8192, dtypes.float, 164)
c232 = UOp(Ops.KERNEL, src=(c231, c72, c63, c222), arg=Kernel(k39(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True))))
c234 = UOp(Ops.KERNEL, src=(c229, c231.after(c232)), arg=Kernel(k40(), (Metadata(name='mean', caller='', backward=True),)))
c235 = c229.after(c234)
c236 = c235.forced_reshape((32,))
# k41 combines both 32-float results with the forward values; its metadata set matches k32's.
c238 = UOp.new_buffer('METAL', 6553600, dtypes.float, 165)
c239 = UOp(Ops.KERNEL, src=(c238, c43, c53, c226, c72, c63, c222, c235), arg=Kernel(k41(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True), Metadata(name='__sub__', caller='', backward=False), Metadata(name='__mul__', caller='', backward=True), Metadata(name='relu', caller='', backward=False), Metadata(name='relu', caller='', backward=True))))
c240 = c238.after(c239)
c241 = c240.reshape((512,32,20,20))
# k42 (conv2d backward) takes the forward activation c33, the buffer c39 and c240.
c243 = UOp.new_buffer('METAL', 9437184, dtypes.float, 166)
c244 = UOp(Ops.KERNEL, src=(c243, c33, c39, c240), arg=Kernel(k42(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=True))))
c245 = c243.after(c244)
c246 = c245.reshape((512,32,24,24))
# --- Fourteen result buffers, later passed together (in this order) to the 'cat' kernel k60 ---
# Each result has the same element count as one of the buffers c29, c31, c39, c41, c72, c74,
# c80, c82, c90, c92, c115, c117, c123, c125 (presumably one gradient per parameter -- confirm).
# Results 1-8 go through a larger scratch buffer and a second kernel tagged 'contiguous';
# results 9-14 come from a single kernel each.
# 1) 800 floats, viewed as (32,1,5,5).
c248 = UOp.new_buffer('METAL', 800, dtypes.float, 167)
c250 = UOp.new_buffer('METAL', 204800, dtypes.float, 168)
c251 = UOp(Ops.KERNEL, src=(c250, c27, c245), arg=Kernel(k43(), (Metadata(name='conv2d', caller='', backward=False), Metadata(name='conv2d', caller='', backward=True))))
c253 = UOp(Ops.KERNEL, src=(c248, c250.after(c251)), arg=Kernel(k44(), (Metadata(name='contiguous', caller='', backward=False),)))
c254 = c248.after(c253)
c256 = c254.reshape((32,1,5,5))
# 2) 32 floats from c245; the first-stage kernel k45 has an empty metadata tuple.
c258 = UOp.new_buffer('METAL', 32, dtypes.float, 169)
c260 = UOp.new_buffer('METAL', 8192, dtypes.float, 170)
c261 = UOp(Ops.KERNEL, src=(c260, c245), arg=Kernel(k45(), ()))
c263 = UOp(Ops.KERNEL, src=(c258, c260.after(c261)), arg=Kernel(k46(), (Metadata(name='contiguous', caller='', backward=False),)))
c264 = c258.after(c263)
c265 = c264.forced_reshape((32,))
# 3) 25600 floats, viewed as (32,32,5,5).
c267 = UOp.new_buffer('METAL', 25600, dtypes.float, 171)
c269 = UOp.new_buffer('METAL', 3276800, dtypes.float, 172)
c270 = UOp(Ops.KERNEL, src=(c269, c33, c240), arg=Kernel(k47(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=False), Metadata(name='conv2d', caller='', backward=True))))
c272 = UOp(Ops.KERNEL, src=(c267, c269.after(c270)), arg=Kernel(k48(), (Metadata(name='contiguous', caller='', backward=False),)))
c273 = c267.after(c272)
c275 = c273.reshape((32,32,5,5))
# 4) 32 floats from c240; k49 (empty metadata) and k46 are each used again below.
c277 = UOp.new_buffer('METAL', 32, dtypes.float, 173)
c279 = UOp.new_buffer('METAL', 8192, dtypes.float, 174)
c280 = UOp(Ops.KERNEL, src=(c279, c240), arg=Kernel(k49(), ()))
c282 = UOp(Ops.KERNEL, src=(c277, c279.after(c280)), arg=Kernel(k46(), (Metadata(name='contiguous', caller='', backward=False),)))
c283 = c277.after(c282)
c284 = c283.forced_reshape((32,))
# 5) 32 floats from c43, c53, c63 and c222 (k50 metadata includes batchnorm backward).
c286 = UOp.new_buffer('METAL', 32, dtypes.float, 175)
c288 = UOp.new_buffer('METAL', 8192, dtypes.float, 176)
c289 = UOp(Ops.KERNEL, src=(c288, c43, c53, c63, c222), arg=Kernel(k50(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True))))
c291 = UOp(Ops.KERNEL, src=(c286, c288.after(c289)), arg=Kernel(k46(), (Metadata(name='contiguous', caller='', backward=False),)))
c292 = c286.after(c291)
c293 = c292.forced_reshape((32,))
# 6) 32 floats from c222, reusing k49 and k46.
c295 = UOp.new_buffer('METAL', 32, dtypes.float, 177)
c297 = UOp.new_buffer('METAL', 8192, dtypes.float, 178)
c298 = UOp(Ops.KERNEL, src=(c297, c222), arg=Kernel(k49(), ()))
c300 = UOp(Ops.KERNEL, src=(c295, c297.after(c298)), arg=Kernel(k46(), (Metadata(name='contiguous', caller='', backward=False),)))
c301 = c295.after(c300)
c302 = c301.forced_reshape((32,))
# 7) 18432 floats, viewed as (64,32,3,3).
c304 = UOp.new_buffer('METAL', 18432, dtypes.float, 179)
c306 = UOp.new_buffer('METAL', 2359296, dtypes.float, 180)
c307 = UOp(Ops.KERNEL, src=(c306, c78, c205), arg=Kernel(k51(), (Metadata(name='conv2d', caller='', backward=True),)))
c309 = UOp(Ops.KERNEL, src=(c304, c306.after(c307)), arg=Kernel(k52(), (Metadata(name='contiguous', caller='', backward=False),)))
c310 = c304.after(c309)
c312 = c310.reshape((64,32,3,3))
# 8) 64 floats from c205; the first-stage kernel k53 has an empty metadata tuple.
c314 = UOp.new_buffer('METAL', 64, dtypes.float, 181)
c316 = UOp.new_buffer('METAL', 16384, dtypes.float, 182)
c317 = UOp(Ops.KERNEL, src=(c316, c205), arg=Kernel(k53(), ()))
c319 = UOp(Ops.KERNEL, src=(c314, c316.after(c317)), arg=Kernel(k54(), (Metadata(name='contiguous', caller='', backward=False),)))
c320 = c314.after(c319)
c321 = c320.forced_reshape((64,))
# 9) 36864 floats, viewed as (64,64,3,3); produced by the single kernel k55.
c323 = UOp.new_buffer('METAL', 36864, dtypes.float, 183)
c324 = UOp(Ops.KERNEL, src=(c323, c84, c200), arg=Kernel(k55(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=False), Metadata(name='conv2d', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False))))
c325 = c323.after(c324)
c327 = c325.reshape((64,64,3,3))
# 10) 64 floats from c200 via k56.
c329 = UOp.new_buffer('METAL', 64, dtypes.float, 184)
c330 = UOp(Ops.KERNEL, src=(c329, c200), arg=Kernel(k56(), (Metadata(name='conv2d', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False))))
c331 = c329.after(c330)
c332 = c331.forced_reshape((64,))
# 11) 64 floats from c94, c100, c106 and c188 via k57.
c334 = UOp.new_buffer('METAL', 64, dtypes.float, 185)
c335 = UOp(Ops.KERNEL, src=(c334, c94, c100, c106, c188), arg=Kernel(k57(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False))))
c336 = c334.after(c335)
c337 = c336.forced_reshape((64,))
# 12) 64 floats from c188; k56 is reused here with different metadata than in step 10.
c339 = UOp.new_buffer('METAL', 64, dtypes.float, 186)
c340 = UOp(Ops.KERNEL, src=(c339, c188), arg=Kernel(k56(), (Metadata(name='batchnorm', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False))))
c341 = c339.after(c340)
c342 = c341.forced_reshape((64,))
# 13) 5760 floats, viewed as (10,576), from c121 and c173 via k58.
c344 = UOp.new_buffer('METAL', 5760, dtypes.float, 187)
c345 = UOp(Ops.KERNEL, src=(c344, c121, c173), arg=Kernel(k58(), (Metadata(name='linear', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False))))
c346 = c344.after(c345)
c348 = c346.reshape((10,576))
# 14) 10 floats from c173 via k59.
c350 = UOp.new_buffer('METAL', 10, dtypes.float, 188)
c351 = UOp(Ops.KERNEL, src=(c350, c173), arg=Kernel(k59(), (Metadata(name='linear', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False))))
c352 = c350.after(c351)
c354 = c352.forced_reshape((10,))
# --- Flat 87850-float stage ---
# k60 ('cat') takes the fourteen buffers computed above; their sizes sum to 87850
# (800+32+25600+32+32+32+18432+64+36864+64+64+64+5760+10), the size of c356.
c356 = UOp.new_buffer('METAL', 87850, dtypes.float, 189)
c357 = UOp(Ops.KERNEL, src=(c356, c254, c264, c273, c283, c292, c301, c310, c320, c325, c331, c336, c341, c346, c352), arg=Kernel(k60(), (Metadata(name='cat', caller='', backward=False),)))
c358 = c356.after(c357)
c360 = c358.forced_reshape((87850,))
# k61 and k62 each take a persistent 87850-float buffer (c362, c368) plus c358; both carry 'assign' metadata.
# NOTE(review): these look like optimizer state accumulators -- confirm against the training script.
c362 = UOp.new_buffer('METAL', 87850, dtypes.float, 103)
c363 = UOp(Ops.KERNEL, src=(c362, c358), arg=Kernel(k61(), (Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False))))
c364 = c362.after(c363)
c366 = c364.forced_reshape((87850,)).forced_reshape((87850,))
c368 = UOp.new_buffer('METAL', 87850, dtypes.float, 104)
c369 = UOp(Ops.KERNEL, src=(c368, c358), arg=Kernel(k62(), (Metadata(name='__mul__', caller='', backward=False), Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False))))
c370 = c368.after(c369)
c372 = c370.forced_reshape((87850,)).forced_reshape((87850,))
# k63 takes the fourteen original buffers (c29 ... c125), the 1-float c376, both updated
# 87850-float buffers (c364, c370) and the two '__imul__' results (c157, c163), and fills c374.
c374 = UOp.new_buffer('METAL', 87850, dtypes.float, 190)
c376 = UOp.new_buffer('METAL', 1, dtypes.float, 105)
c377 = UOp(Ops.KERNEL, src=(c374, c29, c31, c39, c41, c72, c74, c80, c82, c90, c92, c115, c117, c123, c125, c376, c364, c157, c370, c163), arg=Kernel(k63(), (Metadata(name='__truediv__', caller='', backward=False), Metadata(name='sqrt', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='__rsub__', caller='', backward=False), Metadata(name='__mul__', caller='', backward=False), Metadata(name='cat', caller='', backward=False), Metadata(name='__sub__', caller='', backward=False))))
c378 = c374.after(c377)
c379 = c378.forced_reshape((87850,))
# --- Write-back stage: fourteen 'assign' kernels (k64 ... k77) ---
# Each kernel takes one of the original buffers as its first source and the flat c378 as its second.
# The kernels visible in this chunk (k76, k77) copy a fixed-offset slice of the flat buffer.
# The multi-dimensional buffers are reshaped to their layer shape; the 1-D ones keep their length.
c380 = UOp(Ops.KERNEL, src=(c29, c378), arg=Kernel(k64(), (Metadata(name='assign', caller='', backward=False),)))
c383 = c29.after(c380).reshape((32,1,5,5)).forced_reshape((32,1,5,5))
c384 = UOp(Ops.KERNEL, src=(c31, c378), arg=Kernel(k65(), (Metadata(name='assign', caller='', backward=False),)))
c387 = c31.after(c384).forced_reshape((32,)).forced_reshape((32,))
c388 = UOp(Ops.KERNEL, src=(c39, c378), arg=Kernel(k66(), (Metadata(name='assign', caller='', backward=False),)))
c391 = c39.after(c388).reshape((32,32,5,5)).forced_reshape((32,32,5,5))
c392 = UOp(Ops.KERNEL, src=(c41, c378), arg=Kernel(k67(), (Metadata(name='assign', caller='', backward=False),)))
c395 = c41.after(c392).forced_reshape((32,)).forced_reshape((32,))
c396 = UOp(Ops.KERNEL, src=(c72, c378), arg=Kernel(k68(), (Metadata(name='assign', caller='', backward=False),)))
c399 = c72.after(c396).forced_reshape((32,)).forced_reshape((32,))
c400 = UOp(Ops.KERNEL, src=(c74, c378), arg=Kernel(k69(), (Metadata(name='assign', caller='', backward=False),)))
c403 = c74.after(c400).forced_reshape((32,)).forced_reshape((32,))
c404 = UOp(Ops.KERNEL, src=(c80, c378), arg=Kernel(k70(), (Metadata(name='assign', caller='', backward=False),)))
c407 = c80.after(c404).reshape((64,32,3,3)).forced_reshape((64,32,3,3))
c408 = UOp(Ops.KERNEL, src=(c82, c378), arg=Kernel(k71(), (Metadata(name='assign', caller='', backward=False),)))
c411 = c82.after(c408).forced_reshape((64,)).forced_reshape((64,))
c412 = UOp(Ops.KERNEL, src=(c90, c378), arg=Kernel(k72(), (Metadata(name='assign', caller='', backward=False),)))
c415 = c90.after(c412).reshape((64,64,3,3)).forced_reshape((64,64,3,3))
c416 = UOp(Ops.KERNEL, src=(c92, c378), arg=Kernel(k73(), (Metadata(name='assign', caller='', backward=False),)))
c419 = c92.after(c416).forced_reshape((64,)).forced_reshape((64,))
c420 = UOp(Ops.KERNEL, src=(c115, c378), arg=Kernel(k74(), (Metadata(name='assign', caller='', backward=False),)))
c423 = c115.after(c420).forced_reshape((64,)).forced_reshape((64,))
c424 = UOp(Ops.KERNEL, src=(c117, c378), arg=Kernel(k75(), (Metadata(name='assign', caller='', backward=False),)))
c427 = c117.after(c424).forced_reshape((64,)).forced_reshape((64,))
c428 = UOp(Ops.KERNEL, src=(c123, c378), arg=Kernel(k76(), (Metadata(name='assign', caller='', backward=False),)))
c431 = c123.after(c428).reshape((10,576)).forced_reshape((10,576))
c432 = UOp(Ops.KERNEL, src=(c125, c378), arg=Kernel(k77(), (Metadata(name='assign', caller='', backward=False),)))
c435 = c125.after(c432).forced_reshape((10,)).forced_reshape((10,))
# --- In-place bookkeeping updates ---
# k78 ('__iadd__') adds 1 to the single long in c437; the result is viewed as a scalar.
c437 = UOp.new_buffer('METAL', 1, dtypes.long, 121)
c438 = UOp(Ops.KERNEL, src=(c437,), arg=Kernel(k78(), (Metadata(name='__iadd__', caller='', backward=False),)))
c441 = c437.after(c438).reshape(()).forced_reshape(())
# k79 ... k82 each blend a persistent buffer in place with one of the forward-pass statistics
# (c53, c63, c100, c106) using a 0.9 weight on the old value and roughly 0.1 on the new one.
# NOTE(review): presumably batchnorm running statistics -- confirm against the model definition.
c443 = UOp.new_buffer('METAL', 32, dtypes.float, 122)
c444 = UOp(Ops.KERNEL, src=(c443, c53), arg=Kernel(k79(), (Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False))))
c447 = c443.after(c444).forced_reshape((32,)).forced_reshape((32,))
c449 = UOp.new_buffer('METAL', 32, dtypes.float, 123)
c450 = UOp(Ops.KERNEL, src=(c449, c63), arg=Kernel(k80(), (Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False))))
c453 = c449.after(c450).forced_reshape((32,)).forced_reshape((32,))
c455 = UOp.new_buffer('METAL', 64, dtypes.float, 124)
c456 = UOp(Ops.KERNEL, src=(c455, c100), arg=Kernel(k81(), (Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False))))
c459 = c455.after(c456).forced_reshape((64,)).forced_reshape((64,))
c461 = UOp.new_buffer('METAL', 64, dtypes.float, 125)
c462 = UOp(Ops.KERNEL, src=(c461, c106), arg=Kernel(k82(), (Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False))))
c465 = c461.after(c462).forced_reshape((64,)).forced_reshape((64,))
# Root of the whole graph: a single sink over every reshaped result defined above, so all
# kernel outputs in the script are reachable from `ast`.
ast = c7.sink(c15, c35, c45, c55, c64, c86, c96, c102, c107, c129, c153, c159, c165, c174, c191, c196, c201, c206, c227, c236, c241, c246, c256, c265, c275, c284, c293, c302, c312, c321, c327, c332, c337, c342, c348, c354, c360, c366, c372, c379, c383, c387, c391, c395, c399, c403, c407, c411, c415, c419, c423, c427, c431, c435, c441, c447, c453, c459, c465)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment