Forked from paulotaylor/OpenAIRealtimeWebSocket.kt
A simple Kotlin WebSocket implementation of the OpenAI Realtime API.
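Below is a minimal usage sketch (not part of the original gist) showing how a caller might implement the listener and drive the socket. The property values, callback bodies, and the placeholder tool-call result are illustrative; capturing and playing PCM16 audio is assumed to happen elsewhere.

// Illustrative usage sketch: values and callback bodies are placeholders.
fun startRealtimeSession(): OpenAIRealtimeWebSocket {
    val socket = OpenAIRealtimeWebSocket(object : OpenAIRealtimeWebSocketListener {
        override val system = "You are a helpful assistant."
        override val voice = true
        override val voiceId = "alloy"
        override val tools: List<FunctionCall>? = null

        override fun onConnected(webSocket: OpenAIRealtimeWebSocket) {}
        override fun onSessionCreated(webSocket: OpenAIRealtimeWebSocket) {}
        override fun onSpeechStarted(webSocket: OpenAIRealtimeWebSocket) {}
        override fun onSpeechStopped(webSocket: OpenAIRealtimeWebSocket) {}
        override fun onResponseDone(webSocket: OpenAIRealtimeWebSocket) {}
        override fun onResponseAudio(webSocket: OpenAIRealtimeWebSocket, data: ByteArray) {
            // data is raw PCM16 audio; play it back (e.g. via AudioTrack).
        }
        override fun onResponseAudioDone(webSocket: OpenAIRealtimeWebSocket) {}
        override fun onResponseTextDone(socket: OpenAIRealtimeWebSocket, transcript: String) {}
        override fun onResponseTextDelta(socket: OpenAIRealtimeWebSocket, delta: String) {}
        override fun onResponseAudioTranscriptDone(socket: OpenAIRealtimeWebSocket, transcript: String) {}
        override fun onResponseAudioTranscriptDelta(socket: OpenAIRealtimeWebSocket, delta: String) {}
        override fun onFunctionCall(webSocket: OpenAIRealtimeWebSocket, functionCall: FunctionCallRequest) {
            // Execute the requested tool, then return its output to the model.
            webSocket.sendResult(functionCall, """{"ok": true}""")
        }
        override fun onException(webSocket: OpenAIRealtimeWebSocket, throwable: Throwable) {}
    })
    socket.connect()
    // Once connected: socket.sendMessage("Hello!") for text, or stream microphone
    // audio with socket.appendAudio(pcm16Chunk).
    return socket
}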
import android.util.Log
import com.google.firebase.remoteconfig.FirebaseRemoteConfig
import io.ktor.util.decodeBase64Bytes
import io.ktor.util.encodeBase64
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import okio.ByteString
import org.json.JSONArray
import org.json.JSONObject
/**
 * Listener for the OpenAI Realtime API.
 */
interface OpenAIRealtimeWebSocketListener {
    /** System instructions for the session. */
    val system: String
    /** Whether audio responses are enabled. */
    val voice: Boolean
    /** Voice to use for audio responses. */
    val voiceId: String
    /** Tools (function definitions) exposed to the model, if any. */
    val tools: List<FunctionCall>?

    fun onException(webSocket: OpenAIRealtimeWebSocket, throwable: Throwable)
    fun onFunctionCall(webSocket: OpenAIRealtimeWebSocket, functionCall: FunctionCallRequest)
    fun onConnected(webSocket: OpenAIRealtimeWebSocket)
    fun onSpeechStarted(webSocket: OpenAIRealtimeWebSocket)
    fun onSpeechStopped(webSocket: OpenAIRealtimeWebSocket)
    fun onResponseDone(webSocket: OpenAIRealtimeWebSocket)
    fun onResponseAudio(webSocket: OpenAIRealtimeWebSocket, data: ByteArray)
    fun onResponseAudioDone(webSocket: OpenAIRealtimeWebSocket)
    fun onSessionCreated(webSocket: OpenAIRealtimeWebSocket)
    fun onResponseTextDone(socket: OpenAIRealtimeWebSocket, transcript: String)
    fun onResponseTextDelta(socket: OpenAIRealtimeWebSocket, delta: String)
    fun onResponseAudioTranscriptDone(socket: OpenAIRealtimeWebSocket, transcript: String)
    fun onResponseAudioTranscriptDelta(socket: OpenAIRealtimeWebSocket, delta: String)
}
/**
 * WebSocket client for the OpenAI Realtime API.
 */
class OpenAIRealtimeWebSocket(val listener: OpenAIRealtimeWebSocketListener) {

    companion object {
        const val TAG = "OpenAIRealtimeWebSocket"
    }

    val url = "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview"
    private val apiKey = getRemoteConfigString("openai")
    private val client = OkHttpClient()
    private lateinit var ws: WebSocket
    private var connected = false

    val isConnected: Boolean
        get() = connected

    private fun send(text: String) {
        Log.d(TAG, "sending $text")
        ws.send(text)
    }
    /**
     * Handle an output item from the OpenAI Realtime API.
     */
    private fun handleOutputItem(output: JSONObject) {
        output.optJSONObject("item")?.let { item ->
            val type = item.optString("type")
            if (type == "message") { // text response
                if (item.optString("role") == "assistant" && item.optString("status") == "completed") {
                    val content = item.optJSONArray("content")
                    content?.let {
                        val text = it.optJSONObject(0)?.optString("text")
                        text?.let {
                            listener.onResponseTextDone(this@OpenAIRealtimeWebSocket, it)
                        }
                    }
                }
            }
            if (type == "function_call") { // function call request
                val name = item.optString("name")
                val id = item.optString("call_id")
                val arguments = item.optString("arguments")
                val fc = FunctionCallRequest(name, arguments, id)
                try {
                    fc.json = JSONObject(arguments)
                } catch (e: Exception) {
                    Log.i(TAG, "Error parsing arguments: $arguments", e)
                }
                listener.onFunctionCall(this@OpenAIRealtimeWebSocket, fc)
            }
        }
    }
    /**
     * Connect to the OpenAI Realtime API WebSocket server.
     */
    fun connect() {
        val request = Request.Builder().url(url)
            .header("Authorization", "Bearer $apiKey")
            .header("OpenAI-Beta", "realtime=v1")
            .build()
        ws = client.newWebSocket(request, object : WebSocketListener() {

            /**
             * Called when the connection to the server is established.
             */
            override fun onOpen(webSocket: WebSocket, response: Response) {
                Log.i(TAG, "Connection opened")
                connected = true
                setupSession()
                listener.onConnected(this@OpenAIRealtimeWebSocket)
            }

            /**
             * Called when a message is received from the server.
             */
            override fun onMessage(webSocket: WebSocket, text: String) {
                Log.i(TAG, "Received text: $text")
                val json = JSONObject(text)
                val type = json.optString("type")
                when (type) {
                    "session.created" -> {
                        listener.onSessionCreated(this@OpenAIRealtimeWebSocket)
                    }
                    "session.updated" -> {
                        // Session configuration acknowledged; nothing to do.
                    }
                    "input_audio_buffer.speech_started" -> {
                        // Server VAD detected the start of user speech.
                        listener.onSpeechStarted(this@OpenAIRealtimeWebSocket)
                    }
                    "input_audio_buffer.speech_stopped" -> {
                        // Server VAD detected the end of user speech.
                        listener.onSpeechStopped(this@OpenAIRealtimeWebSocket)
                    }
                    "response.done" -> {
                        listener.onResponseDone(this@OpenAIRealtimeWebSocket)
                    }
                    "response.audio.delta" -> {
                        val delta = json.optString("delta")
                        listener.onResponseAudio(this@OpenAIRealtimeWebSocket, delta.decodeBase64Bytes())
                    }
                    "response.audio.done" -> {
                        listener.onResponseAudioDone(this@OpenAIRealtimeWebSocket)
                    }
                    "response.audio_transcript.done" -> {
                        listener.onResponseAudioTranscriptDone(this@OpenAIRealtimeWebSocket, json.optString("transcript"))
                    }
                    "response.audio_transcript.delta" -> {
                        listener.onResponseAudioTranscriptDelta(this@OpenAIRealtimeWebSocket, json.optString("delta"))
                    }
                    "response.output_item.done" -> {
                        handleOutputItem(json)
                    }
                    "response.text.delta" -> {
                        listener.onResponseTextDelta(this@OpenAIRealtimeWebSocket, json.optString("delta"))
                    }
                }
            }

            /**
             * Called when binary data is received from the server.
             */
            override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
                Log.i(TAG, "Received bytes: ${bytes.hex()}")
                onMessage(webSocket, String(bytes.toByteArray(), Charsets.UTF_8))
            }

            /**
             * Called when the connection to the server is closing.
             */
            override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
                Log.i(TAG, "Connection closing: $code / $reason")
                webSocket.close(1000, null)
                connected = false
            }

            /**
             * Called when an error occurs during the connection to the server.
             */
            override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
                Log.i(TAG, "Connection failed: ${t.message}")
                listener.onException(this@OpenAIRealtimeWebSocket, t)
                connected = false
            }
        })
    }
    /**
     * Set up the session with the OpenAI Realtime API.
     */
    fun setupSession() {
        val session = JSONObject()
            .put("instructions", listener.system)
            .put("voice", listener.voiceId)
        val modalities = JSONArray().put("text")
        if (listener.voice) {
            modalities.put("audio")
            val transcription = JSONObject()
            transcription.put("model", "whisper-1")
            session.put("input_audio_transcription", transcription)
        } else {
            // A null value leaves the key out, so no input transcription is requested.
            session.put("input_audio_transcription", null)
        }
        session.put("modalities", modalities)
        val tools = JSONArray()
        listener.tools?.forEach { tool ->
            tools.put(tool.toJson())
        }
        if (tools.length() > 0) {
            session.put("tools", tools)
            session.put("tool_choice", "auto")
        }
        val content = JSONObject()
            .put("type", "session.update")
            .put("session", session)
        send(content.toString())
    }
    /**
     * Request a new response from the OpenAI Realtime API.
     */
    fun createResponse() {
        val content = JSONObject()
            .put("type", "response.create")
        send(content.toString())
    }

    /**
     * Append audio data to the OpenAI Realtime API input buffer.
     */
    fun appendAudio(audio: ByteArray) {
        val content = JSONObject()
            .put("type", "input_audio_buffer.append")
            .put("audio", audio.encodeBase64())
        send(content.toString())
    }

    /**
     * Commit the audio data in the OpenAI Realtime API input buffer.
     * Use for manual turn detection.
     */
    fun commitAudio() {
        val content = JSONObject()
            .put("type", "input_audio_buffer.commit")
        send(content.toString())
    }

    /**
     * Clear the audio data from the OpenAI Realtime API input buffer.
     */
    fun clearAudio() {
        val content = JSONObject()
            .put("type", "input_audio_buffer.clear")
        send(content.toString())
    }

    /**
     * Send the result of a function call to the OpenAI Realtime API.
     */
    fun sendResult(functionCall: FunctionCallRequest, result: String) {
        Log.i(TAG, "Sending result: $result")
        val item = JSONObject()
            .put("type", "function_call_output")
            .put("call_id", functionCall.id)
            .put("output", result)
        val json = JSONObject()
            .put("type", "conversation.item.create")
            .put("item", item)
        send(json.toString())
        createResponse()
    }

    /**
     * Send a text message to the OpenAI Realtime API.
     */
    fun sendMessage(message: String, respond: Boolean = true) {
        Log.i(TAG, "Sending message: $message")
        val content = JSONObject()
            .put("type", "input_text")
            .put("text", message)
        val item = JSONObject()
            .put("type", "message")
            .put("role", "user")
            .put("content", JSONArray().put(content))
        val json = JSONObject()
            .put("type", "conversation.item.create")
            .put("item", item)
        send(json.toString())
        if (respond) {
            createResponse()
        }
    }

    /**
     * Close the WebSocket connection.
     */
    fun close(): Boolean {
        try {
            return ws.close(1000, "Closed by user")
        } finally {
            client.dispatcher.executorService.shutdown()
        }
    }
}
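// Illustrative helper (not part of the original gist): a sketch of manual turn handling
// when server-side voice activity detection is not used. The caller streams PCM16 chunks,
// then commits the buffer and requests a response once the user stops speaking.
// `chunks` stands in for audio captured elsewhere.
fun sendUserTurn(socket: OpenAIRealtimeWebSocket, chunks: List<ByteArray>) {
    chunks.forEach { socket.appendAudio(it) }
    socket.commitAudio()
    socket.createResponse()
}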
@Serializable
data class FunctionCallRequest(
    @SerialName("name") var name: String,
    @SerialName("parameters") var parameters: String? = null,
    @Transient var id: String? = null,
    @Transient var extra: Any? = null,
    @Transient var json: JSONObject? = null
)

@Serializable
data class FunctionCallResponse(
    @SerialName("name") var name: String,
    @SerialName("content") var content: String
)

@Serializable
data class FunctionCall(
    @SerialName("name") var name: String,
    @SerialName("description") var description: String,
    @SerialName("parameters") var parameters: FunctionParameters? = null
)

@Serializable
data class FunctionParameters(
    @SerialName("type") var type: String,
    @SerialName("properties") var properties: Map<String, FunctionParameter>,
    @SerialName("required") var required: List<String>? = null
)

@Serializable
data class FunctionParameter(
    @SerialName("type") var type: String,
    @SerialName("description") var description: String,
    @SerialName("enum") var enum: List<String>? = null,
    @SerialName("items") var items: FunctionParameterItemsType? = null
)

@Serializable
data class FunctionParameterItemsType(
    @SerialName("type") var type: String
)

fun FunctionCallRequest.getStringParameter(key: String): String? {
    return json?.optString(key)?.takeIf {
        it.isNotEmpty()
    }
}

fun FunctionCallRequest.getJsonArrayParameter(key: String): JSONArray? {
    return json?.optJSONArray(key)
}
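// The toJson() helpers below call a `toJsonArray()` extension on List<String> that is
// not defined in this gist; a minimal sketch of the assumed implementation:
fun List<String>.toJsonArray(): JSONArray = JSONArray(this)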
fun FunctionParameters.toJson(): JSONObject {
    val json = JSONObject()
    json.put("type", "object")
    val properties = JSONObject()
    this.properties.forEach { (key, value) ->
        val property = JSONObject().put("type", value.type).put("description", value.description)
        if (value.enum?.isNotEmpty() == true) {
            property.put("enum", value.enum?.toJsonArray())
        }
        value.items?.let {
            property.put("items", JSONObject().put("type", it.type))
        }
        properties.put(key, property)
    }
    json.put("properties", properties)
    if (required?.isEmpty() == false) {
        json.put("required", required?.toJsonArray())
    }
    return json
}

fun FunctionCall.toJson(): JSONObject {
    val json = JSONObject()
    json.put("type", "function")
    json.put("name", name)
    json.put("description", description)
    json.put("parameters", parameters?.toJson())
    return json
}

fun getRemoteConfigString(key: String, defaultValue: String? = null): String? {
    val value = FirebaseRemoteConfig.getInstance().getString(key)
    return value.ifEmpty { defaultValue }
}
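For reference, here is a sketch of how a tool could be declared with these data classes and exposed through the listener's `tools` property. The `get_weather` name and its fields are illustrative, not part of the gist.

// Illustrative only: a hypothetical "get_weather" tool definition.
val exampleWeatherTool = FunctionCall(
    name = "get_weather",
    description = "Look up the current weather for a city",
    parameters = FunctionParameters(
        type = "object",
        properties = mapOf(
            "city" to FunctionParameter(
                type = "string",
                description = "Name of the city to look up"
            )
        ),
        required = listOf("city")
    )
)
// exampleWeatherTool.toJson() yields the {"type":"function", "name":..., "parameters":{...}}
// object that setupSession() places into the session.update "tools" array.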