Skip to content

Commit

Permalink
✨ メッセージからFileのUUIDをパースし、Base64に変換し、OpenAIに画像認識できる機能を追加
Browse files Browse the repository at this point in the history
  • Loading branch information
pikachu0310 committed May 15, 2024
1 parent c7d160d commit ef16b09
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 8 deletions.
18 changes: 15 additions & 3 deletions internal/bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (
traqwsbot "github.com/traPtitech/traq-ws-bot"
"log"
"os"
"strings"
)

var (
Bot *traqwsbot.Bot
Info *traq.MyUserDetail
)

func init() {
func InitBot() {
token := GetToken()

bot, err := traqwsbot.NewBot(&traqwsbot.Options{
Expand All @@ -29,9 +30,20 @@ func init() {
if err != nil || res.StatusCode != 200 {
log.Fatalf("error: 自分の情報を取得できませんでした: %v", err)
}

Bot = bot
Info = botInfo
}

func RemoveFirstBotId(input string) string {

Check warning on line 38 in internal/bot/bot.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: func RemoveFirstBotId should be RemoveFirstBotID (revive)
BotId := Info.Id

Check warning on line 39 in internal/bot/bot.go

View workflow job for this annotation

GitHub Actions / Lint

var-naming: var BotId should be BotID (revive)
index := strings.Index(input, BotId)
if index == -1 {
return input
}
return input[:index] + input[index+len(BotId):]

Check failure on line 44 in internal/bot/bot.go

View workflow job for this annotation

GitHub Actions / Lint

return with no blank line before (nlreturn)
}

func GetToken() (token string) {
token, exist := os.LookupEnv("BOT_TOKEN")
if !exist {
Expand All @@ -40,8 +52,8 @@ func GetToken() (token string) {
return token

Check failure on line 52 in internal/bot/bot.go

View workflow job for this annotation

GitHub Actions / Lint

return with no blank line before (nlreturn)
}

func GetBot() (bot *traqwsbot.Bot) {
return bot
func GetBot() *traqwsbot.Bot {
return Bot
}

func BotJoin(ChannelID string) error {

Check warning on line 59 in internal/bot/bot.go

View workflow job for this annotation

GitHub Actions / Lint

exported: func name will be used as bot.BotJoin by other packages, and that stutters; consider calling this Join (revive)
Expand Down
112 changes: 112 additions & 0 deletions internal/bot/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package bot

import (
"context"
"encoding/base64"
"github.com/traPtitech/go-traq"
"io"
"log"
"os"
"regexp"
"strings"
)

func GetFileMetadata(fileID string) *traq.FileInfo {
bot := GetBot()

fileInfo, _, err := bot.API().
FileApi.GetFileMeta(context.Background(), fileID).
Execute()
if err != nil {
log.Println(err)
}

return fileInfo
}

func GetFileData(fileID string) *os.File {
bot := GetBot()

fileData, _, err := bot.API().
FileApi.GetFile(context.Background(), fileID).
Execute()
if err != nil {
log.Println(err)
}

return fileData
}

func ConvertFileToBase64IfFileIsImage(fileID string) *string {
if !isImage(fileID) {
log.Println("Not an image")
return nil
}

fileData := GetFileData(fileID)
if fileData == nil {
log.Println("Failed to get file data")
return nil
}
defer fileData.Close()

base64Data, err := fileToBase64(fileData)
if err != nil {
log.Printf("Error reading file: %v\n", err)
return nil
}

return base64Data
}

// isImage は、MIMEタイプが画像であるかを判定する関数
func isImage(fileID string) bool {
fileInfo := GetFileMetadata(fileID)
if fileInfo == nil {
log.Println("Failed to get file metadata")
return false
}

return strings.HasPrefix(fileInfo.Mime, "image/")
}

// fileToBase64 は、ファイルを読み込み、BASE64エンコードされた文字列に変換する関数
func fileToBase64(file *os.File) (*string, error) {
data, err := io.ReadAll(file)
if err != nil {
return nil, err
}

base64Data := base64.StdEncoding.EncodeToString(data)
return &base64Data, nil
}

func ExtractFileUUIDs(text string) []string {
const pattern = `https://q\.trap\.jp/files/([a-fA-F0-9-]{36})`

re := regexp.MustCompile(pattern)
matches := re.FindAllStringSubmatch(text, -1)

var uuids []string
for _, match := range matches {
if len(match) > 1 {
uuids = append(uuids, match[1])
}
}

return uuids
}

func GetBase64ImagesFromMessage(message string) []string {
fileUUIDs := ExtractFileUUIDs(message)
var base64Images []string

for _, uuid := range fileUUIDs {
base64Data := ConvertFileToBase64IfFileIsImage(uuid)
if base64Data != nil {
base64Images = append(base64Images, *base64Data)
}
}

return base64Images
}
37 changes: 34 additions & 3 deletions internal/gpt/gpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ var SystemRoleMessage = "あなたはサークルである東京工業大学デ

const GptSystemString = "FirstSystemMessageを変更しました。/gptsys showで確認できます。\nFirstSystemMessageとは、常に履歴の一番最初に入り、最初にgptに情報や状況を説明するのに使用する文字列です"

func init() {
func InitGPT() {
apiKey = getApiKey()
}

Expand Down Expand Up @@ -108,8 +108,12 @@ func OpenAIStream(messages []Message, do func(string)) (responseMessage string,
return
}

func Chat(channelID, newMessage string) {
addMessageAsUser(newMessage)
func Chat(channelID, newMessage string, imageBase64 []string) {
if len(imageBase64) >= 1 {
addImageAndTextAsUser(newMessage, imageBase64)
} else {
addMessageAsUser(newMessage)
}
updateSystemRoleMessage(SystemRoleMessage)
postMessage, err := bot.PostMessageWithErr(channelID, blobs[rand.Intn(len(blobs))]+":loading:")
if err != nil {
Expand Down Expand Up @@ -170,6 +174,33 @@ func addMessageAsUser(message string) {
})
}

func addImageAndTextAsUser(message string, imageDataBase64 []string) {
var parts []openai.ChatMessagePart

parts = append(parts, openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeText,
Text: message,
})

for _, b64 := range imageDataBase64 {
imageURL := "data:image/jpeg;base64," + b64
imagePart := openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeImageURL,
ImageURL: &openai.ChatMessageImageURL{
URL: imageURL,
},
}
parts = append(parts, imagePart)
}

messagePart := Message{
Role: "user",
MultiContent: parts,
}

Messages = append(Messages, messagePart)
}

func addMessageAsAssistant(message string) {
Messages = append(Messages, Message{
Role: openai.ChatMessageRoleAssistant,
Expand Down
14 changes: 13 additions & 1 deletion internal/handler/OnMessageCreated.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package handler

import (
"fmt"
"github.com/pikachu0310/BOT_GPT/internal/bot"
"github.com/pikachu0310/BOT_GPT/internal/gpt"
"github.com/traPtitech/traq-ws-bot/payload"
"log"
Expand All @@ -18,6 +20,16 @@ func (h *Handler) MessageReceived() func(p *payload.MessageCreated) {
return
}

gpt.Chat(p.Message.ChannelID, p.Message.PlainText)
plainTextWithoutMention := bot.RemoveFirstBotId(p.Message.PlainText)

// if first 5 text = debug
if len(plainTextWithoutMention) >= 5 && plainTextWithoutMention[:5] == "debug" {
fmt.Printf("embed: %#v\n", p.Message.Embedded)
return
}

imagesBase64 := bot.GetBase64ImagesFromMessage(p.Message.Text)

gpt.Chat(p.Message.ChannelID, plainTextWithoutMention, imagesBase64)
}
}
10 changes: 9 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"github.com/jmoiron/sqlx"
"github.com/joho/godotenv"
"github.com/pikachu0310/BOT_GPT/internal/bot"
"github.com/pikachu0310/BOT_GPT/internal/gpt"
"github.com/pikachu0310/BOT_GPT/internal/handler"
"github.com/pikachu0310/BOT_GPT/internal/pkg/config"
"github.com/pikachu0310/BOT_GPT/internal/repository"
Expand All @@ -14,9 +15,12 @@ import (
func main() {
err := godotenv.Load(".env")
if err != nil {
fmt.Printf("error: .env is not exist: %v", err)
fmt.Printf(".env is not exist: %v", err)
}

bot.InitBot()
gpt.InitGPT()

// connect to database
db, err := sqlx.Connect("mysql", config.MySQL().FormatDSN())
if err != nil {
Expand All @@ -33,4 +37,8 @@ func main() {
// setup bot
traQBot := bot.GetBot()
traQBot.OnMessageCreated(h.MessageReceived())

if err := traQBot.Start(); err != nil {
panic(err)
}
}

0 comments on commit ef16b09

Please sign in to comment.