# Created October 27, 2025 10:36
# Save geohot/cb8c6ea335dfed87a707618d7fff39af to your computer and use it in GitHub Desktop.
# Beautiful MNIST rendered to UOps.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
def k0():
    """Build kernel 0: add 512 in place to element 0 of a 1-element uint buffer.

    Loads ``buf0[0]`` (uint), adds 512 and stores the sum back into the same
    slot. Presumably this advances the RNG counter that ``k1`` reads and
    rewinds by 512 -- unconfirmed from this file alone.

    Returns:
        The SINK UOp of the kernel AST.
    """
    counter = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(1), arg=0)
    slot = counter.index(UOp.const(dtypes.index, 0))
    bumped = slot.load() + 512
    return slot.store(bumped).sink()
def k1():
    """Build kernel 1: fill a 512-float buffer with values in [0, 1) from THREEFRY.

    Arg 1 is a 1-element uint counter and arg 2 is a 2-element uint key.
    The key is packed as ``key[1] * 2**32 | key[0]``, and the counter is
    rewound by 512 to form ``base``.

    Lane ``i < 256`` hashes counter ``i + 1 + base`` and keeps the low 32
    bits. Lane ``i >= 256`` hashes ``i - 255 + base`` and keeps the high
    32 bits. Each hash input packs ``(ctr + 255) * 2**32 | (ctr - 1)``.

    The random bits ``r`` become ``bitcast((r // 512) | 0x3F800000) - 1.0``.
    The top 23 bits fill the mantissa of a float in [1, 2); 1065353216 ==
    0x3F800000 is the bit pattern of 1.0f. Subtracting 1.0 maps the result
    into [0, 1). Presumably this is ``Tensor.rand(512)`` -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=0)
    i = UOp.range(512, 0, AxisType.LOOP, tag=())
    low_lane = i < 256
    invalid = UOp.const(dtypes.index, Invalid)

    counter = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(1), arg=1)
    base = counter.index(UOp.const(dtypes.index, 0)).load() + -512

    key_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(2), arg=2)
    key_hi = key_buf.index(UOp.const(dtypes.index, 1)).load()
    key_lo = key_buf.index(UOp.const(dtypes.index, 0)).load()
    key = (key_hi.cast(dtypes.ulong) * 4294967296) | key_lo.cast(dtypes.ulong)

    def threefry(ctr):
        # Pack (ctr + 255, ctr - 1) into one ulong and hash it with ``key``.
        packed = ((ctr + 255).cast(dtypes.ulong) * 4294967296) | (ctr + -1).cast(dtypes.ulong)
        return UOp(Ops.THREEFRY, dtypes.ulong, (packed, key))

    # Lanes [0, 256): low 32 bits of the hash.
    low_idx = low_lane.where(i, invalid)
    low_ctr = (low_idx + 1).cast(dtypes.uint) + base
    low_bits = low_lane.where(threefry(low_ctr).cast(dtypes.uint) & 4294967295, UOp.const(dtypes.uint, 0))

    # Lanes [256, 512): high 32 bits of the hash. ``!= True`` builds a UOp
    # comparison node; it is not a Python truth test, so keep it as written.
    high_idx = (low_lane != True).where(i + -256, invalid)  # noqa: E712
    high_ctr = (high_idx + 1).cast(dtypes.uint) + base
    high_bits = low_lane.where(UOp.const(dtypes.uint, 0), (threefry(high_ctr) // 4294967296).cast(dtypes.uint) & 4294967295)

    bits = low_bits + high_bits
    value = ((bits // 512) | 1065353216).bitcast(dtypes.float) + -1.0
    return out.index(i).store(value).end(i).sink()
def k2():
    """Build kernel 2: select 512 rows of a 60000 x 784 uchar table, 10 chunks at a time.

    For batch slot ``b``, the wanted row is ``int(60000.0 * rand[b])``, plus
    60000 if negative. ``rand`` is the float buffer in arg 1.

    Arg 2 holds 47,040,000 == 60000 * 784 uchars, addressed as
    ``row_id * 784 + y * 28 + x``. There is no direct gather. Every row in
    one of 10 chunks of 6000 is visited, and its pixel is kept only when
    ``row_id`` equals the wanted row, then summed. So at most one term per
    chunk is non-zero.

    The output (arg 0, 4,014,080 == 512*28*28*10 uchars) is laid out
    ``[b][y][x][chunk]``, and ``k3`` sums over ``chunk``. Presumably this is
    ``X_train[samples]`` -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(4014080), arg=0)
    x = UOp.range(28, 3, AxisType.LOOP, tag=())
    chunk = UOp.range(10, 4, AxisType.LOOP, tag=())
    y = UOp.range(28, 2, AxisType.LOOP, tag=())
    b = UOp.range(512, 1, AxisType.LOOP, tag=())

    rand = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=1)
    scaled = (60000.0 * rand.index(b).load()).cast(dtypes.int)
    wanted = (scaled < 0).where(scaled + 60000, scaled)

    k = UOp.range(6000, 0, AxisType.REDUCE, tag=())
    row_id = (chunk * 6000) + k
    table = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(47040000), arg=2)
    pixel = table.index(((y * 28) + x) + (row_id * 784)).load()
    masked = (wanted != row_id.cast(dtypes.int)).where(
        UOp.const(dtypes.uchar, 0), pixel.cast(dtypes.uint).cast(dtypes.uchar))
    partial = masked.reduce(k, arg=Ops.ADD)

    dst = out.index((((x * 10) + chunk) + (y * 280)) + (b * 7840))
    return dst.store(partial).end(b, y, x, chunk).sink()
def k3():
    """Build kernel 3: sum the 10 per-chunk partials of a 512 x 28 x 28 uchar gather.

    Arg 1 (4,014,080 uchars) is read with layout ``[b][y][x][chunk]``. The
    sum over ``chunk`` is cast to uint and written to arg 0 (401,408 ==
    512*28*28 uints) with layout ``[b][y][x]``.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(401408), arg=0)
    y = UOp.range(28, 2, AxisType.LOOP, tag=())
    x = UOp.range(28, 3, AxisType.LOOP, tag=())
    b = UOp.range(512, 1, AxisType.LOOP, tag=())
    partials = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(4014080), arg=1)
    chunk = UOp.range(10, 0, AxisType.REDUCE, tag=())
    part = partials.index((((x * 10) + chunk) + (y * 280)) + (b * 7840)).load()
    total = part.reduce(chunk, arg=Ops.ADD).cast(dtypes.uint)
    return out.index(((y * 28) + x) + (b * 784)).store(total).end(b, y, x).sink()
def k4():
    """Build kernel 4: valid 5x5 convolution, 1 -> 32 channels, 28x28 -> 24x24, plus bias.

    Computes ``out[b][oc][y][x] = sum_{ky,kx} img[b][y+ky][x+kx] * w[oc][ky][kx] + bias[oc]``.
    ``img`` is the uint buffer in arg 1, and each pixel goes through the cast
    chain uchar -> uint -> uchar -> float. ``w`` is arg 2 (800 floats) and
    ``bias`` is arg 3 (32 floats). The output is arg 0 (512*32*24*24 floats).

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=0)
    y = UOp.range(24, 4, AxisType.LOOP, tag=())
    x = UOp.range(24, 5, AxisType.LOOP, tag=())
    oc = UOp.range(32, 3, AxisType.LOOP, tag=())
    b = UOp.range(512, 2, AxisType.LOOP, tag=())

    image = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(401408), arg=1)
    kx = UOp.range(5, 1, AxisType.REDUCE, tag=())
    ky = UOp.range(5, 0, AxisType.REDUCE, tag=())
    pixel = image.index(((x + kx) + ((y + ky) * 28)) + (b * 784)).load()
    pixel_f = pixel.cast(dtypes.uchar).cast(dtypes.uint).cast(dtypes.uchar).cast(dtypes.float)

    weight = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(800), arg=2)
    w = weight.index(((ky * 5) + kx) + (oc * 25)).load()
    bias = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3)

    acc = (pixel_f * w).reduce(ky, kx, arg=Ops.ADD) + bias.index(oc).load()
    dst = out.index((((y * 24) + x) + (oc * 576)) + (b * 18432))
    return dst.store(acc).end(b, oc, y, x).sink()
def k5():
    """Build kernel 5: valid 5x5 convolution of relu(input), 32 -> 32 channels, 24x24 -> 20x20.

    Computes ``out[b][oc][y][x] = sum_{ic,ky,kx} relu(act[b][ic][y+ky][x+kx]) * w[oc][ic][ky][kx] + bias[oc]``.
    ``act`` is arg 1, ``w`` is arg 2 (25,600 floats) and ``bias`` is arg 3.
    The output is arg 0 (512*32*20*20 floats).

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=0)
    y = UOp.range(20, 5, AxisType.LOOP, tag=())
    x = UOp.range(20, 6, AxisType.LOOP, tag=())
    oc = UOp.range(32, 4, AxisType.LOOP, tag=())
    b = UOp.range(512, 3, AxisType.LOOP, tag=())

    act = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=1)
    kx = UOp.range(5, 2, AxisType.REDUCE, tag=())
    ky = UOp.range(5, 1, AxisType.REDUCE, tag=())
    ic = UOp.range(32, 0, AxisType.REDUCE, tag=())
    v = act.index((((x + kx) + ((y + ky) * 24)) + (ic * 576)) + (b * 18432)).load()
    relu_v = (0.0 < v).where(v, UOp.const(dtypes.float, 0.0))

    weight = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25600), arg=2)
    w = weight.index((((ky * 5) + kx) + (ic * 25)) + (oc * 800)).load()
    bias = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3)

    acc = (relu_v * w).reduce(ic, ky, kx, arg=Ops.ADD) + bias.index(oc).load()
    dst = out.index((((y * 20) + x) + (oc * 400)) + (b * 12800))
    return dst.store(acc).end(b, oc, y, x).sink()
def k6():
    """Build kernel 6: per-channel partial sums of relu(input) over a 512 x 32 x 20 x 20 batch.

    Computes ``partial[ch][pair] = sum_{sub in 0..1, y, x} relu(act[2*pair+sub][ch][y][x])``,
    giving 32 x 256 partials in arg 0. ``k7`` finishes the reduction.
    Presumably this is the first step of a BatchNorm batch mean -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    pair = UOp.range(256, 4, AxisType.LOOP, tag=())

    act = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    y = UOp.range(20, 1, AxisType.REDUCE, tag=())
    x = UOp.range(20, 2, AxisType.REDUCE, tag=())
    sub = UOp.range(2, 0, AxisType.REDUCE, tag=())
    batch_off = ((pair * 2) + sub) * 12800
    v = act.index((((y * 20) + x) + (ch * 400)) + batch_off).load()
    relu_v = (0.0 < v).where(v, UOp.const(dtypes.float, 0.0))

    total = relu_v.reduce(sub, y, x, arg=Ops.ADD)
    return out.index((ch * 256) + pair).store(total).end(ch, pair).sink()
def k7():
    """Build kernel 7: reduce 256 partials per channel and scale by 1/204800.

    Computes ``out[ch] = sum_j partial[ch][j] * 4.8828125e-06`` for 32
    channels. The constant is exactly ``1 / 204800 == 1 / (512 * 20 * 20)``.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    ch = UOp.range(32, 1, AxisType.LOOP, tag=())
    partials = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=1)
    j = UOp.range(256, 0, AxisType.REDUCE, tag=())
    summed = partials.index((ch * 256) + j).load().reduce(j, arg=Ops.ADD)
    mean = summed * 4.8828125e-06
    return out.index(ch).store(mean).end(ch).sink()
def k8():
    """Build kernel 8: per-channel partial sums of squared deviations of relu(input).

    Computes ``partial[ch][pair] = sum_{sub, y, x} (relu(act[2*pair+sub][ch][y][x]) - c[ch])**2``.
    ``c`` is the 32-float buffer in arg 2. The kernel writes 32 x 256 partials
    to arg 0. Presumably this is the first step of a BatchNorm batch variance
    -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    pair = UOp.range(256, 4, AxisType.LOOP, tag=())

    act = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    y = UOp.range(20, 1, AxisType.REDUCE, tag=())
    x = UOp.range(20, 2, AxisType.REDUCE, tag=())
    sub = UOp.range(2, 0, AxisType.REDUCE, tag=())
    v = act.index((((y * 20) + x) + (ch * 400)) + (((pair * 2) + sub) * 12800)).load()
    relu_v = (0.0 < v).where(v, UOp.const(dtypes.float, 0.0))

    centre = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    dev = relu_v + (centre.index(ch).load() * -1.0)
    total = (dev * dev).reduce(sub, y, x, arg=Ops.ADD)
    return out.index((ch * 256) + pair).store(total).end(ch, pair).sink()
def k9():
    """Build kernel 9: per-channel affine-normalise relu(input) and regroup 2x2 windows.

    Each element is computed as ``(relu(v) - p2[ch]) * p3[ch] * rsqrt(p4[ch] + 1e-5) + p5[ch]``.
    ``v`` is read from arg 1 (``[b][32][20][20]``) at ``y = 2*py + dy`` and
    ``x = 2*px + dx``, and ``p2``..``p5`` are the 32-float buffers in args 2..5.

    The output (arg 0) has layout ``[b][ch][py][px][dy][dx]``, so each 2x2
    window is contiguous. The formula matches BatchNorm with p2 = mean,
    p3 = weight, p4 = variance and p5 = bias -- those roles are inferred, not
    confirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=0)
    dy = UOp.range(2, 4, AxisType.LOOP, tag=())
    dx = UOp.range(2, 5, AxisType.LOOP, tag=())
    px = UOp.range(10, 3, AxisType.LOOP, tag=())
    py = UOp.range(10, 2, AxisType.LOOP, tag=())
    ch = UOp.range(32, 1, AxisType.LOOP, tag=())
    ch_off = ch * 400
    b = UOp.range(512, 0, AxisType.LOOP, tag=())
    b_off = b * 12800

    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    v = src.index(((((px * 2) + dx) + (((py * 2) + dy) * 20)) + ch_off) + b_off).load()
    relu_v = (0.0 < v).where(v, UOp.const(dtypes.float, 0.0))

    def per_channel(slot):
        # One float per channel, read at index ``ch``.
        return UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=slot).index(ch).load()

    mean, gain, var, shift = per_channel(2), per_channel(3), per_channel(4), per_channel(5)
    normed = (((relu_v + (mean * -1.0)) * gain) * (var + 1e-05).sqrt().reciprocal()) + shift

    dst = (((((dy * 2) + dx) + (px * 4)) + (py * 40)) + ch_off) + b_off
    return out.index(dst).store(normed).end(b, ch, py, px, dy, dx).sink()
def k10():
    """Build kernel 10: maximum over each contiguous 2x2 window, 32 channels, 10x10 windows.

    Arg 1 is read with layout ``[b][ch][py][px][dy][dx]``. The MAX over
    ``(dy, dx)`` is written to arg 0 with layout ``[b][ch][py][px]``.
    Presumably ``max_pool2d`` -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=0)
    py = UOp.range(10, 4, AxisType.LOOP, tag=())
    px = UOp.range(10, 5, AxisType.LOOP, tag=())
    ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    b = UOp.range(512, 2, AxisType.LOOP, tag=())

    windows = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    dy = UOp.range(2, 0, AxisType.REDUCE, tag=())
    dx = UOp.range(2, 1, AxisType.REDUCE, tag=())
    src_idx = (((((dy * 2) + dx) + (px * 4)) + (py * 40)) + (ch * 400)) + (b * 12800)
    peak = windows.index(src_idx).load().reduce(dy, dx, arg=Ops.MAX)

    dst = out.index((((py * 10) + px) + (ch * 100)) + (b * 3200))
    return dst.store(peak).end(b, ch, py, px).sink()
def k11():
    """Build kernel 11: valid 3x3 convolution, 32 -> 64 channels, 10x10 -> 8x8, plus bias.

    Computes ``out[b][oc][y][x] = sum_{ic,ky,kx} act[b][ic][y+ky][x+kx] * w[oc][ic][ky][kx] + bias[oc]``.
    ``act`` is arg 1, ``w`` is arg 2 (18,432 floats) and ``bias`` is arg 3
    (64 floats). The output is arg 0.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=0)
    y = UOp.range(8, 5, AxisType.LOOP, tag=())
    x = UOp.range(8, 6, AxisType.LOOP, tag=())
    oc = UOp.range(64, 4, AxisType.LOOP, tag=())
    b = UOp.range(512, 3, AxisType.LOOP, tag=())

    act = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=1)
    kx = UOp.range(3, 2, AxisType.REDUCE, tag=())
    ky = UOp.range(3, 1, AxisType.REDUCE, tag=())
    ic = UOp.range(32, 0, AxisType.REDUCE, tag=())
    v = act.index((((x + kx) + ((y + ky) * 10)) + (ic * 100)) + (b * 3200)).load()

    weight = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18432), arg=2)
    w = weight.index((((ky * 3) + kx) + (ic * 9)) + (oc * 288)).load()
    bias = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3)

    acc = (v * w).reduce(ic, ky, kx, arg=Ops.ADD) + bias.index(oc).load()
    dst = out.index((((y * 8) + x) + (oc * 64)) + (b * 4096))
    return dst.store(acc).end(b, oc, y, x).sink()
def k12():
    """Build kernel 12: valid 3x3 convolution of relu(input), 64 -> 64 channels, 8x8 -> 6x6.

    Computes ``out[b][oc][y][x] = sum_{ic,ky,kx} relu(act[b][ic][y+ky][x+kx]) * w[oc][ic][ky][kx] + bias[oc]``.
    ``act`` is arg 1, ``w`` is arg 2 (36,864 floats) and ``bias`` is arg 3.
    The output is arg 0.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=0)
    y = UOp.range(6, 5, AxisType.LOOP, tag=())
    x = UOp.range(6, 6, AxisType.LOOP, tag=())
    oc = UOp.range(64, 4, AxisType.LOOP, tag=())
    b = UOp.range(512, 3, AxisType.LOOP, tag=())

    act = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=1)
    kx = UOp.range(3, 2, AxisType.REDUCE, tag=())
    ky = UOp.range(3, 1, AxisType.REDUCE, tag=())
    ic = UOp.range(64, 0, AxisType.REDUCE, tag=())
    v = act.index((((x + kx) + ((y + ky) * 8)) + (ic * 64)) + (b * 4096)).load()
    relu_v = (0.0 < v).where(v, UOp.const(dtypes.float, 0.0))

    weight = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=2)
    w = weight.index((((ky * 3) + kx) + (ic * 9)) + (oc * 576)).load()
    bias = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3)

    acc = (relu_v * w).reduce(ic, ky, kx, arg=Ops.ADD) + bias.index(oc).load()
    dst = out.index((((y * 6) + x) + (oc * 36)) + (b * 2304))
    return dst.store(acc).end(b, oc, y, x).sink()
def k13():
    """Build kernel 13: per-channel mean of relu(input) over a 512 x 64 x 6 x 6 batch.

    Computes ``out[ch] = sum_{b,y,x} relu(act[b][ch][y][x]) * 5.425347222222222e-05``.
    The constant is ``1 / 18432 == 1 / (512 * 6 * 6)``.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())
    act = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    y = UOp.range(6, 1, AxisType.REDUCE, tag=())
    x = UOp.range(6, 2, AxisType.REDUCE, tag=())
    b = UOp.range(512, 0, AxisType.REDUCE, tag=())
    v = act.index((((y * 6) + x) + (ch * 36)) + (b * 2304)).load()
    relu_v = (0.0 < v).where(v, UOp.const(dtypes.float, 0.0))
    mean = relu_v.reduce(b, y, x, arg=Ops.ADD) * 5.425347222222222e-05
    return out.index(ch).store(mean).end(ch).sink()
def k14():
    """Build kernel 14: per-channel mean squared deviation of relu(input), 64 channels.

    Computes ``out[ch] = sum_{b,y,x} (relu(act[b][ch][y][x]) - c[ch])**2 / 18432``.
    ``c`` is the 64-float buffer in arg 2. Presumably this is the BatchNorm
    batch variance -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())
    act = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    y = UOp.range(6, 1, AxisType.REDUCE, tag=())
    x = UOp.range(6, 2, AxisType.REDUCE, tag=())
    b = UOp.range(512, 0, AxisType.REDUCE, tag=())
    v = act.index((((y * 6) + x) + (ch * 36)) + (b * 2304)).load()
    relu_v = (0.0 < v).where(v, UOp.const(dtypes.float, 0.0))

    centre = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2)
    dev = relu_v + (centre.index(ch).load() * -1.0)
    var = (dev * dev).reduce(b, y, x, arg=Ops.ADD) * 5.425347222222222e-05
    return out.index(ch).store(var).end(ch).sink()
def k15():
    """Build kernel 15: per-channel affine-normalise relu(input) and regroup 2x2 windows (64 ch, 6x6).

    Each element is computed as ``(relu(v) - p2[ch]) * p3[ch] * rsqrt(p4[ch] + 1e-5) + p5[ch]``.
    ``v`` is read from arg 1 (``[b][64][6][6]``) at ``y = 2*py + dy`` and
    ``x = 2*px + dx``. The output (arg 0) has layout ``[b][ch][py][px][dy][dx]``.
    The formula matches BatchNorm (mean, weight, variance, bias) -- those
    roles are inferred, not confirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=0)
    dy = UOp.range(2, 4, AxisType.LOOP, tag=())
    dx = UOp.range(2, 5, AxisType.LOOP, tag=())
    px = UOp.range(3, 3, AxisType.LOOP, tag=())
    py = UOp.range(3, 2, AxisType.LOOP, tag=())
    ch = UOp.range(64, 1, AxisType.LOOP, tag=())
    ch_off = ch * 36
    b = UOp.range(512, 0, AxisType.LOOP, tag=())
    b_off = b * 2304

    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    v = src.index(((((px * 2) + dx) + (((py * 2) + dy) * 6)) + ch_off) + b_off).load()
    relu_v = (0.0 < v).where(v, UOp.const(dtypes.float, 0.0))

    def per_channel(slot):
        # One float per channel, read at index ``ch``.
        return UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=slot).index(ch).load()

    mean, gain, var, shift = per_channel(2), per_channel(3), per_channel(4), per_channel(5)
    normed = (((relu_v + (mean * -1.0)) * gain) * (var + 1e-05).sqrt().reciprocal()) + shift

    dst = (((((dy * 2) + dx) + (px * 4)) + (py * 12)) + ch_off) + b_off
    return out.index(dst).store(normed).end(b, ch, py, px, dy, dx).sink()
def k16():
    """Build kernel 16: maximum over each contiguous 2x2 window, 64 channels, 3x3 windows.

    Arg 1 is read with layout ``[b][ch][py][px][dy][dx]``. The MAX over
    ``(dy, dx)`` is written to arg 0 with layout ``[b][ch][py][px]``
    (576 floats per batch item). Presumably ``max_pool2d`` -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=0)
    py = UOp.range(3, 4, AxisType.LOOP, tag=())
    px = UOp.range(3, 5, AxisType.LOOP, tag=())
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())
    b = UOp.range(512, 2, AxisType.LOOP, tag=())

    windows = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    dy = UOp.range(2, 0, AxisType.REDUCE, tag=())
    dx = UOp.range(2, 1, AxisType.REDUCE, tag=())
    src_idx = (((((dy * 2) + dx) + (px * 4)) + (py * 12)) + (ch * 36)) + (b * 2304)
    peak = windows.index(src_idx).load().reduce(dy, dx, arg=Ops.MAX)

    dst = out.index((((py * 3) + px) + (ch * 9)) + (b * 576))
    return dst.store(peak).end(b, ch, py, px).sink()
def k17():
    """Build kernel 17: dense layer 576 -> 10 with bias.

    Computes ``out[b][cls] = sum_k feat[b][k] * w[cls][k] + bias[cls]``.
    ``feat`` is arg 1, ``w`` is arg 2 (5,760 floats) and ``bias`` is arg 3.
    The feature index is written as ``((k//3)%3)*3 + k%3 + (k//9)*9``, which
    equals ``k`` for every ``k`` in [0, 576). It is kept in that form because
    it is part of the graph.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=0)
    b = UOp.range(512, 1, AxisType.LOOP, tag=())
    cls = UOp.range(10, 2, AxisType.LOOP, tag=())

    feats = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=1)
    k = UOp.range(576, 0, AxisType.REDUCE, tag=())
    feat_idx = (((((k // 3) % 3) * 3) + (k % 3)) + ((k // 9) * 9)) + (b * 576)
    f = feats.index(feat_idx).load()

    weight = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5760), arg=2)
    w = weight.index((cls * 576) + k).load()
    bias = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=3)

    logit = (f * w).reduce(k, arg=Ops.ADD) + bias.index(cls).load()
    return out.index((b * 10) + cls).store(logit).end(b, cls).sink()
def k18():
    """Build kernel 18: row-wise maximum of a 512 x 10 float matrix.

    Computes ``out[b] = max_cls in[b][cls]``.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=0)
    b = UOp.range(512, 1, AxisType.LOOP, tag=())
    logits = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=1)
    cls = UOp.range(10, 0, AxisType.REDUCE, tag=())
    row_max = logits.index((b * 10) + cls).load().reduce(cls, arg=Ops.MAX)
    return out.index(b).store(row_max).end(b).sink()
def k19():
    """Build kernel 19: row-wise sum of exp(x - m) for a 512 x 10 matrix.

    Computes ``out[b] = sum_cls exp(in[b][cls] - m[b])``, where ``m`` is the
    512-float buffer in arg 2. The exponential is written as
    ``exp2(v * 1.4426950408889634)``, and that constant is log2(e).

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=0)
    b = UOp.range(512, 1, AxisType.LOOP, tag=())
    logits = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=1)
    cls = UOp.range(10, 0, AxisType.REDUCE, tag=())
    logit = logits.index((b * 10) + cls).load()
    row_max = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=2).index(b).load()
    shifted_exp = ((logit + (row_max * -1.0)) * 1.4426950408889634).exp2()
    return out.index(b).store(shifted_exp.reduce(cls, arg=Ops.ADD)).end(b).sink()
def k20():
    """Build kernel 20: select 512 entries of a 60000-uchar table, split into 250 chunks.

    The wanted index per batch slot is ``int(60000.0 * rand[b])``, plus 60000
    if negative, as in ``k2``. For each chunk of 240 table entries, the kernel
    sums the entries whose position equals the wanted index, so at most one is
    non-zero. The result goes to arg 0 with layout ``[b][chunk]``
    (512 x 250 uchars). Presumably ``Y_train[samples]`` -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(128000), arg=0)
    b = UOp.range(512, 1, AxisType.LOOP, tag=())
    chunk = UOp.range(250, 2, AxisType.LOOP, tag=())

    rand = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=1)
    scaled = (60000.0 * rand.index(b).load()).cast(dtypes.int)
    wanted = (scaled < 0).where(scaled + 60000, scaled)

    k = UOp.range(240, 0, AxisType.REDUCE, tag=())
    pos = (chunk * 240) + k
    table = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(60000), arg=2)
    entry = table.index(pos).load()
    masked = (wanted != pos.cast(dtypes.int)).where(
        UOp.const(dtypes.uchar, 0), entry.cast(dtypes.uint).cast(dtypes.uchar))

    partial = masked.reduce(k, arg=Ops.ADD)
    return out.index((b * 250) + chunk).store(partial).end(b, chunk).sink()
def k21():
    """Build kernel 21: sum the 250 per-chunk partials per batch slot and cast to int.

    Computes ``out[b] = int(sum_chunk partial[b][chunk])`` for 512 slots.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=0)
    b = UOp.range(512, 1, AxisType.LOOP, tag=())
    partials = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(128000), arg=1)
    chunk = UOp.range(250, 0, AxisType.REDUCE, tag=())
    part = partials.index((b * 250) + chunk).load()
    total = part.reduce(chunk, arg=Ops.ADD).cast(dtypes.int)
    return out.index(b).store(total).end(b).sink()
def k22():
    """Build kernel 22: scalar loss ``-(1/512) * sum_b log_softmax(logits)[b][label[b]]``.

    ``log_prob[b][cls] = logit - m[b] - ln(s[b])``, where ``m`` (arg 2) and
    ``s`` (arg 3) are per-row buffers. ``ln`` is computed as
    ``log2(s) * 0.6931471805599453`` (ln 2).

    The selection of the labelled class uses the one-hot mask
    ``(label[b] != cls) != True``, cast to float. The total is scaled by
    -0.001953125 == -1/512 and stored to the single float in arg 0.
    Presumably sparse categorical cross-entropy -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0)
    logits = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=1)
    b = UOp.range(512, 1, AxisType.REDUCE, tag=())
    cls = UOp.range(10, 0, AxisType.REDUCE, tag=())
    logit = logits.index((b * 10) + cls).load()

    row_max = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=2).index(b).load()
    sum_exp = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=3).index(b).load()
    label = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=4).index(b).load()

    log_prob = (logit + (row_max * -1.0)) + ((sum_exp.log2() * 0.6931471805599453) * -1.0)
    one_hot = ((label != cls.cast(dtypes.int)) != True).cast(dtypes.float)  # noqa: E712 (UOp compare)
    per_row = (log_prob * one_hot).reduce(cls, arg=Ops.ADD)
    loss = per_row.reduce(b, arg=Ops.ADD) * -0.001953125
    return out.index(UOp.const(dtypes.index, 0)).store(loss).sink()
def k23():
    """Build kernel 23: multiply the single float in arg 0 by 0.9, in place.

    Presumably an Adam ``beta1`` power accumulator -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    cell = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0).index(UOp.const(dtypes.index, 0))
    decayed = cell.load() * 0.9
    return cell.store(decayed).sink()
def k24():
    """Build kernel 24: multiply the single float in arg 0 by 0.999, in place.

    Presumably an Adam ``beta2`` power accumulator -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    cell = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0).index(UOp.const(dtypes.index, 0))
    decayed = cell.load() * 0.999
    return cell.store(decayed).sink()
def k25():
    """Build kernel 25: per-row scalar ``sum_cls -(onehot * -1/512) / (s[b] * ln2) * ln2``.

    ``onehot`` is ``(label[b] != cls) != True`` cast to float, with labels
    taken from the int buffer in arg 1. ``s`` is the float buffer in arg 2.
    Presumably the gradient of the loss with respect to the softmax
    denominator -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=0)
    b = UOp.range(512, 1, AxisType.LOOP, tag=())
    label = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1).index(b).load()
    cls = UOp.range(10, 0, AxisType.REDUCE, tag=())

    one_hot = ((label != cls.cast(dtypes.int)) != True).cast(dtypes.float)  # noqa: E712 (UOp compare)
    term = -1.0 * (one_hot * -0.001953125)
    sum_exp = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=2).index(b).load()

    grad = (term.reduce(cls, arg=Ops.ADD) * (sum_exp * 0.6931471805599453).reciprocal()) * 0.6931471805599453
    return out.index(b).store(grad).end(b).sink()
def k26():
    """Build kernel 26: per-logit value ``onehot * -1/512 + exp(logit - m[b]) * s[b]``.

    ``onehot`` is ``(label[b] != cls) != True`` cast to float (arg 1 holds the
    labels), ``logit`` comes from arg 2, ``m`` from arg 3 and ``s`` from arg 4.
    The result is written as a 512 x 10 matrix to arg 0. Presumably the
    gradient of the loss with respect to the logits -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=0)
    b = UOp.range(512, 0, AxisType.LOOP, tag=())
    cls = UOp.range(10, 1, AxisType.LOOP, tag=())
    flat = (b * 10) + cls

    label = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1).index(b).load()
    logit = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=2).index(flat).load()
    row_max = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=3).index(b).load()
    scale = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(512), arg=4).index(b).load()

    one_hot_term = ((label != cls.cast(dtypes.int)) != True).cast(dtypes.float) * -0.001953125  # noqa: E712
    softmax_term = ((logit + (row_max * -1.0)) * 1.4426950408889634).exp2() * scale
    return out.index(flat).store(one_hot_term + softmax_term).end(b, cls).sink()
def k27():
    """Build kernel 27: count, per 2x2 window, the elements equal to the pooled value (64 ch, 3x3).

    Arg 1 holds windows with layout ``[b][ch][py][px][dy][dx]``, and arg 2
    holds pooled values with layout ``[b][ch][py][px]``. The kernel writes
    ``sum_{dy,dx} float(window == pooled)`` to arg 0 with the pooled layout.
    Presumably a tie count for the max-pool gradient -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=0)
    py = UOp.range(3, 4, AxisType.LOOP, tag=())
    px = UOp.range(3, 5, AxisType.LOOP, tag=())
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())
    b = UOp.range(512, 2, AxisType.LOOP, tag=())
    pooled_idx = (((py * 3) + px) + (ch * 9)) + (b * 576)

    windows = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    dy = UOp.range(2, 0, AxisType.REDUCE, tag=())
    dx = UOp.range(2, 1, AxisType.REDUCE, tag=())
    v = windows.index((((((dy * 2) + dx) + (px * 4)) + (py * 12)) + (ch * 36)) + (b * 2304)).load()
    pooled = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=2).index(pooled_idx).load()

    hits = ((v != pooled) != True).cast(dtypes.float).reduce(dy, dx, arg=Ops.ADD)  # noqa: E712
    return out.index(pooled_idx).store(hits).end(b, ch, py, px).sink()
def k28():
    """Build kernel 28: ``out[b][k] = sum_cls w[cls][k] * g[b][cls]`` for a 10 x 576 ``w``.

    ``w`` is arg 1 (5,760 floats) and ``g`` is arg 2 (512 x 10). The
    512 x 576 result is written to arg 0. Presumably the input gradient of
    the 576 -> 10 dense layer -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=0)
    b = UOp.range(512, 1, AxisType.LOOP, tag=())
    k = UOp.range(576, 2, AxisType.LOOP, tag=())
    weight = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5760), arg=1)
    cls = UOp.range(10, 0, AxisType.REDUCE, tag=())
    w = weight.index((cls * 576) + k).load()
    grad = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=2).index((b * 10) + cls).load()
    acc = (w * grad).reduce(cls, arg=Ops.ADD)
    return out.index((b * 576) + k).store(acc).end(b, k).sink()
def k29():
    """Build kernel 29: spread a pooled gradient back onto 6x6 positions (64 channels).

    For output ``[b][ch][y][x]``, set ``py, px = y // 2, x // 2`` and
    ``(dy, dx) = (y % 2, x % 2)``. The kernel writes
    ``float(window == pooled) / count * grad``. ``window`` comes from arg 1
    (layout ``[b][ch][py][px][dy][dx]``), and ``pooled``, ``count`` and
    ``grad`` come from args 2..4 (layout ``[b][ch][py][px]``). Presumably the
    max-pool backward pass -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=0)
    y = UOp.range(6, 2, AxisType.LOOP, tag=())
    x = UOp.range(6, 3, AxisType.LOOP, tag=())
    ch = UOp.range(64, 1, AxisType.LOOP, tag=())
    ch_off = ch * 36
    b = UOp.range(512, 0, AxisType.LOOP, tag=())
    b_off = b * 2304

    windows = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    px = x // 2
    py = y // 2
    v = windows.index((((((y % 2) * 2) + (x % 2)) + (px * 4)) + (py * 12)) + ch_off + b_off).load() \
        if False else windows.index(((((((y % 2) * 2) + (x % 2)) + (px * 4)) + (py * 12)) + ch_off) + b_off).load()

    pooled_idx = (((py * 3) + px) + (ch * 9)) + (b * 576)
    pooled = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=2).index(pooled_idx).load()
    count = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=3).index(pooled_idx).load()
    grad = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=4).index(pooled_idx).load()

    is_peak = ((v != pooled) != True).cast(dtypes.float)  # noqa: E712 (UOp compare)
    spread = (is_peak * count.reciprocal()) * grad
    return out.index((((y * 6) + x) + ch_off) + b_off).store(spread).end(b, ch, y, x).sink()
def k30():
    """Build kernel 30: per-channel reduction used for the variance term (64 ch, 6x6).

    Let ``t = sum_{b,y,x} (relu(act) - p2[ch]) * p3[ch] * up`` and
    ``std = sqrt(p5[ch] + 1e-5)``. The kernel writes
    ``t / std**2 / (2*std) * -1/18432`` to arg 0. ``act`` is arg 1, ``up`` is
    arg 4 (both ``[b][64][6][6]``), and ``p2``, ``p3``, ``p5`` are 64-float
    buffers. Presumably part of the BatchNorm backward pass -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())
    act = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    y = UOp.range(6, 1, AxisType.REDUCE, tag=())
    x = UOp.range(6, 2, AxisType.REDUCE, tag=())
    b = UOp.range(512, 0, AxisType.REDUCE, tag=())
    idx = (((y * 6) + x) + (ch * 36)) + (b * 2304)

    v = act.index(idx).load()
    relu_v = (0.0 < v).where(v, UOp.const(dtypes.float, 0.0))
    mean = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2).index(ch).load()
    gain = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3).index(ch).load()
    upstream = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=4).index(idx).load()
    term = ((relu_v + (mean * -1.0)) * gain) * upstream

    var = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=5).index(ch).load()
    std = (var + 1e-05).sqrt()
    inv_std = std.reciprocal()

    summed = term.reduce(b, y, x, arg=Ops.ADD)
    result = (((summed * inv_std) * inv_std) * (std * 2.0).reciprocal()) * -5.425347222222222e-05
    return out.index(ch).store(result).end(ch).sink()
def k31():
    """Build kernel 31: per-channel ``(1/18432) * sum_{b,y,x} -(p1 * rsqrt(p2 + 1e-5) * up)``.

    ``p1`` and ``p2`` are the 64-float buffers in args 1 and 2, and ``up`` is
    arg 3 (``[b][64][6][6]``). Presumably the mean term of the BatchNorm
    backward pass -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    ch = UOp.range(64, 3, AxisType.LOOP, tag=())
    gain = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=1).index(ch).load()
    var = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2).index(ch).load()

    upstream = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=3)
    y = UOp.range(6, 1, AxisType.REDUCE, tag=())
    x = UOp.range(6, 2, AxisType.REDUCE, tag=())
    b = UOp.range(512, 0, AxisType.REDUCE, tag=())
    up = upstream.index((((y * 6) + x) + (ch * 36)) + (b * 2304)).load()

    term = -1.0 * (gain * ((var + 1e-05).sqrt().reciprocal() * up))
    result = 5.425347222222222e-05 * term.reduce(b, y, x, arg=Ops.ADD)
    return out.index(ch).store(result).end(ch).sink()
def k32():
    """Build kernel 32: combine per-channel terms and mask by ``act > 0`` (64 ch, 6x6).

    Writes ``act > 0 ? ((relu(act) - p2) * p3 * 2 + p4 * rsqrt(p5 + 1e-5) * up + p7) : 0``.
    ``act`` is arg 1 and ``up`` is arg 6 (both ``[b][64][6][6]``), and
    ``p2``, ``p3``, ``p4``, ``p5``, ``p7`` are 64-float buffers in the
    matching arg slots. Presumably the BatchNorm backward pass followed by
    the relu backward pass -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=0)
    y = UOp.range(6, 2, AxisType.LOOP, tag=())
    x = UOp.range(6, 3, AxisType.LOOP, tag=())
    ch = UOp.range(64, 1, AxisType.LOOP, tag=())
    b = UOp.range(512, 0, AxisType.LOOP, tag=())
    idx = (((y * 6) + x) + (ch * 36)) + (b * 2304)

    act = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1).index(idx).load()
    positive = 0.0 < act
    relu_v = positive.where(act, UOp.const(dtypes.float, 0.0))

    def per_channel(slot):
        # One float per channel, read at index ``ch``.
        return UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=slot).index(ch).load()

    p2, p3, p4, p5 = per_channel(2), per_channel(3), per_channel(4), per_channel(5)
    upstream = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=6).index(idx).load()
    p7 = per_channel(7)

    var_part = ((relu_v + (p2 * -1.0)) * p3) * 2.0
    direct_part = p4 * ((p5 + 1e-05).sqrt().reciprocal() * upstream)
    grad = positive.where((var_part + direct_part) + p7, UOp.const(dtypes.float, 0.0))
    return out.index(idx).store(grad).end(b, ch, y, x).sink()
def k33():
    """Build kernel 33: input-side gradient of a 3x3, 64 -> 64 conv (6x6 -> 8x8), relu-masked.

    For each input position ``(ic, y, x)``, the kernel walks 4x4 candidate
    taps ``(ty, tx)``. It forms ``sy = 8*ty + y`` and ``sx = 8*tx + x``,
    splits each into a kernel index (``// 9``) and an output index (``% 9``),
    and keeps only taps with ``s < 27`` and output index ``< 6``.

    Invalid taps get ``Invalid`` indices and a 0.0 contribution. Valid taps
    add ``sum_oc w[oc][ic][ky][kx] * up[b][oc][oy][ox]``, with ``w`` from
    arg 2 and ``up`` from arg 3. The total is zeroed where
    ``pre[b][ic][y][x] <= 0`` (``pre`` is arg 1). Presumably the conv and relu
    backward pass -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=0)
    y = UOp.range(8, 5, AxisType.LOOP, tag=())
    x = UOp.range(8, 6, AxisType.LOOP, tag=())
    ic = UOp.range(64, 4, AxisType.LOOP, tag=())
    b = UOp.range(512, 3, AxisType.LOOP, tag=())
    idx = (((y * 8) + x) + (ic * 64)) + (b * 4096)
    invalid = UOp.const(dtypes.index, Invalid)

    pre = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=1).index(idx).load()

    ty = UOp.range(4, 1, AxisType.REDUCE, tag=())
    sy = (ty * 8) + y
    sy_in = sy < 27
    tx = UOp.range(4, 2, AxisType.REDUCE, tag=())
    sx = (tx * 8) + x
    sx_in = sx < 27

    oy = sy % 9
    y_ok = sy_in & (oy < 6)
    ox = sx % 9
    ox_in = ox < 6
    x_ok = sx_in & ox_in

    weight = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=2)
    ky = y_ok.where(sy // 9, invalid)
    kx = x_ok.where(sx // 9, invalid)
    oc = UOp.range(64, 0, AxisType.REDUCE, tag=())
    w = weight.index((((ky * 3) + kx) + (ic * 9)) + (oc * 576)).load()

    upstream = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=3)
    up_idx = (x_ok & y_ok).where((((oy * 6) + ox) + (oc * 36)) + (b * 2304), invalid)
    up = upstream.index(up_idx).load()

    tap_ok = (sy_in & sx_in) & ((y_ok & sx_in) & ox_in)
    tap = tap_ok.where((w * up).reduce(oc, arg=Ops.ADD), UOp.const(dtypes.float, 0.0))
    grad = (0.0 < pre).where(tap.reduce(ty, tx, arg=Ops.ADD), UOp.const(dtypes.float, 0.0))
    return out.index(idx).store(grad).end(b, ic, y, x).sink()
def k34():
    """Build kernel 34: count, per 2x2 window, the elements equal to the pooled value (32 ch, 10x10).

    Arg 1 holds windows with layout ``[b][ch][py][px][dy][dx]``, and arg 2
    holds pooled values with layout ``[b][ch][py][px]``. The kernel writes
    ``sum_{dy,dx} float(window == pooled)`` to arg 0 with the pooled layout.
    Presumably a tie count for the max-pool gradient -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=0)
    py = UOp.range(10, 4, AxisType.LOOP, tag=())
    px = UOp.range(10, 5, AxisType.LOOP, tag=())
    ch = UOp.range(32, 3, AxisType.LOOP, tag=())
    b = UOp.range(512, 2, AxisType.LOOP, tag=())
    pooled_idx = (((py * 10) + px) + (ch * 100)) + (b * 3200)

    windows = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    dy = UOp.range(2, 0, AxisType.REDUCE, tag=())
    dx = UOp.range(2, 1, AxisType.REDUCE, tag=())
    v = windows.index((((((dy * 2) + dx) + (px * 4)) + (py * 40)) + (ch * 400)) + (b * 12800)).load()
    pooled = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=2).index(pooled_idx).load()

    hits = ((v != pooled) != True).cast(dtypes.float).reduce(dy, dx, arg=Ops.ADD)  # noqa: E712
    return out.index(pooled_idx).store(hits).end(b, ch, py, px).sink()
def k35():
    """Build kernel 35: input-side gradient of a 3x3, 32 -> 64 conv (8x8 -> 10x10).

    Same masked-gather shape as the 64-channel variant. Taps
    ``sy = 10*ty + y`` and ``sx = 10*tx + x`` are split into a kernel index
    (``// 11``) and an output index (``% 11``), and kept only when
    ``s < 33`` and the output index is ``< 8``.

    Valid taps add ``sum_oc w[oc][ic][ky][kx] * up[b][oc][oy][ox]``, with
    ``w`` from arg 1 (18,432 floats) and ``up`` from arg 2. There is no relu
    mask. Presumably the conv backward pass -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=0)
    y = UOp.range(10, 5, AxisType.LOOP, tag=())
    x = UOp.range(10, 6, AxisType.LOOP, tag=())
    ic = UOp.range(32, 4, AxisType.LOOP, tag=())
    b = UOp.range(512, 3, AxisType.LOOP, tag=())
    invalid = UOp.const(dtypes.index, Invalid)

    ty = UOp.range(4, 1, AxisType.REDUCE, tag=())
    sy = (ty * 10) + y
    sy_in = sy < 33
    tx = UOp.range(4, 2, AxisType.REDUCE, tag=())
    sx = (tx * 10) + x
    sx_in = sx < 33

    oy = sy % 11
    y_ok = sy_in & (oy < 8)
    ox = sx % 11
    ox_in = ox < 8
    x_ok = sx_in & ox_in

    weight = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18432), arg=1)
    ky = y_ok.where(sy // 11, invalid)
    kx = x_ok.where(sx // 11, invalid)
    oc = UOp.range(64, 0, AxisType.REDUCE, tag=())
    w = weight.index((((ky * 3) + kx) + (ic * 9)) + (oc * 288)).load()

    upstream = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=2)
    up_idx = (x_ok & y_ok).where((((oy * 8) + ox) + (oc * 64)) + (b * 4096), invalid)
    up = upstream.index(up_idx).load()

    tap_ok = (sy_in & sx_in) & ((y_ok & sx_in) & ox_in)
    tap = tap_ok.where((w * up).reduce(oc, arg=Ops.ADD), UOp.const(dtypes.float, 0.0))
    grad = tap.reduce(ty, tx, arg=Ops.ADD)

    dst = out.index((((y * 10) + x) + (ic * 100)) + (b * 3200))
    return dst.store(grad).end(b, ic, y, x).sink()
def k36():
    """Build kernel 36: spread a pooled gradient back onto 20x20 positions (32 channels).

    For pooled cell ``(py, px)`` and offset ``(dy, dx)``, the kernel writes
    ``float(window == pooled) / count * grad`` to position
    ``(2*py + dy, 2*px + dx)`` of arg 0, which has natural ``[b][32][20][20]``
    layout. ``window`` is read from arg 1 (layout
    ``[b][ch][py][px][dy][dx]``), and ``pooled``, ``count`` and ``grad`` come
    from args 2..4 (layout ``[b][ch][py][px]``). Presumably the max-pool
    backward pass -- unconfirmed.

    Returns:
        The SINK UOp of the kernel AST.
    """
    out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=0)
    px = UOp.range(10, 4, AxisType.LOOP, tag=())
    dx = UOp.range(2, 5, AxisType.LOOP, tag=())
    dy = UOp.range(2, 3, AxisType.LOOP, tag=())
    py = UOp.range(10, 2, AxisType.LOOP, tag=())
    py_off = py * 40
    ch = UOp.range(32, 1, AxisType.LOOP, tag=())
    ch_off = ch * 400
    b = UOp.range(512, 0, AxisType.LOOP, tag=())
    b_off = b * 12800

    windows = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    v = windows.index((((((dy * 2) + dx) + (px * 4)) + py_off) + ch_off) + b_off).load()

    pooled_idx = (((py * 10) + px) + (ch * 100)) + (b * 3200)
    pooled = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=2).index(pooled_idx).load()
    count = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=3).index(pooled_idx).load()
    grad = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=4).index(pooled_idx).load()

    is_peak = ((v != pooled) != True).cast(dtypes.float)  # noqa: E712 (UOp compare)
    spread = (is_peak * count.reciprocal()) * grad
    dst = (((((px * 2) + dx) + (dy * 20)) + py_off) + ch_off) + b_off
    return out.index(dst).store(spread).end(b, ch, py, dy, px, dx).sink()
def k37():
    """Return the sink AST for kernel 37 (sum over three reduce axes).

    For each (axis3 < 32, axis4 < 256) it accumulates, over reduce axes 0, 1, 2
    (extents 2, 20, 20), ``(buf1[k] clamped at 0 - buf2[axis3]) * buf3[axis3] * buf4[k]``
    and stores the sum at ``buf0[axis3*256 + axis4]``. ``k`` is the offset
    shared by the two 6553600-float inputs.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    buf3 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3)
    buf4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=4)
    ax3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(256, 4, AxisType.LOOP, tag=())
    red0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(20, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(20, 2, AxisType.REDUCE, tag=())
    k = red1*20 + red2 + ax3*400 + (ax4*2 + red0)*12800
    raw = buf1.index(k).load()
    # (0 < x).where(x, 0) clamps negative values to zero
    clamped = (0.0 < raw).where(raw, UOp.const(dtypes.float, 0.0))
    term = (clamped + buf2.index(ax3).load()*-1.0) * buf3.index(ax3).load() * buf4.index(k).load()
    total = term.reduce(red0, red1, red2, arg=Ops.ADD)
    return buf0.index(ax3*256 + ax4).store(total).end(ax3, ax4).sink()
def k38():
    """Return the sink AST for kernel 38 (one reduce axis of extent 256).

    For each axis1 < 32, with ``r = sqrt(buf2[axis1] + 1e-05)``, it stores
    ``sum_j buf1[axis1*256 + j] * (1/r) * (1/r) * (1/(2*r)) * -4.8828125e-06``
    into ``buf0[axis1]``. The constant -4.8828125e-06 is exactly -1/204800.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    ax1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    red0 = UOp.range(256, 0, AxisType.REDUCE, tag=())
    partial_sum = buf1.index(ax1*256 + red0).load().reduce(red0, arg=Ops.ADD)
    root = (buf2.index(ax1).load() + 1e-05).sqrt()
    inv_root = root.reciprocal()
    value = partial_sum * inv_root * inv_root * (root*2.0).reciprocal() * -4.8828125e-06
    return buf0.index(ax1).store(value).end(ax1).sink()
def k39():
    """Return the sink AST for kernel 39 (sum over three reduce axes).

    For each (axis3 < 32, axis4 < 256) it accumulates
    ``-1.0 * (buf1[axis3] * (1/sqrt(buf2[axis3] + 1e-05) * buf3[k]))`` over
    reduce axes 0, 1, 2 (extents 2, 20, 20). The sum is stored at
    ``buf0[axis3*256 + axis4]``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    buf3 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=3)
    ax3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(256, 4, AxisType.LOOP, tag=())
    red0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(20, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(20, 2, AxisType.REDUCE, tag=())
    scale = buf1.index(ax3).load()
    inv_root = (buf2.index(ax3).load() + 1e-05).sqrt().reciprocal()
    x = buf3.index(red1*20 + red2 + ax3*400 + (ax4*2 + red0)*12800).load()
    term = -1.0 * (scale * (inv_root * x))
    total = term.reduce(red0, red1, red2, arg=Ops.ADD)
    return buf0.index(ax3*256 + ax4).store(total).end(ax3, ax4).sink()
def k40():
    """Return the sink AST for kernel 40.

    For each axis1 < 32 it stores
    ``4.8828125e-06 * sum_j buf1[axis1*256 + j]`` (j < 256) into ``buf0[axis1]``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=1)
    ax1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    red0 = UOp.range(256, 0, AxisType.REDUCE, tag=())
    row = buf1.index(ax1*256 + red0).load()
    scaled = 4.8828125e-06 * row.reduce(red0, arg=Ops.ADD)
    return buf0.index(ax1).store(scaled).end(ax1).sink()
def k41():
    """Return the sink AST for kernel 41 (elementwise over 6553600 floats).

    With ``x = buf1[k]`` and ``c = axis1`` (extent 32), it writes
    ``((max(x,0) - buf2[c]) * buf3[c] * 2.0 + buf4[c] * (1/sqrt(buf5[c] + 1e-05) * buf6[k]) + buf7[c])``
    to ``buf0[k]`` where ``x > 0``, and 0.0 elsewhere.
    """
    zero = UOp.const(dtypes.float, 0.0)
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    buf3 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3)
    buf4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=4)
    buf5 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=5)
    buf6 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=6)
    buf7 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=7)
    ax0 = UOp.range(512, 0, AxisType.LOOP, tag=())
    ax1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    ax2 = UOp.range(20, 2, AxisType.LOOP, tag=())
    ax3 = UOp.range(20, 3, AxisType.LOOP, tag=())
    k = ax2*20 + ax3 + ax1*400 + ax0*12800
    x = buf1.index(k).load()
    positive = 0.0 < x
    clamped = positive.where(x, zero)
    first = (clamped + buf2.index(ax1).load()*-1.0) * buf3.index(ax1).load() * 2.0
    second = buf4.index(ax1).load() * ((buf5.index(ax1).load() + 1e-05).sqrt().reciprocal() * buf6.index(k).load())
    value = positive.where(first + second + buf7.index(ax1).load(), zero)
    return buf0.index(k).store(value).end(ax0, ax1, ax2, ax3).sink()
def k42():
    """Return the sink AST for kernel 42 (masked double reduction).

    For every offset ``k`` of the 9437184-float ``buf0`` it sums
    ``buf2[...] * buf3[...]`` over reduce axis 0 (extent 32). The two windowed
    positions ``pos1 = red1*24 + axis5`` and ``pos2 = red2*24 + axis6`` must
    both be < 125 and have a remainder mod 25 below 20. The masked partial sums
    are then reduced over axes 1 and 2 (extent 6 each). The result is kept only
    where ``buf1[k] > 0``; elsewhere 0.0 is stored.
    """
    invalid = UOp.const(dtypes.index, Invalid)
    zero = UOp.const(dtypes.float, 0.0)
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25600), arg=2)
    buf3 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=3)
    ax3 = UOp.range(512, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(32, 4, AxisType.LOOP, tag=())
    ax5 = UOp.range(24, 5, AxisType.LOOP, tag=())
    ax6 = UOp.range(24, 6, AxisType.LOOP, tag=())
    red0 = UOp.range(32, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(6, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(6, 2, AxisType.REDUCE, tag=())
    k = ax5*24 + ax6 + ax4*576 + ax3*18432
    pos1 = red1*24 + ax5
    pos2 = red2*24 + ax6
    in1 = pos1 < 125
    in2 = pos2 < 125
    mod1 = pos1 % 25
    mod2 = pos2 % 25
    ok1 = in1 & (mod1 < 20)
    mod2_ok = mod2 < 20
    ok2 = in2 & mod2_ok
    # indices fall back to the Invalid sentinel where their mask is false
    w_off = ok1.where(pos1 // 25, invalid)*5 + ok2.where(pos2 // 25, invalid) + ax4*25 + red0*800
    g_off = (ok2 & ok1).where(mod1*20 + mod2 + red0*400 + ax3*12800, invalid)
    inner = (buf2.index(w_off).load() * buf3.index(g_off).load()).reduce(red0, arg=Ops.ADD)
    masked = (in1 & in2 & (ok1 & in2 & mod2_ok)).where(inner, zero)
    gate = buf1.index(k).load()
    value = (0.0 < gate).where(masked.reduce(red1, red2, arg=Ops.ADD), zero)
    return buf0.index(k).store(value).end(ax3, ax4, ax5, ax6).sink()
def k43():
    """Return the sink AST for kernel 43 (sum over three reduce axes).

    It multiplies the uint buffer ``buf1`` (converted to float through a
    uchar/uint/uchar cast chain) by ``buf2``. The products are summed over
    reduce axes 0, 1, 2 (extents 2, 24, 24) and written into the 204800-float
    ``buf0``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(204800), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(401408), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=2)
    ax3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(5, 4, AxisType.LOOP, tag=())
    ax5 = UOp.range(5, 5, AxisType.LOOP, tag=())
    ax6 = UOp.range(256, 6, AxisType.LOOP, tag=())
    red0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(24, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(24, 2, AxisType.REDUCE, tag=())
    outer = ax6*2 + red0
    packed = buf1.index(ax5 + red2 + (ax4 + red1)*28 + outer*784).load()
    # the uchar -> uint -> uchar -> float chain is kept exactly as emitted
    as_float = packed.cast(dtypes.uchar).cast(dtypes.uint).cast(dtypes.uchar).cast(dtypes.float)
    other = buf2.index(red1*24 + red2 + ax3*576 + outer*18432).load()
    total = (as_float * other).reduce(red0, red1, red2, arg=Ops.ADD)
    dst = ax5*256 + ax6 + ax4*1280 + ax3*6400
    return buf0.index(dst).store(total).end(ax3, ax4, ax5, ax6).sink()
def k44():
    """Return the sink AST for kernel 44.

    It sums ``buf1`` over reduce axis 0 (extent 256) and writes each result
    to ``buf0[axis2*5 + axis3 + axis1*25]`` (800 floats).
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(800), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(204800), arg=1)
    ax1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    ax2 = UOp.range(5, 2, AxisType.LOOP, tag=())
    ax3 = UOp.range(5, 3, AxisType.LOOP, tag=())
    red0 = UOp.range(256, 0, AxisType.REDUCE, tag=())
    src = ax3*256 + red0 + ax2*1280 + ax1*6400
    total = buf1.index(src).load().reduce(red0, arg=Ops.ADD)
    return buf0.index(ax2*5 + ax3 + ax1*25).store(total).end(ax1, ax2, ax3).sink()
def k45():
    """Return the sink AST for kernel 45.

    It sums ``buf1`` (9437184 floats) over reduce axes 0, 1, 2 (extents 2, 24, 24)
    and writes the result to ``buf0[axis3*256 + axis4]``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=1)
    ax3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(256, 4, AxisType.LOOP, tag=())
    red0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(24, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(24, 2, AxisType.REDUCE, tag=())
    src = red1*24 + red2 + ax3*576 + (ax4*2 + red0)*18432
    total = buf1.index(src).load().reduce(red0, red1, red2, arg=Ops.ADD)
    return buf0.index(ax3*256 + ax4).store(total).end(ax3, ax4).sink()
def k46():
    """Return the sink AST for kernel 46: ``buf0[i] = sum_j buf1[i*256 + j]`` for i < 32, j < 256."""
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=1)
    ax1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    red0 = UOp.range(256, 0, AxisType.REDUCE, tag=())
    total = buf1.index(ax1*256 + red0).load().reduce(red0, arg=Ops.ADD)
    return buf0.index(ax1).store(total).end(ax1).sink()
def k47():
    """Return the sink AST for kernel 47 (sum over three reduce axes).

    It multiplies ``buf1`` (clamped below at 0) by ``buf2``. The products are
    summed over reduce axes 0, 1, 2 (extents 4, 20, 20) and written into the
    3276800-float ``buf0``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(3276800), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9437184), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=2)
    ax3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(32, 4, AxisType.LOOP, tag=())
    ax5 = UOp.range(5, 5, AxisType.LOOP, tag=())
    ax6 = UOp.range(5, 6, AxisType.LOOP, tag=())
    ax7 = UOp.range(128, 7, AxisType.LOOP, tag=())
    red0 = UOp.range(4, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(20, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(20, 2, AxisType.REDUCE, tag=())
    outer = ax7*4 + red0
    x = buf1.index(ax6 + red2 + (ax5 + red1)*24 + ax4*576 + outer*18432).load()
    clamped = (0.0 < x).where(x, UOp.const(dtypes.float, 0.0))
    y = buf2.index(red1*20 + red2 + ax3*400 + outer*12800).load()
    total = (clamped * y).reduce(red0, red1, red2, arg=Ops.ADD)
    dst = ax6*128 + ax7 + ax5*640 + ax4*3200 + ax3*102400
    return buf0.index(dst).store(total).end(ax3, ax4, ax5, ax6, ax7).sink()
def k48():
    """Return the sink AST for kernel 48.

    It sums ``buf1`` over reduce axis 0 (extent 128) and writes each result
    to ``buf0[axis3*5 + axis4 + axis2*25 + axis1*800]`` (25600 floats).
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25600), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(3276800), arg=1)
    ax1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    ax2 = UOp.range(32, 2, AxisType.LOOP, tag=())
    ax3 = UOp.range(5, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(5, 4, AxisType.LOOP, tag=())
    red0 = UOp.range(128, 0, AxisType.REDUCE, tag=())
    src = ax4*128 + red0 + ax3*640 + ax2*3200 + ax1*102400
    total = buf1.index(src).load().reduce(red0, arg=Ops.ADD)
    dst = ax3*5 + ax4 + ax2*25 + ax1*800
    return buf0.index(dst).store(total).end(ax1, ax2, ax3, ax4).sink()
def k49():
    """Return the sink AST for kernel 49.

    It sums ``buf1`` (6553600 floats) over reduce axes 0, 1, 2 (extents 2, 20, 20)
    and writes the result to ``buf0[axis3*256 + axis4]``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    ax3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(256, 4, AxisType.LOOP, tag=())
    red0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(20, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(20, 2, AxisType.REDUCE, tag=())
    src = red1*20 + red2 + ax3*400 + (ax4*2 + red0)*12800
    total = buf1.index(src).load().reduce(red0, red1, red2, arg=Ops.ADD)
    return buf0.index(ax3*256 + ax4).store(total).end(ax3, ax4).sink()
def k50():
    """Return the sink AST for kernel 50 (sum over three reduce axes).

    It accumulates
    ``(buf1[k] clamped at 0 - buf2[axis3]) * (1/sqrt(buf3[axis3] + 1e-05) * buf4[k])``
    over reduce axes 0, 1, 2 (extents 2, 20, 20) and stores the sum at
    ``buf0[axis3*256 + axis4]``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8192), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=2)
    buf3 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=3)
    buf4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6553600), arg=4)
    ax3 = UOp.range(32, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(256, 4, AxisType.LOOP, tag=())
    red0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(20, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(20, 2, AxisType.REDUCE, tag=())
    k = red1*20 + red2 + ax3*400 + (ax4*2 + red0)*12800
    x = buf1.index(k).load()
    clamped = (0.0 < x).where(x, UOp.const(dtypes.float, 0.0))
    centered = clamped + buf2.index(ax3).load()*-1.0
    scaled = (buf3.index(ax3).load() + 1e-05).sqrt().reciprocal() * buf4.index(k).load()
    total = (centered * scaled).reduce(red0, red1, red2, arg=Ops.ADD)
    return buf0.index(ax3*256 + ax4).store(total).end(ax3, ax4).sink()
def k51():
    """Return the sink AST for kernel 51 (sum over three reduce axes).

    It multiplies ``buf1`` by ``buf2`` and sums over reduce axes 0, 1, 2
    (extents 4, 8, 8). The results go into the 2359296-float ``buf0``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2359296), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1638400), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=2)
    ax3 = UOp.range(64, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(32, 4, AxisType.LOOP, tag=())
    ax5 = UOp.range(3, 5, AxisType.LOOP, tag=())
    ax6 = UOp.range(3, 6, AxisType.LOOP, tag=())
    ax7 = UOp.range(128, 7, AxisType.LOOP, tag=())
    red0 = UOp.range(4, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(8, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(8, 2, AxisType.REDUCE, tag=())
    outer = ax7*4 + red0
    lhs = buf1.index(ax6 + red2 + (ax5 + red1)*10 + ax4*100 + outer*3200).load()
    rhs = buf2.index(red1*8 + red2 + ax3*64 + outer*4096).load()
    total = (lhs * rhs).reduce(red0, red1, red2, arg=Ops.ADD)
    dst = ax6*128 + ax7 + ax5*384 + ax4*1152 + ax3*36864
    return buf0.index(dst).store(total).end(ax3, ax4, ax5, ax6, ax7).sink()
def k52():
    """Return the sink AST for kernel 52.

    It sums ``buf1`` over reduce axis 0 (extent 128) and writes each result
    to ``buf0[axis3*3 + axis4 + axis2*9 + axis1*288]`` (18432 floats).
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18432), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2359296), arg=1)
    ax1 = UOp.range(64, 1, AxisType.LOOP, tag=())
    ax2 = UOp.range(32, 2, AxisType.LOOP, tag=())
    ax3 = UOp.range(3, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(3, 4, AxisType.LOOP, tag=())
    red0 = UOp.range(128, 0, AxisType.REDUCE, tag=())
    src = ax4*128 + red0 + ax3*384 + ax2*1152 + ax1*36864
    total = buf1.index(src).load().reduce(red0, arg=Ops.ADD)
    dst = ax3*3 + ax4 + ax2*9 + ax1*288
    return buf0.index(dst).store(total).end(ax1, ax2, ax3, ax4).sink()
def k53():
    """Return the sink AST for kernel 53.

    It sums ``buf1`` (2097152 floats) over reduce axes 0, 1, 2 (extents 2, 8, 8)
    and writes the result to ``buf0[axis3*256 + axis4]``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(16384), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=1)
    ax3 = UOp.range(64, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(256, 4, AxisType.LOOP, tag=())
    red0 = UOp.range(2, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(8, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(8, 2, AxisType.REDUCE, tag=())
    src = red1*8 + red2 + ax3*64 + (ax4*2 + red0)*4096
    total = buf1.index(src).load().reduce(red0, red1, red2, arg=Ops.ADD)
    return buf0.index(ax3*256 + ax4).store(total).end(ax3, ax4).sink()
def k54():
    """Return the sink AST for kernel 54: ``buf0[i] = sum_j buf1[i*256 + j]`` for i < 64, j < 256."""
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(16384), arg=1)
    ax1 = UOp.range(64, 1, AxisType.LOOP, tag=())
    red0 = UOp.range(256, 0, AxisType.REDUCE, tag=())
    total = buf1.index(ax1*256 + red0).load().reduce(red0, arg=Ops.ADD)
    return buf0.index(ax1).store(total).end(ax1).sink()
def k55():
    """Return the sink AST for kernel 55 (sum over three reduce axes).

    It multiplies ``buf1`` (clamped below at 0) by ``buf2``. The products are
    summed over reduce axes 0, 1, 2 (extents 512, 6, 6) and written into the
    36864-float ``buf0``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=2)
    ax3 = UOp.range(64, 3, AxisType.LOOP, tag=())
    ax4 = UOp.range(64, 4, AxisType.LOOP, tag=())
    ax5 = UOp.range(3, 5, AxisType.LOOP, tag=())
    ax6 = UOp.range(3, 6, AxisType.LOOP, tag=())
    red0 = UOp.range(512, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(6, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(6, 2, AxisType.REDUCE, tag=())
    x = buf1.index(ax6 + red2 + (ax5 + red1)*8 + ax4*64 + red0*4096).load()
    clamped = (0.0 < x).where(x, UOp.const(dtypes.float, 0.0))
    y = buf2.index(red1*6 + red2 + ax3*36 + red0*2304).load()
    total = (clamped * y).reduce(red0, red1, red2, arg=Ops.ADD)
    dst = ax5*3 + ax6 + ax4*9 + ax3*576
    return buf0.index(dst).store(total).end(ax3, ax4, ax5, ax6).sink()
def k56():
    """Return the sink AST for kernel 56.

    For each axis3 < 64 it sums ``buf1`` over reduce axes 0, 1, 2
    (extents 512, 6, 6) and stores the result in ``buf0[axis3]``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    ax3 = UOp.range(64, 3, AxisType.LOOP, tag=())
    red0 = UOp.range(512, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(6, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(6, 2, AxisType.REDUCE, tag=())
    src = red1*6 + red2 + ax3*36 + red0*2304
    total = buf1.index(src).load().reduce(red0, red1, red2, arg=Ops.ADD)
    return buf0.index(ax3).store(total).end(ax3).sink()
def k57():
    """Return the sink AST for kernel 57 (sum over three reduce axes).

    For each axis3 < 64 it accumulates
    ``(buf1[k] clamped at 0 - buf2[axis3]) * (1/sqrt(buf3[axis3] + 1e-05) * buf4[k])``
    over reduce axes 0, 1, 2 (extents 512, 6, 6) and stores the sum in
    ``buf0[axis3]``.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=2)
    buf3 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3)
    buf4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=4)
    ax3 = UOp.range(64, 3, AxisType.LOOP, tag=())
    red0 = UOp.range(512, 0, AxisType.REDUCE, tag=())
    red1 = UOp.range(6, 1, AxisType.REDUCE, tag=())
    red2 = UOp.range(6, 2, AxisType.REDUCE, tag=())
    k = red1*6 + red2 + ax3*36 + red0*2304
    x = buf1.index(k).load()
    clamped = (0.0 < x).where(x, UOp.const(dtypes.float, 0.0))
    centered = clamped + buf2.index(ax3).load()*-1.0
    scaled = (buf3.index(ax3).load() + 1e-05).sqrt().reciprocal() * buf4.index(k).load()
    total = (centered * scaled).reduce(red0, red1, red2, arg=Ops.ADD)
    return buf0.index(ax3).store(total).end(ax3).sink()
def k58():
    """Return the sink AST for kernel 58.

    It writes ``buf0[axis1*576 + axis2] = sum_j buf1[digits(axis2) + j*576] * buf2[j*10 + axis1]``
    for j < 512.
    """
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5760), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=1)
    buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=2)
    ax1 = UOp.range(10, 1, AxisType.LOOP, tag=())
    ax2 = UOp.range(576, 2, AxisType.LOOP, tag=())
    red0 = UOp.range(512, 0, AxisType.REDUCE, tag=())
    # mixed-radix split of ax2; arithmetically equal to ax2 itself, but kept
    # verbatim so the emitted AST is unchanged
    digits = ((ax2//3) % 3)*3 + ax2 % 3 + (ax2//9)*9
    lhs = buf1.index(digits + red0*576).load()
    rhs = buf2.index(red0*10 + ax1).load()
    total = (lhs * rhs).reduce(red0, arg=Ops.ADD)
    return buf0.index(ax1*576 + ax2).store(total).end(ax1, ax2).sink()
def k59():
    """Return the sink AST for kernel 59: ``buf0[i] = sum_j buf1[j*10 + i]`` for i < 10, j < 512."""
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=0)
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5120), arg=1)
    ax1 = UOp.range(10, 1, AxisType.LOOP, tag=())
    red0 = UOp.range(512, 0, AxisType.REDUCE, tag=())
    total = buf1.index(red0*10 + ax1).load().reduce(red0, arg=Ops.ADD)
    return buf0.index(ax1).store(total).end(ax1).sink()
def k60():
    """Return the sink AST for kernel 60: concatenate buffers 1-14 into buf0.

    ``buf0`` holds 87850 floats. Position ``i`` takes its value from exactly
    one input, selected by these half-open ranges:

        [0, 800) buf1          [800, 832) buf2        [832, 26432) buf3
        [26432, 26464) buf4    [26464, 26496) buf5    [26496, 26528) buf6
        [26528, 44960) buf7    [44960, 45024) buf8    [45024, 81888) buf9
        [81888, 81952) buf10   [81952, 82016) buf11   [82016, 82080) buf12
        [82080, 87840) buf13   [87840, 87850) buf14

    Each input contributes 0.0 outside its own range, and the 14
    contributions are summed. buf1, buf3, buf7 and buf9 are read through the
    mixed-radix offset expressions emitted for this kernel. The others use a
    plain ``i - start`` offset.
    """
    invalid = UOp.const(dtypes.index, Invalid)
    zero = UOp.const(dtypes.float, 0.0)
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=0)
    pos = UOp.range(87850, 0, AxisType.LOOP, tag=())
    pos1 = pos + 1
    pos3 = pos + 3

    # non-trivial offsets, kept term-for-term as emitted
    off3 = ((pos3//5 + 3) % 5)*5 + pos3 % 5 + (((pos + 18)//25 + 30) % 32)*25 + ((pos + 768)//800)*800 + -1600
    off7 = ((pos1//3 + 1) % 3)*3 + pos1 % 3 + (((pos + 4)//9 + 28) % 32)*9 + ((pos + 256)//288)*288 + -26784
    off9 = ((pos//3 + 1) % 3)*3 + pos % 3 + ((pos3//9 + 53) % 64)*9 + ((pos + 480)//576)*576 + -45504

    # first segment [0, 800)
    below = pos < 800
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(800), arg=1)
    head_off = ((pos//5) % 5)*5 + pos % 5 + (pos//25)*25
    parts = [below.where(buf1.index(below.where(head_off, invalid)).load(), zero)]

    # (exclusive upper bound, buffer arg, buffer size, offset into that buffer)
    layout = (
        (832, 2, 32, pos + -800),
        (26432, 3, 25600, off3),
        (26464, 4, 32, pos + -26432),
        (26496, 5, 32, pos + -26464),
        (26528, 6, 32, pos + -26496),
        (44960, 7, 18432, off7),
        (45024, 8, 64, pos + -44960),
        (81888, 9, 36864, off9),
        (81952, 10, 64, pos + -81888),
        (82016, 11, 64, pos + -81952),
        (82080, 12, 64, pos + -82016),
        (87840, 13, 5760, pos + -82080),
    )
    for upper, arg, size, offset in layout:
        below_upper = pos < upper
        # inside the segment: not below the previous bound, below this one
        inside = (below != True) & below_upper
        src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(size), arg=arg)
        parts.append(inside.where(src.index(inside.where(offset, invalid)).load(), zero))
        below = below_upper

    # last segment [87840, 87850)
    buf14 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=14)
    tail = buf14.index((below != True).where(pos + -87840, invalid)).load()
    parts.append(below.where(zero, tail))

    # left fold, so the ADD tree nests exactly as ((p0 + p1) + p2) + ...
    total = parts[0]
    for part in parts[1:]:
        total = total + part
    return buf0.index(pos).store(total).end(pos).sink()
def k61():
    """Return the sink AST for kernel 61: an in-place blend over 87850 floats.

    For every i: ``buf0[i] = 0.9*buf0[i] + 0.09999999999999998*buf1[i]``. The
    second weight is the float64 value of ``1 - 0.9``.
    NOTE(review): reads like an exponential-moving-average update — confirm.
    """
    pos = UOp.range(87850, 0, AxisType.LOOP, tag=())
    target = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=0).index(pos)
    incoming = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1).index(pos).load()
    blended = 0.9*target.load() + 0.09999999999999998*incoming
    return target.store(blended).end(pos).sink()
def k62():
    """Return the sink AST for kernel 62: an in-place blend with squared input.

    For every i < 87850:
    ``buf0[i] = 0.999*buf0[i] + 0.0010000000000000009*(buf1[i]*buf1[i])``. The
    second weight is the float64 value of ``1 - 0.999``.
    NOTE(review): reads like a second-moment moving-average update — confirm.
    """
    pos = UOp.range(87850, 0, AxisType.LOOP, tag=())
    target = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=0).index(pos)
    incoming = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1).index(pos).load()
    blended = 0.999*target.load() + 0.0010000000000000009*(incoming*incoming)
    return target.store(blended).end(pos).sink()
def k63():
    """Return the sink AST for kernel 63: concatenate buffers 1-14 and add a step.

    ``buf0`` (87850 floats) receives, at position ``i``, the value of the one
    input whose range contains ``i``:

        [0, 800) buf1          [800, 832) buf2        [832, 26432) buf3
        [26432, 26464) buf4    [26464, 26496) buf5    [26496, 26528) buf6
        [26528, 44960) buf7    [44960, 45024) buf8    [45024, 81888) buf9
        [81888, 81952) buf10   [81952, 82016) buf11   [82016, 82080) buf12
        [82080, 87840) buf13   [87840, 87850) buf14

    Each input is read at ``i - start``. The following term is then added:
    ``-buf15[0] * buf16[i] / ((1 - buf17[0]) * (sqrt(buf18[i] / (1 - buf19[0])) + 1e-08))``.
    Divisions are expressed as reciprocals.
    NOTE(review): the added term has the shape of a bias-corrected Adam step
    (buf15 ~ learning rate, buf16/buf18 ~ moments, buf17/buf19 ~ beta powers) — confirm.
    """
    invalid = UOp.const(dtypes.index, Invalid)
    zero = UOp.const(dtypes.float, 0.0)
    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=0)
    pos = UOp.range(87850, 0, AxisType.LOOP, tag=())

    # first segment [0, 800) is read at pos directly
    below = pos < 800
    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(800), arg=1)
    parts = [below.where(buf1.index(below.where(pos, invalid)).load(), zero)]

    # (exclusive upper bound, buffer arg, buffer size); each starts where the last ended
    layout = (
        (832, 2, 32),
        (26432, 3, 25600),
        (26464, 4, 32),
        (26496, 5, 32),
        (26528, 6, 32),
        (44960, 7, 18432),
        (45024, 8, 64),
        (81888, 9, 36864),
        (81952, 10, 64),
        (82016, 11, 64),
        (82080, 12, 64),
        (87840, 13, 5760),
    )
    start = 800
    for upper, arg, size in layout:
        below_upper = pos < upper
        # inside the segment: not below the previous bound, below this one
        inside = (below != True) & below_upper
        src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(size), arg=arg)
        parts.append(inside.where(src.index(inside.where(pos + -start, invalid)).load(), zero))
        below, start = below_upper, upper

    # last segment [87840, 87850)
    buf14 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=14)
    tail = buf14.index((below != True).where(pos + -87840, invalid)).load()
    parts.append(below.where(zero, tail))

    # per-element step built from scalar buffers (15, 17, 19) and vectors (16, 18)
    first = UOp.const(dtypes.index, 0)
    s15 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=15).index(first).load()
    v16 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=16).index(pos).load()
    s17 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=17).index(first).load()
    v18 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=18).index(pos).load()
    s19 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=19).index(first).load()
    denom = (1.0 + s17*-1.0) * ((v18 * (1.0 + s19*-1.0).reciprocal()).sqrt() + 1e-08)
    parts.append(s15 * (v16 * denom.reciprocal()) * -1.0)

    # left fold, so the ADD tree nests exactly as ((p0 + p1) + p2) + ...
    total = parts[0]
    for part in parts[1:]:
        total = total + part
    return buf0.index(pos).store(total).end(pos).sink()
def k64():
    """Return the sink AST for kernel 64: copy buf1[0:800] into the 800-float buf0."""
    ax0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    ax1 = UOp.range(5, 1, AxisType.LOOP, tag=())
    ax2 = UOp.range(5, 2, AxisType.LOOP, tag=())
    flat = ax1*5 + ax2 + ax0*25
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(800), arg=0)
    return dst.index(flat).store(src.index(flat).load()).end(ax0, ax1, ax2).sink()
def k65():
    """Return the sink AST for kernel 65: copy buf1[800:832] into the 32-float buf0."""
    ax0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    return dst.index(ax0).store(src.index(ax0 + 800).load()).end(ax0).sink()
def k66():
    """Return the sink AST for kernel 66: copy buf1[832:26432] into the 25600-float buf0."""
    ax0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    ax1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    ax2 = UOp.range(5, 2, AxisType.LOOP, tag=())
    ax3 = UOp.range(5, 3, AxisType.LOOP, tag=())
    flat = ax2*5 + ax3 + ax1*25 + ax0*800
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25600), arg=0)
    return dst.index(flat).store(src.index(flat + 832).load()).end(ax0, ax1, ax2, ax3).sink()
def k67():
    """Return the sink AST for kernel 67: copy buf1[26432:26464] into the 32-float buf0."""
    ax0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    return dst.index(ax0).store(src.index(ax0 + 26432).load()).end(ax0).sink()
def k68():
    """Return the sink AST for kernel 68: copy buf1[26464:26496] into the 32-float buf0."""
    ax0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    return dst.index(ax0).store(src.index(ax0 + 26464).load()).end(ax0).sink()
def k69():
    """Return the sink AST for kernel 69: copy buf1[26496:26528] into the 32-float buf0."""
    ax0 = UOp.range(32, 0, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0)
    return dst.index(ax0).store(src.index(ax0 + 26496).load()).end(ax0).sink()
def k70():
    """Return the sink AST for kernel 70: copy buf1[26528:44960] into the 18432-float buf0."""
    ax0 = UOp.range(64, 0, AxisType.LOOP, tag=())
    ax1 = UOp.range(32, 1, AxisType.LOOP, tag=())
    ax2 = UOp.range(3, 2, AxisType.LOOP, tag=())
    ax3 = UOp.range(3, 3, AxisType.LOOP, tag=())
    flat = ax2*3 + ax3 + ax1*9 + ax0*288
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18432), arg=0)
    return dst.index(flat).store(src.index(flat + 26528).load()).end(ax0, ax1, ax2, ax3).sink()
def k71():
    """Return the sink AST for kernel 71: copy buf1[44960:45024] into the 64-float buf0."""
    ax0 = UOp.range(64, 0, AxisType.LOOP, tag=())
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0)
    return dst.index(ax0).store(src.index(ax0 + 44960).load()).end(ax0).sink()
def k72():
    """Return the sink AST for kernel 72: copy buf1[45024:81888] into the 36864-float buf0."""
    ax0 = UOp.range(64, 0, AxisType.LOOP, tag=())
    ax1 = UOp.range(64, 1, AxisType.LOOP, tag=())
    ax2 = UOp.range(3, 2, AxisType.LOOP, tag=())
    ax3 = UOp.range(3, 3, AxisType.LOOP, tag=())
    flat = ax2*3 + ax3 + ax1*9 + ax0*576
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1)
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=0)
    return dst.index(flat).store(src.index(flat + 45024).load()).end(ax0, ax1, ax2, ax3).sink()
| def k73(): | |
| c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0) | |
| c2 = UOp.range(64, 0, AxisType.LOOP, tag=()) | |
| c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1) | |
| c8 = c4.index((c2+81888)).load() | |
| c10 = c0.index(c2).store(c8).end(c2) | |
| ast = c10.sink() | |
| return ast | |
| def k74(): | |
| c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0) | |
| c2 = UOp.range(64, 0, AxisType.LOOP, tag=()) | |
| c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1) | |
| c8 = c4.index((c2+81952)).load() | |
| c10 = c0.index(c2).store(c8).end(c2) | |
| ast = c10.sink() | |
| return ast | |
| def k75(): | |
| c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0) | |
| c2 = UOp.range(64, 0, AxisType.LOOP, tag=()) | |
| c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1) | |
| c8 = c4.index((c2+82016)).load() | |
| c10 = c0.index(c2).store(c8).end(c2) | |
| ast = c10.sink() | |
| return ast | |
| def k76(): | |
| c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5760), arg=0) | |
| c2 = UOp.range(10, 0, AxisType.LOOP, tag=()) | |
| c5 = UOp.range(576, 1, AxisType.LOOP, tag=()) | |
| c6 = ((c2*576)+c5) | |
| c8 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1) | |
| c12 = c8.index((c6+82080)).load() | |
| c14 = c0.index(c6).store(c12).end(c2, c5) | |
| ast = c14.sink() | |
| return ast | |
| def k77(): | |
| c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10), arg=0) | |
| c2 = UOp.range(10, 0, AxisType.LOOP, tag=()) | |
| c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(87850), arg=1) | |
| c8 = c4.index((c2+87840)).load() | |
| c10 = c0.index(c2).store(c8).end(c2) | |
| ast = c10.sink() | |
| return ast | |
| def k78(): | |
| c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(1), arg=0) | |
| c2 = c0.index(UOp.const(dtypes.index, 0)) | |
| c3 = c2.load() | |
| c5 = (c3+1) | |
| c6 = c2.store(c5) | |
| ast = c6.sink() | |
| return ast | |
| def k79(): | |
| c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0) | |
| c2 = UOp.range(32, 0, AxisType.LOOP, tag=()) | |
| c3 = c0.index(c2) | |
| c5 = c3.load() | |
| c8 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=1) | |
| c10 = c8.index(c2).load() | |
| c12 = ((0.9*c5)+(0.1*c10)) | |
| c14 = c3.store(c12).end(c2) | |
| ast = c14.sink() | |
| return ast | |
| def k80(): | |
| c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=0) | |
| c2 = UOp.range(32, 0, AxisType.LOOP, tag=()) | |
| c3 = c0.index(c2) | |
| c5 = c3.load() | |
| c8 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(32), arg=1) | |
| c10 = c8.index(c2).load() | |
| c12 = ((0.9*c5)+(0.1000004882836342*c10)) | |
| c14 = c3.store(c12).end(c2) | |
| ast = c14.sink() | |
| return ast | |
| def k81(): | |
| c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0) | |
| c2 = UOp.range(64, 0, AxisType.LOOP, tag=()) | |
| c3 = c0.index(c2) | |
| c5 = c3.load() | |
| c8 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=1) | |
| c10 = c8.index(c2).load() | |
| c12 = ((0.9*c5)+(0.1*c10)) | |
| c14 = c3.store(c12).end(c2) | |
| ast = c14.sink() | |
| return ast | |
| def k82(): | |
| c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0) | |
| c2 = UOp.range(64, 0, AxisType.LOOP, tag=()) | |
| c3 = c0.index(c2) | |
| c5 = c3.load() | |
| c8 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=1) | |
| c10 = c8.index(c2).load() | |
| c12 = ((0.9*c5)+(0.10000542564158212*c10)) | |
| c14 = c3.store(c12).end(c2) | |
| ast = c14.sink() | |
| return ast | |
| c2 = UOp.new_buffer('METAL', 1, dtypes.uint, 17) | |
| c3 = UOp(Ops.KERNEL, src=(c2,), arg=Kernel(k0(), (Metadata(name='randint', caller='', backward=False),))) | |
| c4 = c2.after(c3) | |
| c7 = c4.forced_reshape((1,)).forced_reshape((1,)) | |
| c9 = UOp.new_buffer('METAL', 512, dtypes.float, 126) | |
| c11 = UOp.new_buffer('METAL', 2, dtypes.uint, 18) | |
| c12 = UOp(Ops.KERNEL, src=(c9, c4, c11), arg=Kernel(k1(), (Metadata(name='randint', caller='', backward=False),))) | |
| c13 = c9.after(c12) | |
| c15 = c13.forced_reshape((512,)) | |
| c17 = UOp.new_buffer('METAL', 9437184, dtypes.float, 127) | |
| c19 = UOp.new_buffer('METAL', 401408, dtypes.uint, 128) | |
| c21 = UOp.new_buffer('METAL', 4014080, dtypes.uchar, 129) | |
| c23 = UOp.new_buffer('METAL', 47040000, dtypes.uchar, 20) | |
| c24 = UOp(Ops.KERNEL, src=(c21, c13, c23), arg=Kernel(k2(), (Metadata(name='randint', caller='', backward=False), Metadata(name='__getitem__', caller='', backward=False)))) | |
| c26 = UOp(Ops.KERNEL, src=(c19, c21.after(c24)), arg=Kernel(k3(), (Metadata(name='conv2d', caller='', backward=False),))) | |
| c27 = c19.after(c26) | |
| c29 = UOp.new_buffer('METAL', 800, dtypes.float, 107) | |
| c31 = UOp.new_buffer('METAL', 32, dtypes.float, 108) | |
| c32 = UOp(Ops.KERNEL, src=(c17, c27, c29, c31), arg=Kernel(k4(), (Metadata(name='conv2d', caller='', backward=False),))) | |
| c33 = c17.after(c32) | |
| c35 = c33.reshape((512,32,24,24)) | |
| c37 = UOp.new_buffer('METAL', 6553600, dtypes.float, 130) | |
| c39 = UOp.new_buffer('METAL', 25600, dtypes.float, 109) | |
| c41 = UOp.new_buffer('METAL', 32, dtypes.float, 110) | |
| c42 = UOp(Ops.KERNEL, src=(c37, c33, c39, c41), arg=Kernel(k5(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=False)))) | |
| c43 = c37.after(c42) | |
| c45 = c43.reshape((512,32,20,20)) | |
| c47 = UOp.new_buffer('METAL', 32, dtypes.float, 131) | |
| c49 = UOp.new_buffer('METAL', 8192, dtypes.float, 132) | |
| c50 = UOp(Ops.KERNEL, src=(c49, c43), arg=Kernel(k6(), (Metadata(name='relu', caller='', backward=False),))) | |
| c52 = UOp(Ops.KERNEL, src=(c47, c49.after(c50)), arg=Kernel(k7(), (Metadata(name='mean', caller='', backward=False),))) | |
| c53 = c47.after(c52) | |
| c55 = c53.forced_reshape((32,)) | |
| c57 = UOp.new_buffer('METAL', 32, dtypes.float, 133) | |
| c59 = UOp.new_buffer('METAL', 8192, dtypes.float, 134) | |
| c60 = UOp(Ops.KERNEL, src=(c59, c43, c53), arg=Kernel(k8(), (Metadata(name='relu', caller='', backward=False), Metadata(name='__sub__', caller='', backward=False), Metadata(name='__mul__', caller='', backward=False)))) | |
| c62 = UOp(Ops.KERNEL, src=(c57, c59.after(c60)), arg=Kernel(k7(), (Metadata(name='mean', caller='', backward=False),))) | |
| c63 = c57.after(c62) | |
| c64 = c63.forced_reshape((32,)) | |
| c66 = UOp.new_buffer('METAL', 2097152, dtypes.float, 135) | |
| c68 = UOp.new_buffer('METAL', 1638400, dtypes.float, 136) | |
| c70 = UOp.new_buffer('METAL', 6553600, dtypes.float, 137) | |
| c72 = UOp.new_buffer('METAL', 32, dtypes.float, 111) | |
| c74 = UOp.new_buffer('METAL', 32, dtypes.float, 112) | |
| c75 = UOp(Ops.KERNEL, src=(c70, c43, c53, c72, c63, c74), arg=Kernel(k9(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='max_pool2d', caller='', backward=False)))) | |
| c76 = c70.after(c75) | |
| c77 = UOp(Ops.KERNEL, src=(c68, c76), arg=Kernel(k10(), (Metadata(name='max_pool2d', caller='', backward=False),))) | |
| c78 = c68.after(c77) | |
| c80 = UOp.new_buffer('METAL', 18432, dtypes.float, 113) | |
| c82 = UOp.new_buffer('METAL', 64, dtypes.float, 114) | |
| c83 = UOp(Ops.KERNEL, src=(c66, c78, c80, c82), arg=Kernel(k11(), (Metadata(name='conv2d', caller='', backward=False),))) | |
| c84 = c66.after(c83) | |
| c86 = c84.reshape((512,64,8,8)) | |
| c88 = UOp.new_buffer('METAL', 1179648, dtypes.float, 138) | |
| c90 = UOp.new_buffer('METAL', 36864, dtypes.float, 115) | |
| c92 = UOp.new_buffer('METAL', 64, dtypes.float, 116) | |
| c93 = UOp(Ops.KERNEL, src=(c88, c84, c90, c92), arg=Kernel(k12(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=False)))) | |
| c94 = c88.after(c93) | |
| c96 = c94.reshape((512,64,6,6)) | |
| c98 = UOp.new_buffer('METAL', 64, dtypes.float, 139) | |
| c99 = UOp(Ops.KERNEL, src=(c98, c94), arg=Kernel(k13(), (Metadata(name='relu', caller='', backward=False), Metadata(name='mean', caller='', backward=False)))) | |
| c100 = c98.after(c99) | |
| c102 = c100.forced_reshape((64,)) | |
| c104 = UOp.new_buffer('METAL', 64, dtypes.float, 140) | |
| c105 = UOp(Ops.KERNEL, src=(c104, c94, c100), arg=Kernel(k14(), (Metadata(name='relu', caller='', backward=False), Metadata(name='__sub__', caller='', backward=False), Metadata(name='__mul__', caller='', backward=False), Metadata(name='mean', caller='', backward=False)))) | |
| c106 = c104.after(c105) | |
| c107 = c106.forced_reshape((64,)) | |
| c109 = UOp.new_buffer('METAL', 5120, dtypes.float, 141) | |
| c111 = UOp.new_buffer('METAL', 294912, dtypes.float, 142) | |
| c113 = UOp.new_buffer('METAL', 1179648, dtypes.float, 143) | |
| c115 = UOp.new_buffer('METAL', 64, dtypes.float, 117) | |
| c117 = UOp.new_buffer('METAL', 64, dtypes.float, 118) | |
| c118 = UOp(Ops.KERNEL, src=(c113, c94, c100, c115, c106, c117), arg=Kernel(k15(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='max_pool2d', caller='', backward=False)))) | |
| c119 = c113.after(c118) | |
| c120 = UOp(Ops.KERNEL, src=(c111, c119), arg=Kernel(k16(), (Metadata(name='max_pool2d', caller='', backward=False),))) | |
| c121 = c111.after(c120) | |
| c123 = UOp.new_buffer('METAL', 5760, dtypes.float, 119) | |
| c125 = UOp.new_buffer('METAL', 10, dtypes.float, 120) | |
| c126 = UOp(Ops.KERNEL, src=(c109, c121, c123, c125), arg=Kernel(k17(), (Metadata(name='linear', caller='', backward=False),))) | |
| c127 = c109.after(c126) | |
| c129 = c127.reshape((512,10)) | |
| c131 = UOp.new_buffer('METAL', 1, dtypes.float, 144) | |
| c133 = UOp.new_buffer('METAL', 512, dtypes.float, 145) | |
| c134 = UOp(Ops.KERNEL, src=(c133, c127), arg=Kernel(k18(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False),))) | |
| c135 = c133.after(c134) | |
| c137 = UOp.new_buffer('METAL', 512, dtypes.float, 146) | |
| c138 = UOp(Ops.KERNEL, src=(c137, c127, c135), arg=Kernel(k19(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False),))) | |
| c139 = c137.after(c138) | |
| c141 = UOp.new_buffer('METAL', 512, dtypes.int, 147) | |
| c143 = UOp.new_buffer('METAL', 128000, dtypes.uchar, 148) | |
| c145 = UOp.new_buffer('METAL', 60000, dtypes.uchar, 47) | |
| c146 = UOp(Ops.KERNEL, src=(c143, c13, c145), arg=Kernel(k20(), (Metadata(name='randint', caller='', backward=False), Metadata(name='__getitem__', caller='', backward=False)))) | |
| c148 = UOp(Ops.KERNEL, src=(c141, c143.after(c146)), arg=Kernel(k21(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False),))) | |
| c149 = c141.after(c148) | |
| c150 = UOp(Ops.KERNEL, src=(c131, c127, c135, c139, c149), arg=Kernel(k22(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False),))) | |
| c152 = UOp(Ops.VECTORIZE, dtypes.index.vec(0)) | |
| c153 = c131.after(c150).reshape(()) | |
| c155 = UOp.new_buffer('METAL', 1, dtypes.float, 55) | |
| c156 = UOp(Ops.KERNEL, src=(c155,), arg=Kernel(k23(), (Metadata(name='__imul__', caller='', backward=False),))) | |
| c157 = c155.after(c156) | |
| c159 = c157.forced_reshape((1,)).forced_reshape((1,)) | |
| c161 = UOp.new_buffer('METAL', 1, dtypes.float, 56) | |
| c162 = UOp(Ops.KERNEL, src=(c161,), arg=Kernel(k24(), (Metadata(name='__imul__', caller='', backward=False),))) | |
| c163 = c161.after(c162) | |
| c165 = c163.forced_reshape((1,)).forced_reshape((1,)) | |
| c167 = UOp.new_buffer('METAL', 5120, dtypes.float, 149) | |
| c169 = UOp.new_buffer('METAL', 512, dtypes.float, 150) | |
| c170 = UOp(Ops.KERNEL, src=(c169, c149, c139), arg=Kernel(k25(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False), Metadata(name='sparse_categorical_crossentropy', caller='', backward=True)))) | |
| c172 = UOp(Ops.KERNEL, src=(c167, c149, c127, c135, c169.after(c170)), arg=Kernel(k26(), (Metadata(name='sparse_categorical_crossentropy', caller='', backward=False), Metadata(name='sparse_categorical_crossentropy', caller='', backward=True)))) | |
| c173 = c167.after(c172) | |
| c174 = c173.reshape((512,10)) | |
| c176 = UOp.new_buffer('METAL', 64, dtypes.float, 151) | |
| c178 = UOp.new_buffer('METAL', 1179648, dtypes.float, 152) | |
| c180 = UOp.new_buffer('METAL', 294912, dtypes.float, 153) | |
| c181 = UOp(Ops.KERNEL, src=(c180, c119, c121), arg=Kernel(k27(), (Metadata(name='max_pool2d', caller='', backward=True),))) | |
| c184 = UOp.new_buffer('METAL', 294912, dtypes.float, 154) | |
| c185 = UOp(Ops.KERNEL, src=(c184, c123, c173), arg=Kernel(k28(), (Metadata(name='linear', caller='', backward=True),))) | |
| c187 = UOp(Ops.KERNEL, src=(c178, c119, c121, c180.after(c181), c184.after(c185)), arg=Kernel(k29(), (Metadata(name='max_pool2d', caller='', backward=True),))) | |
| c188 = c178.after(c187) | |
| c189 = UOp(Ops.KERNEL, src=(c176, c94, c100, c115, c188, c106), arg=Kernel(k30(), (Metadata(name='rsqrt', caller='', backward=True), Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True), Metadata(name='mean', caller='', backward=True)))) | |
| c190 = c176.after(c189) | |
| c191 = c190.forced_reshape((64,)) | |
| c193 = UOp.new_buffer('METAL', 64, dtypes.float, 155) | |
| c194 = UOp(Ops.KERNEL, src=(c193, c115, c106, c188), arg=Kernel(k31(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True), Metadata(name='mean', caller='', backward=True)))) | |
| c195 = c193.after(c194) | |
| c196 = c195.forced_reshape((64,)) | |
| c198 = UOp.new_buffer('METAL', 1179648, dtypes.float, 156) | |
| c199 = UOp(Ops.KERNEL, src=(c198, c94, c100, c190, c115, c106, c188, c195), arg=Kernel(k32(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True), Metadata(name='__sub__', caller='', backward=False), Metadata(name='__mul__', caller='', backward=True), Metadata(name='relu', caller='', backward=False), Metadata(name='relu', caller='', backward=True)))) | |
| c200 = c198.after(c199) | |
| c201 = c200.reshape((512,64,6,6)) | |
| c203 = UOp.new_buffer('METAL', 2097152, dtypes.float, 157) | |
| c204 = UOp(Ops.KERNEL, src=(c203, c84, c90, c200), arg=Kernel(k33(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=True)))) | |
| c205 = c203.after(c204) | |
| c206 = c205.reshape((512,64,8,8)) | |
| c208 = UOp.new_buffer('METAL', 32, dtypes.float, 158) | |
| c210 = UOp.new_buffer('METAL', 8192, dtypes.float, 159) | |
| c212 = UOp.new_buffer('METAL', 6553600, dtypes.float, 160) | |
| c214 = UOp.new_buffer('METAL', 1638400, dtypes.float, 161) | |
| c215 = UOp(Ops.KERNEL, src=(c214, c76, c78), arg=Kernel(k34(), (Metadata(name='max_pool2d', caller='', backward=True),))) | |
| c218 = UOp.new_buffer('METAL', 1638400, dtypes.float, 162) | |
| c219 = UOp(Ops.KERNEL, src=(c218, c80, c205), arg=Kernel(k35(), (Metadata(name='conv2d', caller='', backward=True),))) | |
| c221 = UOp(Ops.KERNEL, src=(c212, c76, c78, c214.after(c215), c218.after(c219)), arg=Kernel(k36(), (Metadata(name='max_pool2d', caller='', backward=True),))) | |
| c222 = c212.after(c221) | |
| c223 = UOp(Ops.KERNEL, src=(c210, c43, c53, c72, c222), arg=Kernel(k37(), (Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True)))) | |
| c225 = UOp(Ops.KERNEL, src=(c208, c210.after(c223), c63), arg=Kernel(k38(), (Metadata(name='rsqrt', caller='', backward=True), Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='mean', caller='', backward=True)))) | |
| c226 = c208.after(c225) | |
| c227 = c226.forced_reshape((32,)) | |
| c229 = UOp.new_buffer('METAL', 32, dtypes.float, 163) | |
| c231 = UOp.new_buffer('METAL', 8192, dtypes.float, 164) | |
| c232 = UOp(Ops.KERNEL, src=(c231, c72, c63, c222), arg=Kernel(k39(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True)))) | |
| c234 = UOp(Ops.KERNEL, src=(c229, c231.after(c232)), arg=Kernel(k40(), (Metadata(name='mean', caller='', backward=True),))) | |
| c235 = c229.after(c234) | |
| c236 = c235.forced_reshape((32,)) | |
| c238 = UOp.new_buffer('METAL', 6553600, dtypes.float, 165) | |
| c239 = UOp(Ops.KERNEL, src=(c238, c43, c53, c226, c72, c63, c222, c235), arg=Kernel(k41(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True), Metadata(name='__sub__', caller='', backward=False), Metadata(name='__mul__', caller='', backward=True), Metadata(name='relu', caller='', backward=False), Metadata(name='relu', caller='', backward=True)))) | |
| c240 = c238.after(c239) | |
| c241 = c240.reshape((512,32,20,20)) | |
| c243 = UOp.new_buffer('METAL', 9437184, dtypes.float, 166) | |
| c244 = UOp(Ops.KERNEL, src=(c243, c33, c39, c240), arg=Kernel(k42(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=True)))) | |
| c245 = c243.after(c244) | |
| c246 = c245.reshape((512,32,24,24)) | |
| c248 = UOp.new_buffer('METAL', 800, dtypes.float, 167) | |
| c250 = UOp.new_buffer('METAL', 204800, dtypes.float, 168) | |
| c251 = UOp(Ops.KERNEL, src=(c250, c27, c245), arg=Kernel(k43(), (Metadata(name='conv2d', caller='', backward=False), Metadata(name='conv2d', caller='', backward=True)))) | |
| c253 = UOp(Ops.KERNEL, src=(c248, c250.after(c251)), arg=Kernel(k44(), (Metadata(name='contiguous', caller='', backward=False),))) | |
| c254 = c248.after(c253) | |
| c256 = c254.reshape((32,1,5,5)) | |
| c258 = UOp.new_buffer('METAL', 32, dtypes.float, 169) | |
| c260 = UOp.new_buffer('METAL', 8192, dtypes.float, 170) | |
| c261 = UOp(Ops.KERNEL, src=(c260, c245), arg=Kernel(k45(), ())) | |
| c263 = UOp(Ops.KERNEL, src=(c258, c260.after(c261)), arg=Kernel(k46(), (Metadata(name='contiguous', caller='', backward=False),))) | |
| c264 = c258.after(c263) | |
| c265 = c264.forced_reshape((32,)) | |
| c267 = UOp.new_buffer('METAL', 25600, dtypes.float, 171) | |
| c269 = UOp.new_buffer('METAL', 3276800, dtypes.float, 172) | |
| c270 = UOp(Ops.KERNEL, src=(c269, c33, c240), arg=Kernel(k47(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=False), Metadata(name='conv2d', caller='', backward=True)))) | |
| c272 = UOp(Ops.KERNEL, src=(c267, c269.after(c270)), arg=Kernel(k48(), (Metadata(name='contiguous', caller='', backward=False),))) | |
| c273 = c267.after(c272) | |
| c275 = c273.reshape((32,32,5,5)) | |
| c277 = UOp.new_buffer('METAL', 32, dtypes.float, 173) | |
| c279 = UOp.new_buffer('METAL', 8192, dtypes.float, 174) | |
| c280 = UOp(Ops.KERNEL, src=(c279, c240), arg=Kernel(k49(), ())) | |
| c282 = UOp(Ops.KERNEL, src=(c277, c279.after(c280)), arg=Kernel(k46(), (Metadata(name='contiguous', caller='', backward=False),))) | |
| c283 = c277.after(c282) | |
| c284 = c283.forced_reshape((32,)) | |
| c286 = UOp.new_buffer('METAL', 32, dtypes.float, 175) | |
| c288 = UOp.new_buffer('METAL', 8192, dtypes.float, 176) | |
| c289 = UOp(Ops.KERNEL, src=(c288, c43, c53, c63, c222), arg=Kernel(k50(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True)))) | |
| c291 = UOp(Ops.KERNEL, src=(c286, c288.after(c289)), arg=Kernel(k46(), (Metadata(name='contiguous', caller='', backward=False),))) | |
| c292 = c286.after(c291) | |
| c293 = c292.forced_reshape((32,)) | |
| c295 = UOp.new_buffer('METAL', 32, dtypes.float, 177) | |
| c297 = UOp.new_buffer('METAL', 8192, dtypes.float, 178) | |
| c298 = UOp(Ops.KERNEL, src=(c297, c222), arg=Kernel(k49(), ())) | |
| c300 = UOp(Ops.KERNEL, src=(c295, c297.after(c298)), arg=Kernel(k46(), (Metadata(name='contiguous', caller='', backward=False),))) | |
| c301 = c295.after(c300) | |
| c302 = c301.forced_reshape((32,)) | |
| c304 = UOp.new_buffer('METAL', 18432, dtypes.float, 179) | |
| c306 = UOp.new_buffer('METAL', 2359296, dtypes.float, 180) | |
| c307 = UOp(Ops.KERNEL, src=(c306, c78, c205), arg=Kernel(k51(), (Metadata(name='conv2d', caller='', backward=True),))) | |
| c309 = UOp(Ops.KERNEL, src=(c304, c306.after(c307)), arg=Kernel(k52(), (Metadata(name='contiguous', caller='', backward=False),))) | |
| c310 = c304.after(c309) | |
| c312 = c310.reshape((64,32,3,3)) | |
| c314 = UOp.new_buffer('METAL', 64, dtypes.float, 181) | |
| c316 = UOp.new_buffer('METAL', 16384, dtypes.float, 182) | |
| c317 = UOp(Ops.KERNEL, src=(c316, c205), arg=Kernel(k53(), ())) | |
| c319 = UOp(Ops.KERNEL, src=(c314, c316.after(c317)), arg=Kernel(k54(), (Metadata(name='contiguous', caller='', backward=False),))) | |
| c320 = c314.after(c319) | |
| c321 = c320.forced_reshape((64,)) | |
| c323 = UOp.new_buffer('METAL', 36864, dtypes.float, 183) | |
| c324 = UOp(Ops.KERNEL, src=(c323, c84, c200), arg=Kernel(k55(), (Metadata(name='relu', caller='', backward=False), Metadata(name='conv2d', caller='', backward=False), Metadata(name='conv2d', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False)))) | |
| c325 = c323.after(c324) | |
| c327 = c325.reshape((64,64,3,3)) | |
| c329 = UOp.new_buffer('METAL', 64, dtypes.float, 184) | |
| c330 = UOp(Ops.KERNEL, src=(c329, c200), arg=Kernel(k56(), (Metadata(name='conv2d', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False)))) | |
| c331 = c329.after(c330) | |
| c332 = c331.forced_reshape((64,)) | |
| c334 = UOp.new_buffer('METAL', 64, dtypes.float, 185) | |
| c335 = UOp(Ops.KERNEL, src=(c334, c94, c100, c106, c188), arg=Kernel(k57(), (Metadata(name='add', caller='', backward=False), Metadata(name='rsqrt', caller='', backward=False), Metadata(name='relu', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=False), Metadata(name='batchnorm', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False)))) | |
| c336 = c334.after(c335) | |
| c337 = c336.forced_reshape((64,)) | |
| c339 = UOp.new_buffer('METAL', 64, dtypes.float, 186) | |
| c340 = UOp(Ops.KERNEL, src=(c339, c188), arg=Kernel(k56(), (Metadata(name='batchnorm', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False)))) | |
| c341 = c339.after(c340) | |
| c342 = c341.forced_reshape((64,)) | |
| c344 = UOp.new_buffer('METAL', 5760, dtypes.float, 187) | |
| c345 = UOp(Ops.KERNEL, src=(c344, c121, c173), arg=Kernel(k58(), (Metadata(name='linear', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False)))) | |
| c346 = c344.after(c345) | |
| c348 = c346.reshape((10,576)) | |
| c350 = UOp.new_buffer('METAL', 10, dtypes.float, 188) | |
| c351 = UOp(Ops.KERNEL, src=(c350, c173), arg=Kernel(k59(), (Metadata(name='linear', caller='', backward=True), Metadata(name='contiguous', caller='', backward=False)))) | |
| c352 = c350.after(c351) | |
| c354 = c352.forced_reshape((10,)) | |
| c356 = UOp.new_buffer('METAL', 87850, dtypes.float, 189) | |
| c357 = UOp(Ops.KERNEL, src=(c356, c254, c264, c273, c283, c292, c301, c310, c320, c325, c331, c336, c341, c346, c352), arg=Kernel(k60(), (Metadata(name='cat', caller='', backward=False),))) | |
| c358 = c356.after(c357) | |
| c360 = c358.forced_reshape((87850,)) | |
| c362 = UOp.new_buffer('METAL', 87850, dtypes.float, 103) | |
| c363 = UOp(Ops.KERNEL, src=(c362, c358), arg=Kernel(k61(), (Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False)))) | |
| c364 = c362.after(c363) | |
| c366 = c364.forced_reshape((87850,)).forced_reshape((87850,)) | |
| c368 = UOp.new_buffer('METAL', 87850, dtypes.float, 104) | |
| c369 = UOp(Ops.KERNEL, src=(c368, c358), arg=Kernel(k62(), (Metadata(name='__mul__', caller='', backward=False), Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False)))) | |
| c370 = c368.after(c369) | |
| c372 = c370.forced_reshape((87850,)).forced_reshape((87850,)) | |
| c374 = UOp.new_buffer('METAL', 87850, dtypes.float, 190) | |
| c376 = UOp.new_buffer('METAL', 1, dtypes.float, 105) | |
| c377 = UOp(Ops.KERNEL, src=(c374, c29, c31, c39, c41, c72, c74, c80, c82, c90, c92, c115, c117, c123, c125, c376, c364, c157, c370, c163), arg=Kernel(k63(), (Metadata(name='__truediv__', caller='', backward=False), Metadata(name='sqrt', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='__rsub__', caller='', backward=False), Metadata(name='__mul__', caller='', backward=False), Metadata(name='cat', caller='', backward=False), Metadata(name='__sub__', caller='', backward=False)))) | |
| c378 = c374.after(c377) | |
| c379 = c378.forced_reshape((87850,)) | |
| c380 = UOp(Ops.KERNEL, src=(c29, c378), arg=Kernel(k64(), (Metadata(name='assign', caller='', backward=False),))) | |
| c383 = c29.after(c380).reshape((32,1,5,5)).forced_reshape((32,1,5,5)) | |
| c384 = UOp(Ops.KERNEL, src=(c31, c378), arg=Kernel(k65(), (Metadata(name='assign', caller='', backward=False),))) | |
| c387 = c31.after(c384).forced_reshape((32,)).forced_reshape((32,)) | |
| c388 = UOp(Ops.KERNEL, src=(c39, c378), arg=Kernel(k66(), (Metadata(name='assign', caller='', backward=False),))) | |
| c391 = c39.after(c388).reshape((32,32,5,5)).forced_reshape((32,32,5,5)) | |
| c392 = UOp(Ops.KERNEL, src=(c41, c378), arg=Kernel(k67(), (Metadata(name='assign', caller='', backward=False),))) | |
| c395 = c41.after(c392).forced_reshape((32,)).forced_reshape((32,)) | |
| c396 = UOp(Ops.KERNEL, src=(c72, c378), arg=Kernel(k68(), (Metadata(name='assign', caller='', backward=False),))) | |
| c399 = c72.after(c396).forced_reshape((32,)).forced_reshape((32,)) | |
| c400 = UOp(Ops.KERNEL, src=(c74, c378), arg=Kernel(k69(), (Metadata(name='assign', caller='', backward=False),))) | |
| c403 = c74.after(c400).forced_reshape((32,)).forced_reshape((32,)) | |
| c404 = UOp(Ops.KERNEL, src=(c80, c378), arg=Kernel(k70(), (Metadata(name='assign', caller='', backward=False),))) | |
| c407 = c80.after(c404).reshape((64,32,3,3)).forced_reshape((64,32,3,3)) | |
| c408 = UOp(Ops.KERNEL, src=(c82, c378), arg=Kernel(k71(), (Metadata(name='assign', caller='', backward=False),))) | |
| c411 = c82.after(c408).forced_reshape((64,)).forced_reshape((64,)) | |
| c412 = UOp(Ops.KERNEL, src=(c90, c378), arg=Kernel(k72(), (Metadata(name='assign', caller='', backward=False),))) | |
| c415 = c90.after(c412).reshape((64,64,3,3)).forced_reshape((64,64,3,3)) | |
| c416 = UOp(Ops.KERNEL, src=(c92, c378), arg=Kernel(k73(), (Metadata(name='assign', caller='', backward=False),))) | |
| c419 = c92.after(c416).forced_reshape((64,)).forced_reshape((64,)) | |
| c420 = UOp(Ops.KERNEL, src=(c115, c378), arg=Kernel(k74(), (Metadata(name='assign', caller='', backward=False),))) | |
| c423 = c115.after(c420).forced_reshape((64,)).forced_reshape((64,)) | |
| c424 = UOp(Ops.KERNEL, src=(c117, c378), arg=Kernel(k75(), (Metadata(name='assign', caller='', backward=False),))) | |
| c427 = c117.after(c424).forced_reshape((64,)).forced_reshape((64,)) | |
| c428 = UOp(Ops.KERNEL, src=(c123, c378), arg=Kernel(k76(), (Metadata(name='assign', caller='', backward=False),))) | |
| c431 = c123.after(c428).reshape((10,576)).forced_reshape((10,576)) | |
| c432 = UOp(Ops.KERNEL, src=(c125, c378), arg=Kernel(k77(), (Metadata(name='assign', caller='', backward=False),))) | |
| c435 = c125.after(c432).forced_reshape((10,)).forced_reshape((10,)) | |
| c437 = UOp.new_buffer('METAL', 1, dtypes.long, 121) | |
| c438 = UOp(Ops.KERNEL, src=(c437,), arg=Kernel(k78(), (Metadata(name='__iadd__', caller='', backward=False),))) | |
| c441 = c437.after(c438).reshape(()).forced_reshape(()) | |
| c443 = UOp.new_buffer('METAL', 32, dtypes.float, 122) | |
| c444 = UOp(Ops.KERNEL, src=(c443, c53), arg=Kernel(k79(), (Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False)))) | |
| c447 = c443.after(c444).forced_reshape((32,)).forced_reshape((32,)) | |
| c449 = UOp.new_buffer('METAL', 32, dtypes.float, 123) | |
| c450 = UOp(Ops.KERNEL, src=(c449, c63), arg=Kernel(k80(), (Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False)))) | |
| c453 = c449.after(c450).forced_reshape((32,)).forced_reshape((32,)) | |
| c455 = UOp.new_buffer('METAL', 64, dtypes.float, 124) | |
| c456 = UOp(Ops.KERNEL, src=(c455, c100), arg=Kernel(k81(), (Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False)))) | |
| c459 = c455.after(c456).forced_reshape((64,)).forced_reshape((64,)) | |
| c461 = UOp.new_buffer('METAL', 64, dtypes.float, 125) | |
| c462 = UOp(Ops.KERNEL, src=(c461, c106), arg=Kernel(k82(), (Metadata(name='__rmul__', caller='', backward=False), Metadata(name='__add__', caller='', backward=False), Metadata(name='assign', caller='', backward=False)))) | |
| c465 = c461.after(c462).forced_reshape((64,)).forced_reshape((64,)) | |
| ast = c7.sink(c15, c35, c45, c55, c64, c86, c96, c102, c107, c129, c153, c159, c165, c174, c191, c196, c201, c206, c227, c236, c241, c246, c256, c265, c275, c284, c293, c302, c312, c321, c327, c332, c337, c342, c348, c354, c360, c366, c372, c379, c383, c387, c391, c395, c399, c403, c407, c411, c415, c419, c423, c427, c431, c435, c441, c447, c453, c459, c465) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment