server: improve correctness of request parsing and responses (#2929)
Signed-off-by: Jared Van Bortel <[email protected]>
cebtenzzre committed Sep 9, 2024
1 parent 1aae4ff commit 3900528
Showing 22 changed files with 779 additions and 317 deletions.
123 changes: 68 additions & 55 deletions .circleci/continue_config.yml

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions .gitmodules
@@ -8,3 +8,6 @@
[submodule "gpt4all-chat/deps/SingleApplication"]
path = gpt4all-chat/deps/SingleApplication
url = https://github.com/nomic-ai/SingleApplication.git
[submodule "gpt4all-chat/deps/fmt"]
path = gpt4all-chat/deps/fmt
url = https://github.com/fmtlib/fmt.git
2 changes: 1 addition & 1 deletion gpt4all-backend/CMakeLists.txt
@@ -33,7 +33,7 @@ set(LLMODEL_VERSION_PATCH 0)
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
set(BUILD_SHARED_LIBS ON)
2 changes: 1 addition & 1 deletion gpt4all-backend/deps/llama.cpp-mainline
7 changes: 4 additions & 3 deletions gpt4all-backend/include/gpt4all-backend/llmodel.h
@@ -162,7 +162,7 @@ class LLModel {
bool allowContextShift,
PromptContext &ctx,
bool special = false,
std::string *fakeReply = nullptr);
std::optional<std::string_view> fakeReply = {});

using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);

@@ -212,7 +212,7 @@ class LLModel {
protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
virtual std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special = false) = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
@@ -249,7 +249,8 @@ class LLModel {
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp);
std::vector<Token> embd_inp,
bool isResponse = false);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx);
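
The practical effect of the new signature: callers no longer have to materialize a separate std::string just to spoof a reply, and "no fake reply" is simply an empty optional. A minimal sketch of both call styles against the declaration above (the callback bodies and the template string are made up for illustration):

    #include "llmodel.h"

    #include <cstdint>
    #include <optional>
    #include <string>
    #include <string_view>

    void runTurn(LLModel &model, LLModel::PromptContext &ctx,
                 std::optional<std::string_view> storedReply)
    {
        auto promptCb   = [](int32_t /*token*/) { return true; };
        auto responseCb = [](int32_t /*token*/, const std::string &piece) {
            // stream `piece` to the consumer; return false to stop generation
            return true;
        };

        // storedReply == std::nullopt -> generate a real response;
        // storedReply engaged         -> replay the stored assistant message.
        model.prompt("Hello", "%1", promptCb, responseCb,
                     /*allowContextShift*/ true, ctx, /*special*/ false, storedReply);
    }
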
4 changes: 2 additions & 2 deletions gpt4all-backend/src/llamamodel.cpp
@@ -536,13 +536,13 @@ size_t LLamaModel::restoreState(const uint8_t *src)
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
}

std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, std::string_view str, bool special)
{
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
std::vector<LLModel::Token> fres(str.length() + 4);
int32_t fres_len = llama_tokenize_gpt4all(
d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
/*parse_special*/ special, /*insert_space*/ insertSpace
);
fres.resize(fres_len);
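
Switching from c_str() to data() matters because a std::string_view is not guaranteed to be NUL-terminated; the call stays correct because the explicit length is passed alongside the pointer. A small illustration (not from this diff) of why NUL termination can't be assumed:

    #include <string>
    #include <string_view>

    std::string transcript = "User: hi\nAssistant: hello";
    // A view into the middle of the buffer: data() points at "hi\nAssistant: hello",
    // and only the stored length (2) delimits the intended text "hi".
    std::string_view userTurn = std::string_view(transcript).substr(6, 2);
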
3 changes: 2 additions & 1 deletion gpt4all-backend/src/llamamodel_impl.h
@@ -8,6 +8,7 @@

#include <memory>
#include <string>
#include <string_view>
#include <vector>

struct LLamaPrivate;
@@ -52,7 +53,7 @@ class LLamaModel : public LLModel {
bool m_supportsCompletion = false;

protected:
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special) override;
bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override;
Token sampleToken(PromptContext &ctx) const override;
8 changes: 3 additions & 5 deletions gpt4all-backend/src/llmodel_c.cpp
@@ -12,6 +12,7 @@
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

struct LLModelWrapper {
@@ -130,13 +131,10 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;

std::string fake_reply_str;
if (fake_reply) { fake_reply_str = fake_reply; }
auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;

// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
wrapper->promptContext, special, fake_reply_p);
wrapper->promptContext, special,
fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);

// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
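
The binding previously copied fake_reply into a temporary std::string and passed its address; now the nullable C string maps directly onto the optional view with no copy. The conversion pattern in isolation (a hedged sketch, not a helper added by this commit):

    #include <optional>
    #include <string_view>

    // nullptr -> std::nullopt; otherwise a view that borrows the caller's bytes.
    std::optional<std::string_view> toView(const char *s)
    {
        return s ? std::make_optional<std::string_view>(s) : std::nullopt;
    }

Borrowing is safe here because llmodel_prompt calls the C++ prompt method synchronously, so the C string outlives the view.
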
14 changes: 9 additions & 5 deletions gpt4all-backend/src/llmodel_shared.cpp
@@ -11,6 +11,7 @@
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>

namespace ranges = std::ranges;
@@ -45,7 +46,7 @@ void LLModel::prompt(const std::string &prompt,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::string *fakeReply)
std::optional<std::string_view> fakeReply)
{
if (!isModelLoaded()) {
std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
@@ -129,11 +130,11 @@ void LLModel::prompt(const std::string &prompt,
return; // error

// decode the assistant's reply, either generated or spoofed
if (fakeReply == nullptr) {
if (!fakeReply) {
generateResponse(responseCallback, allowContextShift, promptCtx);
} else {
embd_inp = tokenize(promptCtx, *fakeReply, false);
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
return; // error
}

@@ -157,7 +158,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp) {
std::vector<Token> embd_inp,
bool isResponse) {
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
@@ -196,7 +198,9 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
for (size_t t = 0; t < tokens; ++t) {
promptCtx.tokens.push_back(batch.at(t));
promptCtx.n_past += 1;
if (!promptCallback(batch.at(t)))
Token tok = batch.at(t);
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
}
i = batch_end;
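
With the new isResponse flag, tokens of a spoofed reply are reported through responseCallback together with their decoded text, while ordinary prompt tokens still go through promptCallback, so a replayed reply is surfaced to the caller like a normal response stream. The dispatch in isolation (a sketch with stand-in types):

    #include <cstdint>
    #include <functional>
    #include <string>

    using Token = std::int32_t;

    bool reportToken(Token tok, bool isResponse,
                     const std::function<bool(Token)> &promptCallback,
                     const std::function<bool(Token, const std::string &)> &responseCallback,
                     const std::function<std::string(Token)> &tokenToString)
    {
        return isResponse ? responseCallback(tok, tokenToString(tok))
                          : promptCallback(tok);
    }
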
1 change: 1 addition & 0 deletions gpt4all-chat/CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Fix a typo in Model Settings (by [@3Simplex](https://github.com/3Simplex) in [#2916](https://github.com/nomic-ai/gpt4all/pull/2916))
- Fix the antenna icon tooltip when using the local server ([#2922](https://github.com/nomic-ai/gpt4all/pull/2922))
- Fix a few issues with locating files and handling errors when loading remote models on startup ([#2875](https://github.com/nomic-ai/gpt4all/pull/2875))
- Significantly improve API server request parsing and response correctness ([#2929](https://github.com/nomic-ai/gpt4all/pull/2929))

## [3.2.1] - 2024-08-13

10 changes: 8 additions & 2 deletions gpt4all-chat/CMakeLists.txt
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.16)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

if(APPLE)
@@ -64,6 +64,12 @@ message(STATUS "Qt 6 root directory: ${Qt6_ROOT_DIR}")

set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

set(FMT_INSTALL OFF)
set(BUILD_SHARED_LIBS_SAVED "${BUILD_SHARED_LIBS}")
set(BUILD_SHARED_LIBS OFF)
add_subdirectory(deps/fmt)
set(BUILD_SHARED_LIBS "${BUILD_SHARED_LIBS_SAVED}")

add_subdirectory(../gpt4all-backend llmodel)

set(CHAT_EXE_RESOURCES)
@@ -240,7 +246,7 @@ else()
PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf)
endif()
target_link_libraries(chat
PRIVATE llmodel SingleApplication)
PRIVATE llmodel SingleApplication fmt::fmt)


# -- install --
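
fmt is built as a static library here (BUILD_SHARED_LIBS is saved, forced OFF for the subdirectory, then restored) and linked into the chat target, giving the C++ sources std::format-style formatting even where the toolchain's standard library lacks it. A hedged illustration of the kind of code this enables; the message text is invented, not taken from this commit:

    #include <fmt/format.h>

    #include <string>

    std::string invalidParamMessage(const std::string &name, int value, int min, int max)
    {
        return fmt::format("invalid value for '{}': expected {}..{}, got {}",
                           name, min, max, value);
    }
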
1 change: 1 addition & 0 deletions gpt4all-chat/deps/fmt
Submodule fmt added at 0c9fce
5 changes: 3 additions & 2 deletions gpt4all-chat/src/chat.cpp
@@ -239,16 +239,17 @@ void Chat::newPromptResponsePair(const QString &prompt)
resetResponseState();
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
m_chatModel->appendPrompt("Prompt: ", prompt);
m_chatModel->appendResponse("Response: ", prompt);
m_chatModel->appendResponse("Response: ", QString());
emit resetResponseRequested();
}

// the server needs to block until response is reset, so it calls resetResponse on its own m_llmThread
void Chat::serverNewPromptResponsePair(const QString &prompt)
{
resetResponseState();
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
m_chatModel->appendPrompt("Prompt: ", prompt);
m_chatModel->appendResponse("Response: ", prompt);
m_chatModel->appendResponse("Response: ", QString());
}

bool Chat::restoringFromText() const
4 changes: 2 additions & 2 deletions gpt4all-chat/src/chatapi.cpp
@@ -93,7 +93,7 @@ void ChatAPI::prompt(const std::string &prompt,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::string *fakeReply) {
std::optional<std::string_view> fakeReply) {

Q_UNUSED(promptCallback);
Q_UNUSED(allowContextShift);
@@ -121,7 +121,7 @@
if (fakeReply) {
promptCtx.n_past += 1;
m_context.append(formattedPrompt);
m_context.append(QString::fromStdString(*fakeReply));
m_context.append(QString::fromUtf8(fakeReply->data(), fakeReply->size()));
return;
}

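
QString::fromStdString no longer applies because there is no std::string to convert; instead the view's data pointer and length are passed explicitly, which also avoids relying on NUL termination. The same conversion as a tiny helper (a sketch, assuming Qt 6):

    #include <QString>

    #include <string_view>

    QString toQString(std::string_view v)
    {
        return QString::fromUtf8(v.data(), v.size());
    }
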
7 changes: 4 additions & 3 deletions gpt4all-chat/src/chatapi.h
@@ -12,9 +12,10 @@

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <functional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>

class QNetworkAccessManager;
@@ -72,7 +73,7 @@ class ChatAPI : public QObject, public LLModel {
bool allowContextShift,
PromptContext &ctx,
bool special,
std::string *fakeReply) override;
std::optional<std::string_view> fakeReply) override;

void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
@@ -97,7 +98,7 @@ class ChatAPI : public QObject, public LLModel {
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace

std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override
std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special) override
{
(void)ctx;
(void)str;
38 changes: 22 additions & 16 deletions gpt4all-chat/src/chatllm.cpp
@@ -626,16 +626,16 @@ void ChatLLM::regenerateResponse()
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
m_promptResponseTokens = 0;
m_promptTokens = 0;
m_response = std::string();
emit responseChanged(QString::fromStdString(m_response));
m_response = m_trimmedResponse = std::string();
emit responseChanged(QString::fromStdString(m_trimmedResponse));
}

void ChatLLM::resetResponse()
{
m_promptTokens = 0;
m_promptResponseTokens = 0;
m_response = std::string();
emit responseChanged(QString::fromStdString(m_response));
m_response = m_trimmedResponse = std::string();
emit responseChanged(QString::fromStdString(m_trimmedResponse));
}

void ChatLLM::resetContext()
@@ -645,9 +645,12 @@ void ChatLLM::resetContext()
m_ctx = LLModel::PromptContext();
}

QString ChatLLM::response() const
QString ChatLLM::response(bool trim) const
{
return QString::fromStdString(remove_leading_whitespace(m_response));
std::string resp = m_response;
if (trim)
resp = remove_leading_whitespace(resp);
return QString::fromStdString(resp);
}

ModelInfo ChatLLM::modelInfo() const
@@ -705,7 +708,8 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
// check for error
if (token < 0) {
m_response.append(response);
emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
m_trimmedResponse = remove_leading_whitespace(m_response);
emit responseChanged(QString::fromStdString(m_trimmedResponse));
return false;
}

@@ -715,7 +719,8 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
m_timer->inc();
Q_ASSERT(!response.empty());
m_response.append(response);
emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
m_trimmedResponse = remove_leading_whitespace(m_response);
emit responseChanged(QString::fromStdString(m_trimmedResponse));
return !m_stopGenerating;
}

@@ -741,7 +746,7 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt

bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens)
int32_t repeat_penalty_tokens, std::optional<QString> fakeReply)
{
if (!isModelLoaded())
return false;
@@ -751,7 +756,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString

QList<ResultInfo> databaseResults;
const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
if (!collectionList.isEmpty()) {
if (!fakeReply && !collectionList.isEmpty()) {
emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
emit databaseResultsChanged(databaseResults);
}
@@ -797,17 +802,18 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
m_ctx.n_predict = old_n_predict; // now we are ready for a response
}
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
/*allowContextShift*/ true, m_ctx);
/*allowContextShift*/ true, m_ctx, false,
fakeReply.transform(std::mem_fn(&QString::toStdString)));
#if defined(DEBUG)
printf("\n");
fflush(stdout);
#endif
m_timer->stop();
qint64 elapsed = totalTime.elapsed();
std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) {
m_response = trimmed;
emit responseChanged(QString::fromStdString(m_response));
if (trimmed != m_trimmedResponse) {
m_trimmedResponse = trimmed;
emit responseChanged(QString::fromStdString(m_trimmedResponse));
}

SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
@@ -1078,6 +1084,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
QString response;
stream >> response;
m_response = response.toStdString();
m_trimmedResponse = trim_whitespace(m_response);
QString nameResponse;
stream >> nameResponse;
m_nameResponse = nameResponse.toStdString();
@@ -1306,10 +1313,9 @@ void ChatLLM::processRestoreStateFromText()

auto &response = *it++;
Q_ASSERT(response.first != "Prompt: ");
auto responseText = response.second.toStdString();

m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
/*allowContextShift*/ true, m_ctx, false, &responseText);
/*allowContextShift*/ true, m_ctx, false, response.second.toUtf8().constData());
}

if (!m_stopGenerating) {
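
The fakeReply plumbing in promptInternal uses std::optional's monadic interface (fakeReply.transform(std::mem_fn(&QString::toStdString))), a C++23 feature that lines up with the CMAKE_CXX_STANDARD bump to 23 in this commit. The pattern in isolation (a hedged sketch):

    #include <QString>

    #include <functional>
    #include <optional>
    #include <string>

    std::optional<std::string> toStdString(const std::optional<QString> &s)
    {
        // Engaged: returns an optional holding s->toStdString();
        // empty:   returns std::nullopt without invoking the callable.
        return s.transform(std::mem_fn(&QString::toStdString));
    }
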
(Diffs for the remaining changed files are not rendered.)