Last active
August 9, 2020 05:09
-
-
Save fowlmouth/9b9010397ad5fe4b9872 to your computer and use it in GitHub Desktop.
nim getType() fun
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# An implementation of curry: the arguments passed are wrapped up in a new
# closure, and a function is returned that accepts the remaining arguments.
# usage:
#   proc foo (a,b: int): int = a + b
#   let f = curry(foo, 10)
#   assert f(10) == 20
#
# note: to use an overloaded function you must annotate its type, e.g.
#   curry((proc(c:char,len:int):string)strutils.repeat, 'x')
# | |
import macros | |
proc type_to_nim (n: NimNode): NimNode {.compileTime.} =
  ## Returns a node that can be used in generated AST to denote the type
  ## described by `n`.
  ##
  ## `n` should come from the type graph returned by `macros.getType`.
  ## For ref, ptr, range and proc type nodes the first child of `n` is
  ## returned; any other node is returned unchanged, because the symbols
  ## produced by `getType` can be used directly in generated AST.
  case n.typeKind
  of ntyRef, ntyPtr, ntyRange, ntyProc:
    result = n[0]
  else:
    # Tracing is opt-in (-d:Debug) like the rest of this file; printing
    # unconditionally spammed the compile output for every plain type.
    when defined(Debug):
      echo n.typeKind
      echo n.treeRepr
    result = n
when false:
  # Now you can use the symbol returned from getType() in AST for the
  # represented type, so this old function is no longer needed. It is kept
  # for reference only; `when false` excludes it from compilation.
  proc type_to_nim (n: NimNode): NimNode {.compileTime.} =
    ## Rebuilds type AST (`ref T`, `ptr T`, `range[a..b]`, `array[I, T]`,
    ## proc types, numeric/string types) from a `getType` type-graph node.
    ## Unhandled type kinds abort via `quit`.
    let ty = n.typeKind
    case ty
    of ntyRef:
      result = newNimNode(nnkRefTy).add(n[1].type_to_nim)
    of ntyPtr:
      # a pointer type must be rebuilt as `ptr T`, not `ref T`
      result = newNimNode(nnkPtrTy).add(n[1].type_to_nim)
    of ntyRange:
      result = newNimNode(nnkBracketExpr).add(
        ident"range",
        infix(n[1].type_to_nim, "..", n[2].type_to_nim))
    of ntyArray:
      result = newNimNode(nnkBracketExpr).add(
        ident"array",
        n[1].type_to_nim, n[2].type_to_nim)
    of ntyEmpty:
      result = newEmptyNode()
    of ntyProc:
      # n[1] is the return type, n[2..] are the parameter types
      let params = newNimNode(nnkFormalParams)
      params.add n[1].type_to_nim
      for i in 2 .. len(n)-1:
        params.add newIdentDefs(ident("arg"& $(i-2)), n[i].type_to_nim)
      result = newNimNode(nnkProcTy).add(params, newEmptyNode())
    of ntyInt .. ntyFloat128, ntyString:
      if n.kind == nnkSym:
        result = ident($ n.symbol)
      else:
        # literal?
        result = n
    else:
      echo ty, ": ", n.repr
      quit "unhandled type"
    when defined(Debug):
      result.repr.echo
macro curry (f: stmt; args: varargs[expr]): expr =
  ## Partially applies the proc `f` to `args`.
  ##
  ## Expands to a lambda that takes the remaining parameters of `f` and
  ## calls `f(args..., remaining...)`. The `args` expressions are copied
  ## into the lambda body, so they are evaluated each time the returned
  ## lambda is called, not once at the `curry` site.
  let ty = getType(f)
  # This macro reads the proc type as `proc[ReturnType, Param1, Param2, ...]`:
  # ty[0] names the type constructor, ty[1] is the return type and the
  # parameter types start at index 2. Check the node has children before
  # indexing it, so a non-proc argument hits the message below.
  doAssert(ty.len > 0 and $ty[0] == "proc", "first param is not a function")
  let n_remaining = ty.len - 2 - args.len
  doAssert n_remaining > 0, "cannot curry all the parameters"
  var callExpr = newCall(f)
  args.copyChildrenTo callExpr
  var params: seq[NimNode] = @[]
  # return type
  params.add ty[1].type_to_nim
  for i in 0 .. n_remaining - 1:
    # genSym yields fresh symbols, so the lambda's parameters cannot
    # capture identifiers used inside the user-supplied `args`
    # (e.g. a caller variable named `arg0`).
    let param = genSym(nskParam, "arg" & $i)
    params.add newIdentDefs(param, ty[i + 2 + args.len].type_to_nim)
    callExpr.add param
  result = newProc(procType = nnkLambda, params = params, body = callExpr)
  when defined(Debug):
    result.repr.echo
when isMainModule:
  # Self-tests for `curry`. doAssert is used instead of assert so the checks
  # still run when the module is built with -d:release / -d:danger.
  # need better examples
  proc foo (a, b, c: int): int =
    a + b * c
  let f = curry(foo, 1)
  doAssert f(2, 3) == 7 # 1+2*3 = 7
  doAssert curry(foo, 1, 2)(3) == 7
  proc qux (x: int, y: float, z: float): int =
    (x.float * y * z).int
  let f1 = curry(qux, 42)
  let f2 = curry(f1, 100.0)
  doAssert f2(6) == int(42 * 100 * 6)
  import strutils
  # overloaded procs must be disambiguated with an explicit proc type
  let fz = curry((proc(c: char, n: int): string)strutils.repeat, 'x')
  doAssert fz(3) == "xxx"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This is an implementation of `==` that properly handles object variants
# (objects with `case` sections). Below is the expression it generates for
# the test type `Variant` declared in the `when isMainModule` section below.
# (`discard` of a string literal: documentation only, no runtime effect.)
discard """ | |
a.xx == b.xx and a.k == b.k and a.k2 == b.k2 and | |
case a.k | |
of aa, cc: a.i == b.i and a.z == b.z | |
of bb: a.f == b.f | |
of dd: true | |
else: true | |
and | |
case a.k2 | |
of xa: a.xi == b.xi | |
of xb: a.xb == b.xb | |
""" | |
import macros | |
proc flatten_reclist (n: NimNode): seq[NimNode] {.compileTime.} =
  ## Collects the field nodes of a record list into a flat seq.
  ##
  ## A `nnkRecList` yields its children in order, a bare `nnkSym` yields a
  ## one-element seq, and any other node kind yields an empty seq.
  result = @[]
  case n.kind
  of nnkRecList:
    for idx in 0 .. n.len - 1:
      result.add n[idx]
  of nnkSym:
    result.add n
  else:
    discard
proc `==` [T: object] (a, b: T): bool =
  ## Structural equality for object types, including object variants.
  ##
  ## The comparison is generated at compile time by `mkEq`:
  ## 1. Every ordinary field and every case discriminator is compared.
  ## 2. For each `case` section, a `case a.<discriminator>` expression
  ##    compares only the fields of the branch selected by `a`.
  ## All conditions are chained with `and`. The discriminators are compared
  ## earlier in that chain, so branch fields are only read once both
  ## objects are known to be in the same branch.
  macro mkEq (a, b: stmt): expr =
    let
      ty = a.getType
      tyKind = ty.typeKind
    case tyKind
    of ntyObject:
      when defined(Debug):
        echo ty.treeRepr
      ty[1].expectKind nnkRecList
      # Check all of the object's normal fields first, including variant
      # discriminators; after that, check the union (branch) fields.
      var varfields: seq[NimNode] = @[]
      var conditions: seq[NimNode] = @[]
      template checkcond (sym): expr =
        # builds `a.<sym> == b.<sym>`
        infix(
          newDotExpr(ident"a", ident($sym)),
          "==",
          newDotExpr(ident"b", ident($sym)))
      for field in ty[1].children:
        case field.kind
        of nnkSym:
          conditions.add checkcond(field.symbol)
        of nnkRecCase:
          field[0].expectKind nnkSym
          conditions.add checkcond(field[0].symbol)
          varfields.add field
        else:
          # Report through the compiler instead of terminating it with
          # `quit`, which gives the user no diagnostic context.
          error("unexpected field member " & treeRepr(field))
      for vf in varfields:
        let cs = newNimNode(nnkCaseStmt)
        cs.add newDotExpr(ident"a", ident($ vf[0].symbol))
        # iterate over the `of` / `else` branches (vf[0] is the discriminator)
        for i in 1 .. len(vf) - 1:
          let tyBranch = vf[i]
          let newBranch = newNimNode(tyBranch.kind)
          # last entry is reclist(sym, ...) or sym
          let syms = flatten_reclist(tyBranch[len(tyBranch) - 1])
          # the entries before the last one are the branch's `of` values;
          # an `else` branch has none, so this loop is then empty
          for ii in 0 .. len(tyBranch) - 2:
            newBranch.add tyBranch[ii]
          if syms.len > 0:
            var res: NimNode
            for s in syms:
              if res.isNil:
                res = checkcond(s.symbol)
              else:
                res = infix(res, "and", checkcond(s.symbol))
            newBranch.add res
          else:
            # this is nil/discard so it always passes
            newBranch.add ident"true"
          cs.add newBranch
        conditions.add newNimNode(nnkStmtListExpr).add(cs)
      # fold all conditions into one `and` chain
      for c in conditions:
        result =
          if result.isNil: c
          else: infix(result, "and", c)
      if result.isNil:
        result = ident"true"
    of ntyRef, ntyPtr:
      # NOTE(review): with the `T: object` constraint this branch looks
      # unreachable - confirm before relying on it.
      result = quote do: (if a.isNil: b.isNil elif b.isNil: false else: a[] == b[])
    else:
      echo "typekind not handled for ==: "& $tyKind
      result = quote do: system.`==`(a,b)
      # author's note: this causes an error when it is hit..? but it is
      # fine under ntyEnum
    if result.isNil:
      result = ident"false"
    when defined(Debug):
      echo "result: ", repr(result)
  mkEq(a, b)
when isMainModule:
  # Self-tests for `==` on an object variant with two case sections.
  type
    En = enum
      aa, bb, cc, dd
    En2 = enum
      xa, xb
    Variant = object
      xx: int
      case k: En
      of aa, cc:
        i: int
        z: string
      of bb:
        f: float
      of dd:
        discard
      else: discard
      case k2: En2
      of xa: xi: int
      of xb: xb: int
  template test (boolExpr): stmt =
    echo "[", (if boolExpr: "pass" else: "FAIL"), "] ", astToStr(boolExpr)
  let
    v1 = Variant(k: aa, i: 42)
    v2 = Variant(k: bb, f: 1.9)
    v3 = Variant(k: aa, i: 43)
    v4 = Variant(k: dd, k2: xb, xb: 7)
    v5 = Variant(k: dd, k2: xb, xb: 8)
  test v1 == v1
  test v2 == v2
  test v1 != v2
  # same discriminators, different field inside the selected branch
  test v1 != v3
  # ordinary (non-variant) field differs
  test Variant(k: aa, i: 42, xx: 1) != v1
  # `discard` branch always passes; the second case section still compares
  test v4 == v4
  test v4 != v5
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Another example (gives the all-important Answer to life, the universe, and everything — you get it...):