From 55922b716b96ce63c96d02e090b698e06d6f5c4b Mon Sep 17 00:00:00 2001 From: Gary Tokman Date: Sat, 21 Oct 2023 13:14:04 -0400 Subject: [PATCH] feat: custom domain support --- .../ReactNativeOpenaiModule.kt | 29 +++++++++++++----- example/ios/Podfile.lock | 4 +-- example/src/App.tsx | 9 +++++- ios/OpenAIKit/Chat/ChatProvider.swift | 2 -- ios/ReactNativeOpenai.mm | 2 +- ios/ReactNativeOpenai.swift | 30 ++++++++++++++++--- src/index.tsx | 17 +++++++++-- 7 files changed, 74 insertions(+), 19 deletions(-) diff --git a/android/src/main/java/com/candlefinance/reactnativeopenai/ReactNativeOpenaiModule.kt b/android/src/main/java/com/candlefinance/reactnativeopenai/ReactNativeOpenaiModule.kt index c4f292d..ca66310 100644 --- a/android/src/main/java/com/candlefinance/reactnativeopenai/ReactNativeOpenaiModule.kt +++ b/android/src/main/java/com/candlefinance/reactnativeopenai/ReactNativeOpenaiModule.kt @@ -1,5 +1,6 @@ package com.candlefinance.reactnativeopenai +import com.aallam.openai.api.BetaOpenAI import com.aallam.openai.api.chat.ChatCompletionChunk import com.aallam.openai.api.chat.ChatCompletionRequest import com.aallam.openai.api.chat.ChatMessage @@ -8,6 +9,7 @@ import com.aallam.openai.api.http.Timeout import com.aallam.openai.api.model.ModelId import com.aallam.openai.client.OpenAI import com.aallam.openai.client.OpenAIConfig +import com.aallam.openai.client.OpenAIHost import com.facebook.react.bridge.Arguments import com.facebook.react.bridge.ReactApplicationContext import com.facebook.react.bridge.ReactContextBaseJavaModule @@ -15,6 +17,7 @@ import com.facebook.react.bridge.ReactMethod import com.facebook.react.bridge.ReadableMap import com.facebook.react.bridge.WritableMap import com.facebook.react.modules.core.DeviceEventManagerModule +import io.ktor.client.utils.EmptyContent.headers import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job @@ -59,16 +62,28 @@ class ReactNativeOpenaiModule(reactContext: ReactApplicationContext) : private var openAIClient: OpenAI? = null @ReactMethod - fun initialize(apiKey: String, organization: String) { - println("Initializing client with $apiKey and org $organization") + fun initialize(config: ReadableMap) { + val apiKey = if (config.hasKey("apiKey")) config.getString("apiKey") else null + val organization = if (config.hasKey("organization")) config.getString("organization") else null + val scheme = if (config.hasKey("scheme")) config.getString("scheme") else null + val baseUrl = if (config.hasKey("host")) config.getString("host") else null + val pathPrefix = if (config.hasKey("pathPrefix")) config.getString("pathPrefix") else null + val host = baseUrl?.let { + OpenAIHost( + baseUrl = "${scheme ?: "https"}://${it}/${pathPrefix ?: "v1"}/" + ) + } + println(host) val config = OpenAIConfig( - token = apiKey, + token = apiKey ?: "", organization = organization, - timeout = Timeout(socket = 60.seconds) + timeout = Timeout(socket = 60.seconds), + host = host ?: OpenAIHost.OpenAI ) this.openAIClient = OpenAI(config) } + @OptIn(BetaOpenAI::class) @ReactMethod fun stream(input: ReadableMap) { val model = input.getString("model") @@ -114,11 +129,11 @@ class ReactNativeOpenaiModule(reactContext: ReactApplicationContext) : "choices" to (completion.choices?.map { mapOf( "delta" to mapOf( - "content" to it.delta.content, - "role" to it.delta.role.toString() + "content" to (it.delta?.content ?: ""), + "role" to it.delta?.role.toString() ), "index" to it.index, - "finishReason" to (it.finishReason?.value ?: "stop") + "finishReason" to (it.finishReason ?: "stop") ) } ?: {}), ) diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock index b62cf7c..deea1af 100644 --- a/example/ios/Podfile.lock +++ b/example/ios/Podfile.lock @@ -448,7 +448,7 @@ PODS: - React-jsi (= 0.72.3) - React-logger (= 0.72.3) - React-perflogger (= 0.72.3) - - ReactNativeOpenAI (0.1.0): + - ReactNativeOpenAI (0.3.0): - React-Core - SocketRocket (0.6.1) - Yoga (1.14.0) @@ -666,7 +666,7 @@ SPEC CHECKSUMS: React-runtimescheduler: af0b24628c1d543a3f87251c9efa29c5a589e08a React-utils: bcb57da67eec2711f8b353f6e3d33bd8e4b2efa3 ReactCommon: d7d63a5b3c3ff29304a58fc8eb3b4f1b077cd789 - ReactNativeOpenAI: 93da4285ad4a32ab357523643ec91e8d4d1513e8 + ReactNativeOpenAI: c1cfbd3c7de80c94b55401268b78cac86c1cf03e SocketRocket: f32cd54efbe0f095c4d7594881e52619cfe80b17 Yoga: 8796b55dba14d7004f980b54bcc9833ee45b28ce YogaKit: f782866e155069a2cca2517aafea43200b01fd5a diff --git a/example/src/App.tsx b/example/src/App.tsx index 46e2bd9..c86803b 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -18,7 +18,14 @@ const AnimatedTextInput = Animated.createAnimatedComponent(TextInput); export default function App() { const scheme = useColorScheme(); const [result, setResult] = React.useState(''); - const openAI = React.useMemo(() => new OpenAI('', ''), []); + const openAI = React.useMemo( + () => + new OpenAI({ + apiKey: 'YOUR_API_KEY', + organization: 'YOUR_ORGANIZATION', + }), + [] + ); const yPosition = React.useRef(new Animated.Value(0)).current; diff --git a/ios/OpenAIKit/Chat/ChatProvider.swift b/ios/OpenAIKit/Chat/ChatProvider.swift index afab050..7108a59 100644 --- a/ios/OpenAIKit/Chat/ChatProvider.swift +++ b/ios/OpenAIKit/Chat/ChatProvider.swift @@ -88,8 +88,6 @@ public struct ChatProvider { logitBias: logitBias, user: user ) - return try await requestHandler.stream(request: request) - } } diff --git a/ios/ReactNativeOpenai.mm b/ios/ReactNativeOpenai.mm index 63e97f2..fb781f7 100644 --- a/ios/ReactNativeOpenai.mm +++ b/ios/ReactNativeOpenai.mm @@ -5,7 +5,7 @@ @interface RCT_EXTERN_MODULE(ReactNativeOpenai, RCTEventEmitter) // API RCT_EXTERN_METHOD(supportedEvents) -RCT_EXTERN_METHOD(initialize:(NSString *)apiKey organization:(NSString *)organization) +RCT_EXTERN_METHOD(initialize:(NSDictionary *)config) // Chat RCT_EXTERN_METHOD(stream:(NSDictionary *)input) diff --git a/ios/ReactNativeOpenai.swift b/ios/ReactNativeOpenai.swift index 9bd6b35..5b2c148 100644 --- a/ios/ReactNativeOpenai.swift +++ b/ios/ReactNativeOpenai.swift @@ -19,9 +19,31 @@ final class ReactNativeOpenai: RCTEventEmitter { Self.emitter = self } - @objc(initialize:organization:) - public func initialize(apiKey: String, organization: String) { - self.configuration = Configuration(apiKey: apiKey, organization: organization) + struct Config: Codable { + let apiKey: String? + let organization: String? + let scheme: String? + let host: String? + let pathPrefix: String? + } + + @objc(initialize:) + public func initialize(config: NSDictionary) { + do { + let decoded = try DictionaryDecoder().decode(Config.self, from: config) + var api: API? + if let host = decoded.host { + let scheme = decoded.scheme != nil ? API.Scheme.custom(decoded.scheme!) : .https + api = API(scheme: scheme, host: host, pathPrefix: decoded.pathPrefix) + } + self.configuration = Configuration( + apiKey: decoded.apiKey ?? "", + organization: decoded.organization ?? "", + api: api + ) + } catch { + print("Error:", error.localizedDescription) + } } override public static func requiresMainQueueSetup() -> Bool { @@ -128,5 +150,5 @@ extension ReactNativeOpenai { reject("error", "error", error) } } - } + } } diff --git a/src/index.tsx b/src/index.tsx index d86a0c4..2bdc709 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,13 +1,26 @@ import { NativeEventEmitter, NativeModules } from 'react-native'; +export type Config = + | { + apiKey: string; + organization: string; + } + | { + apiKey?: string; + organization?: string; + scheme?: string; + host: string; + pathPrefix?: string; + }; + class OpenAI { module = NativeModules.ReactNativeOpenai; private bridge: NativeEventEmitter; public chat: Chat; - public constructor(apiKey: string, organization: string) { + public constructor(config: Config) { this.bridge = new NativeEventEmitter(this.module); - this.module.initialize(apiKey, organization); + this.module.initialize(config); this.chat = new Chat(this.module, this.bridge); } }