diff --git a/spx-backend/internal/aigc/aigc.go b/spx-backend/internal/aigc/aigc.go index 49e717505..c4b57401f 100644 --- a/spx-backend/internal/aigc/aigc.go +++ b/spx-backend/internal/aigc/aigc.go @@ -20,7 +20,7 @@ func NewAigcClient(endpoint string) *AigcClient { return &AigcClient{ endpoint: endpoint, client: &http.Client{ - Timeout: 20 * time.Second, + Timeout: 60 * time.Second, }, } } diff --git a/spx-backend/internal/model/milvus.go b/spx-backend/internal/model/milvus.go index 8a7d2f9ca..9014b07d2 100644 --- a/spx-backend/internal/model/milvus.go +++ b/spx-backend/internal/model/milvus.go @@ -83,6 +83,38 @@ func SearchByVector(ctx context.Context, cli client.Client, collectionName strin return assetNames, nil } +func ExistsMilvusAsset(ctx context.Context, cli client.Client, assetID string) bool { + logger := log.GetReqLogger(ctx) + + if cli == nil || assetID == "" { + logger.Printf("Invalid input: %v, %v", cli, assetID) + return false + } + + opt := client.SearchQueryOptionFunc(func(option *client.SearchQueryOption) { + option.Limit = 3 + option.Offset = 0 + option.ConsistencyLevel = entity.ClStrong + option.IgnoreGrowing = false + }) + + // Search for the asset ID in the collection + _, err := cli.Query( + ctx, + "asset", + []string{}, + "asset_id == '"+assetID+"'", + []string{"asset_id"}, + opt, + ) + if err != nil { + logger.Printf("Failed to search: %v", err) + return false + } + + return true +} + // Add an asset func AddMilvusAsset(ctx context.Context, cli client.Client, asset *MilvusAsset) error { logger := log.GetReqLogger(ctx) @@ -92,6 +124,11 @@ func AddMilvusAsset(ctx context.Context, cli client.Client, asset *MilvusAsset) return nil } + if ExistsMilvusAsset(ctx, cli, asset.AssetID) { + logger.Printf("Asset %s already exists in Milvus", asset.AssetName) + return nil + } + vector := asset.Vector columns := []entity.Column{ diff --git a/spx-backend/internal/model/user_asset.go b/spx-backend/internal/model/user_asset.go index 78ce3ae4b..8c535b6c6 100644 --- a/spx-backend/internal/model/user_asset.go +++ b/spx-backend/internal/model/user_asset.go @@ -40,7 +40,20 @@ const TableUserAsset = "user_asset" // AddUserAsset adds an asset. func AddUserAsset(ctx context.Context, db *gorm.DB, p *UserAsset) error { logger := log.GetReqLogger(ctx) - result := db.Create(p) + + // check if the asset already exists + var count int64 + result := db.Model(&UserAsset{}).Where("asset_id = ? AND relation_type = ? AND owner = ?", p.AssetID, p.RelationType, p.Owner).Count(&count) + if result.Error != nil { + logger.Printf("failed to check if asset exists: %v", result.Error) + return result.Error + } + + if count > 0 { + return nil + } + + result = db.Create(p) if result.Error != nil { logger.Printf("failed to add asset: %v", result.Error) return result.Error diff --git a/spx-backend/internal/model/user_asset_test.go b/spx-backend/internal/model/user_asset_test.go index bc36c593c..d7065439d 100644 --- a/spx-backend/internal/model/user_asset_test.go +++ b/spx-backend/internal/model/user_asset_test.go @@ -25,6 +25,10 @@ func TestAddUserAsset(t *testing.T) { }), &gorm.Config{}) require.NoError(t, err) + mock.ExpectQuery("SELECT count(*) FROM `user_assets` WHERE asset_id = ? AND relation_type = ? AND owner = ?"). + WithArgs(1, "owned", "user1"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectBegin() mock.ExpectExec("INSERT INTO `user_assets` (`owner`,`asset_id`,`relation_type`,`relation_timestamp`) VALUES (?,?,?,?)"). WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()). @@ -55,6 +59,10 @@ func TestAddUserAsset(t *testing.T) { SkipInitializeWithVersion: true}), &gorm.Config{}) require.NoError(t, err) + mock.ExpectQuery("SELECT count(*) FROM `user_assets` WHERE asset_id = ? AND relation_type = ? AND owner = ?"). + WithArgs(1, "owned", "user1"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectBegin() mock.ExpectExec("INSERT INTO `user_assets` (`owner`,`asset_id`,`relation_type`,`relation_timestamp`) VALUES (?,?,?,?)"). WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()). diff --git a/spx-backend/loadToMilvus/loadToMilvus.go b/spx-backend/loadToMilvus/loadToMilvus.go new file mode 100644 index 000000000..d943c3441 --- /dev/null +++ b/spx-backend/loadToMilvus/loadToMilvus.go @@ -0,0 +1,131 @@ +// This script is used to load all assets to Milvus. +// +// The script reads the assets from the database +// and then calls the AIGC service to get the embeddings of the assets +// and inserts them into the Milvus database. +package main + +import ( + "context" + "database/sql" + "errors" + _ "image/png" + "io/fs" + "net/http" + "os" + + _ "github.com/go-sql-driver/mysql" + "github.com/goplus/builder/spx-backend/internal/aigc" + "github.com/goplus/builder/spx-backend/internal/log" + "github.com/joho/godotenv" + milvus "github.com/milvus-io/milvus-sdk-go/v2/client" + _ "github.com/qiniu/go-cdk-driver/kodoblob" + qiniuLog "github.com/qiniu/x/log" + + "github.com/goplus/builder/spx-backend/internal/controller" + "github.com/goplus/builder/spx-backend/internal/model" +) + +var ( + ErrNotExist = errors.New("not exist") + ErrUnauthorized = errors.New("unauthorized") + ErrForbidden = errors.New("forbidden") +) + +func Load() (err error) { + logger := log.GetLogger() + ctx := context.Background() + + if err := godotenv.Load(); err != nil && !errors.Is(err, fs.ErrNotExist) { + logger.Printf("failed to load env: %v", err) + return err + } + + dsn := mustEnv(logger, "GOP_SPX_DSN") + db, err := sql.Open("mysql", dsn) + if err != nil { + logger.Printf("failed to connect sql: %v", err) + return err + } + + aigcClient := aigc.NewAigcClient(mustEnv(logger, "AIGC_ENDPOINT")) + + var milvusClient milvus.Client + if os.Getenv("ENV") != "test" && os.Getenv("MILVUS_ADDRESS") != "disabled" { + milvusClient, err = milvus.NewClient(ctx, milvus.Config{ + Address: os.Getenv("MILVUS_ADDRESS"), + }) + if err != nil { + logger.Printf("failed to create milvus client: %v,%v", err, os.Getenv("MILVUS_ADDRESS")) + return err + } + } + + // load all assets from the database + assets, err := LoadAssets(ctx, db) + if err != nil { + logger.Printf("Failed to load assets: %v", err) + return err + } + + // for each asset, call embedding service to get the embedding + for i, asset := range assets { + logger.Printf("Processing asset %d/%d: %s", i+1, len(assets), asset.DisplayName) + + // check if the asset id is already in the milvus + if model.ExistsMilvusAsset(ctx, milvusClient, asset.ID) { + logger.Printf("Asset %s already exists in Milvus", asset.DisplayName) + continue + } + + var embeddingResult controller.GetEmbeddingResult + err = aigcClient.Call(ctx, http.MethodPost, "/embedding", &controller.GetEmbeddingParams{ + Prompt: asset.DisplayName, + CallbackUrl: "", + }, &embeddingResult) + + if err != nil { + logger.Printf("failed to call: %v", err) + return err + } + + // insert the embedding into the milvus + model.AddMilvusAsset(ctx, milvusClient, &model.MilvusAsset{ + AssetID: asset.ID, + AssetName: asset.DisplayName, + Vector: embeddingResult.Embedding, + }) + } + + return nil +} + +// LoadAssets loads all assets from the database. +func LoadAssets(ctx context.Context, db *sql.DB) ([]model.Asset, error) { + logger := log.GetReqLogger(ctx) + + assets, err := model.ListAssets(ctx, db, model.Pagination{ + Index: 1, + Size: 65535, + }, nil, nil, nil) + if err != nil { + logger.Printf("ListAssets failed: %v", err) + return nil, err + } + return assets.Data, nil +} + +// mustEnv gets the environment variable value or exits the program. +func mustEnv(logger *qiniuLog.Logger, key string) string { + value := os.Getenv(key) + if value == "" { + logger.Fatalf("Missing required environment variable: %s", key) + } + return value +} + +func main() { + if err := Load(); err != nil { + os.Exit(1) + } +} diff --git a/spx-gui/src/components/asset/animation/VideoRecorder.vue b/spx-gui/src/components/asset/animation/VideoRecorder.vue index cceb1f04c..3fbb6573e 100644 --- a/spx-gui/src/components/asset/animation/VideoRecorder.vue +++ b/spx-gui/src/components/asset/animation/VideoRecorder.vue @@ -3,7 +3,7 @@