Last active
September 10, 2024 10:18
-
-
Save cursive-ide/af961a5c513adcb12bd75509813b1619 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns cursive.llm.api.malli-repro | |
(:require [clojure.data.json :as json] | |
[clojure.string :as str] | |
[malli.core :as m] | |
[malli.transform :as mt])) | |
(defn decode-type | |
"This is required for decoding :type fields, otherwise they are not | |
converted and don't dispatch properly." | |
[x] | |
(if-let [type (:type x)] | |
(assoc x :type (keyword (str/replace type "_" "-"))) | |
x)) | |
(defn encode-type | |
"This is the inverse of the above." | |
[x] | |
(if-let [type (:type x)] | |
(assoc x :type (str/replace (name type) "_" "-")) | |
x)) | |
(def Delta | |
[:map | |
[:stop-reason {:optional true} [:maybe [:enum :end-turn :max-tokens]]]]) | |
(def FirstMessage | |
[:map | |
[:type [:= :first-message]] | |
[:reason {:optional true} [:maybe [:enum :foo-bar :baz-bup]]] | |
[:delta Delta]]) | |
(def SecondMessage | |
[:map | |
[:type [:= :second-message]] | |
[:reason {:encode/json identity} [:enum "foo-bar" "baz-bup"]] | |
[:value :float]]) | |
(def Message | |
[:multi {:dispatch :type | |
:decode/json decode-type} | |
[:first-message FirstMessage] | |
[:second-message SecondMessage]]) | |
(defn kebab-case [s] | |
(clojure.string/replace s "_" "-")) | |
(defn snake-case [s] | |
(clojure.string/replace s "-" "_")) | |
(defn transform-keyword [transform-fn x] | |
(if-let [ns (namespace x)] | |
(keyword (transform-fn ns) (transform-fn (name x))) | |
(keyword (transform-fn (name x))))) | |
(defn transform-map-keys [transform-fn m] | |
(reduce-kv (fn [acc k v] | |
(assoc acc (if (keyword? k) | |
(transform-keyword transform-fn k) | |
k) | |
v)) | |
{} | |
m)) | |
(def ->kebab | |
(mt/transformer | |
{:name :->kebab | |
:decoders | |
{:enum #(if (keyword? %) (transform-keyword kebab-case %) %) | |
:keyword #(transform-keyword kebab-case %) | |
:= #(transform-keyword kebab-case %) | |
:map #(transform-map-keys kebab-case %)} | |
:encoders | |
{:enum #(if (keyword %) (clojure.string/replace % "-" "_") %) | |
:keyword #(transform-keyword snake-case %) | |
:= #(clojure.string/replace % "-" "_") | |
:map #(transform-map-keys snake-case %)}})) | |
(def message-input-transformer (mt/transformer | |
mt/json-transformer | |
->kebab)) | |
(def message-output-transformer (mt/transformer | |
mt/json-transformer | |
->kebab)) | |
(defn round-trip [msg-json] | |
(let [decoded (m/coerce Message | |
(json/read-str msg-json :key-fn keyword) | |
message-input-transformer) | |
encoded (m/encode Message | |
decoded | |
message-output-transformer)] | |
(json/write-str encoded))) | |
(defn test-it [] | |
(let [message "{\"type\":\"first_message\",\"reason\":\"foo_bar\",\"delta\":{\"stop_reason\":\"end_turn\"}}"] | |
(println message) | |
(println (round-trip message))) | |
(let [message "{\"type\":\"second_message\",\"value\":0.0,\"reason\":\"foo-bar\"}"] | |
(println message) | |
(println (round-trip message)))) | |
(comment | |
(test-it)) | |
; {"type":"first_message","reason":"foo_bar","delta":{"stop_reason":"end_turn"}} | |
; {"type":"first_message","reason":"foo_bar","delta":{"stop_reason":"end-turn"}} | |
; {"type":"second_message","value":0.0,"reason":"foo-bar"} | |
; {"type":"second_message","value":0.0,"reason":"foo_bar"} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment