diff --git a/src/python/dart/engine/redshift/command/ddl.py b/src/python/dart/engine/redshift/command/ddl.py
index 082c629..5ef463f 100644
--- a/src/python/dart/engine/redshift/command/ddl.py
+++ b/src/python/dart/engine/redshift/command/ddl.py
@@ -10,11 +10,11 @@ def get_target_schema_and_table_name(action, dataset):
 
 def get_stage_schema_and_table_name(action, dataset):
     schema_name, table_name = get_target_schema_and_table_name(action, dataset)
-    return 'dart_stage', schema_name + '_' + table_name + '_' + action.id
+    return 'dart_stage', schema_name + '_' + table_name + '_' + action.data.workflow_action_id
 
 
 def get_tracking_schema_and_table_name(action):
-    table_name = 's3_files_for_action_%s' % action.id
+    table_name = 's3_files_for_action_%s' % action.data.workflow_action_id
     if action.data.action_type_name == RedshiftActionTypes.consume_subscription.name:
         table_name = 's3_files_for_subscription_%s' % action.data.args['subscription_id']
     return 'dart_tracking', table_name
diff --git a/src/python/dart/model/action.py b/src/python/dart/model/action.py
index 4386352..4763dba 100644
--- a/src/python/dart/model/action.py
+++ b/src/python/dart/model/action.py
@@ -66,7 +66,8 @@ def __init__(self, name, action_type_name, args=None, state=ActionState.HAS_NEVE
                  on_failure_email=None, on_success_email=None, engine_name=None, datastore_id=None,
                  workflow_id=None, workflow_instance_id=None, workflow_action_id=None, first_in_workflow=False,
                  last_in_workflow=False, ecs_task_arn=None, batch_job_id=None, extra_data=None, tags=None, user_id='anonymous',
-                 avg_runtime=None, completed_runs=0, parallelization_parents=None, parallelization_idx=None):
+                 avg_runtime=None, completed_runs=0, parallelization_parents=None, parallelization_idx=None,
+                 vcpus=None, memory_mb=None, job_definition=None, job_queue=None, job_name=None):
         """
         :type name: str
         :type action_type_name: str
@@ -96,6 +97,12 @@ def __init__(self, name, action_type_name, args=None, state=ActionState.HAS_NEVE
         :type completed_runs: int
         :type parallelization_parents: list[int]
         :type parallelization_idx: int
+        :type vcpus: int
+        :type memory_mb: int
+        :type job_definition: str
+        :type job_queue: str
+        :type job_name: str
+
         """
         self.name = name
         self.action_type_name = action_type_name
@@ -126,3 +133,8 @@ def __init__(self, name, action_type_name, args=None, state=ActionState.HAS_NEVE
         self.user_id = user_id
         self.avg_runtime = avg_runtime
         self.completed_runs = completed_runs
+        self.vcpus = vcpus
+        self.memory_mb = memory_mb
+        self.job_definition = job_definition
+        self.job_queue = job_queue
+        self.job_name = job_name
diff --git a/src/python/dart/schema/action.py b/src/python/dart/schema/action.py
index 7c8d8b4..d4de6ff 100644
--- a/src/python/dart/schema/action.py
+++ b/src/python/dart/schema/action.py
@@ -19,6 +19,11 @@ def action_schema(supported_action_type_params_schema):
             'order_idx': {'type': ['number', 'null'], 'minimum': 0.0},
             'parallelization_parents': parallelization_parents(),
             'parallelization_idx': {'type': ['number', 'null'], 'minimum': 0.0},
+            'vcpus': {'type': ['number', 'null'], 'minimum': 1.0, 'maximum': 1024},
+            'memory_mb': {'type': ['number', 'null'], 'minimum': 4.0, 'maximum': 61440},
+            'job_definition': {'type': ['string', 'null']},
+            'job_queue': {'type': ['string', 'null']},
+            'job_name': {'type': ['string', 'null']},
             'error_message': {'type': ['string', 'null'], 'readonly': True, "x-schema-form": {"type": "textarea"}},
             'on_failure': {
                 'type': 'string',
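The schema bounds above track AWS Batch container limits: at least 1 vCPU and at least 4 MB of memory per job, with 61440 MB presumably sized to the largest instance type in the target compute environment. A minimal sketch of how these properties validate, assuming the jsonschema package and reducing action_schema to just the new fields (both assumptions for illustration, not part of this change):

    from jsonschema import validate, ValidationError

    # Illustrative stand-in for the new properties added to action_schema above.
    batch_overrides = {
        'type': 'object',
        'properties': {
            'vcpus': {'type': ['number', 'null'], 'minimum': 1.0, 'maximum': 1024},
            'memory_mb': {'type': ['number', 'null'], 'minimum': 4.0, 'maximum': 61440},
            'job_definition': {'type': ['string', 'null']},
            'job_queue': {'type': ['string', 'null']},
            'job_name': {'type': ['string', 'null']},
        },
    }

    validate({'vcpus': 4, 'memory_mb': 8192}, batch_overrides)  # passes
    validate({'vcpus': None}, batch_overrides)                  # null means "no override"
    try:
        validate({'memory_mb': 2}, batch_overrides)             # below the minimum of 4.0
    except ValidationError as e:
        print(e.message)

Note that every new property admits null, so existing actions that never set these fields still validate.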
diff --git a/src/python/dart/util/aws_batch.py b/src/python/dart/util/aws_batch.py
index 2a5185e..e2f4c56 100644
--- a/src/python/dart/util/aws_batch.py
+++ b/src/python/dart/util/aws_batch.py
@@ -117,6 +117,34 @@ def generate_dag(self, single_ordered_wf_instance_actions, retries_on_failures,
         _logger.info("AWS_Batch: Done building workflow {0} with jobs: {1}".
                      format(wf_attribs['workflow_id'], all_previous_jobs))
+
+    def add_container_overrides(self, oaction, submit_job_input, job_name):
+        ''' action overrides the job_definition or dart-rpt.yaml configs '''
+        # special batch overrides
+        if hasattr(oaction.data, 'vcpus') and oaction.data.vcpus:
+            submit_job_input['containerOverrides']['vcpus'] = oaction.data.vcpus
+            _logger.info("AWS_Batch: job={0} vcpus override={1}".format(job_name, oaction.data.vcpus))
+
+        if hasattr(oaction.data, 'memory_mb') and oaction.data.memory_mb:
+            submit_job_input['containerOverrides']['memory'] = oaction.data.memory_mb
+            _logger.info("AWS_Batch: job={0} memory_mb override={1}".format(job_name, oaction.data.memory_mb))
+
+        if hasattr(oaction.data, 'job_definition') and oaction.data.job_definition:
+            submit_job_input['jobDefinition'] = oaction.data.job_definition
+            _logger.info("AWS_Batch: job={0} jobDefinition override={1}".format(job_name, oaction.data.job_definition))
+
+        if hasattr(oaction.data, 'job_name') and oaction.data.job_name:
+            submit_job_input['jobName'] = oaction.data.job_name
+            _logger.info("AWS_Batch: job={0} jobName override={1}".format(job_name, oaction.data.job_name))
+
+        if hasattr(oaction.data, 'job_queue') and oaction.data.job_queue:
+            submit_job_input['jobQueue'] = oaction.data.job_queue
+            _logger.info("AWS_Batch: job={0} job_queue override={1}".format(job_name, oaction.data.job_queue))
+
+        return submit_job_input
+
     def submit_job(self, wf_attribs, idx, oaction, last_action_index, dependency, action_env):
         job_name = self.generate_job_name(wf_attribs['workflow_id'], oaction.data.order_idx, oaction.data.name, self.job_definition_suffix)
         _logger.info("AWS_Batch: job-name={0}, dependsOn={1}".format(job_name, dependency))
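To make the override semantics concrete, here is a minimal sketch using a SimpleNamespace stub in place of a real action (the stub, job name, and queue names are illustrative, not from this change). The hasattr guard keeps actions persisted before these fields existed from raising, and the truthiness check means None leaves the job-definition defaults alone:

    from types import SimpleNamespace

    # Hypothetical action stub: only vcpus and job_queue are overridden.
    oaction = SimpleNamespace(data=SimpleNamespace(
        vcpus=8, memory_mb=None, job_definition=None,
        job_queue='high-mem-queue', job_name=None))

    submit_job_input = {
        'jobName': 'wf-123-0-example',
        'jobDefinition': 'default-def:3',
        'jobQueue': 'default-queue',
        'dependsOn': [],
        'containerOverrides': {'environment': []},
    }

    # The same checks add_container_overrides performs, minus the logging:
    if hasattr(oaction.data, 'vcpus') and oaction.data.vcpus:
        submit_job_input['containerOverrides']['vcpus'] = oaction.data.vcpus
    if hasattr(oaction.data, 'memory_mb') and oaction.data.memory_mb:
        submit_job_input['containerOverrides']['memory'] = oaction.data.memory_mb
    if hasattr(oaction.data, 'job_queue') and oaction.data.job_queue:
        submit_job_input['jobQueue'] = oaction.data.job_queue

    assert submit_job_input['jobQueue'] == 'high-mem-queue'        # overridden
    assert submit_job_input['containerOverrides']['vcpus'] == 8    # overridden
    assert submit_job_input['jobDefinition'] == 'default-def:3'    # default kept
    assert 'memory' not in submit_job_input['containerOverrides']  # None skipped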
@@ -129,22 +157,25 @@
         # submit_job is sensitive to None value in env variables so we wrap them with str(..)
         input_env = json.dumps(self.generate_env_vars(wf_attribs, action_env, idx == 0, idx == last_action_index))
-        response = self.client.submit_job(jobName=job_name,
-                                          # SNS to notify workflow completion and action completion
-                                          jobDefinition=self.get_latest_active_job_definition(oaction.data.engine_name, self.job_definition_suffix, self.client.describe_job_definitions),
-                                          jobQueue=queue_name,
-                                          dependsOn=dependency,
-                                          containerOverrides={
-                                              'environment': [
-                                                  {'name': 'input_env', 'value': input_env},  # passing execution info to job
-                                                  {'name': 'DART_ACTION_ID', 'value': str(oaction.id)},
-                                                  {'name': 'DART_ACTION_USER_ID', 'value': str(oaction.data.user_id)},
-                                                  {'name': 'DART_CONFIG', 'value': str(self.dart_config)},
-                                                  {'name': 'DART_ROLE', 'value': "worker:{0}".format(oaction.data.engine_name)},  # An implicit convention
-                                                  {'name': 'DART_URL', 'value': str(self.dart_url)},  # Used by abacus to access data lineage
-                                                  {'name': 'AWS_DEFAULT_REGION', 'value': str(self.aws_default_region)}
-                                              ]
-                                          })
+        submit_job_input = {
+            'jobName': job_name,
+            'jobDefinition': self.get_latest_active_job_definition(oaction.data.engine_name, self.job_definition_suffix, self.client.describe_job_definitions),
+            'jobQueue': queue_name,
+            'dependsOn': dependency,
+            'containerOverrides': {
+                'environment': [
+                    {'name': 'input_env', 'value': input_env},  # passing execution info to job
+                    {'name': 'DART_ACTION_ID', 'value': str(oaction.id)},
+                    {'name': 'DART_ACTION_USER_ID', 'value': str(oaction.data.user_id)},
+                    {'name': 'DART_CONFIG', 'value': str(self.dart_config)},
+                    {'name': 'DART_ROLE', 'value': "worker:{0}".format(oaction.data.engine_name)},  # An implicit convention
+                    {'name': 'DART_URL', 'value': str(self.dart_url)},  # Used by abacus to access data lineage
+                    {'name': 'AWS_DEFAULT_REGION', 'value': str(self.aws_default_region)}
+                ]
+            }
+        }
+        job_input = self.add_container_overrides(oaction=oaction, submit_job_input=submit_job_input, job_name=job_name)
+        response = self.client.submit_job(**job_input)
         _logger.info("AWS_Batch: response={0}".format(response))
         return response['jobId']
diff --git a/src/python/dart/web/api/action.py b/src/python/dart/web/api/action.py
index faf46b8..04f8eef 100644
--- a/src/python/dart/web/api/action.py
+++ b/src/python/dart/web/api/action.py
@@ -221,6 +221,12 @@ def update_action(action, updated_action):
     sanitized_action.data.on_success_email = updated_action.data.on_success_email
     sanitized_action.data.extra_data = updated_action.data.extra_data
 
+    sanitized_action.data.vcpus = updated_action.data.vcpus
+    sanitized_action.data.memory_mb = updated_action.data.memory_mb
+    sanitized_action.data.job_definition = updated_action.data.job_definition
+    sanitized_action.data.job_queue = updated_action.data.job_queue
+    sanitized_action.data.job_name = updated_action.data.job_name
+
     # revalidate
     sanitized_action = action_service().default_and_validate_action(sanitized_action)
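Collecting the arguments into submit_job_input and unpacking with ** keeps a single boto3 call site no matter which overrides apply, instead of branching into a separate submit_job(...) invocation per override combination. A sketch of the resulting call, assuming a boto3 Batch client (region, names, and values are placeholders):

    import boto3

    client = boto3.client('batch', region_name='us-east-1')  # placeholder region

    submit_job_input = {
        'jobName': 'wf-123-0-example',
        'jobDefinition': 'default-def:3',
        'jobQueue': 'default-queue',
        'dependsOn': [],
        'containerOverrides': {
            'environment': [{'name': 'input_env', 'value': '{}'}],
        },
    }
    # add_container_overrides mutates and returns the same dict; the submit
    # call is identical whether or not any per-action override was applied.
    response = client.submit_job(**submit_job_input)
    print(response['jobId'])

The update_action change then lets the five new fields round-trip through the update endpoint, and default_and_validate_action re-checks them against the schema bounds introduced above.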