-
-
Save Deraen/72dd1da901671272dc698012b851c210 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns cursive.llm.api.malli-repro | |
(:require [clojure.data.json :as json] | |
[clojure.string :as str] | |
[malli.core :as m] | |
[malli.transform :as mt])) | |
(defn decode-type
  "Normalizes the :type field of x so that multi-schema dispatch works.

  When (:type x) is a string, keyword or symbol, it is replaced with a
  keyword whose name has every \"_\" turned into \"-\", e.g.
  \"first_message\" -> :first-message. Without this step the :type value is
  not converted and does not dispatch properly.

  x is returned unchanged when it has no :type (or is not a map), and when
  :type holds some other kind of value. In that case validation reports the
  problem instead of a ClassCastException being thrown here."
  [x]
  (let [t (:type x)]
    ;; Accept already-keywordized input too: str/replace on a keyword throws.
    (if (or (string? t) (ident? t))
      (assoc x :type (keyword (str/replace (name t) "_" "-")))
      x)))
(defn encode-type
  "Inverse of decode-type: when (:type x) is a keyword, symbol or string, it
  is replaced with a string whose name has every \"-\" turned into \"_\",
  e.g. :first-message -> \"first_message\".

  x is returned unchanged when it has no :type (or is not a map), and when
  :type holds some other kind of value."
  [x]
  (let [t (:type x)]
    ;; Must replace \"-\" with \"_\" (not the reverse) to undo decode-type.
    (if (or (string? t) (ident? t))
      (assoc x :type (str/replace (name t) "-" "_"))
      x)))
(def Delta
  "Schema for the :delta map: an optional :stop-reason that is either nil or
  one of the keywords :end-turn / :max-tokens."
  [:map
   [:stop-reason
    {:optional true}
    [:maybe [:enum :end-turn :max-tokens]]]])
(def FirstMessage
  "Schema for a :first-message: the literal :type, an optional nilable
  keyword-enum :reason (:foo-bar or :baz-bup), and a :delta matching Delta."
  [:map
   [:type [:= :first-message]]
   [:reason
    {:optional true}
    [:maybe [:enum :foo-bar :baz-bup]]]
   [:delta Delta]])
(def SecondMessage
  "Schema for a :second-message: the literal :type, a required :reason that
  is one of the STRINGS \"foo-bar\" / \"baz-bup\" (unlike FirstMessage, whose
  :reason enum holds keywords), and a float :value."
  [:map
   [:type [:= :second-message]]
   [:reason [:enum "foo-bar" "baz-bup"]]
   [:value :float]])
(def Message
  "Multi-schema that dispatches on :type to FirstMessage or SecondMessage.
  decode-type is registered as the :decode/json hook so that a string :type
  read from JSON is turned into the dispatch keyword."
  [:multi
   {:dispatch :type
    :decode/json decode-type}
   [:first-message FirstMessage]
   [:second-message SecondMessage]])
(defn kebab-case
  "Returns string s with every underscore replaced by a hyphen."
  [s]
  (str/replace s \_ \-))
(defn snake-case
  "Returns string s with every hyphen replaced by an underscore."
  [s]
  (str/replace s \- \_))
(defn transform-keyword
  "Builds a keyword by applying transform-fn (a string -> string fn) to the
  name of x, and to its namespace when it has one, e.g. with
  (fn [s] (clojure.string/replace s \"_\" \"-\")): :a_b/c_d -> :a-b/c-d.

  x may be a keyword or a symbol; either way the result is a keyword.
  Any other value (string, number, nil, ...) is returned unchanged. This lets
  malli decoders/encoders call it on values of unknown type without
  namespace/name throwing."
  [transform-fn x]
  (if (ident? x)
    (if-let [x-ns (namespace x)]
      (keyword (transform-fn x-ns) (transform-fn (name x)))
      (keyword (transform-fn (name x))))
    x))
(defn transform-map-keys
  "Returns a plain map with the same entries as m, where every keyword key is
  rewritten via transform-keyword with transform-fn. Non-keyword keys and all
  values are kept as-is; only the top level is rewritten (no recursion).

  Non-map input, including nil, is returned unchanged. Without this guard,
  reduce-kv would turn nil into {} and a vector into an index-keyed map, so
  a JSON null under a :map schema would silently decode to {}."
  [transform-fn m]
  (if (map? m)
    (reduce-kv (fn [acc k v]
                 (assoc acc
                        (if (keyword? k)
                          (transform-keyword transform-fn k)
                          k)
                        v))
               {}
               m)
    m))
(def ->kebab
  "Malli transformer that rewrites snake_case identifiers to kebab-case on
  decode and back to snake_case on encode: keyword map keys, :keyword,
  := and :enum values. The println calls trace which decoder/encoder malli
  invokes and with which value.

  Every decoder/encoder returns values of a type it does not handle
  unchanged instead of throwing: string ops on keywords and namespace/name
  on strings raise ClassCastException. As a result, unexpected input is
  reported by validation rather than by a crash inside the transformer."
  (let [->snake (fn [v]
                  ;; An encoder may see the original keyword or a value that
                  ;; json-transformer has already stringified; handle both so
                  ;; the result does not depend on which runs first.
                  (cond
                    (ident? v) (transform-keyword snake-case v)
                    (string? v) (snake-case v)
                    :else v))]
    (mt/transformer
     {:name :->kebab
      :decoders
      {:enum (fn [v]
               (println "decode enum" v)
               ;; Only keyword enum values are rewritten; strings and other
               ;; values pass through unchanged.
               (if (keyword? v)
                 (transform-keyword kebab-case v)
                 v))
       :keyword (fn [v]
                  (println "decode keyword" v)
                  (if (ident? v)
                    (transform-keyword kebab-case v)
                    v))
       := (fn [v]
            (if (keyword? v)
              (transform-keyword kebab-case v)
              v))
       :map (fn [v]
              (if (map? v)
                (transform-map-keys kebab-case v)
                v))}
      :encoders
      {:enum (fn [v]
               (println "encode enum" v)
               (->snake v))
       :keyword (fn [v]
                  (println "encode keyword" v)
                  (->snake v))
       := ->snake
       ;; Registered via :compile returning a :leave interceptor: a plain fn
       ;; would run on :enter, i.e. before the :enum encoder fn.
       :map {:compile (fn [_schema _]
                        {:leave (fn [v]
                                  (println "encode map" v)
                                  (if (map? v)
                                    (transform-map-keys snake-case v)
                                    v))})}}})))
(def t
  "Composite transformer built from malli's JSON transformer followed by
  ->kebab; round-trip passes it to both m/coerce and m/encode."
  (mt/transformer mt/json-transformer ->kebab))
(defn round-trip
  "Parses msg-json (object keys become keywords) and coerces it against
  Message with transformer t. Prints \"valid\" followed by the result of
  m/validate on the decoded value. Finally encodes the decoded value back
  through t and returns it serialized as a JSON string.

  NOTE(review): m/coerce presumably throws on invalid input, which would make
  the printed flag always true here. Confirm against malli docs."
  [msg-json]
  (let [parsed  (json/read-str msg-json :key-fn keyword)
        decoded (m/coerce Message parsed t)]
    (println "valid" (m/validate Message decoded))
    (-> (m/encode Message decoded t)
        (json/write-str))))
(defn test-it
  "For each sample JSON message, prints the input and then the result of
  running it through round-trip. Returns nil."
  []
  (doseq [message ["{\"type\":\"first_message\",\"reason\":\"foo_bar\",\"delta\":{\"stop_reason\":\"end_turn\"}}"
                   "{\"type\":\"second_message\",\"value\":0.0,\"reason\":\"foo-bar\"}"]]
    (println message)
    (println (round-trip message))))
(comment
  ;; REPL entry point; not evaluated when the namespace is loaded.
  (test-it))
;; Recorded output of (test-it): each input line is followed by its
;; round-tripped JSON. delta.stop_reason comes back as "end-turn" (not
;; snake_cased), while SecondMessage's string reason "foo-bar" comes back as
;; "foo_bar".
; {"type":"first_message","reason":"foo_bar","delta":{"stop_reason":"end_turn"}}
; {"type":"first_message","reason":"foo_bar","delta":{"stop_reason":"end-turn"}}
; {"type":"second_message","value":0.0,"reason":"foo-bar"}
; {"type":"second_message","value":0.0,"reason":"foo_bar"}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment