diff --git a/bcs-common/common/task/manager.go b/bcs-common/common/task/manager.go index baf5e480e7..719f26bcab 100644 --- a/bcs-common/common/task/manager.go +++ b/bcs-common/common/task/manager.go @@ -374,63 +374,58 @@ func (m *TaskManager) doWork(taskID string, stepName string) error { // nolint state.updateStepSuccess(start) return nil } - - retErr := fmt.Errorf("task %s step %s running failed, err=%s", taskID, stepName, stepErr) - // 单步骤主动revoke的不再重试 - if errors.Is(stepErr, istep.ErrRevoked) { - state.updateStepFailure(start, stepErr, &taskEndStatus{status: types.TaskStatusFailure}) - return retErr - } - state.updateStepFailure(start, stepErr, nil) - if step.GetRetryCount() < step.MaxRetries { + + // 单步骤主动revoke或者没有重试次数时, 不再重试 + if !errors.Is(stepErr, istep.ErrRevoked) && step.GetRetryCount() < step.MaxRetries { retryIn := time.Second * time.Duration(retryNext(int(step.GetRetryCount()))) - log.INFO.Printf("retry task %s step %s, retried=%d, maxRetries=%d, retryIn=%s", - taskID, step.GetName(), step.GetRetryCount(), step.MaxRetries, retryIn) - return tasks.NewErrRetryTaskLater(retErr.Error(), retryIn) + log.INFO.Printf("retry task %s step %s, err=%s, retried=%d, maxRetries=%d, retryIn=%s", + taskID, stepName, stepErr, step.GetRetryCount(), step.MaxRetries, retryIn) + return tasks.NewErrRetryTaskLater(stepErr.Error(), retryIn) } if step.GetSkipOnFailed() { return nil } + retErr := fmt.Errorf("task %s step %s running failed, err=%w", taskID, stepName, stepErr) return retErr case <-stepCtx.Done(): + // step timeout stepErr := fmt.Errorf("step exec timeout") - - retErr := fmt.Errorf("%w, task=%s, step=%s", stepErr, taskID, stepName) state.updateStepFailure(start, stepErr, nil) if step.GetRetryCount() < step.MaxRetries { retryIn := time.Second * time.Duration(retryNext(int(step.GetRetryCount()))) - log.INFO.Printf("retry task %s step %s, retried=%d, maxRetries=%d, retryIn=%s", - taskID, step.GetName(), step.GetRetryCount(), step.MaxRetries, retryIn) - return tasks.NewErrRetryTaskLater(retErr.Error(), retryIn) + log.INFO.Printf("retry task %s step %s, err=%s, retried=%d, maxRetries=%d, retryIn=%s", + taskID, stepName, stepErr, step.GetRetryCount(), step.MaxRetries, retryIn) + return tasks.NewErrRetryTaskLater(stepErr.Error(), retryIn) } if step.GetSkipOnFailed() { return nil } + retErr := fmt.Errorf("task %s step %s running failed, err=%w", taskID, stepName, stepErr) return retErr case <-revokeCtx.Done(): // task revoke stepErr := fmt.Errorf("task has been revoked") - retErr := fmt.Errorf("%w, task=%s, step=%s", stepErr, taskID, stepName) state.updateStepFailure(start, stepErr, &taskEndStatus{status: types.TaskStatusRevoked}) // 取消指令, 不再重试 + retErr := fmt.Errorf("task %s step %s running failed, err=%w", taskID, stepName, stepErr) return retErr case <-taskCtx.Done(): // task timeout stepErr := fmt.Errorf("task exec timeout") - retErr := fmt.Errorf("%w, task=%s, step=%s", stepErr, taskID, stepName) state.updateStepFailure(start, stepErr, &taskEndStatus{status: types.TaskStatusTimeout}) // 整个任务结束 + retErr := fmt.Errorf("task %s step %s running failed, err=%w", taskID, stepName, stepErr) return retErr case <-m.ctx.Done(): diff --git a/bcs-common/common/task/state.go b/bcs-common/common/task/state.go index c9f4929815..efd1b17288 100644 --- a/bcs-common/common/task/state.go +++ b/bcs-common/common/task/state.go @@ -15,6 +15,7 @@ package task import ( "context" + "errors" "fmt" "time" @@ -290,7 +291,7 @@ func (s *State) updateStepFailure(start time.Time, stepErr error, taskStatus *ta } // 重试流程中 - if s.step.GetRetryCount() < s.step.MaxRetries { + if !errors.Is(stepErr, istep.ErrRevoked) && s.step.GetRetryCount() < s.step.MaxRetries { s.task.SetStatus(types.TaskStatusRunning).SetMessage(taskFailMsg) return } @@ -310,9 +311,9 @@ func (s *State) updateStepFailure(start time.Time, stepErr error, taskStatus *ta func (s *State) isLastStep(step *types.Step) bool { count := len(s.task.Steps) - // 没有step默认返回false + // 没有step也就没有后续流程, 返回true if count == 0 { - return false + return true } // 非最后一步 diff --git a/bcs-common/common/task/stores/mysql/mysql.go b/bcs-common/common/task/stores/mysql/mysql.go index d2dc45a415..cf2a49b36f 100644 --- a/bcs-common/common/task/stores/mysql/mysql.go +++ b/bcs-common/common/task/stores/mysql/mysql.go @@ -83,11 +83,11 @@ func (s *mysqlStore) EnsureTable(ctx context.Context, dst ...any) error { if len(dst) == 0 { dst = []any{&TaskRecord{}, &StepRecord{}} } - return s.db.AutoMigrate(dst...) + return s.db.WithContext(ctx).AutoMigrate(dst...) } func (s *mysqlStore) CreateTask(ctx context.Context, task *types.Task) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { record := getTaskRecord(task) if err := tx.Create(record).Error; err != nil { return err @@ -103,11 +103,11 @@ func (s *mysqlStore) CreateTask(ctx context.Context, task *types.Task) error { } func (s *mysqlStore) ListTask(ctx context.Context, opt *iface.ListOption) ([]types.Task, error) { - return nil, nil + return nil, types.ErrNotImplemented } func (s *mysqlStore) UpdateTask(ctx context.Context, task *types.Task) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { updateTask := getUpdateTaskRecord(task) if err := tx.Model(&TaskRecord{}). Where("task_id = ?", task.TaskID). @@ -133,10 +133,11 @@ func (s *mysqlStore) UpdateTask(ctx context.Context, task *types.Task) error { } func (s *mysqlStore) DeleteTask(ctx context.Context, taskID string) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Where("task_id = ?", taskID).Delete(&TaskRecord{}).Error; err != nil { return err } + if err := tx.Where("task_id = ?", taskID).Delete(&StepRecord{}).Error; err != nil { return err } @@ -145,12 +146,14 @@ func (s *mysqlStore) DeleteTask(ctx context.Context, taskID string) error { } func (s *mysqlStore) GetTask(ctx context.Context, taskID string) (*types.Task, error) { - stepRecord := []*StepRecord{} - if err := s.db.Where("task_id = ?", taskID).Find(&stepRecord).Error; err != nil { + tx := s.db.WithContext(ctx) + taskRecord := TaskRecord{} + if err := tx.Where("task_id = ?", taskID).First(&taskRecord).Error; err != nil { return nil, err } - taskRecord := TaskRecord{} - if err := s.db.Where("task_id = ?", taskID).First(&taskRecord).Error; err != nil { + + stepRecord := []*StepRecord{} + if err := tx.Where("task_id = ?", taskID).Find(&stepRecord).Error; err != nil { return nil, err } return toTask(&taskRecord, stepRecord), nil diff --git a/bcs-common/common/task/stores/mysql/table.go b/bcs-common/common/task/stores/mysql/table.go index 4e48e94a6a..5a657b1e82 100644 --- a/bcs-common/common/task/stores/mysql/table.go +++ b/bcs-common/common/task/stores/mysql/table.go @@ -86,8 +86,8 @@ func (t *TaskRecord) BeforeUpdate(tx *gorm.DB) error { // StepRecord 步骤记录 type StepRecord struct { gorm.Model - TaskID string `json:"taskID" gorm:"type:varchar(255);index:idx_task_id"` // 索引 - Name string `json:"name" gorm:"type:varchar(255)"` + TaskID string `json:"taskID" gorm:"type:varchar(255);uniqueIndex:idx_task_id_step_name"` + Name string `json:"name" gorm:"type:varchar(255);uniqueIndex:idx_task_id_step_name"` Alias string `json:"alias" gorm:"type:varchar(255)"` Executor string `json:"executor" gorm:"type:varchar(255)"` Params map[string]string `json:"input" gorm:"type:text;serializer:json"`