Last active
March 26, 2025 17:52
-
-
Save moea/6e066c0b403f65cb027705a269b93ad3 to your computer and use it in GitHub Desktop.
Typechecking for GADTs in Clojure
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns thebes.gadt | |
(:require [clojure.string :as str] | |
[clojure.walk :as walk])) | |
;; A type variable; `name` is a symbol (gensym'd by make-type-var).
(defrecord TypeVar [name])

;; A data constructor: `name` is the data type it belongs to, `tag` is the
;; constructor's own name, `params` its argument types and `ret-type` the
;; (possibly refined) type it produces.
(defrecord Variant [name tag params ret-type])

;; A type constructor `name` applied to the sequence of types `args`.
(defrecord TypeApp [name args])

(defn type-var?
  "True when `t` is a TypeVar record."
  [t]
  (instance? TypeVar t))

(defn variant?
  "True when `t` is a Variant (constructor) record."
  [t]
  (instance? Variant t))

(defn atomic-type?
  "Atomic types (e.g. :number, :string, :bool) are plain keywords."
  [t]
  (keyword? t))

(defn type-app?
  "True when `t` is a TypeApp record."
  [t]
  (instance? TypeApp t))

(defn make-type-var
  "Builds a TypeVar whose name is a fresh gensym prefixed with the name of `name'`."
  [name']
  (-> name' name gensym ->TypeVar))

(defn make-variant
  "Builds a Variant record from its four fields."
  [name tag params ret-type]
  (map->Variant {:name name :tag tag :params params :ret-type ret-type}))

(defn make-type-app
  "Builds a TypeApp record applying type constructor `name` to `args`."
  [name args]
  (map->TypeApp {:name name :args args}))
(defn env-lookup
  "Returns the value bound to `var` in `env`.
  Throws ex-info with message \"Unbound identifier: <var>\" and data
  {:var var} when the lookup yields nil or false."
  [env var]
  (if-let [found (get env var)]
    found
    (throw (ex-info (str "Unbound identifier: " var) {:var var}))))
(defn type->str
  "Renders a type as a readable string:
    keyword    -> its name                    (:number  -> \"number\")
    TypeVar    -> name prefixed with a quote  (\"'A123\")
    Variant    -> \"<type>.<tag>\"
    TypeApp    -> \"(<name> <arg> ...)\", args rendered recursively
  Anything else falls back to `str`."
  [type]
  (condp #(%1 %2) type
    atomic-type? (name type)
    type-var?    (str "'" (:name type))
    variant?     (str (:name type) "." (:tag type))
    type-app?    (str "("
                      (:name type)
                      " "
                      (str/join " " (map type->str (:args type)))
                      ")")
    (str type)))
(defn instantiate
  "Applies substitution `subst` (a map from type-variable name to type) to
  `type`. A TypeVar whose name is absent from `subst` is returned unchanged;
  Variant params/ret-type and TypeApp args are rewritten recursively (params
  and args lazily, via `map`). Keywords and unrecognised values pass through.
  The replacement is a single pass: a substituted type is not itself
  re-substituted."
  [type subst]
  (let [recur-into #(instantiate % subst)]
    (cond
      (atomic-type? type) type

      (type-var? type) (subst (:name type) type)

      (variant? type) (make-variant (:name type)
                                    (:tag type)
                                    (map recur-into (:params type))
                                    (recur-into (:ret-type type)))

      (type-app? type) (assoc type :args (map recur-into (:args type)))

      :else type)))
(defn unify
  "Computes a substitution -- a map from type-variable name to type -- that
  makes `t1` and `t2` equal.

  * Equal types unify with the empty substitution.
  * A TypeVar on either side is bound to the other side (no occurs check).
  * Two TypeApps with the same constructor name unify argument-wise.
  * Two sequential collections of equal length unify pairwise, left to right.
    Each pair is instantiated with the substitution accumulated so far. The
    new bindings are then composed into the accumulated ones, so every
    binding is fully resolved. For example, (unify [a b] [b :number]) binds
    both a and b to :number.

  Throws ex-info \"Cannot unify types\" with data {:t1 t1 :t2 t2} otherwise."
  [t1 t2]
  (cond
    (= t1 t2) {}

    (type-var? t1) {(:name t1) t2}
    (type-var? t2) {(:name t2) t1}

    (and (type-app? t1)
         (type-app? t2)
         (= (:name t1) (:name t2)))
    (unify (:args t1) (:args t2))

    (and (sequential? t1)
         (sequential? t2)
         (= (count t1) (count t2)))
    (reduce (fn [subst [a b]]
              (let [step (unify (instantiate a subst) (instantiate b subst))]
                ;; Compose rather than merely merge. A variable bound earlier
                ;; (possibly to another variable) must see what this step
                ;; learned. Existing bindings keep precedence over same-named
                ;; new ones.
                (merge step
                       (into {}
                             (map (fn [[k v]] [k (instantiate v step)]))
                             subst))))
            {}
            (map vector t1 t2))

    :else
    (throw (ex-info "Cannot unify types" {:t1 t1 :t2 t2}))))
(defn- type-var-names
  "Returns the set of type-variable names occurring anywhere in type `t`:
  inside TypeApp args and inside a Variant's params and ret-type."
  [t]
  (cond
    (type-var? t) #{(:name t)}
    (type-app? t) (into #{} (mapcat type-var-names) (:args t))
    (variant? t)  (into (type-var-names (:ret-type t))
                        (mapcat type-var-names)
                        (:params t))
    :else         #{}))

(defn- freshen
  "Returns `ctor-type` with every type variable renamed to a brand-new one.
  Each application of a polymorphic constructor then gets independent type
  variables and cannot clash with the variables in its own arguments' types."
  [ctor-type]
  (instantiate ctor-type
               (into {}
                     (map (fn [n] [n (make-type-var n)]))
                     (type-var-names ctor-type))))

(defn typecheck
  "Infers the type of `expr` under constructor environment `env`.

  * numbers -> :number, strings -> :string, booleans -> :bool
  * a symbol -> whatever `env` binds it to (e.g. a constructor's Variant)
  * (Ctor arg ...) -> Ctor's return type, instantiated with the substitution
    obtained by unifying its (freshened) parameter types with the argument
    types.

  Throws ex-info when:
  * an identifier is unbound (from env-lookup);
  * the head of an application is not a constructor;
  * an application has the wrong number of arguments;
  * argument types do not unify (from unify);
  * the expression is of any other shape."
  [expr env]
  (cond
    (number? expr)  :number
    (string? expr)  :string
    (boolean? expr) :bool
    (symbol? expr)  (env-lookup env expr)

    ;; `seq?` rather than `list?` so programmatically built forms
    ;; (e.g. via `cons`) are accepted too.
    (and (seq? expr) (symbol? (first expr)))
    (let [[ctor & args] expr
          ctor-type     (env-lookup env ctor)]
      ;; Explicit checks instead of `assert`: asserts can be compiled out, and
      ;; AssertionError is not caught by `(catch Exception ...)`.
      (when-not (variant? ctor-type)
        (throw (ex-info (str "Not a constructor: " ctor)
                        {:ctor ctor :type ctor-type})))
      (let [{:keys [params ret-type]} (freshen ctor-type)
            expected                  (count params)
            actual                    (count args)]
        (when-not (= expected actual)
          (throw (ex-info (str "Arity mismatch for " ctor ": expected "
                               expected " argument(s), got " actual)
                          {:ctor ctor :expected expected :actual actual})))
        (let [arg-types (mapv #(typecheck % env) args)
              subst     (unify params arg-types)]
          (instantiate ret-type subst))))

    :else
    (throw (ex-info (str "Cannot typecheck expression: " (pr-str expr))
                    {:expr expr}))))
(defn define-gadt
  "Returns `env` extended with one entry per constructor of data type `name`.
  Each element of `variants` is the argument list [tag params ret-type]
  passed (after `name`) to make-variant; the resulting Variant is stored
  under its tag. `type-params` is accepted but not currently used."
  [env name type-params variants]
  (into env
        (map (fn [spec]
               (let [ctor (apply make-variant name spec)]
                 [(:tag ctor) ctor])))
        variants))
(defmacro data-type
  "Declares data type `tname` with type parameters `tvars`. Expands to a
  define-gadt call that builds a fresh constructor environment (based on {}).

  Each variant is written (Ctor [param-type ...] ret-type). Only the param
  types and the ret-type are rewritten:
  * each symbol in `tvars` becomes a TypeVar (one gensym per parameter per
    expansion);
  * each (tname ...) list becomes a TypeApp.
  The constructor name itself is never rewritten. Constructors may share a
  name with their type (e.g. (Box [A] (Box A))) or with a type parameter."
  [tname tvars & variants]
  (let [subst  (into {}
                     (for [tvar tvars]
                       [tvar (make-type-var tvar)]))
        typify (fn [form]
                 (walk/postwalk
                  (fn [f]
                    (let [f (subst f f)]
                      (if (and (list? f) (= tname (first f)))
                        (make-type-app (first f) (rest f))
                        f)))
                  form))]
    `(define-gadt
       {}
       (quote ~tname)
       (quote ~tvars)
       ~(vec (for [[ctor args ret] variants]
               `[(quote ~ctor) ~(typify args) ~(typify ret)])))))
(defmacro tc-debug
  "Prints the quoted `expr`, then \"=>\", then the string form of the type
  `typecheck` infers for it under `env`."
  [expr env]
  (let [quoted (list 'quote expr)]
    `(println ~quoted "=>" (-> ~quoted (typecheck ~env) type->str))))
;; Demo: a polymorphic List. Homogeneous lists typecheck; a list mixing
;; numbers and strings must be rejected.
(let [list-env (data-type List [A]
                 (Nil [] (List A))
                 (Cons [A (List A)] (List A)))]
  (tc-debug (Nil) list-env)
  (tc-debug (Cons 1 (Cons 2 (Nil))) list-env)
  (try
    (tc-debug (Cons 1 (Cons "hello" (Nil))) list-env)
    (println "ERROR: Mixed-type list should have been rejected!")
    (catch Exception ex
      (println "Type error (expected): " (.getMessage ex)))))
;; Demo: Option, with a nullary None and a unary Some.
(let [option-env (data-type Option [A]
                   (None [] (Option A))
                   (Some [A] (Option A)))]
  (tc-debug (None) option-env)
  (tc-debug (Some 42) option-env)
  (tc-debug (Some "hello") option-env))
;; Demo: a GADT-style expression language. Each constructor fixes the index
;; of its result type; If is polymorphic in its branch type A.
(let [expr-env (data-type Expr [A]
                 (LitNum  [:number]                     (Expr :number))
                 (LitBool [:bool]                       (Expr :bool))
                 (Not     [(Expr :bool)]                (Expr :bool))
                 (Add     [(Expr :number) (Expr :number)] (Expr :number))
                 (Eq?     [(Expr :number) (Expr :number)] (Expr :bool))
                 (If      [(Expr :bool) (Expr A) (Expr A)] (Expr A)))]
  (tc-debug (Eq? (LitNum 1) (LitNum 2)) expr-env)
  (tc-debug (If (Not (LitBool true)) (LitNum 1) (LitNum 2)) expr-env)
  (tc-debug (If (Eq? (Add (LitNum 1) (LitNum 2)) (LitNum 2))
                (LitBool true)
                (LitBool false))
            expr-env))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment