feat: revision based jobs #970

Merged · 3 commits · Sep 23, 2024
12 changes: 12 additions & 0 deletions annotation/annotation/annotations/main.py
@@ -1019,3 +1019,15 @@ def construct_document_links(
)
)
return links


def get_annotations_by_revision(
    revisions: Set[str], db: Session
) -> List[AnnotatedDoc]:
    if not revisions:
        return []
    return (
        db.query(AnnotatedDoc)
        .filter(AnnotatedDoc.revision.in_(revisions))
        .all()
    )
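
Note: the helper returns only the revisions that actually exist, so a caller can detect unknown hashes by comparing counts, which is the same check `post_job` performs later in this PR. A minimal usage sketch (not from the diff), assuming an active SQLAlchemy `Session` named `db` and the project's `NoSuchRevisionsError`:

```python
# Hedged usage sketch: validate that every requested revision hash exists.
requested = {
    "35b7b50a056d00048b0977b195f7f8e9f9f7400f",
    "4dc503a9ade7d7cb55d6be671748a312d663bb0a",
}
found = get_annotations_by_revision(requested, db)
if len(found) != len(requested):
    # at least one hash has no AnnotatedDoc row
    raise NoSuchRevisionsError
```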
35 changes: 35 additions & 0 deletions annotation/annotation/annotations/resources.py
@@ -3,6 +3,7 @@
from typing import Dict, List, Optional, Set
from uuid import UUID

import filter_lib
from fastapi import (
APIRouter,
Depends,
@@ -22,6 +23,7 @@
)
from annotation.database import get_db
from annotation.errors import NoSuchRevisionsError
from annotation.filters import AnnotationRequestFilter
from annotation.jobs.services import update_jobs_categories
from annotation.microservice_communication import jobs_communication
from annotation.microservice_communication.assets_communication import (
@@ -70,6 +72,39 @@
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))


@router.post(
    "",
    status_code=status.HTTP_200_OK,
    response_model=filter_lib.Page[AnnotatedDocSchema],
    summary="Get annotations by filters",
    tags=[ANNOTATION_TAG],
)
async def get_annotations(
    request: AnnotationRequestFilter,
    x_current_tenant: str = X_CURRENT_TENANT_HEADER,
    token: TenantData = Depends(TOKEN),
    db: Session = Depends(get_db),
) -> filter_lib.Page[AnnotatedDocSchema]:
    filter_args = filter_lib.map_request_to_filter(
        request.dict(), AnnotatedDoc.__name__
    )
    # Distinct on revision: filter_lib doesn't handle distinct
    # combined with sorting correctly
    subquery = (
        db.query(AnnotatedDoc)
        .filter(AnnotatedDoc.tenant == x_current_tenant)
        .distinct(AnnotatedDoc.revision)
        .subquery()
    )
    query = db.query(AnnotatedDoc).join(
        subquery, AnnotatedDoc.revision == subquery.c.revision
    )
    query, pagination = filter_lib.form_query(filter_args, query)
    return filter_lib.paginate(
        [AnnotatedDocSchema.from_orm(el) for el in query], pagination
    )


@router.post(
"/{task_id}",
status_code=status.HTTP_201_CREATED,
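
Note: the new `POST` route accepts a filter_lib-style body generated from `AnnotatedDoc` and returns a page of revisions distinct by hash. A hedged example call; the URL, header names, and exact body shape are assumptions based on filter_lib's usual request format, not something this diff shows:

```python
import requests

# Hypothetical request to the new "get annotations by filters" endpoint.
resp = requests.post(
    "http://annotation-service/annotation",  # assumed base path
    headers={
        "X-Current-Tenant": "tenant1",       # assumed header name
        "Authorization": "Bearer <token>",
    },
    json={
        "pagination": {"page_num": 1, "page_size": 15},
        "filters": [
            {
                "field": "revision",
                "operator": "in",
                "value": ["35b7b50a056d00048b0977b195f7f8e9f9f7400f"],
            }
        ],
    },
)
page = resp.json()  # shaped like filter_lib.Page[AnnotatedDocSchema]
```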
2 changes: 2 additions & 0 deletions annotation/annotation/filters.py
@@ -31,3 +31,5 @@
],
include=ADDITIONAL_TASK_FIELDS,
)

AnnotationRequestFilter = create_filter_model(AnnotatedDoc)
54 changes: 52 additions & 2 deletions annotation/annotation/jobs/resources.py
@@ -11,11 +11,18 @@

import annotation.categories.services
from annotation import logger as app_logger
from annotation.annotations import construct_annotated_doc
from annotation.annotations.main import (
construct_particular_rev_response,
get_annotations_by_revision,
)
from annotation.categories import fetch_bunch_categories_db
from annotation.database import get_db
from annotation.distribution import distribute, redistribute
from annotation.errors import NoSuchRevisionsError
from annotation.filters import CategoryFilter
from annotation.microservice_communication.assets_communication import (
get_file_path_and_bucket,
get_files_info,
)
from annotation.microservice_communication.jobs_communication import (
@@ -29,6 +36,7 @@
BadRequestErrorSchema,
CategoryResponseSchema,
ConnectionErrorSchema,
DocForSaveSchema,
FileStatusEnumSchema,
JobFilesInfoSchema,
JobInfoSchema,
@@ -140,6 +148,13 @@ def post_job(
)
)

# Revision based job if any revisions are provided
annotated_doc_revisions = get_annotations_by_revision(
job_info.revisions, db
)
if len(annotated_doc_revisions) != len(job_info.revisions):
raise NoSuchRevisionsError

files = []

if job_info.previous_jobs:
@@ -161,8 +176,9 @@ def post_job(
]
files += tmp_files
else:
revision_file_ids = [rev.file_id for rev in annotated_doc_revisions]
files += get_files_info(
job_info.files,
revision_file_ids or job_info.files,
job_info.datasets,
x_current_tenant,
token.token,
@@ -219,7 +235,41 @@ def post_job(
deadline=job_info.deadline,
extensive_coverage=job_info.extensive_coverage,
)

if job_info.revisions: # revision based job
# for each existing revision from other jobs we copy them over
for annotated_doc in annotated_doc_revisions:
rev_annotation = construct_particular_rev_response(annotated_doc)
doc = DocForSaveSchema(
base_revision=annotated_doc.revision,
user=annotated_doc.user,
pipeline=annotated_doc.pipeline,
pages=rev_annotation.pages,
validated=rev_annotation.validated,
failed_validation_pages=rev_annotation.failed_validation_pages,
similar_revisions=rev_annotation.similar_revisions,
categories=rev_annotation.categories,
links_json=rev_annotation.links_json,
)
s3_file_path, s3_file_bucket = get_file_path_and_bucket(
file_id=annotated_doc.file_id,
tenant=x_current_tenant,
token=token.token,
)
construct_annotated_doc(
db=db,
# assume user_id is always the same as original revision
user_id=annotated_doc.user,
pipeline_id=None,
job_id=job_id,
file_id=annotated_doc.file_id,
doc=doc,
tenant=x_current_tenant,
s3_file_path=s3_file_path,
s3_file_bucket=s3_file_bucket,
latest_doc=None,
task_id=None,
is_latest=True,
)
db.commit()


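
Note on the `revision_file_ids or job_info.files` expression above: file ids resolved from the requested revisions take precedence, and the explicitly supplied files are only used when no revisions were given. A standalone sketch of that plain-Python `or` behavior:

```python
# How `revision_file_ids or job_info.files` selects the file set.
explicit_files = {1, 2, 3}                 # stand-in for job_info.files

revision_file_ids = [12, 34]               # revisions resolved to file ids
assert (revision_file_ids or explicit_files) == [12, 34]

revision_file_ids = []                     # no revisions requested
assert (revision_file_ids or explicit_files) == {1, 2, 3}
```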
14 changes: 11 additions & 3 deletions annotation/annotation/schemas/jobs.py
@@ -68,6 +68,13 @@ class JobInfoSchema(BaseModel):
)
files: Set[int] = Field(..., example={1, 2, 3})
datasets: Set[int] = Field(..., example={1, 2, 3})
revisions: Set[str] = Field(
set(),
example={
"35b7b50a056d00048b0977b195f7f8e9f9f7400f",
"4dc503a9ade7d7cb55d6be671748a312d663bb0a",
},
)
previous_jobs: List[PreviousJobInfoSchema] = Field(...)
is_auto_distribution: bool = Field(default=False, example=False)
categories: Optional[Set[str]] = Field(None, example={"1", "2"})
@@ -83,20 +90,21 @@ class JobInfoSchema(BaseModel):
@root_validator
def check_files_datasets_previous_jobs(cls, values):
"""
Files and datasets should not be empty at the same time.
Files, datasets, and revisions should not all be empty at the same time.
"""
files, datasets = values.get("files"), values.get("datasets")
revisions = values.get("revisions")
previous_jobs = values.get("previous_jobs")

job_type = values.get("job_type")

if (
not (bool(previous_jobs) ^ bool(files or datasets))
not (bool(previous_jobs) ^ bool(files or datasets or revisions))
and job_type != JobTypeEnumSchema.ImportJob
):
raise ValueError(
"Only one field must be specified: "
"either previous_jobs or files/datasets"
"either previous_jobs or files/datasets/revisions"
)
return values

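
Note: with the new field, the validator accepts either `previous_jobs` or at least one of `files`/`datasets`/`revisions`, but not both. A hedged example of a revision-based job body; only the fields visible in this diff are shown, and other `JobInfoSchema` fields are omitted:

```python
# Hypothetical revision-based job payload (partial; non-diff fields omitted).
job_info = {
    "files": [],
    "datasets": [],
    "previous_jobs": [],
    "revisions": [
        "35b7b50a056d00048b0977b195f7f8e9f9f7400f",
        "4dc503a9ade7d7cb55d6be671748a312d663bb0a",
    ],
    "is_auto_distribution": False,
}
```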
32 changes: 32 additions & 0 deletions jobs/alembic/versions/a4b7b64a472c_add_revisions_field_to_job.py
@@ -0,0 +1,32 @@
"""add revisions field to job

Revision ID: a4b7b64a472c
Revises: 2dd22b64e1a9
Create Date: 2024-09-11 17:40:41.052573

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "a4b7b64a472c"
down_revision = "2dd22b64e1a9"
branch_labels = None
depends_on = None


def upgrade():
    op.add_column(
        "job",
        sa.Column(
            "revisions",
            sa.ARRAY(sa.String(50)),
            nullable=False,
            server_default="{}",
        ),
    )


def downgrade():
    op.drop_column("job", "revisions")
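
Note: `server_default="{}"` is the Postgres literal for an empty array, so existing `job` rows end up with an empty `revisions` list rather than `NULL` after the upgrade. A hedged sanity check with an assumed connection string:

```python
from sqlalchemy import create_engine, text

# Assumed DSN; point this at the jobs database after `alembic upgrade head`.
engine = create_engine("postgresql://user:pass@localhost/jobs")
with engine.connect() as conn:
    row = conn.execute(text("SELECT revisions FROM job LIMIT 1")).fetchone()
    # either the table is empty or the new column defaults to an empty list
    assert row is None or row[0] == []
```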
7 changes: 6 additions & 1 deletion jobs/jobs/airflow_utils.py
@@ -66,6 +66,7 @@ async def run(
files: List[pipeline.PipelineFile],
current_tenant: str,
datasets: List[pipeline.Dataset],
revisions: List[str],
) -> None:
configuration = get_configuration()
with client.ApiClient(configuration) as api_client:
@@ -79,6 +80,7 @@
tenant=current_tenant,
files_data=files,
datasets=datasets,
revisions=revisions,
)
),
)
@@ -97,5 +99,8 @@
files: List[pipeline.PipelineFile],
current_tenant: str,
datasets: List[pipeline.Dataset],
revisions: List[str],
) -> None:
return await run(pipeline_id, job_id, files, current_tenant, datasets)
return await run(
pipeline_id, job_id, files, current_tenant, datasets, revisions
)
59 changes: 10 additions & 49 deletions jobs/jobs/create_job_funcs.py
@@ -41,39 +41,17 @@ async def get_all_datasets_and_files_data(
return files_data, valid_dataset_tags, valid_separate_files_uuids


# noinspection PyUnreachableCode
async def create_extraction_job(
extraction_job_input: ExtractionJobParams,
extraction_job_input: ExtractionJobParams, # todo: should be JobParams
current_tenant: str,
jw_token: str,
db: Session = Depends(db_service.get_session),
) -> dbm.CombinedJob:
"""Creates new ExtractionJob and saves it in the database"""

if False:
# old pipelines service
pipeline_instance = await utils.get_pipeline_instance_by_its_name(
pipeline_name=extraction_job_input.pipeline_name,
current_tenant=current_tenant,
jw_token=jw_token,
pipeline_version=extraction_job_input.pipeline_version,
)

pipeline_id = (
extraction_job_input.pipeline_name
if extraction_job_input.pipeline_name.endswith(":airflow")
else pipeline_instance.get("id")
)

pipeline_categories = pipeline_instance.get("meta", {}).get(
"categories", []
)

else:
pipeline_id = extraction_job_input.pipeline_id
pipeline_engine = extraction_job_input.pipeline_engine
# check if categories passed and then append all categories to job
pipeline_categories = []
pipeline_id = extraction_job_input.pipeline_id
pipeline_engine = extraction_job_input.pipeline_engine
# check if categories passed and then append all categories to job
pipeline_categories = []

(
files_data,
@@ -119,6 +97,7 @@ async def create_extraction_job(
extraction_job_input.previous_jobs,
initial_status,
pipeline_categories,
list(extraction_job_input.revisions),
)

return job_in_db
@@ -150,28 +129,10 @@ async def create_extraction_annotation_job(
db: Session = Depends(db_service.get_session),
) -> dbm.CombinedJob:
"""Creates new ExtractionWithAnnotationJob and saves it in the database"""
if False:
pipeline_instance = await utils.get_pipeline_instance_by_its_name(
pipeline_name=extraction_annotation_job_input.pipeline_name,
current_tenant=current_tenant,
jw_token=jw_token,
pipeline_version=extraction_annotation_job_input.pipeline_version,
)
pipeline_id = (
extraction_annotation_job_input.pipeline_name
if extraction_annotation_job_input.pipeline_name.endswith(
":airflow"
)
else pipeline_instance.get("id")
)
pipeline_categories = pipeline_instance.get("meta", {}).get(
"categories", []
)
else:
pipeline_id = extraction_annotation_job_input.pipeline_id
pipeline_engine = extraction_annotation_job_input.pipeline_engine
# check if categories passed and then append all categories to job
pipeline_categories = []
pipeline_id = extraction_annotation_job_input.pipeline_id
pipeline_engine = extraction_annotation_job_input.pipeline_engine
# check if categories passed and then append all categories to job
pipeline_categories = []

(
files_data,
10 changes: 9 additions & 1 deletion jobs/jobs/databricks_utils.py
@@ -46,6 +46,7 @@ async def run(
files: List[pipeline.PipelineFile],
current_tenant: str,
datasets: List[pipeline.Dataset],
revisions: List[str],
) -> None:
logger.info(
"Running pipeline %s, job_id %s, current_tenant: %s with arguments %s",
@@ -65,6 +66,7 @@
tenant=current_tenant,
files_data=files,
datasets=datasets,
revisions=revisions,
)
)
)
@@ -84,7 +86,13 @@ async def run(
files: List[pipeline.PipelineFile],
current_tenant: str,
datasets: List[pipeline.Dataset],
revisions: List[str],
) -> None:
await run(
pipeline_id, int(job_id), files, current_tenant, datasets=datasets
pipeline_id,
int(job_id),
files,
current_tenant,
datasets=datasets,
revisions=revisions,
)