Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save gorshkov-leonid/a7464721457015af6713a747a4981f7f to your computer and use it in GitHub Desktop.
Save gorshkov-leonid/a7464721457015af6713a747a4981f7f to your computer and use it in GitHub Desktop.
Simple Kotlin OpenAI Websocket implementation of the Realtime API
import android.util.Log
import io.ktor.util.decodeBase64Bytes
import io.ktor.util.encodeBase64
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import okio.ByteString
import org.json.JSONArray
import org.json.JSONObject
/**
 * Listener for the OpenAI Realtime API.
 *
 * The properties supply the session configuration that [OpenAIRealtimeWebSocket] sends to the
 * server; the functions are callbacks for connection and server events.
 */
interface OpenAIRealtimeWebSocketListener {
    /** System prompt; sent as the session's `instructions`. */
    val system: String

    /** Whether the session should use audio in addition to text. */
    val voice: Boolean

    /** Voice identifier; sent as the session's `voice`. */
    val voiceId: String

    /** Tools offered to the model, or `null` when there are none. */
    val tools: List<FunctionCall>?

    /** Called when the connection fails; [throwable] is the reported cause. */
    fun onException(webSocket: OpenAIRealtimeWebSocket, throwable: Throwable)

    /** Called when the model requests the function call described by [functionCall]. */
    fun onFunctionCall(webSocket: OpenAIRealtimeWebSocket, functionCall: FunctionCallRequest)

    /** Called once the socket has been opened. */
    fun onConnected(webSocket: OpenAIRealtimeWebSocket)

    /** Signals that speech has started (presumably in the input audio; confirm against the socket's dispatch logic). */
    fun onSpeechStarted(webSocket: OpenAIRealtimeWebSocket)

    /** Signals that speech has stopped (presumably in the input audio; confirm against the socket's dispatch logic). */
    fun onSpeechStopped(webSocket: OpenAIRealtimeWebSocket)

    /** Called on the `response.done` server event. */
    fun onResponseDone(webSocket: OpenAIRealtimeWebSocket)

    /** Called on `response.audio.delta` with the base64-decoded audio chunk in [data]. */
    fun onResponseAudio(webSocket: OpenAIRealtimeWebSocket, data: ByteArray)

    /** Called on the `response.audio.done` server event. */
    fun onResponseAudioDone(webSocket: OpenAIRealtimeWebSocket)

    /** Called on the `session.created` server event. */
    fun onSessionCreated(webSocket: OpenAIRealtimeWebSocket)

    /** Called with the text of a completed assistant message item. */
    fun onResponseTextDone(socket: OpenAIRealtimeWebSocket, transcript: String)

    /** Called on `response.text.delta` with the incremental text in [delta]. */
    fun onResponseTextDelta(socket: OpenAIRealtimeWebSocket, delta: String)

    /** Called on `response.audio_transcript.done` with the event's `transcript` field. */
    fun onResponseAudioTranscriptDone(socket: OpenAIRealtimeWebSocket, transcript: String)

    /** Called on `response.audio_transcript.delta` with the incremental transcript in [delta]. */
    fun onResponseAudioTranscriptDelta(socket: OpenAIRealtimeWebSocket, delta: String)
}
/**
 * WebSocket client for the OpenAI Realtime API.
 *
 * Session settings are read from [listener] and sent as a `session.update` event as soon as the
 * socket opens. Incoming server events are parsed and forwarded to the listener callbacks.
 *
 * Listener callbacks are invoked from the OkHttp [WebSocketListener] callbacks, except for a
 * missing API key, which is reported from the thread that calls [connect].
 */
class OpenAIRealtimeWebSocket(val listener: OpenAIRealtimeWebSocketListener) {
    /** Endpoint (including the model query parameter) this client connects to. */
    val url = "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview"

    private val apiKey = getRemoteConfigString("openai")
    private val client = OkHttpClient()

    // Written by connect()/onOpen and read by send()/close(), which may run on different threads.
    @Volatile
    private var ws: WebSocket? = null

    // Written from the OkHttp callbacks and read through [isConnected] from any thread.
    @Volatile
    private var connected = false

    /** `true` between a successful open and the moment the connection closes, fails or [close] is called. */
    val isConnected: Boolean
        get() = connected

    /**
     * Enqueue a text frame on the socket.
     *
     * The frame is dropped (with a warning) when no socket exists yet, i.e. [connect] has not
     * been called, instead of crashing the caller.
     */
    private fun send(text: String) {
        val socket = ws
        if (socket == null) {
            Log.w(TAG, "Not connected, dropping message: $text")
            return
        }
        Log.d(TAG, "sending $text")
        // WebSocket.send returns false when the frame was rejected (socket closing/closed/canceled
        // or outgoing buffer full); surface that instead of ignoring it.
        if (!socket.send(text)) {
            Log.w(TAG, "Message was not enqueued, the socket is closing, closed or its buffer is full")
        }
    }

    /** Build an event object that carries only the given `type`. */
    private fun event(type: String): JSONObject = JSONObject().put("type", type)

    /**
     * Handle output item from the OpenAI Realtime API.
     *
     * @param output the `response.output_item.done` event; its `item` object is inspected.
     */
    private fun handleOutputItem(output: JSONObject) {
        val item = output.optJSONObject("item") ?: return
        when (item.optString("type")) {
            "message" -> handleMessageItem(item) // text response
            "function_call" -> handleFunctionCallItem(item) // function call request
        }
    }

    /** Forward the text of a completed assistant message to the listener. */
    private fun handleMessageItem(item: JSONObject) {
        if (item.optString("role") != "assistant" || item.optString("status") != "completed") {
            return
        }
        val firstPart = item.optJSONArray("content")?.optJSONObject(0) ?: return
        // NOTE(review): optString yields "" when the first content part has no "text" field
        // (e.g. a non-text part), so the listener may receive an empty string here.
        firstPart.optString("text")?.let { text ->
            listener.onResponseTextDone(this, text)
        }
    }

    /** Build a [FunctionCallRequest] from a `function_call` item and hand it to the listener. */
    private fun handleFunctionCallItem(item: JSONObject) {
        val arguments = item.optString("arguments")
        val request = FunctionCallRequest(item.optString("name"), arguments, item.optString("call_id"))
        try {
            request.json = JSONObject(arguments)
        } catch (e: Exception) {
            // Best effort: the raw argument string is still available in request.parameters.
            Log.i(TAG, "Error parsing arguments: $arguments", e)
        }
        listener.onFunctionCall(this, request)
    }

    /**
     * Parse one server event and dispatch it to the matching listener callback.
     * Events with an unrecognised `type` are ignored.
     */
    private fun handleServerEvent(text: String) {
        val json = JSONObject(text)
        when (json.optString("type")) {
            "session.created" -> listener.onSessionCreated(this)
            "session.updated" -> {
                // Nothing to do.
            }
            // The listener declares these callbacks; dispatch the matching server events.
            "input_audio_buffer.speech_started" -> listener.onSpeechStarted(this)
            "input_audio_buffer.speech_stopped" -> listener.onSpeechStopped(this)
            "response.done" -> listener.onResponseDone(this)
            "response.audio.delta" ->
                listener.onResponseAudio(this, json.optString("delta").decodeBase64Bytes())
            "response.audio.done" -> listener.onResponseAudioDone(this)
            "response.audio_transcript.done" ->
                listener.onResponseAudioTranscriptDone(this, json.optString("transcript"))
            "response.audio_transcript.delta" ->
                listener.onResponseAudioTranscriptDelta(this, json.optString("delta"))
            "response.output_item.done" -> handleOutputItem(json)
            "response.text.delta" -> listener.onResponseTextDelta(this, json.optString("delta"))
        }
    }

    /**
     * Connect to the OpenAI Realtime API web socket server.
     *
     * When no API key is configured the listener's `onException` is called with an
     * [IllegalStateException] and no connection is attempted.
     */
    fun connect() {
        val key = apiKey
        if (key.isNullOrBlank()) {
            // Without this check the header would literally read "Bearer null".
            listener.onException(this, IllegalStateException("OpenAI API key is not configured"))
            return
        }
        val request = Request.Builder().url(url)
            .header("Authorization", "Bearer $key")
            .header("OpenAI-Beta", "realtime=v1")
            .build()
        ws = client.newWebSocket(request, object : WebSocketListener() {
            /**
             * Called when the connection to the server is established.
             */
            override fun onOpen(webSocket: WebSocket, response: Response) {
                Log.i(TAG, "Connection opened")
                // Make sure send() sees the socket even if this callback runs before the
                // assignment in connect() has completed.
                ws = webSocket
                connected = true
                setupSession()
                listener.onConnected(this@OpenAIRealtimeWebSocket)
            }

            /**
             * Called when a message is received from the server.
             */
            override fun onMessage(webSocket: WebSocket, text: String) {
                Log.i(TAG, "Received text: $text")
                handleServerEvent(text)
            }

            /**
             * Called when binary data is received from the server.
             * The payload is decoded as UTF-8 and handled like a text message.
             */
            override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
                Log.i(TAG, "Received bytes: ${bytes.hex()}")
                onMessage(webSocket, String(bytes.toByteArray(), Charsets.UTF_8))
            }

            /**
             * Called when the server starts closing the connection; the close is acknowledged.
             */
            override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
                Log.i(TAG, "Connection closing: $code / $reason")
                webSocket.close(1000, null)
                connected = false
            }

            /**
             * Called when the connection has been closed by both peers.
             */
            override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
                Log.i(TAG, "Connection closed: $code / $reason")
                connected = false
            }

            /**
             * Called when an error occurs during the connection to the server.
             */
            override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
                Log.i(TAG, "Connection failed: ${t.message}")
                // Update the state first so the listener observes isConnected == false.
                connected = false
                listener.onException(this@OpenAIRealtimeWebSocket, t)
            }
        })
    }

    /**
     * Setup the session with the OpenAI Realtime API.
     *
     * Sends a `session.update` event built from the listener's `system`, `voiceId`, `voice`
     * and `tools` values.
     */
    fun setupSession() {
        val session = JSONObject()
            .put("instructions", listener.system)
            .put("voice", listener.voiceId)
        val modalities = JSONArray().put("text")
        if (listener.voice) {
            modalities.put("audio")
            session.put("input_audio_transcription", JSONObject().put("model", "whisper-1"))
        } else {
            // put(name, null) would merely remove the key; JSONObject.NULL emits an explicit
            // JSON null, which is how transcription is switched off.
            session.put("input_audio_transcription", JSONObject.NULL)
        }
        session.put("modalities", modalities)

        val tools = JSONArray()
        listener.tools?.forEach { tool ->
            tools.put(tool.toJson())
        }
        if (tools.length() > 0) {
            session.put("tools", tools)
            session.put("tool_choice", "auto")
        }

        send(event("session.update").put("session", session).toString())
    }

    /**
     * Request a new response for the OpenAI Realtime API.
     */
    fun createResponse() {
        send(event("response.create").toString())
    }

    /**
     * Append audio data to the OpenAI Realtime API input buffer.
     *
     * @param audio raw audio bytes; they are base64-encoded before sending.
     */
    fun appendAudio(audio: ByteArray) {
        send(event("input_audio_buffer.append").put("audio", audio.encodeBase64()).toString())
    }

    /**
     * Commit the audio data to the OpenAI Realtime API input buffer.
     * Use for manual turn detection
     */
    fun commitAudio() {
        send(event("input_audio_buffer.commit").toString())
    }

    /**
     * Clear the audio data from the OpenAI Realtime API input buffer.
     */
    fun clearAudio() {
        send(event("input_audio_buffer.clear").toString())
    }

    /**
     * Send the result of a function call to the OpenAI Realtime API and request a new response.
     *
     * @param functionCall the request being answered; its `id` is sent as `call_id`.
     * @param result the function output, sent as-is in the `output` field.
     */
    fun sendResult(functionCall: FunctionCallRequest, result: String) {
        Log.i(TAG, "Sending result: $result")
        val item = JSONObject()
            .put("type", "function_call_output")
            .put("call_id", functionCall.id)
            .put("output", result)
        send(event("conversation.item.create").put("item", item).toString())
        createResponse()
    }

    /**
     * Send a message to the OpenAI Realtime API.
     *
     * @param message the user text to add to the conversation.
     * @param respond when `true` (the default) a response is requested right after the message.
     */
    fun sendMessage(message: String, respond : Boolean = true) {
        Log.i(TAG, "Sending message: $message")
        val content = JSONObject()
            .put("type", "input_text")
            .put("text", message)
        val item = JSONObject()
            .put("type", "message")
            .put("role", "user")
            .put("content", JSONArray().put(content))
        send(event("conversation.item.create").put("item", item).toString())
        if (respond) {
            createResponse()
        }
    }

    /**
     * Close the WebSocket connection.
     *
     * The HTTP client's dispatcher executor is shut down as well.
     *
     * @return the result of the socket's `close` call, or `false` when [connect] was never called.
     */
    fun close(): Boolean {
        try {
            // onClosing only fires for server-initiated closes, so clear the flag here too.
            connected = false
            return ws?.close(1000, "Closed by user") ?: false
        } finally {
            client.dispatcher.executorService.shutdown()
        }
    }

    companion object {
        const val TAG = "OpenAIRealtimeWebSocket"
    }
}
/**
 * A function call requested by the model.
 *
 * @property name name of the function to invoke.
 * @property parameters raw argument string as received, if any.
 * @property id call identifier; excluded from serialization.
 * @property extra arbitrary attachment; excluded from serialization (not used in this file).
 * @property json parsed form of the arguments when available; excluded from serialization.
 */
@Serializable
data class FunctionCallRequest(
    @SerialName("name") var name: String,
    @SerialName("parameters") var parameters: String? = null,
    @Transient var id: String? = null,
    @Transient var extra: Any? = null,
    @Transient var json: JSONObject? = null,
)
/**
 * Result of a function call.
 *
 * @property name name of the function the result belongs to.
 * @property content the result payload.
 */
@Serializable
data class FunctionCallResponse(
    @SerialName("name") var name: String,
    @SerialName("content") var content: String,
)
/**
 * Declaration of a function (tool) that can be offered to the model.
 *
 * @property name function name.
 * @property description human-readable description of the function.
 * @property parameters parameter schema, or `null` when none is declared.
 */
@Serializable
data class FunctionCall(
    @SerialName("name") var name: String,
    @SerialName("description") var description: String,
    @SerialName("parameters") var parameters: FunctionParameters? = null,
)
/**
 * Parameter schema of a [FunctionCall].
 *
 * @property type schema type of the parameter container.
 * @property properties individual parameters keyed by parameter name.
 * @property required names of mandatory parameters, or `null` when none are listed.
 */
@Serializable
data class FunctionParameters(
    @SerialName("type") var type: String,
    @SerialName("properties") var properties: Map<String, FunctionParameter>,
    @SerialName("required") var required: List<String>? = null,
)
/**
 * Schema of a single function parameter.
 *
 * @property type schema type of the parameter.
 * @property description human-readable description of the parameter.
 * @property enum allowed values, or `null` when unrestricted.
 * @property items element schema, or `null` when not declared.
 */
@Serializable
data class FunctionParameter(
    @SerialName("type") var type: String,
    @SerialName("description") var description: String,
    @SerialName("enum") var enum: List<String>? = null,
    @SerialName("items") var items: FunctionParameterItemsType? = null,
)
/**
 * Element schema referenced by [FunctionParameter.items].
 *
 * @property type schema type of each element.
 */
@Serializable
data class FunctionParameterItemsType(
    @SerialName("type") var type: String,
)
/**
 * Look up a string argument in the parsed arguments of this request.
 *
 * @param key argument name.
 * @return the value, or `null` when there are no parsed arguments or the value is empty.
 */
fun FunctionCallRequest.getStringParameter(key: String): String? {
    val arguments = json ?: return null
    val value = arguments.optString(key)
    return if (value.isNullOrEmpty()) null else value
}
/**
 * Look up an array argument in the parsed arguments of this request.
 *
 * @param key argument name.
 * @return the array, or `null` when there are no parsed arguments or no array is stored under [key].
 */
fun FunctionCallRequest.getJsonArrayParameter(key: String): JSONArray? =
    json?.optJSONArray(key)
/**
 * Convert this parameter schema to its JSON form.
 *
 * The top-level `type` is always written as `"object"`. `enum` and `required` are emitted only
 * when they contain at least one entry, and `items` only when it is set.
 */
fun FunctionParameters.toJson(): JSONObject {
    val propertySchemas = JSONObject()
    for ((propertyName, parameter) in properties) {
        val schema = JSONObject()
        schema.put("type", parameter.type)
        schema.put("description", parameter.description)

        val allowedValues = parameter.enum
        if (!allowedValues.isNullOrEmpty()) {
            schema.put("enum", allowedValues.toJsonArray())
        }

        val elementType = parameter.items
        if (elementType != null) {
            schema.put("items", JSONObject().put("type", elementType.type))
        }

        propertySchemas.put(propertyName, schema)
    }

    val root = JSONObject()
    root.put("type", "object")
    root.put("properties", propertySchemas)

    val requiredNames = required
    if (!requiredNames.isNullOrEmpty()) {
        root.put("required", requiredNames.toJsonArray())
    }
    return root
}
/**
 * Convert this function declaration to its JSON tool description:
 * `type` (always `"function"`), `name`, `description` and the `parameters` schema.
 */
fun FunctionCall.toJson(): JSONObject =
    JSONObject()
        .put("type", "function")
        .put("name", name)
        .put("description", description)
        .put("parameters", parameters?.toJson())
/**
 * Read a string from Firebase Remote Config.
 *
 * @param key remote config key.
 * @param defaultValue value to return when the configured string is empty.
 * @return the configured string, or [defaultValue] when it is empty.
 */
fun getRemoteConfigString(key: String, defaultValue: String? = null): String? {
    val configured = FirebaseRemoteConfig.getInstance().getString(key)
    return if (configured.isEmpty()) defaultValue else configured
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment