Skip to content

Instantly share code, notes, and snippets.

@cursive-ide
Last active September 10, 2024 10:18
Show Gist options
  • Save cursive-ide/af961a5c513adcb12bd75509813b1619 to your computer and use it in GitHub Desktop.
Save cursive-ide/af961a5c513adcb12bd75509813b1619 to your computer and use it in GitHub Desktop.
(ns cursive.llm.api.malli-repro
(:require [clojure.data.json :as json]
[clojure.string :as str]
[malli.core :as m]
[malli.transform :as mt]))
(defn decode-type
"This is required for decoding :type fields, otherwise they are not
converted and don't dispatch properly."
[x]
(if-let [type (:type x)]
(assoc x :type (keyword (str/replace type "_" "-")))
x))
(defn encode-type
"This is the inverse of the above."
[x]
(if-let [type (:type x)]
(assoc x :type (str/replace (name type) "_" "-"))
x))
(def Delta
[:map
[:stop-reason {:optional true} [:maybe [:enum :end-turn :max-tokens]]]])
(def FirstMessage
[:map
[:type [:= :first-message]]
[:reason {:optional true} [:maybe [:enum :foo-bar :baz-bup]]]
[:delta Delta]])
(def SecondMessage
[:map
[:type [:= :second-message]]
[:reason {:encode/json identity} [:enum "foo-bar" "baz-bup"]]
[:value :float]])
(def Message
[:multi {:dispatch :type
:decode/json decode-type}
[:first-message FirstMessage]
[:second-message SecondMessage]])
(defn kebab-case [s]
(clojure.string/replace s "_" "-"))
(defn snake-case [s]
(clojure.string/replace s "-" "_"))
(defn transform-keyword [transform-fn x]
(if-let [ns (namespace x)]
(keyword (transform-fn ns) (transform-fn (name x)))
(keyword (transform-fn (name x)))))
(defn transform-map-keys [transform-fn m]
(reduce-kv (fn [acc k v]
(assoc acc (if (keyword? k)
(transform-keyword transform-fn k)
k)
v))
{}
m))
(def ->kebab
(mt/transformer
{:name :->kebab
:decoders
{:enum #(if (keyword? %) (transform-keyword kebab-case %) %)
:keyword #(transform-keyword kebab-case %)
:= #(transform-keyword kebab-case %)
:map #(transform-map-keys kebab-case %)}
:encoders
{:enum #(if (keyword %) (clojure.string/replace % "-" "_") %)
:keyword #(transform-keyword snake-case %)
:= #(clojure.string/replace % "-" "_")
:map #(transform-map-keys snake-case %)}}))
(def message-input-transformer (mt/transformer
mt/json-transformer
->kebab))
(def message-output-transformer (mt/transformer
mt/json-transformer
->kebab))
(defn round-trip [msg-json]
(let [decoded (m/coerce Message
(json/read-str msg-json :key-fn keyword)
message-input-transformer)
encoded (m/encode Message
decoded
message-output-transformer)]
(json/write-str encoded)))
(defn test-it []
(let [message "{\"type\":\"first_message\",\"reason\":\"foo_bar\",\"delta\":{\"stop_reason\":\"end_turn\"}}"]
(println message)
(println (round-trip message)))
(let [message "{\"type\":\"second_message\",\"value\":0.0,\"reason\":\"foo-bar\"}"]
(println message)
(println (round-trip message))))
(comment
(test-it))
; {"type":"first_message","reason":"foo_bar","delta":{"stop_reason":"end_turn"}}
; {"type":"first_message","reason":"foo_bar","delta":{"stop_reason":"end-turn"}}
; {"type":"second_message","value":0.0,"reason":"foo-bar"}
; {"type":"second_message","value":0.0,"reason":"foo_bar"}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment