Skip to content

Commit

Permalink
WIP: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
cebtenzzre committed Aug 30, 2024
1 parent 39311bb commit ea253e4
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 50 deletions.
11 changes: 6 additions & 5 deletions gpt4all-chat/src/localdocsmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,25 @@ class LocalDocsCollectionsModel : public QSortFilterProxyModel
Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged)
Q_PROPERTY(int updatingCount READ updatingCount NOTIFY updatingCountChanged)

public:
explicit LocalDocsCollectionsModel(QObject *parent);
int count() const { return rowCount(); }
int updatingCount() const;

public Q_SLOTS:
int count() const { return rowCount(); }
void setCollections(const QList<QString> &collections);
int updatingCount() const;

Q_SIGNALS:
void countChanged();
void updatingCountChanged();

private Q_SLOT:
void maybeTriggerUpdatingCountChanged();

protected:
bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override;

private Q_SLOTS:
void maybeTriggerUpdatingCountChanged();

private:
QList<QString> m_collections;
int m_updatingCount = 0;
Expand Down
6 changes: 4 additions & 2 deletions gpt4all-chat/src/modellist.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
#include <QVector>
#include <Qt>
#include <QtGlobal>
#include <QtQml>

#include <utility>

using namespace Qt::Literals::StringLiterals;


struct ModelInfo {
Q_GADGET
Q_PROPERTY(QString id READ id WRITE setId)
Expand Down Expand Up @@ -521,7 +523,7 @@ private Q_SLOTS:

protected:
explicit ModelList();
~ModelList() { for (auto *model: m_models) { delete model; } }
~ModelList() override { for (auto *model: std::as_const(m_models)) { delete model; } }
friend class MyModelList;
};

Expand Down
1 change: 1 addition & 0 deletions gpt4all-chat/src/mysettings.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <QSettings>
#include <QString>
#include <QStringList>
#include <QTranslator>
#include <QVector>

#include <cstdint>
Expand Down
88 changes: 51 additions & 37 deletions gpt4all-chat/src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include "mysettings.h"

#include <QByteArray>
#include <QCborArray>
#include <QCborMap>
#include <QCborValue>
#include <QDateTime>
#include <QDebug>
#include <QHostAddress>
Expand All @@ -19,6 +22,7 @@
#include <QtLogging>

#include <algorithm>
#include <cstdint>
#include <format>
#include <iostream>
#include <optional>
Expand All @@ -36,6 +40,8 @@ using namespace Qt::Literals::StringLiterals;

template <typename T>
struct BasicFormatter {
virtual ~BasicFormatter() = default;

template <typename ParseContext>
constexpr ParseContext::iterator parse(ParseContext &ctx)
{
Expand Down Expand Up @@ -82,6 +88,9 @@ class InvalidRequestError: public std::invalid_argument {
return { QJsonObject {{ "error", error }},
QHttpServerResponder::StatusCode::BadRequest };
}

private:
Q_DISABLE_COPY_MOVE(InvalidRequestError)
};

} // namespace
Expand Down Expand Up @@ -133,17 +142,20 @@ class BaseCompletionRequest {
public:
QString model; // required
// NB: some parameters are not supported yet
int max_tokens = 16;
int n = 1;
int32_t max_tokens = 16;
qint64 n = 1;
float temperature = 1.f;
float top_p = 1.f;
float min_p = 0.f;

BaseCompletionRequest() = default;
virtual ~BaseCompletionRequest() = default;

virtual BaseCompletionRequest &parse(const QCborMap &request)
{
using enum Type;

auto reqValue = [this, &request](auto &&...args) { return getValue(request, args...); };
auto reqValue = [&request](auto &&...args) { return getValue(request, args...); };
QCborValue value;

this->model = reqValue("model", String, /*required*/ true).toString();
Expand All @@ -154,7 +166,7 @@ class BaseCompletionRequest {

value = reqValue("max_tokens", Integer, false, /*min*/ 1);
if (!value.isNull())
this->max_tokens = value.toInteger();
this->max_tokens = int32_t(qMin(value.toInteger(), INT32_MAX));

value = reqValue("n", Integer, false, /*min*/ 1);
if (!value.isNull())
Expand Down Expand Up @@ -198,7 +210,7 @@ class BaseCompletionRequest {
}

protected:
enum class Type {
enum class Type : uint8_t {
Boolean,
Integer,
Number,
Expand All @@ -222,7 +234,7 @@ class BaseCompletionRequest {
Q_UNREACHABLE();
}

QCborValue getValue(
static QCborValue getValue(
const QCborMap &obj, const QString &name, std::optional<Type> type = {}, bool required = false,
std::optional<qint64> min = {}, std::optional<qint64> max = {}
) {
Expand All @@ -243,6 +255,9 @@ class BaseCompletionRequest {
}
return value;
}

private:
Q_DISABLE_COPY_MOVE(BaseCompletionRequest)
};

class CompletionRequest : public BaseCompletionRequest {
Expand All @@ -255,7 +270,7 @@ class CompletionRequest : public BaseCompletionRequest {
{
using enum Type;

auto reqValue = [this, &request](auto &&...args) { return getValue(request, args...); };
auto reqValue = [&request](auto &&...args) { return getValue(request, args...); };
QCborValue value;

BaseCompletionRequest::parse(request);
Expand Down Expand Up @@ -309,7 +324,7 @@ const std::unordered_map<BaseCompletionRequest::Type, const char *> BaseCompleti
class ChatRequest : public BaseCompletionRequest {
public:
struct Message {
enum class Role {
enum class Role : uint8_t {
User,
Assistant,
};
Expand All @@ -323,7 +338,7 @@ class ChatRequest : public BaseCompletionRequest {
{
using enum Type;

auto reqValue = [this, &request](auto &&...args) { return getValue(request, args...); };
auto reqValue = [&request](auto &&...args) { return getValue(request, args...); };
QCborValue value;

BaseCompletionRequest::parse(request);
Expand All @@ -339,7 +354,7 @@ class ChatRequest : public BaseCompletionRequest {
{
QCborArray arr = value.toArray();
Message::Role nextRole = Message::Role::User;
for (size_t i = 0; i < arr.size(); i++) {
for (qsizetype i = 0; i < arr.size(); i++) {
const auto &elem = arr[i];
if (!elem.isMap())
throw InvalidRequestError(std::format(
Expand Down Expand Up @@ -416,10 +431,10 @@ class ChatRequest : public BaseCompletionRequest {
};

template <typename T>
T parseRequest(QJsonObject &&obj)
T &parseRequest(T &request, QJsonObject &&obj)
{
// lossless conversion to CBOR exposes more type information
return T().parse(QCborMap::fromJsonObject(obj));
return request.parse(QCborMap::fromJsonObject(obj));
}

Server::Server(Chat *chat)
Expand All @@ -432,10 +447,6 @@ Server::Server(Chat *chat)
connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
}

Server::~Server()
{
}

static QJsonObject requestFromJson(const QByteArray &request)
{
QJsonParseError err;
Expand All @@ -450,14 +461,14 @@ static QJsonObject requestFromJson(const QByteArray &request)

void Server::start()
{
m_server = new QHttpServer(this);
m_server = std::make_unique<QHttpServer>(this);
if (!m_server->listen(QHostAddress::LocalHost, MySettings::globalInstance()->networkPort())) {
qWarning() << "ERROR: Unable to start the server";
return;
}

m_server->route("/v1/models", QHttpServerRequest::Method::Get,
[](const QHttpServerRequest &request) {
[](const QHttpServerRequest &) {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);

Expand All @@ -477,7 +488,7 @@ void Server::start()
);

m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Get,
[](const QString &model, const QHttpServerRequest &request) {
[](const QString &model, const QHttpServerRequest &) {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);

Expand Down Expand Up @@ -507,7 +518,8 @@ void Server::start()
#if defined(DEBUG)
qDebug().noquote() << "/v1/completions request" << QJsonDocument(reqObj).toJson(QJsonDocument::Indented);
#endif
auto req = parseRequest<CompletionRequest>(std::move(reqObj));
CompletionRequest req;
parseRequest(req, std::move(reqObj));
auto [resp, respObj] = handleCompletionRequest(req);
#if defined(DEBUG)
if (respObj)
Expand All @@ -530,7 +542,8 @@ void Server::start()
#if defined(DEBUG)
qDebug().noquote() << "/v1/chat/completions request" << QJsonDocument(reqObj).toJson(QJsonDocument::Indented);
#endif
auto req = parseRequest<ChatRequest>(std::move(reqObj));
ChatRequest req;
parseRequest(req, std::move(reqObj));
auto [resp, respObj] = handleChatRequest(req);
(void)respObj;
#if defined(DEBUG)
Expand All @@ -546,7 +559,7 @@ void Server::start()

// Respond with code 405 to wrong HTTP methods:
m_server->route("/v1/models", QHttpServerRequest::Method::Post,
[](const QHttpServerRequest &request) {
[] {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return QHttpServerResponse(
Expand All @@ -558,7 +571,8 @@ void Server::start()
);

m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Post,
[](const QString &model, const QHttpServerRequest &request) {
[](const QString &model) {
(void)model;
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return QHttpServerResponse(
Expand All @@ -570,7 +584,7 @@ void Server::start()
);

m_server->route("/v1/completions", QHttpServerRequest::Method::Get,
[](const QHttpServerRequest &request) {
[] {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return QHttpServerResponse(
Expand All @@ -581,7 +595,7 @@ void Server::start()
);

m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Get,
[](const QHttpServerRequest &request) {
[] {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return QHttpServerResponse(
Expand Down Expand Up @@ -640,10 +654,10 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
resetContext();

// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
const float top_k = modelInfo.topK();
const int n_batch = modelInfo.promptBatchSize();
const float repeat_penalty = modelInfo.repeatPenalty();
const int repeat_last_n = modelInfo.repeatPenaltyTokens();
const int top_k = modelInfo.topK();
const int n_batch = modelInfo.promptBatchSize();
const auto repeat_penalty = float(modelInfo.repeatPenalty());
const int repeat_last_n = modelInfo.repeatPenaltyTokens();

int promptTokens = 0;
int responseTokens = 0;
Expand Down Expand Up @@ -707,9 +721,9 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)

responseObject.insert("choices", choices);
responseObject.insert("usage", QJsonObject {
{ "prompt_tokens", int(promptTokens) },
{ "completion_tokens", int(responseTokens) },
{ "total_tokens", int(promptTokens + responseTokens) },
{ "prompt_tokens", promptTokens },
{ "completion_tokens", responseTokens },
{ "total_tokens", promptTokens + responseTokens },
});

return {QHttpServerResponse(responseObject), responseObject};
Expand Down Expand Up @@ -746,9 +760,9 @@ auto Server::handleChatRequest(const ChatRequest &request)
resetContext();

const QString promptTemplate = modelInfo.promptTemplate();
const float top_k = modelInfo.topK();
const int top_k = modelInfo.topK();
const int n_batch = modelInfo.promptBatchSize();
const float repeat_penalty = modelInfo.repeatPenalty();
const auto repeat_penalty = float(modelInfo.repeatPenalty());
const int repeat_last_n = modelInfo.repeatPenaltyTokens();

int promptTokens = 0;
Expand Down Expand Up @@ -853,9 +867,9 @@ auto Server::handleChatRequest(const ChatRequest &request)

responseObject.insert("choices", choices);
responseObject.insert("usage", QJsonObject {
{ "prompt_tokens", int(promptTokens) },
{ "completion_tokens", int(responseTokens) },
{ "total_tokens", int(promptTokens + responseTokens) },
{ "prompt_tokens", promptTokens },
{ "completion_tokens", responseTokens },
{ "total_tokens", promptTokens + responseTokens },
});

return {QHttpServerResponse(responseObject), responseObject};
Expand Down
13 changes: 7 additions & 6 deletions gpt4all-chat/src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
#include "chatllm.h"
#include "database.h"

#include <QHttpServer>
#include <QHttpServerResponse>
#include <QObject>
#include <QJsonObject>
#include <QList>
#include <QObject>
#include <QString>
#include <QJsonObject>

#include <memory>
#include <optional>
#include <utility>

class Chat;
class QHttpServer;
class CompletionRequest;
class ChatRequest;

Expand All @@ -23,8 +24,8 @@ class Server : public ChatLLM
Q_OBJECT

public:
Server(Chat *parent);
virtual ~Server();
explicit Server(Chat *chat);
~Server() override = default;

public Q_SLOTS:
void start();
Expand All @@ -42,7 +43,7 @@ private Q_SLOTS:

private:
Chat *m_chat;
QHttpServer *m_server;
std::unique_ptr<QHttpServer> m_server;
QList<ResultInfo> m_databaseResults;
QList<QString> m_collections;
};
Expand Down

0 comments on commit ea253e4

Please sign in to comment.