From 98923c371df3f40356669e5bc39c5a493ac1be4e Mon Sep 17 00:00:00 2001 From: 0xff-dev Date: Wed, 24 Jan 2024 11:14:18 +0800 Subject: [PATCH] feat: add a controller implementation of rag --- api/evaluation/v1alpha1/common.go | 21 ++ config/rbac/role.yaml | 12 + controllers/evaluation/rag_controller.go | 394 +++++++++++++++++++++- deploy/charts/arcadia/Chart.yaml | 2 +- deploy/charts/arcadia/templates/rbac.yaml | 12 + main.go | 2 + pkg/evaluation/jobs.go | 344 +++++++++++++++++++ 7 files changed, 783 insertions(+), 4 deletions(-) create mode 100644 pkg/evaluation/jobs.go diff --git a/api/evaluation/v1alpha1/common.go b/api/evaluation/v1alpha1/common.go index cef526144..576cdf3a1 100644 --- a/api/evaluation/v1alpha1/common.go +++ b/api/evaluation/v1alpha1/common.go @@ -17,6 +17,8 @@ limitations under the License. package v1alpha1 import ( + "reflect" + batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" ) @@ -134,3 +136,22 @@ func RagStatus(rag *RAG) (string, RAGPhase, string) { } return status, phase, phaseMsg } + +func RAGSpecChanged(a, b RAGSpec) bool { + if !reflect.DeepEqual(*a.Application, *b.Application) { + return true + } + if !reflect.DeepEqual(a.Datasets, b.Datasets) { + return true + } + if !reflect.DeepEqual(a.JudgeLLM, b.JudgeLLM) { + return true + } + if !reflect.DeepEqual(*a.Storage, *b.Storage) { + return true + } + if a.ServiceAccountName != b.ServiceAccountName { + return true + } + return false +} diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index 17c6ce854..f398b06e2 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -36,6 +36,12 @@ rules: - get - list - watch +- apiGroups: + - "" + resources: + - persistentvolumeclaims + verbs: + - '*' - apiGroups: - "" resources: @@ -408,6 +414,12 @@ rules: - subjectaccessreviews verbs: - create +- apiGroups: + - batch + resources: + - jobs + verbs: + - '*' - apiGroups: - chain.arcadia.kubeagi.k8s.com.cn resources: diff --git a/controllers/evaluation/rag_controller.go b/controllers/evaluation/rag_controller.go index 65d442f09..fc7c54f03 100644 --- a/controllers/evaluation/rag_controller.go +++ b/controllers/evaluation/rag_controller.go @@ -18,15 +18,35 @@ package evaluationarcadia import ( "context" + "errors" + "fmt" + "reflect" + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/selection" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/util/workqueue" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/predicate" + "sigs.k8s.io/controller-runtime/pkg/source" evaluationarcadiav1alpha1 "github.com/kubeagi/arcadia/api/evaluation/v1alpha1" + "github.com/kubeagi/arcadia/pkg/evaluation" ) +var errJobNotDone = errors.New("wait for the job to complete, go to the next step") + // RAGReconciler reconciles a RAG object type RAGReconciler struct { client.Client @@ -36,6 +56,8 @@ type RAGReconciler struct { //+kubebuilder:rbac:groups=evaluation.arcadia.kubeagi.k8s.com.cn,resources=rags,verbs=get;list;watch;create;update;patch;delete //+kubebuilder:rbac:groups=evaluation.arcadia.kubeagi.k8s.com.cn,resources=rags/status,verbs=get;update;patch //+kubebuilder:rbac:groups=evaluation.arcadia.kubeagi.k8s.com.cn,resources=rags/finalizers,verbs=update +//+kubebuilder:rbac:groups=batch,resources=jobs,verbs=* +//+kubebuilder:rbac:groups="",resources=persistentvolumeclaims,verbs=* // Reconcile is part of the main kubernetes reconciliation loop which aims to // move the current state of the cluster closer to the desired state. @@ -47,16 +69,382 @@ type RAGReconciler struct { // For more details, check Reconcile and its Result here: // - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.12.2/pkg/reconcile func (r *RAGReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { - _ = log.FromContext(ctx) + logger := log.FromContext(ctx) + + logger.V(5).Info("Start RAG Reconcile") // TODO(user): your logic here + instance := &evaluationarcadiav1alpha1.RAG{} + if err := r.Client.Get(ctx, req.NamespacedName, instance); err != nil { + if k8serrors.IsNotFound(err) { + return ctrl.Result{}, nil + } + logger.V(1).Info("failed to get rag") + return ctrl.Result{}, err + } + if instance.DeletionTimestamp != nil { + return ctrl.Result{}, nil + } - return ctrl.Result{}, nil + if instance.Labels == nil { + instance.Labels = make(map[string]string) + } + if app, ok := instance.Labels[evaluationarcadiav1alpha1.EvaluationApplicationLabel]; !ok || app != instance.Spec.Application.Name { + instance.Labels[evaluationarcadiav1alpha1.EvaluationApplicationLabel] = instance.Spec.Application.Name + err := r.Client.Update(ctx, instance) + if err != nil { + logger.Error(err, "failed to add application name label") + } + return ctrl.Result{}, err + } + return r.phaseHandler(ctx, instance) } // SetupWithManager sets up the controller with the Manager. func (r *RAGReconciler) SetupWithManager(mgr ctrl.Manager) error { return ctrl.NewControllerManagedBy(mgr). - For(&evaluationarcadiav1alpha1.RAG{}). + For(&evaluationarcadiav1alpha1.RAG{}, builder.WithPredicates(predicate.Funcs{ + UpdateFunc: func(ue event.UpdateEvent) bool { + n := ue.ObjectNew.(*evaluationarcadiav1alpha1.RAG) + o := ue.ObjectOld.(*evaluationarcadiav1alpha1.RAG) + if !reflect.DeepEqual(n.Spec, o.Spec) { + // If the spec portion of the RAG changes, the process needs to be re-executed + if evaluationarcadiav1alpha1.RAGSpecChanged(n.Spec, o.Spec) { + _ = r.DeleteJobsAndPvc(context.TODO(), n) + return false + } + return true + } + if evaluationarcadiav1alpha1.RagStatusChanged(n.Status, o.Status) { + return true + } + if !reflect.DeepEqual(n.Labels, o.Labels) { + return true + } + return false + }, + })). + Watches(&source.Kind{ + Type: &corev1.PersistentVolumeClaim{}, + }, handler.Funcs{ + DeleteFunc: func(de event.DeleteEvent, rli workqueue.RateLimitingInterface) { + pvc := de.Object.(*corev1.PersistentVolumeClaim) + r.WhenPVCDeleted(pvc) + }, + }). + Watches(&source.Kind{ + Type: &batchv1.Job{}, + }, handler.Funcs{ + UpdateFunc: func(ue event.UpdateEvent, rli workqueue.RateLimitingInterface) { + job := ue.ObjectNew.(*batchv1.Job) + old := ue.ObjectOld.(*batchv1.Job) + if !reflect.DeepEqual(job.Status.Conditions, old.Status.Conditions) { + r.WhenJobChanged(job) + } + }, + }). Complete(r) } + +func (r *RAGReconciler) DeleteJobsAndPvc(ctx context.Context, instance *evaluationarcadiav1alpha1.RAG) error { + logger := log.FromContext(ctx) + selector := labels.NewSelector() + requirtment, _ := labels.NewRequirement(evaluationarcadiav1alpha1.EvaluationJobLabels, selection.Equals, []string{instance.Name}) + selector = selector.Add(*requirtment) + + m := metav1.DeletePropagationForeground + job := &batchv1.Job{} + err := r.Client.DeleteAllOf(ctx, job, &client.DeleteAllOfOptions{ + DeleteOptions: client.DeleteOptions{ + PropagationPolicy: &m, + }, + ListOptions: client.ListOptions{ + Namespace: instance.Namespace, + LabelSelector: selector, + }, + }) + if err != nil && !k8serrors.IsNotFound(err) { + logger.Error(err, "sepc changed, failed to delete rag associated job.") + return err + } + pvc := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: instance.Name, + Namespace: instance.Namespace, + }, + } + + err = r.Client.Delete(ctx, pvc, &client.DeleteOptions{ + PropagationPolicy: &m, + }) + if err != nil && !k8serrors.IsNotFound(err) { + logger.Error(err, "spec changed, failed to delete pvc", "PvcName", pvc.Name) + return err + } + + deepCopyInstance := instance.DeepCopy() + deepCopyInstance.Status.Conditions = nil + deepCopyInstance.Status.Phase = "" + logger.Info("spec changes, delete all related resources") + return r.Client.Status().Patch(ctx, deepCopyInstance, client.MergeFrom(instance)) +} + +func (r *RAGReconciler) phaseHandler(ctx context.Context, instance *evaluationarcadiav1alpha1.RAG) (ctrl.Result, error) { + logger := log.FromContext(ctx) + curPhase := instance.Status.Phase + switch curPhase { + case "": + deepCopyInstance := instance.DeepCopy() + deepCopyInstance.Status.Phase = evaluationarcadiav1alpha1.InitPvcPhase + deepCopyInstance.Status.Conditions = []batchv1.JobCondition{ + { + Type: batchv1.JobComplete, + Status: corev1.ConditionFalse, + Message: "need to create pvc", + }, + } + err := r.Client.Status().Patch(ctx, deepCopyInstance, client.MergeFrom(instance)) + if err != nil { + logger.Error(err, "failed to initialize RAG state") + } + return ctrl.Result{}, err + case evaluationarcadiav1alpha1.InitPvcPhase: + err := r.initPVC(ctx, instance) + return ctrl.Result{}, err + case evaluationarcadiav1alpha1.DownloadFilesPhase: + err := r.JobGenerator(ctx, instance, curPhase, evaluationarcadiav1alpha1.GenerateTestFilesPhase, evaluation.DownloadJob) + if err != nil && err != errJobNotDone { + return ctrl.Result{}, err + } + return ctrl.Result{}, nil + case evaluationarcadiav1alpha1.GenerateTestFilesPhase: + err := r.JobGenerator(ctx, instance, curPhase, evaluationarcadiav1alpha1.JudgeLLMPhase, evaluation.GenTestDataJob) + if err != nil && err != errJobNotDone { + return ctrl.Result{}, err + } + return ctrl.Result{}, nil + case evaluationarcadiav1alpha1.JudgeLLMPhase: + err := r.JobGenerator(ctx, instance, curPhase, evaluationarcadiav1alpha1.UploadFilesPhase, evaluation.JudgeJobGenerator(ctx, r.Client)) + if err != nil && err != errJobNotDone { + return ctrl.Result{}, err + } + return ctrl.Result{}, nil + case evaluationarcadiav1alpha1.UploadFilesPhase: + err := r.JobGenerator(ctx, instance, curPhase, evaluationarcadiav1alpha1.CompletePhase, evaluation.UploadJobGenerator(ctx, r.Client)) + if err != nil && err != errJobNotDone { + return ctrl.Result{}, err + } + return ctrl.Result{}, nil + case evaluationarcadiav1alpha1.CompletePhase: + logger.Info("evaluation process complete, end reconcile") + } + return ctrl.Result{}, nil +} + +func (r *RAGReconciler) initPVC(ctx context.Context, instance *evaluationarcadiav1alpha1.RAG) error { + logger := log.FromContext(ctx) + deepCopyInstance := instance.DeepCopy() + for _, cond := range instance.Status.Conditions { + if cond.Type == batchv1.JobComplete && cond.Status == corev1.ConditionTrue { + // next phase + deepCopyInstance.Status.Phase = evaluationarcadiav1alpha1.DownloadFilesPhase + deepCopyInstance.Status.Conditions = []batchv1.JobCondition{ + { + Type: batchv1.JobComplete, + Status: corev1.ConditionFalse, + Message: "pvc creation complete, create download file job", + }, + } + err := r.Client.Status().Patch(ctx, deepCopyInstance, client.MergeFrom(instance)) + if err != nil { + logger.Error(err, "update the status of the rag to start downloading the file failed.") + } + return err + } + } + + pvc := corev1.PersistentVolumeClaim{} + if err := r.Client.Get(ctx, types.NamespacedName{Namespace: instance.Namespace, Name: instance.Name}, &pvc); err != nil { + if !k8serrors.IsNotFound(err) { + logger.Error(err, "failed to get pvc", "PVCName", instance.Name) + return err + } + pvc.Name = instance.Name + pvc.Namespace = instance.Namespace + pvc.Spec = *instance.Spec.Storage + _ = controllerutil.SetOwnerReference(instance, &pvc, r.Scheme) + err = r.Client.Create(ctx, &pvc) + if err != nil { + logger.Error(err, "failed to create pvc", "PVCName", pvc.Name) + deepCopyInstance.Status.Conditions = []batchv1.JobCondition{ + { + Type: batchv1.JobFailed, + Status: corev1.ConditionTrue, + Message: fmt.Sprintf("pvc creation failure. %s", err), + LastTransitionTime: metav1.Now(), + }, + } + return r.Client.Status().Patch(ctx, deepCopyInstance, client.MergeFrom(instance)) + } + } + if pvc.DeletionTimestamp != nil { + logger.Info("pvc is being deleted, need to wait for next process", "PVCname", pvc.Name) + return errors.New("pvc is being deleted, need to wait for next process") + } + deepCopyInstance.Status.Conditions = []batchv1.JobCondition{ + { + Type: batchv1.JobComplete, + Status: corev1.ConditionTrue, + Message: "pvc created successfully", + LastTransitionTime: metav1.Now(), + }, + } + + logger.Info("pvc already exists", "PVCName", pvc.Name, "Phase", pvc.Status.Phase) + return r.Client.Status().Patch(ctx, deepCopyInstance, client.MergeFrom(instance)) +} + +func (r *RAGReconciler) JobGenerator( + ctx context.Context, + instance *evaluationarcadiav1alpha1.RAG, + curPhase, nextPhse evaluationarcadiav1alpha1.RAGPhase, + genJob func(*evaluationarcadiav1alpha1.RAG) (*batchv1.Job, error), +) error { + logger := log.FromContext(ctx) + deepCopyInstance := instance.DeepCopy() + for _, cond := range deepCopyInstance.Status.Conditions { + if cond.Type == batchv1.JobComplete && cond.Status == corev1.ConditionTrue { + deepCopyInstance.Status.Phase = nextPhse + d := batchv1.JobCondition{ + Type: batchv1.JobComplete, + Status: corev1.ConditionFalse, + Message: fmt.Sprintf("the %s phase execution is complete, opening the next %s phase.", curPhase, nextPhse), + LastTransitionTime: metav1.Now(), + } + if nextPhse == evaluationarcadiav1alpha1.CompletePhase { + d.Status = corev1.ConditionTrue + d.Message = "evaluation process completed" + deepCopyInstance.Status.CompletionTime = &d.LastTransitionTime + } + deepCopyInstance.Status.Conditions = []batchv1.JobCondition{d} + err := r.Client.Status().Patch(ctx, deepCopyInstance, client.MergeFrom(instance)) + if err != nil { + logger.Error(err, "failed to update rag status") + } + return err + } + } + job := &batchv1.Job{} + jobName := evaluation.PhaseJobName(instance, curPhase) + if err := r.Client.Get(ctx, types.NamespacedName{Namespace: instance.Namespace, Name: jobName}, job); err != nil { + if !k8serrors.IsNotFound(err) { + logger.Error(err, fmt.Sprintf("checking for the existence of jobs in the %s phase has failed.", curPhase), "jobName", jobName) + return err + } + + logger.Info(fmt.Sprintf("start creating %s phase job", curPhase), "jobName", jobName) + job, err = genJob(instance) + if err != nil { + logger.Error(err, "faled to generated %s phase job", curPhase) + return err + } + if err := controllerutil.SetOwnerReference(instance, job, r.Scheme); err != nil { + logger.Error(err, "set the job's owner failed.", "jobName", jobName) + return err + } + if err := r.Client.Create(ctx, job); err != nil { + logger.Error(err, fmt.Sprintf("failed to create %s phase job", curPhase), "jobName", jobName) + deepCopyInstance.Status.Conditions = []batchv1.JobCondition{ + { + Type: batchv1.JobFailed, + Status: corev1.ConditionTrue, + Message: fmt.Sprintf("failed to create %s phase job", curPhase), + LastProbeTime: metav1.Now(), + }, + } + return r.Client.Status().Patch(ctx, deepCopyInstance, client.MergeFrom(instance)) + } + // job变化比你来得更早? + deepCopyInstance.Status.Conditions = []batchv1.JobCondition{ + { + Type: batchv1.JobComplete, + Status: corev1.ConditionFalse, + Message: fmt.Sprintf("the %s phase job has been created and is waiting for the job to complete", curPhase), + LastTransitionTime: metav1.Now(), + }, + } + return r.Client.Status().Patch(ctx, deepCopyInstance, client.MergeFrom(instance)) + } + + if job.DeletionTimestamp != nil { + logger.Info("pvc is being deleted, need to wait for next process", "jobName", jobName) + return errors.New("job is being deleted, need to wait for next process") + } + if *job.Spec.Suspend != instance.Spec.Suspend { + complete := false + for _, cond := range job.Status.Conditions { + if cond.Type == batchv1.JobComplete && cond.Status == corev1.ConditionTrue { + complete = true + break + } + } + if !complete { + logger.Info(fmt.Sprintf("job suspend state switch from %v to %v", *job.Spec.Suspend, instance.Spec.Suspend)) + *job.Spec.Suspend = instance.Spec.Suspend + return r.Client.Update(ctx, job) + } + } + + return errJobNotDone +} + +func (r *RAGReconciler) WhenPVCDeleted(pvc *corev1.PersistentVolumeClaim) { + ctx := context.TODO() + logger := log.FromContext(ctx, "PVC", pvc.Name, "Namespace", pvc.Namespace) + for _, owner := range pvc.OwnerReferences { + if owner.APIVersion == evaluationarcadiav1alpha1.GroupVersion.String() && owner.Kind == "RAG" { + rag := &evaluationarcadiav1alpha1.RAG{} + if err := r.Client.Get(ctx, types.NamespacedName{Name: owner.Name, Namespace: pvc.Namespace}, rag); err != nil { + logger.Error(err, "failed to get rag", "RAG", owner.Name) + return + } + // the pvc was removed and the evaluation process needs to be re-executed + dp := rag.DeepCopy() + dp.Status.Conditions = nil + dp.Status.Phase = "" + if err := r.Client.Status().Patch(ctx, dp, client.MergeFrom(rag)); err != nil { + logger.Error(err, "update the status of the rag to initial status failed.", "RAG", owner.Name) + } + } + } +} + +func (r *RAGReconciler) WhenJobChanged(job *batchv1.Job) { + ctx := context.TODO() + logger := log.FromContext(ctx, "JOB", job.Name, "Namespace", job.Namespace) + if len(job.Status.Conditions) == 0 { + logger.Info("job currently has no status changes and does not do anything about it") + return + } + + for _, owner := range job.OwnerReferences { + if owner.APIVersion == evaluationarcadiav1alpha1.GroupVersion.String() && owner.Kind == "RAG" { + rag := &evaluationarcadiav1alpha1.RAG{} + if err := r.Client.Get(ctx, types.NamespacedName{Name: owner.Name, Namespace: job.Namespace}, rag); err != nil { + logger.Error(err, "failed to get rag", "RAG", owner.Name) + return + } + dp := rag.DeepCopy() + cur := job.Status.Conditions[0] + for i := 1; i < len(job.Status.Conditions); i++ { + if job.Status.Conditions[i].LastTransitionTime.After(cur.LastTransitionTime.Time) { + cur = job.Status.Conditions[i] + } + } + dp.Status.Conditions = []batchv1.JobCondition{cur} + if err := r.Client.Status().Patch(ctx, dp, client.MergeFrom(rag)); err != nil { + logger.Error(err, "set the status of a job to rag failure.", "RAG", owner.Name, "Condition", dp.Status.Conditions[0]) + } + } + } +} diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml index 9516d36bf..0d1569657 100644 --- a/deploy/charts/arcadia/Chart.yaml +++ b/deploy/charts/arcadia/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: arcadia description: A Helm chart(KubeBB Component) for KubeAGI Arcadia type: application -version: 0.2.24 +version: 0.2.25 appVersion: "0.1.0" keywords: diff --git a/deploy/charts/arcadia/templates/rbac.yaml b/deploy/charts/arcadia/templates/rbac.yaml index 8df7d3834..daa8e435e 100644 --- a/deploy/charts/arcadia/templates/rbac.yaml +++ b/deploy/charts/arcadia/templates/rbac.yaml @@ -53,6 +53,12 @@ rules: - get - list - watch +- apiGroups: + - "" + resources: + - persistentvolumeclaims + verbs: + - '*' - apiGroups: - "" resources: @@ -425,6 +431,12 @@ rules: - subjectaccessreviews verbs: - create +- apiGroups: + - batch + resources: + - jobs + verbs: + - '*' - apiGroups: - chain.arcadia.kubeagi.k8s.com.cn resources: diff --git a/main.go b/main.go index a09bcec03..ac12a7798 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,7 @@ import ( "path/filepath" "strconv" + batchv1 "k8s.io/api/batch/v1" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -68,6 +69,7 @@ func init() { utilruntime.Must(apiprompt.AddToScheme(scheme)) utilruntime.Must(apiretriever.AddToScheme(scheme)) utilruntime.Must(evaluationarcadiav1alpha1.AddToScheme(scheme)) + utilruntime.Must(batchv1.AddToScheme(scheme)) //+kubebuilder:scaffold:scheme } diff --git a/pkg/evaluation/jobs.go b/pkg/evaluation/jobs.go new file mode 100644 index 000000000..32c31a4c9 --- /dev/null +++ b/pkg/evaluation/jobs.go @@ -0,0 +1,344 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package evaluation + +import ( + "context" + "fmt" + "path/filepath" + + batchv1 "k8s.io/api/batch/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/utils/env" + "k8s.io/utils/pointer" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kubeagi/arcadia/api/base/v1alpha1" + evav1alpha1 "github.com/kubeagi/arcadia/api/evaluation/v1alpha1" + "github.com/kubeagi/arcadia/pkg/config" + "github.com/kubeagi/arcadia/pkg/llms" + "github.com/kubeagi/arcadia/pkg/utils" +) + +const ( + defaultPVCMountPath = "/data/evaluations" + defaultTestRagFile = "ragas.csv" + defaultMCImage = "kubeagi/minio-mc:RELEASE.2023-01-28T20-29-38Z" +) + +func PhaseJobName(instance *evav1alpha1.RAG, phase evav1alpha1.RAGPhase) string { + return fmt.Sprintf("%s-phase-%s", instance.Name, phase) +} + +func DownloadJob(instance *evav1alpha1.RAG) (*batchv1.Job, error) { + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: instance.Namespace, + Name: PhaseJobName(instance, evav1alpha1.DownloadFilesPhase), + Labels: map[string]string{ + evav1alpha1.EvaluationJobLabels: instance.Name, + }, + }, + Spec: batchv1.JobSpec{ + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + RestartPolicy: v1.RestartPolicyNever, + ServiceAccountName: instance.Spec.ServiceAccountName, + Containers: []v1.Container{ + { + Name: "download-dataset-files", + Image: "kubeagi/arcadia-eval", + Command: []string{ + "arctl", + }, + Args: []string{ + fmt.Sprintf("-n=%s", instance.Namespace), + "eval", "download", + fmt.Sprintf("--rag=%s", instance.Name), + fmt.Sprintf("--application=%s", instance.Spec.Application.Name), + fmt.Sprintf("--dir=%s", defaultPVCMountPath), + fmt.Sprintf("--system-conf-namespace=%s", utils.GetCurrentNamespace()), + fmt.Sprintf("--system-conf-name=%s", env.GetString(config.EnvConfigKey, config.EnvConfigDefaultValue)), + }, + VolumeMounts: []v1.VolumeMount{ + { + Name: "data", + MountPath: defaultPVCMountPath, + }, + }, + }, + }, + Volumes: []v1.Volume{ + { + Name: "data", + VolumeSource: v1.VolumeSource{ + PersistentVolumeClaim: &v1.PersistentVolumeClaimVolumeSource{ + ClaimName: instance.Name, + ReadOnly: false, + }, + }, + }, + }, + }, + }, + BackoffLimit: pointer.Int32(1), + Completions: pointer.Int32(1), + Parallelism: pointer.Int32(1), + Suspend: &instance.Spec.Suspend, + }, + } + return job, nil +} + +func GenTestDataJob(instance *evav1alpha1.RAG) (*batchv1.Job, error) { + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: instance.Namespace, + Name: PhaseJobName(instance, evav1alpha1.GenerateTestFilesPhase), + Labels: map[string]string{ + evav1alpha1.EvaluationJobLabels: instance.Name, + }, + }, + Spec: batchv1.JobSpec{ + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + RestartPolicy: v1.RestartPolicyNever, + ServiceAccountName: instance.Spec.ServiceAccountName, + Containers: []v1.Container{ + { + Name: "gen-test-files", + Image: "kubeagi/arcadia-eval", + Command: []string{ + "arctl", + }, + Args: []string{ + fmt.Sprintf("-n=%s", instance.Namespace), + "eval", "gen_test_dataset", + fmt.Sprintf("--application=%s", instance.Spec.Application.Name), + fmt.Sprintf("--input-dir=%s", defaultPVCMountPath), + "--output=csv", + "--merge=true", + fmt.Sprintf("--merge-file=%s", filepath.Join(defaultPVCMountPath, defaultTestRagFile)), + }, + VolumeMounts: []v1.VolumeMount{ + { + Name: "data", + MountPath: defaultPVCMountPath, + }, + }, + }, + }, + Volumes: []v1.Volume{ + { + Name: "data", + VolumeSource: v1.VolumeSource{ + PersistentVolumeClaim: &v1.PersistentVolumeClaimVolumeSource{ + ClaimName: instance.Name, + ReadOnly: false, + }, + }, + }, + }, + }, + }, + BackoffLimit: pointer.Int32(1), + Completions: pointer.Int32(1), + Parallelism: pointer.Int32(1), + Suspend: &instance.Spec.Suspend, + }, + } + return job, nil +} + +func JudgeJobGenerator(ctx context.Context, c client.Client) func(*evav1alpha1.RAG) (*batchv1.Job, error) { + return func(instance *evav1alpha1.RAG) (*batchv1.Job, error) { + var ( + apiBase, model, apiKey string + err error + ) + llm := v1alpha1.LLM{} + ns := instance.Namespace + if instance.Spec.JudgeLLM.Namespace != nil { + ns = *instance.Spec.JudgeLLM.Namespace + } + if err = c.Get(context.TODO(), types.NamespacedName{Namespace: ns, Name: instance.Spec.JudgeLLM.Name}, &llm); err != nil { + return nil, err + } + + apiBase = llm.Get3rdPartyLLMBaseURL() + apiKey, err = llm.AuthAPIKey(ctx, c, nil) + if err != nil { + return nil, err + } + + switch llm.Spec.Type { + case llms.OpenAI: + model = "gtp4" + case llms.ZhiPuAI: + model = "glm-4" + default: + return nil, fmt.Errorf("not support type %s", llm.Spec.Type) + } + if r := llm.Get3rdPartyModels(); len(r) > 0 { + model = r[0] + } + + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: instance.Namespace, + Name: PhaseJobName(instance, evav1alpha1.JudgeLLMPhase), + Labels: map[string]string{ + evav1alpha1.EvaluationJobLabels: instance.Name, + }, + }, + + Spec: batchv1.JobSpec{ + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + RestartPolicy: v1.RestartPolicyNever, + ServiceAccountName: instance.Spec.ServiceAccountName, + Containers: []v1.Container{ + { + Name: "judge-llm", + Image: "kubeagi/arcadia-eval:v0.1.0", + WorkingDir: defaultPVCMountPath, + Command: []string{ + "python3", + }, + Args: []string{ + "-m", + "ragas_once.cli", + fmt.Sprintf("--apibase=%s", apiBase), + fmt.Sprintf("--model=%s", model), + fmt.Sprintf("--apikey=%s", apiKey), + fmt.Sprintf("--dataset=%s", filepath.Join(defaultPVCMountPath, defaultTestRagFile)), + }, + VolumeMounts: []v1.VolumeMount{ + { + Name: "data", + MountPath: defaultPVCMountPath, + }, + }, + }, + }, + Volumes: []v1.Volume{ + { + Name: "data", + VolumeSource: v1.VolumeSource{ + PersistentVolumeClaim: &v1.PersistentVolumeClaimVolumeSource{ + ClaimName: instance.Name, + ReadOnly: false, + }, + }, + }, + }, + }, + }, + BackoffLimit: pointer.Int32(1), + Completions: pointer.Int32(1), + Parallelism: pointer.Int32(1), + Suspend: &instance.Spec.Suspend, + }, + } + return job, nil + } +} + +func UploadJobGenerator(ctx context.Context, client client.Client) func(*evav1alpha1.RAG) (*batchv1.Job, error) { + return func(instance *evav1alpha1.RAG) (*batchv1.Job, error) { + datasource, err := config.GetSystemDatasource(ctx, client, nil) + if err != nil { + return nil, err + } + url := datasource.Spec.Endpoint.URL + if datasource.Spec.Endpoint.Insecure { + url = "http://" + url + } else { + url = "https://" + url + } + ns := datasource.Namespace + if datasource.Spec.Endpoint.AuthSecret.Namespace != nil { + ns = *datasource.Spec.Endpoint.AuthSecret.Namespace + } + data, err := datasource.Spec.Endpoint.AuthData(ctx, ns, client, nil) + if err != nil { + return nil, err + } + + accessKeyID := string(data["rootUser"]) + secretAccessKey := string(data["rootPassword"]) + + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: instance.Namespace, + Name: PhaseJobName(instance, evav1alpha1.UploadFilesPhase), + Labels: map[string]string{ + evav1alpha1.EvaluationJobLabels: instance.Name, + }, + }, + Spec: batchv1.JobSpec{ + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + RestartPolicy: v1.RestartPolicyNever, + ServiceAccountName: instance.Spec.ServiceAccountName, + Containers: []v1.Container{ + { + Name: "upload-result", + Image: defaultMCImage, + Command: []string{ + "/bin/bash", + "-c", + fmt.Sprintf(`echo "upload result" +mc alias set oss $MINIO_ENDPOINT $MINIO_ACCESS_KEY $MINIO_SECRET_KEY --insecure +mc --insecure cp -r %s/ oss/%s/evals/%s/%s`, defaultPVCMountPath, instance.Namespace, instance.Spec.Application.Name, instance.Name), + }, + VolumeMounts: []v1.VolumeMount{ + { + Name: "data", + MountPath: defaultPVCMountPath, + }, + }, + Env: []v1.EnvVar{ + {Name: "MINIO_ENDPOINT", Value: url}, + {Name: "MINIO_ACCESS_KEY", Value: accessKeyID}, + {Name: "MINIO_SECRET_KEY", Value: secretAccessKey}, + }, + }, + }, + Volumes: []v1.Volume{ + { + Name: "data", + VolumeSource: v1.VolumeSource{ + PersistentVolumeClaim: &v1.PersistentVolumeClaimVolumeSource{ + ClaimName: instance.Name, + ReadOnly: false, + }, + }, + }, + }, + }, + }, + BackoffLimit: pointer.Int32(1), + Completions: pointer.Int32(1), + Parallelism: pointer.Int32(1), + Suspend: &instance.Spec.Suspend, + }, + } + return job, nil + } +}