feat: revision based jobs (#970)
cakeinsauce authored Sep 23, 2024
1 parent b4974db commit 9ea8c63
Showing 16 changed files with 197 additions and 62 deletions.
12 changes: 12 additions & 0 deletions annotation/annotation/annotations/main.py
@@ -1019,3 +1019,15 @@ def construct_document_links(
)
)
return links


def get_annotations_by_revision(
revisions: Set[str], db: Session
) -> List[AnnotatedDoc]:
if not revisions:
return []
return (
db.query(AnnotatedDoc)
.filter(AnnotatedDoc.revision.in_(revisions))
.all()
)
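
A small usage sketch of this helper as an existence check, mirroring how post_job consumes it further down; the session and revision values here are illustrative, not taken from the commit:

requested = {
    "35b7b50a056d00048b0977b195f7f8e9f9f7400f",
    "4dc503a9ade7d7cb55d6be671748a312d663bb0a",
}
docs = get_annotations_by_revision(requested, db)
# One AnnotatedDoc per matched revision; a shorter result means at
# least one requested revision does not exist.
missing = requested - {doc.revision for doc in docs}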
35 changes: 35 additions & 0 deletions annotation/annotation/annotations/resources.py
@@ -3,6 +3,7 @@
from typing import Dict, List, Optional, Set
from uuid import UUID

import filter_lib
from fastapi import (
APIRouter,
Depends,
@@ -22,6 +23,7 @@
)
from annotation.database import get_db
from annotation.errors import NoSuchRevisionsError
from annotation.filters import AnnotationRequestFilter
from annotation.jobs.services import update_jobs_categories
from annotation.microservice_communication import jobs_communication
from annotation.microservice_communication.assets_communication import (
@@ -70,6 +72,39 @@
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))


@router.post(
"",
status_code=status.HTTP_200_OK,
response_model=filter_lib.Page[AnnotatedDocSchema],
summary="Get annotations by filters",
tags=[ANNOTATION_TAG],
)
async def get_annotations(
request: AnnotationRequestFilter,
x_current_tenant: str = X_CURRENT_TENANT_HEADER,
token: TenantData = Depends(TOKEN),
db: Session = Depends(get_db),
) -> filter_lib.Page[AnnotatedDocSchema]:
filter_args = filter_lib.map_request_to_filter(
request.dict(), AnnotatedDoc.__name__
)
    # Select distinct revisions via a subquery; filter_lib does not
    # handle DISTINCT combined with sorting correctly
subquery = (
db.query(AnnotatedDoc)
.filter(AnnotatedDoc.tenant == x_current_tenant)
.distinct(AnnotatedDoc.revision)
.subquery()
)
query = db.query(AnnotatedDoc).join(
subquery, AnnotatedDoc.revision == subquery.c.revision
)
query, pagination = filter_lib.form_query(filter_args, query)
return filter_lib.paginate(
[AnnotatedDocSchema.from_orm(el) for el in query], pagination
)
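
For context, a sketch of a request body the new endpoint would accept. The pagination/filters/sorting shape follows filter_lib's request model as an assumption, and the URL, header values, and "date" sort field are hypothetical:

import requests  # illustrative client; any HTTP library works

body = {
    "pagination": {"page_num": 1, "page_size": 15},
    "filters": [
        {
            "field": "revision",
            "operator": "in",
            "value": [
                "35b7b50a056d00048b0977b195f7f8e9f9f7400f",
                "4dc503a9ade7d7cb55d6be671748a312d663bb0a",
            ],
        }
    ],
    "sorting": [{"field": "date", "direction": "desc"}],
}
response = requests.post(
    "http://annotation-service/annotation",  # hypothetical route
    json=body,
    headers={"X-Current-Tenant": "tenant1", "Authorization": "Bearer <jwt>"},
)
page = response.json()  # a filter_lib Page: pagination metadata plus matched docs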


@router.post(
"/{task_id}",
status_code=status.HTTP_201_CREATED,
2 changes: 2 additions & 0 deletions annotation/annotation/filters.py
@@ -31,3 +31,5 @@
],
include=ADDITIONAL_TASK_FIELDS,
)

AnnotationRequestFilter = create_filter_model(AnnotatedDoc)
54 changes: 52 additions & 2 deletions annotation/annotation/jobs/resources.py
@@ -11,11 +11,18 @@

import annotation.categories.services
from annotation import logger as app_logger
from annotation.annotations import construct_annotated_doc
from annotation.annotations.main import (
construct_particular_rev_response,
get_annotations_by_revision,
)
from annotation.categories import fetch_bunch_categories_db
from annotation.database import get_db
from annotation.distribution import distribute, redistribute
from annotation.errors import NoSuchRevisionsError
from annotation.filters import CategoryFilter
from annotation.microservice_communication.assets_communication import (
get_file_path_and_bucket,
get_files_info,
)
from annotation.microservice_communication.jobs_communication import (
@@ -29,6 +36,7 @@
BadRequestErrorSchema,
CategoryResponseSchema,
ConnectionErrorSchema,
DocForSaveSchema,
FileStatusEnumSchema,
JobFilesInfoSchema,
JobInfoSchema,
@@ -140,6 +148,13 @@ def post_job(
)
)

    # Revision-based job if any revisions are provided
annotated_doc_revisions = get_annotations_by_revision(
job_info.revisions, db
)
if len(annotated_doc_revisions) != len(job_info.revisions):
raise NoSuchRevisionsError

files = []

if job_info.previous_jobs:
@@ -161,8 +176,9 @@ def post_job(
]
files += tmp_files
else:
revision_file_ids = [rev.file_id for rev in annotated_doc_revisions]
files += get_files_info(
job_info.files,
revision_file_ids or job_info.files,
job_info.datasets,
x_current_tenant,
token.token,
@@ -219,7 +235,41 @@ def post_job(
deadline=job_info.deadline,
extensive_coverage=job_info.extensive_coverage,
)

    if job_info.revisions:  # revision-based job
        # copy each existing revision from other jobs into this job
for annotated_doc in annotated_doc_revisions:
rev_annotation = construct_particular_rev_response(annotated_doc)
doc = DocForSaveSchema(
base_revision=annotated_doc.revision,
user=annotated_doc.user,
pipeline=annotated_doc.pipeline,
pages=rev_annotation.pages,
validated=rev_annotation.validated,
failed_validation_pages=rev_annotation.failed_validation_pages,
similar_revisions=rev_annotation.similar_revisions,
categories=rev_annotation.categories,
links_json=rev_annotation.links_json,
)
s3_file_path, s3_file_bucket = get_file_path_and_bucket(
file_id=annotated_doc.file_id,
tenant=x_current_tenant,
token=token.token,
)
construct_annotated_doc(
db=db,
            # assume user_id is always the same as in the original revision
user_id=annotated_doc.user,
pipeline_id=None,
job_id=job_id,
file_id=annotated_doc.file_id,
doc=doc,
tenant=x_current_tenant,
s3_file_path=s3_file_path,
s3_file_bucket=s3_file_bucket,
latest_doc=None,
task_id=None,
is_latest=True,
)
db.commit()


14 changes: 11 additions & 3 deletions annotation/annotation/schemas/jobs.py
@@ -68,6 +68,13 @@ class JobInfoSchema(BaseModel):
)
files: Set[int] = Field(..., example={1, 2, 3})
datasets: Set[int] = Field(..., example={1, 2, 3})
revisions: Set[str] = Field(
set(),
example={
"35b7b50a056d00048b0977b195f7f8e9f9f7400f",
"4dc503a9ade7d7cb55d6be671748a312d663bb0a",
},
)
previous_jobs: List[PreviousJobInfoSchema] = Field(...)
is_auto_distribution: bool = Field(default=False, example=False)
categories: Optional[Set[str]] = Field(None, example={"1", "2"})
@@ -83,20 +90,21 @@ class JobInfoSchema(BaseModel):
@root_validator
def check_files_datasets_previous_jobs(cls, values):
"""
Files and datasets should not be empty at the same time.
        Files, datasets, and revisions must not all be empty at the same time.
"""
files, datasets = values.get("files"), values.get("datasets")
revisions = values.get("revisions")
previous_jobs = values.get("previous_jobs")

job_type = values.get("job_type")

if (
not (bool(previous_jobs) ^ bool(files or datasets))
not (bool(previous_jobs) ^ bool(files or datasets or revisions))
and job_type != JobTypeEnumSchema.ImportJob
):
raise ValueError(
"Only one field must be specified: "
"either previous_jobs or files/datasets"
"either previous_jobs or files/datasets/revisions"
)
return values
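
The rule this validator enforces is an exclusive-or between previous_jobs and the files/datasets/revisions group, with ImportJob exempt. A minimal runnable restatement of that rule, assuming a non-ImportJob type:

def allowed(has_previous_jobs: bool, has_files_datasets_or_revisions: bool) -> bool:
    # Exactly one of the two work sources must be provided.
    return has_previous_jobs ^ has_files_datasets_or_revisions

assert allowed(True, False)       # previous_jobs only
assert allowed(False, True)       # files, datasets, and/or revisions only
assert not allowed(True, True)    # both groups -> ValueError in the schema
assert not allowed(False, False)  # neither group -> ValueError in the schema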

32 changes: 32 additions & 0 deletions jobs/alembic/versions/a4b7b64a472c_add_revisions_field_to_job.py
@@ -0,0 +1,32 @@
"""add revisions field to job
Revision ID: a4b7b64a472c
Revises: 2dd22b64e1a9
Create Date: 2024-09-11 17:40:41.052573
"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "a4b7b64a472c"
down_revision = "2dd22b64e1a9"
branch_labels = None
depends_on = None


def upgrade():
op.add_column(
"job",
sa.Column(
"revisions",
sa.ARRAY(sa.String(50)),
nullable=False,
server_default="{}",
),
)


def downgrade():
op.drop_column("job", "revisions")
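
For reference, the Job model column this migration backs would look roughly as follows; everything beyond the column definition itself is an assumption. The "{}" server default is the Postgres literal for an empty array, so pre-existing job rows backfill to no revisions:

import sqlalchemy as sa

# Hypothetical sketch of the corresponding ORM column on the Job model.
revisions = sa.Column(
    sa.ARRAY(sa.String(50)),
    nullable=False,
    server_default="{}",  # empty-array default for existing rows
)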
7 changes: 6 additions & 1 deletion jobs/jobs/airflow_utils.py
@@ -66,6 +66,7 @@ async def run(
files: List[pipeline.PipelineFile],
current_tenant: str,
datasets: List[pipeline.Dataset],
revisions: List[str],
) -> None:
configuration = get_configuration()
with client.ApiClient(configuration) as api_client:
@@ -79,6 +80,7 @@
tenant=current_tenant,
files_data=files,
datasets=datasets,
revisions=revisions,
)
),
)
@@ -97,5 +99,8 @@
files: List[pipeline.PipelineFile],
current_tenant: str,
datasets: List[pipeline.Dataset],
revisions: List[str],
) -> None:
return await run(pipeline_id, job_id, files, current_tenant, datasets)
return await run(
pipeline_id, job_id, files, current_tenant, datasets, revisions
)
59 changes: 10 additions & 49 deletions jobs/jobs/create_job_funcs.py
@@ -41,39 +41,17 @@ async def get_all_datasets_and_files_data(
return files_data, valid_dataset_tags, valid_separate_files_uuids


# noinspection PyUnreachableCode
async def create_extraction_job(
extraction_job_input: ExtractionJobParams,
extraction_job_input: ExtractionJobParams, # todo: should be JobParams
current_tenant: str,
jw_token: str,
db: Session = Depends(db_service.get_session),
) -> dbm.CombinedJob:
"""Creates new ExtractionJob and saves it in the database"""

if False:
# old pipelines service
pipeline_instance = await utils.get_pipeline_instance_by_its_name(
pipeline_name=extraction_job_input.pipeline_name,
current_tenant=current_tenant,
jw_token=jw_token,
pipeline_version=extraction_job_input.pipeline_version,
)

pipeline_id = (
extraction_job_input.pipeline_name
if extraction_job_input.pipeline_name.endswith(":airflow")
else pipeline_instance.get("id")
)

pipeline_categories = pipeline_instance.get("meta", {}).get(
"categories", []
)

else:
pipeline_id = extraction_job_input.pipeline_id
pipeline_engine = extraction_job_input.pipeline_engine
# check if categories passed and then append all categories to job
pipeline_categories = []
pipeline_id = extraction_job_input.pipeline_id
pipeline_engine = extraction_job_input.pipeline_engine
    # if categories were passed, append them all to the job
pipeline_categories = []

(
files_data,
@@ -119,6 +97,7 @@ async def create_extraction_job(
extraction_job_input.previous_jobs,
initial_status,
pipeline_categories,
list(extraction_job_input.revisions),
)

return job_in_db
@@ -150,28 +129,10 @@ async def create_extraction_annotation_job(
db: Session = Depends(db_service.get_session),
) -> dbm.CombinedJob:
"""Creates new ExtractionWithAnnotationJob and saves it in the database"""
if False:
pipeline_instance = await utils.get_pipeline_instance_by_its_name(
pipeline_name=extraction_annotation_job_input.pipeline_name,
current_tenant=current_tenant,
jw_token=jw_token,
pipeline_version=extraction_annotation_job_input.pipeline_version,
)
pipeline_id = (
extraction_annotation_job_input.pipeline_name
if extraction_annotation_job_input.pipeline_name.endswith(
":airflow"
)
else pipeline_instance.get("id")
)
pipeline_categories = pipeline_instance.get("meta", {}).get(
"categories", []
)
else:
pipeline_id = extraction_annotation_job_input.pipeline_id
pipeline_engine = extraction_annotation_job_input.pipeline_engine
# check if categories passed and then append all categories to job
pipeline_categories = []
pipeline_id = extraction_annotation_job_input.pipeline_id
pipeline_engine = extraction_annotation_job_input.pipeline_engine
    # if categories were passed, append them all to the job
pipeline_categories = []

(
files_data,
10 changes: 9 additions & 1 deletion jobs/jobs/databricks_utils.py
@@ -46,6 +46,7 @@ async def run(
files: List[pipeline.PipelineFile],
current_tenant: str,
datasets: List[pipeline.Dataset],
revisions: List[str],
) -> None:
logger.info(
"Running pipeline %s, job_id %s, current_tenant: %s with arguments %s",
@@ -65,6 +66,7 @@
tenant=current_tenant,
files_data=files,
datasets=datasets,
revisions=revisions,
)
)
)
@@ -84,7 +86,13 @@
files: List[pipeline.PipelineFile],
current_tenant: str,
datasets: List[pipeline.Dataset],
revisions: List[str],
) -> None:
await run(
pipeline_id, int(job_id), files, current_tenant, datasets=datasets
pipeline_id,
int(job_id),
files,
current_tenant,
datasets=datasets,
revisions=revisions,
)