Skip to content

Commit

Permalink
Finalize base chat components
Browse files Browse the repository at this point in the history
  • Loading branch information
Palm1r committed Sep 19, 2024
1 parent f5de1b9 commit 6f2029f
Show file tree
Hide file tree
Showing 17 changed files with 77 additions and 58 deletions.
7 changes: 5 additions & 2 deletions LLMClientInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,16 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request)
auto updatedContext = prepareContext(request);

LLMConfig config;
config.requestType = RequestType::Fim;
config.provider = LLMProvidersManager::instance().getCurrentFimProvider();
config.promptTemplate = PromptTemplateManager::instance().getCurrentTemplate();
config.promptTemplate = PromptTemplateManager::instance().getCurrentFimTemplate();
config.url = QUrl(QString("%1%2").arg(Settings::generalSettings().url(),
Settings::generalSettings().endPoint()));

config.providerRequest = {{"model", Settings::generalSettings().modelName.value()},
{"stream", true}};
{"stream", true},
{"stop",
QJsonArray::fromStringList(config.promptTemplate->stopWords())}};

config.promptTemplate->prepareRequest(config.providerRequest, updatedContext);
config.provider->prepareRequest(config.providerRequest);
Expand Down
39 changes: 27 additions & 12 deletions PromptTemplateManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,48 @@ PromptTemplateManager &PromptTemplateManager::instance()

void PromptTemplateManager::setCurrentFimTemplate(const QString &name)
{
if (!m_templates.contains(name)) {
logMessage("Can't find prompt with name: " + name);
if (!m_fimTemplates.contains(name) || m_fimTemplates[name] == nullptr) {
logMessage("Error to set current FIM template" + name);
return;
}

if (m_templates[name] == nullptr) {
logMessage("Prompt is null");
return;
m_currentFimTemplate = m_fimTemplates[name];
}

Templates::PromptTemplate *PromptTemplateManager::getCurrentFimTemplate()
{
if (m_currentFimTemplate == nullptr) {
logMessage("Current fim provider is null");
return nullptr;
}

m_currentFimPrompt
return m_currentFimTemplate;
}

Templates::PromptTemplate *PromptTemplateManager::getCurrentTemplate()
void PromptTemplateManager::setCurrentChatTemplate(const QString &name)
{
auto it = m_templates.find(m_currentTemplateName);
return it != m_templates.end() ? it.value() : nullptr;
if (!m_chatTemplates.contains(name) || m_chatTemplates[name] == nullptr) {
logMessage("Error to set current chat template" + name);
return;
}

m_currentChatTemplate = m_chatTemplates[name];
}

QStringList PromptTemplateManager::getTemplateNames() const
Templates::PromptTemplate *PromptTemplateManager::getCurrentChatTemplate()
{
return m_templates.keys();
if (m_currentChatTemplate == nullptr) {
logMessage("Current chat provider is null");
return nullptr;
}

return m_currentChatTemplate;
}

PromptTemplateManager::~PromptTemplateManager()
{
qDeleteAll(m_templates);
qDeleteAll(m_fimTemplates);
qDeleteAll(m_chatTemplates);
}

} // namespace QodeAssist
20 changes: 13 additions & 7 deletions PromptTemplateManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,30 @@ class PromptTemplateManager
"T must inherit from PromptTemplate");
T *template_ptr = new T();
QString name = template_ptr->name();
m_templates[name] = template_ptr;
Settings::generalSettings().fimPrompts.addOption(name);
if (template_ptr->type() == Templates::TemplateType::Fim) {
m_fimTemplates[name] = template_ptr;
Settings::generalSettings().fimPrompts.addOption(name);
} else if (template_ptr->type() == Templates::TemplateType::Chat) {
m_chatTemplates[name] = template_ptr;
Settings::generalSettings().chatPrompts.addOption(name);
}
}

void setCurrentFimTemplate(const QString &name);
Templates::PromptTemplate *getCurrentFimTemplate();

QStringList getTemplateNames() const;

void setCurrentChatTemplate(const QString &name);
Templates::PromptTemplate *getCurrentChatTemplate();

private:
PromptTemplateManager() = default;
PromptTemplateManager(const PromptTemplateManager &) = delete;
PromptTemplateManager &operator=(const PromptTemplateManager &) = delete;

QMap<QString, Templates::PromptTemplate *> m_templates;
Templates::PromptTemplate *m_currentFimPrompt;
Templates::PromptTemplate *m_currentChatPrompt;
QMap<QString, Templates::PromptTemplate *> m_fimTemplates;
QMap<QString, Templates::PromptTemplate *> m_chatTemplates;
Templates::PromptTemplate *m_currentFimTemplate;
Templates::PromptTemplate *m_currentChatTemplate;
};

} // namespace QodeAssist
9 changes: 5 additions & 4 deletions chat/ChatClientInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ ChatClientInterface::ChatClientInterface(QObject *parent)

ChatClientInterface::~ChatClientInterface()
{
logMessage("ChatClientInterface destroyed");
}

void ChatClientInterface::sendMessage(const QString &message)
Expand All @@ -65,9 +64,11 @@ void ChatClientInterface::sendMessage(const QString &message)
prepareRequest(providerRequest, message);

LLMConfig config;
config.requestType = RequestType::Chat;
config.provider = LLMProvidersManager::instance().getCurrentChatProvider();
config.promptTemplate = PromptTemplateManager::instance().getCurrentTemplate();
config.url = QString("%1%2").arg(Settings::generalSettings().url(), "/api/chat");
config.promptTemplate = PromptTemplateManager::instance().getCurrentChatTemplate();
config.url = QString("%1%2").arg(Settings::generalSettings().chatUrl(),
Settings::generalSettings().chatEndPoint());
config.providerRequest = providerRequest;

QJsonObject request;
Expand All @@ -81,7 +82,7 @@ void ChatClientInterface::prepareRequest(QJsonObject &request, const QString &me
{
auto &settings = Settings::presetPromptsSettings();

request["model"] = Settings::generalSettings().modelName(); //MODEL_NAME;
request["model"] = Settings::generalSettings().chatModelName();

QJsonArray messages = {QJsonObject{{"role", "user"}, {"content", message}}};
request["messages"] = messages;
Expand Down
4 changes: 0 additions & 4 deletions chat/ChatClientInterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ class ChatClientInterface : public QObject

LLMRequestHandler *m_requestHandler;
QString m_accumulatedResponse;

const QString MODEL_NAME = "bartowski/Llama-3.1-SauerkrautLM-8b-Instruct-GGUF";
const QString SERVER_URL = "http://localhost:1234";
const QString ENDPOINT = "/v1/chat/completions";
};

} // namespace QodeAssist::Chat
2 changes: 1 addition & 1 deletion chat/ChatWidget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace QodeAssist::Chat {

ChatWidget::ChatWidget(QWidget *parent)
: QWidget(parent)
, m_showTimestamp(true)
, m_showTimestamp(false)
, m_chatClient(new ChatClientInterface(this))
{
setupUi();
Expand Down
3 changes: 3 additions & 0 deletions core/LLMRequestConfig.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@

namespace QodeAssist {

enum class RequestType { Fim, Chat };

struct LLMConfig
{
QUrl url;
Providers::LLMProvider *provider;
Templates::PromptTemplate *promptTemplate;
QJsonObject providerRequest;
RequestType requestType;
};

} // namespace QodeAssist
26 changes: 13 additions & 13 deletions core/LLMRequestHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,24 @@ void LLMRequestHandler::handleLLMResponse(QNetworkReply *reply,

QString &accumulatedResponse = m_accumulatedResponses[reply];

auto provider = LLMProvidersManager::instance().getCurrentFimProvider();
if (provider == nullptr)
qDebug() << "No provider selected";
bool isComplete = config.provider->handleResponse(reply, accumulatedResponse);

bool isComplete = LLMProvidersManager::instance()
.getCurrentFimProvider()
->handleResponse(reply, accumulatedResponse);

if (!Settings::generalSettings().multiLineCompletion()
&& processSingleLineCompletion(reply, request, accumulatedResponse, config)) {
return;
if (config.requestType == RequestType::Fim) {
if (!Settings::generalSettings().multiLineCompletion()
&& processSingleLineCompletion(reply, request, accumulatedResponse, config)) {
return;
}
}

if (isComplete || reply->isFinished()) {
if (isComplete) {
// auto cleanedCompletion = removeStopWords(accumulatedResponse,
// config.promptTemplate->stopWords());
emit completionReceived(accumulatedResponse, request, true);
if (config.requestType == RequestType::Fim) {
auto cleanedCompletion = removeStopWords(accumulatedResponse,
config.promptTemplate->stopWords());
emit completionReceived(cleanedCompletion, request, true);
} else {
emit completionReceived(accumulatedResponse, request, true);
}
} else {
emit completionReceived(accumulatedResponse, request, false);
}
Expand Down
4 changes: 0 additions & 4 deletions providers/LMStudioProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ QString LMStudioProvider::chatEndpoint() const
void LMStudioProvider::prepareRequest(QJsonObject &request)
{
auto &settings = Settings::presetPromptsSettings();
const auto &currentTemplate = PromptTemplateManager::instance().getCurrentTemplate();
if (currentTemplate->name() == "Custom Template")
return;
if (request.contains("prompt")) {
QJsonArray messages{
{QJsonObject{{"role", "user"}, {"content", request.take("prompt").toString()}}}};
Expand All @@ -67,7 +64,6 @@ void LMStudioProvider::prepareRequest(QJsonObject &request)

request["max_tokens"] = settings.maxTokens();
request["temperature"] = settings.temperature();
request["stop"] = QJsonArray::fromStringList(currentTemplate->stopWords());
if (settings.useTopP())
request["top_p"] = settings.topP();
if (settings.useTopK())
Expand Down
4 changes: 0 additions & 4 deletions providers/OllamaProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,11 @@ QString OllamaProvider::chatEndpoint() const
void OllamaProvider::prepareRequest(QJsonObject &request)
{
auto &settings = Settings::presetPromptsSettings();
auto currentTemplate = PromptTemplateManager::instance().getCurrentTemplate();
if (currentTemplate->name() == "Custom Template")
return;

QJsonObject options;
options["num_predict"] = settings.maxTokens();
options["keep_alive"] = settings.ollamaLivetime();
options["temperature"] = settings.temperature();
options["stop"] = QJsonArray::fromStringList(currentTemplate->stopWords());
if (settings.useTopP())
options["top_p"] = settings.topP();
if (settings.useTopK())
Expand Down
5 changes: 0 additions & 5 deletions providers/OpenAICompatProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ QString OpenAICompatProvider::chatEndpoint() const
void OpenAICompatProvider::prepareRequest(QJsonObject &request)
{
auto &settings = Settings::presetPromptsSettings();
const auto &currentTemplate = PromptTemplateManager::instance().getCurrentTemplate();
if (currentTemplate->name() == "Custom Template")
return;

if (request.contains("prompt")) {
QJsonArray messages{
{QJsonObject{{"role", "user"}, {"content", request.take("prompt").toString()}}}};
Expand All @@ -66,7 +62,6 @@ void OpenAICompatProvider::prepareRequest(QJsonObject &request)

request["max_tokens"] = settings.maxTokens();
request["temperature"] = settings.temperature();
request["stop"] = QJsonArray::fromStringList(currentTemplate->stopWords());
if (settings.useTopP())
request["top_p"] = settings.topP();
if (settings.useTopK())
Expand Down
3 changes: 2 additions & 1 deletion qodeassist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,14 @@ class QodeAssistPlugin final : public ExtensionSystem::IPlugin
providerManager.registerProvider<Providers::LMStudioProvider>();
providerManager.registerProvider<Providers::OpenAICompatProvider>();
providerManager.setCurrentFimProvider("Ollama");
providerManager.setCurrentChatProvider("Ollama");

auto &templateManager = PromptTemplateManager::instance();
templateManager.registerTemplate<Templates::CodeLLamaTemplate>();
templateManager.registerTemplate<Templates::StarCoder2Template>();
templateManager.registerTemplate<Templates::DeepSeekCoderV2Template>();
templateManager.registerTemplate<Templates::CustomTemplate>();
templateManager.setCurrentTemplate("StarCoder2");
templateManager.setCurrentFimTemplate("StarCoder2");

Utils::Icon QCODEASSIST_ICON(
{{":/resources/images/qoderassist-icon.png", Utils::Theme::IconsBaseColor}});
Expand Down
1 change: 1 addition & 0 deletions templates/CodeLLamaTemplate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace QodeAssist::Templates {
class CodeLLamaTemplate : public PromptTemplate
{
public:
TemplateType type() const override { return TemplateType::Fim; }
QString name() const override { return "CodeLlama"; }
QString promptTemplate() const override { return "%1<PRE> %2 <SUF>%3 <MID>"; }
QStringList stopWords() const override
Expand Down
3 changes: 2 additions & 1 deletion templates/CustomTemplate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace QodeAssist::Templates {
class CustomTemplate : public PromptTemplate
{
public:
QString name() const override { return "Custom Template"; }
TemplateType type() const override { return TemplateType::Fim; }
QString name() const override { return "Custom FIM Template"; }
QString promptTemplate() const override
{
return Settings::customPromptSettings().customJsonTemplate();
Expand Down
1 change: 1 addition & 0 deletions templates/DeepSeekCoderV2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace QodeAssist::Templates {
class DeepSeekCoderV2Template : public PromptTemplate
{
public:
TemplateType type() const override { return TemplateType::Fim; }
QString name() const override { return "DeepSeekCoderV2"; }
QString promptTemplate() const override
{
Expand Down
3 changes: 3 additions & 0 deletions templates/PromptTemplate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@

namespace QodeAssist::Templates {

enum class TemplateType { Chat, Fim };

class PromptTemplate
{
public:
virtual ~PromptTemplate() = default;
virtual TemplateType type() const = 0;
virtual QString name() const = 0;
virtual QString promptTemplate() const = 0;
virtual QStringList stopWords() const = 0;
Expand Down
1 change: 1 addition & 0 deletions templates/StarCoder2Template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace QodeAssist::Templates {
class StarCoder2Template : public PromptTemplate
{
public:
TemplateType type() const override { return TemplateType::Fim; }
QString name() const override { return "StarCoder2"; }
QString promptTemplate() const override { return "%1<fim_prefix>%2<fim_suffix>%3<fim_middle>"; }
QStringList stopWords() const override
Expand Down

0 comments on commit 6f2029f

Please sign in to comment.