From 64416f4914ee2905d87be79ebc1078ec81a98b89 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Sat, 29 Jul 2023 15:29:48 +0800 Subject: [PATCH 1/2] fix: factorization model diverged --- base/progress/progress.go | 23 +++++++++++++++++++---- master/tasks.go | 4 +++- model/click/model.go | 1 + storage/cache/mongodb.go | 3 +++ 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/base/progress/progress.go b/base/progress/progress.go index ddab44463..05027c079 100644 --- a/base/progress/progress.go +++ b/base/progress/progress.go @@ -89,12 +89,15 @@ func (s *Span) Add(n int) { } func (s *Span) End() { - s.status = StatusComplete - s.count = s.total - s.finish = time.Now() + if s.status == StatusRunning { + s.status = StatusComplete + s.count = s.total + s.finish = time.Now() + } } -func (s *Span) Error(err error) { +func (s *Span) Fail(err error) { + s.status = StatusFailed s.err = err.Error() } @@ -111,6 +114,10 @@ func (s *Span) Progress() Progress { if progress.Status == StatusRunning { children = append(children, progress) } + if s.err == "" && progress.Error != "" { + s.err = progress.Error + s.status = StatusFailed + } return true }) // leaf node @@ -162,6 +169,14 @@ func Start(ctx context.Context, name string, total int) (context.Context, *Span) return context.WithValue(ctx, spanKeyName, childSpan), childSpan } +func Fail(ctx context.Context, err error) { + span, ok := (ctx).Value(spanKeyName).(*Span) + if !ok { + return + } + span.Fail(err) +} + type Progress struct { Tracer string Name string diff --git a/master/tasks.go b/master/tasks.go index 666c66086..cdeb2f8b3 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -223,7 +223,7 @@ func (t *FindItemNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) numItems := dataset.ItemCount() numFeedback := dataset.Count() - _, span := t.tracer.Start(ctx, "Find Item Neighbors", dataset.ItemCount()) + newCtx, span := t.tracer.Start(ctx, "Find Item Neighbors", dataset.ItemCount()) defer span.End() if numItems == 0 { @@ -312,6 +312,7 @@ func (t *FindItemNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) close(completed) if err != nil { log.Logger().Error("failed to searching neighbors of items", zap.Error(err)) + progress.Fail(newCtx, err) FindItemNeighborsTotalSeconds.Set(0) } else { if err := t.CacheClient.Set(ctx, cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateItemNeighborsTime), time.Now())); err != nil { @@ -641,6 +642,7 @@ func (t *FindUserNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) close(completed) if err != nil { log.Logger().Error("failed to searching neighbors of users", zap.Error(err)) + progress.Fail(newCtx, err) FindUserNeighborsTotalSeconds.Set(0) } else { if err := t.CacheClient.Set(ctx, cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateUserNeighborsTime), time.Now())); err != nil { diff --git a/model/click/model.go b/model/click/model.go index 0b43e20ce..4c20f659d 100644 --- a/model/click/model.go +++ b/model/click/model.go @@ -434,6 +434,7 @@ func (fm *FM) Fit(ctx context.Context, trainSet, testSet *Dataset, config *FitCo // check NaN if math32.IsNaN(cost) || math32.IsNaN(score.GetValue()) { log.Logger().Warn("model diverged", zap.Float32("lr", fm.lr)) + span.Fail(errors.New("model diverged")) break } snapshots.AddSnapshot(score, fm.V, fm.W, fm.B) diff --git a/storage/cache/mongodb.go b/storage/cache/mongodb.go index 6ba7e0567..cfbc8195b 100644 --- a/storage/cache/mongodb.go +++ b/storage/cache/mongodb.go @@ -318,6 +318,9 @@ func (m MongoDB) Remain(ctx context.Context, name string) (int64, error) { } func (m MongoDB) AddDocuments(ctx context.Context, collection, subset string, documents []Document) error { + if len(documents) == 0 { + return nil + } var models []mongo.WriteModel for _, document := range documents { models = append(models, mongo.NewUpdateOneModel(). From 5dc9ef6b2bc63f0d7f956069d010daac1493c35f Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Sat, 29 Jul 2023 17:24:15 +0800 Subject: [PATCH 2/2] use math32.Log1p --- model/click/model.go | 12 ++++++++++-- model/ranking/model.go | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/model/click/model.go b/model/click/model.go index 4c20f659d..1dab347f9 100644 --- a/model/click/model.go +++ b/model/click/model.go @@ -327,8 +327,8 @@ func (fm *FM) Fit(ctx context.Context, trainSet, testSet *Dataset, config *FitCo cost += grad * grad / 2 case FMClassification: grad = -target * (1 - 1/(1+math32.Exp(-target*prediction))) - cost += (1 + target) * math32.Log(1+math32.Exp(-prediction)) / 2 - cost += (1 - target) * math32.Log(1+math32.Exp(prediction)) / 2 + cost += (1 + target) * math32.Log1p(exp(-prediction)) / 2 + cost += (1 - target) * math32.Log1p(exp(prediction)) / 2 default: log.Logger().Fatal("unknown task", zap.String("task", string(fm.Task))) } @@ -646,3 +646,11 @@ func (fm *FM) Unmarshal(r io.Reader) error { } return nil } + +func exp(x float32) float32 { + e := math32.Exp(x) + if math32.IsInf(e, 1) { + return math32.MaxFloat32 + } + return e +} diff --git a/model/ranking/model.go b/model/ranking/model.go index 43685b77b..2e5e3c747 100644 --- a/model/ranking/model.go +++ b/model/ranking/model.go @@ -448,7 +448,7 @@ func (bpr *BPR) Fit(ctx context.Context, trainSet, valSet *DataSet, config *FitC } } diff := bpr.InternalPredict(userIndex, posIndex) - bpr.InternalPredict(userIndex, negIndex) - cost[workerId] += math32.Log(1 + math32.Exp(-diff)) + cost[workerId] += math32.Log1p(math32.Exp(-diff)) grad := math32.Exp(-diff) / (1.0 + math32.Exp(-diff)) // Pairwise update copy(userFactor[workerId], bpr.UserFactor[userIndex])