From 384abe5ab28f81d2e2f7ccf3e46a0fed7fced166 Mon Sep 17 00:00:00 2001 From: Gary Tokman Date: Sun, 22 Oct 2023 11:59:00 -0400 Subject: [PATCH] feat: add support for chat on android --- android/build.gradle | 2 +- .../ReactNativeOpenaiModule.kt | 94 ++++++++++++++++++- example/android/build.gradle | 4 + example/ios/Podfile.lock | 4 +- src/index.tsx | 7 +- 5 files changed, 103 insertions(+), 8 deletions(-) diff --git a/android/build.gradle b/android/build.gradle index 9f9b930..3e25e74 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -1,6 +1,6 @@ buildscript { // Buildscript is evaluated before everything else so we can't use getExtOrDefault - def kotlin_version = '1.9.0' + def kotlin_version = '1.7.22' repositories { google() diff --git a/android/src/main/java/com/candlefinance/reactnativeopenai/ReactNativeOpenaiModule.kt b/android/src/main/java/com/candlefinance/reactnativeopenai/ReactNativeOpenaiModule.kt index ca66310..b4e23f0 100644 --- a/android/src/main/java/com/candlefinance/reactnativeopenai/ReactNativeOpenaiModule.kt +++ b/android/src/main/java/com/candlefinance/reactnativeopenai/ReactNativeOpenaiModule.kt @@ -1,6 +1,5 @@ package com.candlefinance.reactnativeopenai -import com.aallam.openai.api.BetaOpenAI import com.aallam.openai.api.chat.ChatCompletionChunk import com.aallam.openai.api.chat.ChatCompletionRequest import com.aallam.openai.api.chat.ChatMessage @@ -11,13 +10,14 @@ import com.aallam.openai.client.OpenAI import com.aallam.openai.client.OpenAIConfig import com.aallam.openai.client.OpenAIHost import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.Promise import com.facebook.react.bridge.ReactApplicationContext import com.facebook.react.bridge.ReactContextBaseJavaModule import com.facebook.react.bridge.ReactMethod +import com.facebook.react.bridge.ReadableArray import com.facebook.react.bridge.ReadableMap import com.facebook.react.bridge.WritableMap import com.facebook.react.modules.core.DeviceEventManagerModule -import io.ktor.client.utils.EmptyContent.headers import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job @@ -83,7 +83,6 @@ class ReactNativeOpenaiModule(reactContext: ReactApplicationContext) : this.openAIClient = OpenAI(config) } - @OptIn(BetaOpenAI::class) @ReactMethod fun stream(input: ReadableMap) { val model = input.getString("model") @@ -116,6 +115,8 @@ class ReactNativeOpenaiModule(reactContext: ReactApplicationContext) : presencePenalty = presencePenalty, frequencyPenalty = frequencyPenalty, user = user, + stop = toList(stops), + logitBias = toMap(logitBias) ) runBlocking { job = scope.launch { @@ -143,6 +144,93 @@ class ReactNativeOpenaiModule(reactContext: ReactApplicationContext) : } } + @ReactMethod + fun create(input: ReadableMap, promise: Promise) { + val model = input.getString("model") + val messages = input.getArray("messages") + val temperature = if (input.hasKey("temperature")) input.getDouble("temperature") else null + val topP = if (input.hasKey("topP")) input.getDouble("topP") else null + val n = if (input.hasKey("n")) input.getInt("n") else null + val stops = if (input.hasKey("stops")) input.getArray("stops") else null + val maxTokens = if (input.hasKey("maxTokens")) input.getInt("maxTokens") else null + val presencePenalty = if (input.hasKey("presencePenalty")) input.getDouble("presencePenalty") else null + val frequencyPenalty = if (input.hasKey("frequencyPenalty")) input.getDouble("frequencyPenalty") else null + val logitBias = if (input.hasKey("logitBias")) input.getMap("logitBias") else null + val user = if (input.hasKey("user")) input.getString("user") else null + val m = messages?.toArrayList()?.map { it -> + val role: String = (it as HashMap).get("role") ?: "user" + val content: String = it.get("content") as String + ChatMessage( + role = ChatRole(role), + content = content + ) + } ?: emptyList() + + val chatCompletionRequest = ChatCompletionRequest( + model = ModelId(model ?: "gpt-3.5-turbo"), + messages = m, + maxTokens = maxTokens, + temperature = temperature, + topP = topP, + n = n, + presencePenalty = presencePenalty, + frequencyPenalty = frequencyPenalty, + user = user, + stop = toList(stops), + logitBias = toMap(logitBias) + ) + runBlocking { + job = scope.launch { + val completion = openAIClient?.chatCompletion(chatCompletionRequest) + val map = mapOf( + "id" to (completion?.id ?: ""), + "created" to (completion?.created ?: ""), + "model" to (completion?.model?.id ?: "$model"), + "object" to "chat.completions", + "choices" to (completion?.choices?.map { + mapOf( + "message" to mapOf( + "content" to (it.message?.content ?: ""), + "role" to it.message?.role.toString() + ), + "index" to it.index, + "finishReason" to (it.finishReason.toString() ?: "stop") + ) + } ?: {}), + "usage" to mapOf( + "promptTokens" to (completion?.usage?.promptTokens ?: 0), + "totalTokens" to (completion?.usage?.totalTokens ?: 0), + "completionTokens" to (completion?.usage?.completionTokens ?: 0) + ), + ) + val toReadableMap = Arguments.makeNativeMap(map) + promise.resolve(toReadableMap) + } + } + } + + private fun toList(array: ReadableArray?): List { + val list = mutableListOf() + if (array != null) { + for (i in 0 until array.size()) { + list.add(array.getString(i) ?: "") + } + } + return list + } + + private fun toMap(map: ReadableMap?): Map? { + val hashMap = mutableMapOf() + if (map != null) { + val iterator = map.keySetIterator() + while (iterator.hasNextKey()) { + val key = iterator.nextKey() + hashMap[key] = map.getInt(key) + } + } + return hashMap + } + private fun dispatch(action: String, payload: Map) { val map = mapOf( "type" to action, diff --git a/example/android/build.gradle b/example/android/build.gradle index 019c627..1089aeb 100644 --- a/example/android/build.gradle +++ b/example/android/build.gradle @@ -19,3 +19,7 @@ buildscript { classpath("com.facebook.react:react-native-gradle-plugin") } } + +plugins { + id "org.jetbrains.kotlin.android" version "1.9.0" apply false +} diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock index deea1af..f420459 100644 --- a/example/ios/Podfile.lock +++ b/example/ios/Podfile.lock @@ -448,7 +448,7 @@ PODS: - React-jsi (= 0.72.3) - React-logger (= 0.72.3) - React-perflogger (= 0.72.3) - - ReactNativeOpenAI (0.3.0): + - ReactNativeOpenAI (0.4.0): - React-Core - SocketRocket (0.6.1) - Yoga (1.14.0) @@ -666,7 +666,7 @@ SPEC CHECKSUMS: React-runtimescheduler: af0b24628c1d543a3f87251c9efa29c5a589e08a React-utils: bcb57da67eec2711f8b353f6e3d33bd8e4b2efa3 ReactCommon: d7d63a5b3c3ff29304a58fc8eb3b4f1b077cd789 - ReactNativeOpenAI: c1cfbd3c7de80c94b55401268b78cac86c1cf03e + ReactNativeOpenAI: f2e3b4e5d6dae755f0e687d0035c582a9fc7c653 SocketRocket: f32cd54efbe0f095c4d7594881e52619cfe80b17 Yoga: 8796b55dba14d7004f980b54bcc9833ee45b28ce YogaKit: f782866e155069a2cca2517aafea43200b01fd5a diff --git a/src/index.tsx b/src/index.tsx index 2bdc709..52d57f8 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,4 +1,4 @@ -import { NativeEventEmitter, NativeModules } from 'react-native'; +import { NativeEventEmitter, NativeModules, Platform } from 'react-native'; export type Config = | { @@ -129,7 +129,10 @@ class Chat { input: ChatModels.StreamInput ): Promise { const result = await this.module.create(input); - return JSON.parse(result); + if (Platform.OS === 'ios') { + return JSON.parse(result); + } + return result; } public addListener(