From c15b3f1a15924dca11c5e890213b2ce63cf70fec Mon Sep 17 00:00:00 2001 From: Foad Kesheh Date: Thu, 23 Nov 2023 22:51:11 -0300 Subject: [PATCH 1/3] Fixes Function Execution on GPT-4-Turbo --- .../completion/chat/ChatFunctionCall.java | 3 +- .../openai/utils/TikTokensUtil.java | 10 +++-- .../OpenAiApiDynamicFunctionExample.java | 28 +++++++++--- .../OpenAiApiFunctionsWithStreamExample.java | 18 ++++++-- .../openai/service/FunctionExecutor.java | 10 +++-- .../openai/service/OpenAiService.java | 40 ++++++++++++----- .../openai/service/ChatCompletionTest.java | 44 +++++++++++++------ 7 files changed, 111 insertions(+), 42 deletions(-) diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatFunctionCall.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatFunctionCall.java index 962fbe12..1e47a851 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatFunctionCall.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatFunctionCall.java @@ -1,6 +1,5 @@ package com.theokanning.openai.completion.chat; -import com.fasterxml.jackson.databind.JsonNode; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -18,6 +17,6 @@ public class ChatFunctionCall { /** * The arguments of the call produced by the model, represented as a JsonNode for easy manipulation. */ - JsonNode arguments; + String arguments; } diff --git a/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java b/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java index 0a50907e..017ef43e 100644 --- a/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java +++ b/api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java @@ -9,7 +9,11 @@ import lombok.AllArgsConstructor; import lombok.Getter; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; /** * Token calculation tool class @@ -173,12 +177,12 @@ public static int tokens(String modelName, List messages) { Encoding encoding = getEncoding(modelName); int tokensPerMessage = 0; int tokensPerName = 0; - //3.5统一处理 + //3.5 if (modelName.equals("gpt-3.5-turbo-0301") || modelName.equals("gpt-3.5-turbo")) { tokensPerMessage = 4; tokensPerName = -1; } - //4.0统一处理 + //4.0 if (modelName.equals("gpt-4") || modelName.equals("gpt-4-0314")) { tokensPerMessage = 3; tokensPerName = 1; diff --git a/example/src/main/java/example/OpenAiApiDynamicFunctionExample.java b/example/src/main/java/example/OpenAiApiDynamicFunctionExample.java index 75f9b8e2..b787e558 100644 --- a/example/src/main/java/example/OpenAiApiDynamicFunctionExample.java +++ b/example/src/main/java/example/OpenAiApiDynamicFunctionExample.java @@ -1,17 +1,30 @@ package example; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; -import com.theokanning.openai.completion.chat.*; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatFunctionCall; +import com.theokanning.openai.completion.chat.ChatFunctionDynamic; +import com.theokanning.openai.completion.chat.ChatFunctionProperty; +import com.theokanning.openai.completion.chat.ChatMessage; +import com.theokanning.openai.completion.chat.ChatMessageRole; import com.theokanning.openai.service.OpenAiService; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Scanner; public class OpenAiApiDynamicFunctionExample { - + static ObjectMapper mapper = new ObjectMapper(); private static JsonNode getWeather(String location, String unit) { - ObjectMapper mapper = new ObjectMapper(); + ObjectNode response = mapper.createObjectNode(); response.put("location", location); response.put("unit", unit); @@ -20,7 +33,7 @@ private static JsonNode getWeather(String location, String unit) { return response; } - public static void main(String... args) { + public static void main(String... args) throws JsonProcessingException { String token = System.getenv("OPENAI_TOKEN"); OpenAiService service = new OpenAiService(token); @@ -68,8 +81,9 @@ public static void main(String... args) { ChatFunctionCall functionCall = responseMessage.getFunctionCall(); if (functionCall != null) { if (functionCall.getName().equals("get_weather")) { - String location = functionCall.getArguments().get("location").asText(); - String unit = functionCall.getArguments().get("unit").asText(); + JsonNode arguments = mapper.readTree(functionCall.getArguments()); + String location = arguments.get("location").asText(); + String unit = arguments.get("unit").asText(); JsonNode weather = getWeather(location, unit); ChatMessage weatherMessage = new ChatMessage(ChatMessageRole.FUNCTION.value(), weather.toString(), "get_weather"); messages.add(weatherMessage); diff --git a/example/src/main/java/example/OpenAiApiFunctionsWithStreamExample.java b/example/src/main/java/example/OpenAiApiFunctionsWithStreamExample.java index e6de65b6..07c58408 100644 --- a/example/src/main/java/example/OpenAiApiFunctionsWithStreamExample.java +++ b/example/src/main/java/example/OpenAiApiFunctionsWithStreamExample.java @@ -1,13 +1,22 @@ package example; -import com.theokanning.openai.completion.chat.*; +import com.theokanning.openai.completion.chat.ChatCompletionChunk; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatFunction; +import com.theokanning.openai.completion.chat.ChatMessage; +import com.theokanning.openai.completion.chat.ChatMessageRole; import com.theokanning.openai.service.FunctionExecutor; import com.theokanning.openai.service.OpenAiService; import example.OpenAiApiFunctionsExample.Weather; import example.OpenAiApiFunctionsExample.WeatherResponse; import io.reactivex.Flowable; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Random; +import java.util.Scanner; import java.util.concurrent.atomic.AtomicBoolean; public class OpenAiApiFunctionsWithStreamExample { @@ -34,7 +43,7 @@ public static void main(String... args) { while (true) { ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest .builder() - .model("gpt-3.5-turbo-0613") + .model("gpt-4-1106-preview") .messages(messages) .functions(functionExecutor.getFunctions()) .functionCall(ChatCompletionRequest.ChatCompletionRequestFunctionCall.of("auto")) @@ -48,6 +57,7 @@ public static void main(String... args) { ChatMessage chatMessage = service.mapStreamToAccumulator(flowable) .doOnNext(accumulator -> { if (accumulator.isFunctionCall()) { + System.out.println("Trying to execute " + accumulator.getAccumulatedChatFunctionCall().getArguments()); if (isFirst.getAndSet(false)) { System.out.println("Executing function " + accumulator.getAccumulatedChatFunctionCall().getName() + "..."); } @@ -83,4 +93,4 @@ public static void main(String... args) { } } -} \ No newline at end of file +} diff --git a/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java b/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java index 5d143a95..c0f1c4ff 100644 --- a/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java +++ b/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java @@ -10,7 +10,11 @@ import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.completion.chat.ChatMessageRole; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; public class FunctionExecutor { @@ -81,8 +85,8 @@ public T execute(ChatFunctionCall call) { ChatFunction function = FUNCTIONS.get(call.getName()); Object obj; try { - JsonNode arguments = call.getArguments(); - obj = MAPPER.readValue(arguments instanceof TextNode ? arguments.asText() : arguments.toPrettyString(), function.getParametersClass()); + String arguments = call.getArguments(); + obj = MAPPER.readValue(arguments, function.getParametersClass()); } catch (JsonProcessingException e) { throw new RuntimeException(e); } diff --git a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java index ee63c419..04b46fa7 100644 --- a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java +++ b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java @@ -5,17 +5,34 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategy; -import com.fasterxml.jackson.databind.node.TextNode; -import com.theokanning.openai.*; -import com.theokanning.openai.assistants.*; -import com.theokanning.openai.audio.*; +import com.theokanning.openai.DeleteResult; +import com.theokanning.openai.ListSearchParameters; +import com.theokanning.openai.OpenAiError; +import com.theokanning.openai.OpenAiHttpException; +import com.theokanning.openai.OpenAiResponse; +import com.theokanning.openai.assistants.Assistant; +import com.theokanning.openai.assistants.AssistantFile; +import com.theokanning.openai.assistants.AssistantFileRequest; +import com.theokanning.openai.assistants.AssistantRequest; +import com.theokanning.openai.assistants.ModifyAssistantRequest; +import com.theokanning.openai.audio.CreateSpeechRequest; +import com.theokanning.openai.audio.CreateTranscriptionRequest; +import com.theokanning.openai.audio.CreateTranslationRequest; +import com.theokanning.openai.audio.TranscriptionResult; +import com.theokanning.openai.audio.TranslationResult; import com.theokanning.openai.billing.BillingUsage; import com.theokanning.openai.billing.Subscription; import com.theokanning.openai.client.OpenAiApi; import com.theokanning.openai.completion.CompletionChunk; import com.theokanning.openai.completion.CompletionRequest; import com.theokanning.openai.completion.CompletionResult; -import com.theokanning.openai.completion.chat.*; +import com.theokanning.openai.completion.chat.ChatCompletionChunk; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatCompletionResult; +import com.theokanning.openai.completion.chat.ChatFunction; +import com.theokanning.openai.completion.chat.ChatFunctionCall; +import com.theokanning.openai.completion.chat.ChatMessage; +import com.theokanning.openai.completion.chat.ChatMessageRole; import com.theokanning.openai.edit.EditRequest; import com.theokanning.openai.edit.EditResult; import com.theokanning.openai.embedding.EmbeddingRequest; @@ -48,7 +65,12 @@ import io.reactivex.BackpressureStrategy; import io.reactivex.Flowable; import io.reactivex.Single; -import okhttp3.*; +import okhttp3.ConnectionPool; +import okhttp3.MediaType; +import okhttp3.MultipartBody; +import okhttp3.OkHttpClient; +import okhttp3.RequestBody; +import okhttp3.ResponseBody; import retrofit2.Call; import retrofit2.HttpException; import retrofit2.Retrofit; @@ -584,7 +606,6 @@ public static ObjectMapper defaultObjectMapper() { mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class); mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class); - mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class); return mapper; } @@ -617,8 +638,8 @@ public Flowable mapStreamToAccumulator(Flowable mapStreamToAccumulator(Flowable functions = Collections.singletonList(ChatFunction.builder() .name("get_weather") .description("Get the current weather in a given location") @@ -220,7 +236,8 @@ void streamChatCompletionWithFunctions() { .getAccumulatedMessage(); assertNotNull(accumulatedMessage.getFunctionCall()); assertEquals("get_weather", accumulatedMessage.getFunctionCall().getName()); - assertInstanceOf(ObjectNode.class, accumulatedMessage.getFunctionCall().getArguments()); + JsonNode arguments = this.mapper.readTree(accumulatedMessage.getFunctionCall().getArguments()); + assertInstanceOf(ObjectNode.class, arguments); ChatMessage callResponse = functionExecutor.executeAndConvertToMessageHandlingExceptions(accumulatedMessage.getFunctionCall()); assertNotEquals("error", callResponse.getName()); @@ -256,7 +273,7 @@ void streamChatCompletionWithFunctions() { } @Test - void streamChatCompletionWithDynamicFunctions() { + void streamChatCompletionWithDynamicFunctions() throws JsonProcessingException { ChatFunctionDynamic function = ChatFunctionDynamic.builder() .name("get_weather") .description("Get the current weather of a location") @@ -295,9 +312,10 @@ void streamChatCompletionWithDynamicFunctions() { .getAccumulatedMessage(); assertNotNull(accumulatedMessage.getFunctionCall()); assertEquals("get_weather", accumulatedMessage.getFunctionCall().getName()); - assertInstanceOf(ObjectNode.class, accumulatedMessage.getFunctionCall().getArguments()); - assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("location")); - assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit")); + JsonNode arguments = this.mapper.readTree(accumulatedMessage.getFunctionCall().getArguments()); + assertInstanceOf(ObjectNode.class, arguments); + assertNotNull(arguments.get("location")); + assertNotNull(arguments.get("unit")); } } From 65f76c3c9343dc1c6a25b1fe5db726653da46dcc Mon Sep 17 00:00:00 2001 From: Foad Kesheh Date: Thu, 23 Nov 2023 23:12:45 -0300 Subject: [PATCH 2/3] Fixes Tests --- .../openai/service/ChatFunctionCallMixIn.java | 13 ------------- .../openai/service/AssistantFunctionTest.java | 19 ++++++------------- 2 files changed, 6 insertions(+), 26 deletions(-) delete mode 100644 service/src/main/java/com/theokanning/openai/service/ChatFunctionCallMixIn.java diff --git a/service/src/main/java/com/theokanning/openai/service/ChatFunctionCallMixIn.java b/service/src/main/java/com/theokanning/openai/service/ChatFunctionCallMixIn.java deleted file mode 100644 index 7b32e051..00000000 --- a/service/src/main/java/com/theokanning/openai/service/ChatFunctionCallMixIn.java +++ /dev/null @@ -1,13 +0,0 @@ -package com.theokanning.openai.service; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; - -public abstract class ChatFunctionCallMixIn { - - @JsonSerialize(using = ChatFunctionCallArgumentsSerializerAndDeserializer.Serializer.class) - @JsonDeserialize(using = ChatFunctionCallArgumentsSerializerAndDeserializer.Deserializer.class) - abstract JsonNode getArguments(); - -} diff --git a/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java b/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java index 9ad819a7..a175537f 100644 --- a/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java @@ -6,7 +6,6 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategy; -import com.theokanning.openai.ListSearchParameters; import com.theokanning.openai.OpenAiResponse; import com.theokanning.openai.assistants.Assistant; import com.theokanning.openai.assistants.AssistantFunction; @@ -15,15 +14,12 @@ import com.theokanning.openai.assistants.Tool; import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatFunction; -import com.theokanning.openai.completion.chat.ChatFunctionCall; import com.theokanning.openai.messages.Message; import com.theokanning.openai.messages.MessageRequest; import com.theokanning.openai.runs.RequiredAction; import com.theokanning.openai.runs.Run; import com.theokanning.openai.runs.RunCreateRequest; -import com.theokanning.openai.runs.RunStep; import com.theokanning.openai.runs.SubmitToolOutputRequestItem; -import com.theokanning.openai.runs.SubmitToolOutputs; import com.theokanning.openai.runs.SubmitToolOutputsRequest; import com.theokanning.openai.runs.ToolCall; import com.theokanning.openai.threads.Thread; @@ -35,9 +31,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Objects; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; class AssistantFunctionTest { @@ -53,8 +47,7 @@ void createRetrieveRun() throws JsonProcessingException { mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class); mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class); - mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class); - + String funcDef = "{\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" + @@ -79,8 +72,8 @@ void createRetrieveRun() throws JsonProcessingException { List toolList = new ArrayList<>(); Tool funcTool = new Tool(AssistantToolsEnum.FUNCTION, function); toolList.add(funcTool); - - + + AssistantRequest assistantRequest = AssistantRequest.builder() .model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName()) .name("MATH_TUTOR") @@ -107,8 +100,8 @@ void createRetrieveRun() throws JsonProcessingException { assertNotNull(run); Run retrievedRun = service.retrieveRun(thread.getId(), run.getId()); - while (!(retrievedRun.getStatus().equals("completed")) - && !(retrievedRun.getStatus().equals("failed")) + while (!(retrievedRun.getStatus().equals("completed")) + && !(retrievedRun.getStatus().equals("failed")) && !(retrievedRun.getStatus().equals("requires_action"))){ retrievedRun = service.retrieveRun(thread.getId(), run.getId()); } @@ -142,7 +135,7 @@ void createRetrieveRun() throws JsonProcessingException { List messages = response.getData(); System.out.println(mapper.writeValueAsString(messages)); - + } } } From b5d511ed0e8c49c9fc9adc87ba3e64024c15851d Mon Sep 17 00:00:00 2001 From: Foad Kesheh Date: Thu, 23 Nov 2023 23:14:12 -0300 Subject: [PATCH 3/3] Minor improvements --- .../openai/service/AssistantTest.java | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/service/src/test/java/com/theokanning/openai/service/AssistantTest.java b/service/src/test/java/com/theokanning/openai/service/AssistantTest.java index cf4fc361..6ed41377 100644 --- a/service/src/test/java/com/theokanning/openai/service/AssistantTest.java +++ b/service/src/test/java/com/theokanning/openai/service/AssistantTest.java @@ -3,14 +3,19 @@ import com.theokanning.openai.DeleteResult; import com.theokanning.openai.ListSearchParameters; import com.theokanning.openai.OpenAiResponse; -import com.theokanning.openai.assistants.*; +import com.theokanning.openai.assistants.Assistant; +import com.theokanning.openai.assistants.AssistantFile; +import com.theokanning.openai.assistants.AssistantFileRequest; +import com.theokanning.openai.assistants.AssistantRequest; +import com.theokanning.openai.assistants.AssistantToolsEnum; +import com.theokanning.openai.assistants.ModifyAssistantRequest; +import com.theokanning.openai.assistants.Tool; import com.theokanning.openai.file.File; import com.theokanning.openai.utils.TikTokensUtil; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; import java.util.Collections; -import java.util.List; import static org.junit.jupiter.api.Assertions.*; @@ -19,7 +24,7 @@ public class AssistantTest { public static final String MATH_TUTOR = "Math Tutor"; public static final String ASSISTANT_INSTRUCTION = "You are a personal Math Tutor."; - static String token = System.getenv("OPENAI_TOKEN");; + static String token = System.getenv("OPENAI_TOKEN"); static OpenAiService service = new OpenAiService(token); @@ -105,9 +110,7 @@ static void clean() { .limit(100) .build(); OpenAiResponse assistantListAssistant = service.listAssistants(queryFilter); - assistantListAssistant.getData().forEach(assistant ->{ - service.deleteAssistant(assistant.getId()); - }); + assistantListAssistant.getData().forEach(assistant -> service.deleteAssistant(assistant.getId())); } private static File uploadAssistantFile() { @@ -137,7 +140,7 @@ private static void validateAssistantResponse(Assistant assistantResponse) { assertNotNull(assistantResponse.getId()); assertNotNull(assistantResponse.getCreatedAt()); assertNotNull(assistantResponse.getObject()); - assertEquals(assistantResponse.getTools().get(0).getType(), AssistantToolsEnum.CODE_INTERPRETER); + assertEquals(AssistantToolsEnum.CODE_INTERPRETER, assistantResponse.getTools().get(0).getType()); assertEquals(MATH_TUTOR, assistantResponse.getName()); } }