Skip to content

Commit

Permalink
feat: add support for chat on android
Browse files Browse the repository at this point in the history
  • Loading branch information
gtokman committed Oct 22, 2023
1 parent d362e3a commit 384abe5
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 8 deletions.
2 changes: 1 addition & 1 deletion android/build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
buildscript {
// Buildscript is evaluated before everything else so we can't use getExtOrDefault
def kotlin_version = '1.9.0'
def kotlin_version = '1.7.22'

repositories {
google()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.candlefinance.reactnativeopenai

import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.chat.ChatCompletionChunk
import com.aallam.openai.api.chat.ChatCompletionRequest
import com.aallam.openai.api.chat.ChatMessage
Expand All @@ -11,13 +10,14 @@ import com.aallam.openai.client.OpenAI
import com.aallam.openai.client.OpenAIConfig
import com.aallam.openai.client.OpenAIHost
import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReactContextBaseJavaModule
import com.facebook.react.bridge.ReactMethod
import com.facebook.react.bridge.ReadableArray
import com.facebook.react.bridge.ReadableMap
import com.facebook.react.bridge.WritableMap
import com.facebook.react.modules.core.DeviceEventManagerModule
import io.ktor.client.utils.EmptyContent.headers
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
Expand Down Expand Up @@ -83,7 +83,6 @@ class ReactNativeOpenaiModule(reactContext: ReactApplicationContext) :
this.openAIClient = OpenAI(config)
}

@OptIn(BetaOpenAI::class)
@ReactMethod
fun stream(input: ReadableMap) {
val model = input.getString("model")
Expand Down Expand Up @@ -116,6 +115,8 @@ class ReactNativeOpenaiModule(reactContext: ReactApplicationContext) :
presencePenalty = presencePenalty,
frequencyPenalty = frequencyPenalty,
user = user,
stop = toList(stops),
logitBias = toMap(logitBias)
)
runBlocking {
job = scope.launch {
Expand Down Expand Up @@ -143,6 +144,93 @@ class ReactNativeOpenaiModule(reactContext: ReactApplicationContext) :
}
}

@ReactMethod
fun create(input: ReadableMap, promise: Promise) {
val model = input.getString("model")
val messages = input.getArray("messages")
val temperature = if (input.hasKey("temperature")) input.getDouble("temperature") else null
val topP = if (input.hasKey("topP")) input.getDouble("topP") else null
val n = if (input.hasKey("n")) input.getInt("n") else null
val stops = if (input.hasKey("stops")) input.getArray("stops") else null
val maxTokens = if (input.hasKey("maxTokens")) input.getInt("maxTokens") else null
val presencePenalty = if (input.hasKey("presencePenalty")) input.getDouble("presencePenalty") else null
val frequencyPenalty = if (input.hasKey("frequencyPenalty")) input.getDouble("frequencyPenalty") else null
val logitBias = if (input.hasKey("logitBias")) input.getMap("logitBias") else null
val user = if (input.hasKey("user")) input.getString("user") else null
val m = messages?.toArrayList()?.map { it ->
val role: String = (it as HashMap<String, String>).get("role") ?: "user"
val content: String = it.get("content") as String
ChatMessage(
role = ChatRole(role),
content = content
)
} ?: emptyList()

val chatCompletionRequest = ChatCompletionRequest(
model = ModelId(model ?: "gpt-3.5-turbo"),
messages = m,
maxTokens = maxTokens,
temperature = temperature,
topP = topP,
n = n,
presencePenalty = presencePenalty,
frequencyPenalty = frequencyPenalty,
user = user,
stop = toList(stops),
logitBias = toMap(logitBias)
)
runBlocking {
job = scope.launch {
val completion = openAIClient?.chatCompletion(chatCompletionRequest)
val map = mapOf(
"id" to (completion?.id ?: ""),
"created" to (completion?.created ?: ""),
"model" to (completion?.model?.id ?: "$model"),
"object" to "chat.completions",
"choices" to (completion?.choices?.map {
mapOf(
"message" to mapOf(
"content" to (it.message?.content ?: ""),
"role" to it.message?.role.toString()
),
"index" to it.index,
"finishReason" to (it.finishReason.toString() ?: "stop")
)
} ?: {}),
"usage" to mapOf(
"promptTokens" to (completion?.usage?.promptTokens ?: 0),
"totalTokens" to (completion?.usage?.totalTokens ?: 0),
"completionTokens" to (completion?.usage?.completionTokens ?: 0)
),
)
val toReadableMap = Arguments.makeNativeMap(map)
promise.resolve(toReadableMap)
}
}
}

private fun toList(array: ReadableArray?): List<String> {
val list = mutableListOf<String>()
if (array != null) {
for (i in 0 until array.size()) {
list.add(array.getString(i) ?: "")
}
}
return list
}

private fun toMap(map: ReadableMap?): Map<String, Int>? {
val hashMap = mutableMapOf<String, Int>()
if (map != null) {
val iterator = map.keySetIterator()
while (iterator.hasNextKey()) {
val key = iterator.nextKey()
hashMap[key] = map.getInt(key)
}
}
return hashMap
}

private fun dispatch(action: String, payload: Map<String, Any?>) {
val map = mapOf(
"type" to action,
Expand Down
4 changes: 4 additions & 0 deletions example/android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ buildscript {
classpath("com.facebook.react:react-native-gradle-plugin")
}
}

plugins {
id "org.jetbrains.kotlin.android" version "1.9.0" apply false
}
4 changes: 2 additions & 2 deletions example/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ PODS:
- React-jsi (= 0.72.3)
- React-logger (= 0.72.3)
- React-perflogger (= 0.72.3)
- ReactNativeOpenAI (0.3.0):
- ReactNativeOpenAI (0.4.0):
- React-Core
- SocketRocket (0.6.1)
- Yoga (1.14.0)
Expand Down Expand Up @@ -666,7 +666,7 @@ SPEC CHECKSUMS:
React-runtimescheduler: af0b24628c1d543a3f87251c9efa29c5a589e08a
React-utils: bcb57da67eec2711f8b353f6e3d33bd8e4b2efa3
ReactCommon: d7d63a5b3c3ff29304a58fc8eb3b4f1b077cd789
ReactNativeOpenAI: c1cfbd3c7de80c94b55401268b78cac86c1cf03e
ReactNativeOpenAI: f2e3b4e5d6dae755f0e687d0035c582a9fc7c653
SocketRocket: f32cd54efbe0f095c4d7594881e52619cfe80b17
Yoga: 8796b55dba14d7004f980b54bcc9833ee45b28ce
YogaKit: f782866e155069a2cca2517aafea43200b01fd5a
Expand Down
7 changes: 5 additions & 2 deletions src/index.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { NativeEventEmitter, NativeModules } from 'react-native';
import { NativeEventEmitter, NativeModules, Platform } from 'react-native';

export type Config =
| {
Expand Down Expand Up @@ -129,7 +129,10 @@ class Chat {
input: ChatModels.StreamInput
): Promise<ChatModels.CreateOutput> {
const result = await this.module.create(input);
return JSON.parse(result);
if (Platform.OS === 'ios') {
return JSON.parse(result);
}
return result;
}

public addListener(
Expand Down

0 comments on commit 384abe5

Please sign in to comment.