diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py
index 7928880e..83c79e32 100644
--- a/dbt/adapters/databricks/api_client.py
+++ b/dbt/adapters/databricks/api_client.py
@@ -6,6 +6,7 @@
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import List
 from typing import Optional
 from typing import Set
@@ -41,6 +42,11 @@ def post(
     ) -> Response:
         return self.session.post(f"{self.prefix}{suffix}", json=json, params=params)

+    def put(
+        self, suffix: str = "", json: Optional[Any] = None, params: Optional[Dict[str, Any]] = None
+    ) -> Response:
+        return self.session.put(f"{self.prefix}{suffix}", json=json, params=params)
+

 class DatabricksApi(ABC):
     def __init__(self, session: Session, host: str, api: str):
@@ -357,6 +363,86 @@ def cancel(self, run_id: str) -> None:
             raise DbtRuntimeError(f"Cancel run {run_id} failed.\n {response.content!r}")


+class JobPermissionsApi(DatabricksApi):
+    def __init__(self, session: Session, host: str):
+        super().__init__(session, host, "/api/2.0/permissions/jobs")
+
+    def put(self, job_id: str, access_control_list: List[Dict[str, Any]]) -> None:
+        request_body = {"access_control_list": access_control_list}
+
+        response = self.session.put(f"/{job_id}", json=request_body)
+        logger.info(f"Workflow permissions update response={response.json()}")
+
+        if response.status_code != 200:
+            raise DbtRuntimeError(f"Error updating Databricks workflow.\n {response.content!r}")
+
+    def get(self, job_id: str) -> Dict[str, Any]:
+        response = self.session.get(f"/{job_id}")
+
+        if response.status_code != 200:
+            raise DbtRuntimeError(
+                f"Error fetching Databricks workflow permissions.\n {response.content!r}"
+            )
+
+        return response.json()
+
+
+class WorkflowJobApi(DatabricksApi):
+    def __init__(self, session: Session, host: str):
+        super().__init__(session, host, "/api/2.1/jobs")
+
+    def search_by_name(self, job_name: str) -> List[Dict[str, Any]]:
+        response = self.session.get("/list", json={"name": job_name})
+
+        if response.status_code != 200:
+            raise DbtRuntimeError(f"Error fetching job by name.\n {response.content!r}")
+
+        logger.info(f"Job list response={response.json()}")
+        return response.json().get("jobs", [])
+
+    def create(self, job_spec: Dict[str, Any]) -> str:
+        """
+        :return: the job_id
+        """
+        response = self.session.post("/create", json=job_spec)
+
+        if response.status_code != 200:
+            raise DbtRuntimeError(f"Error creating Workflow.\n {response.content!r}")
+
+        logger.info(f"Workflow creation response={response.json()}")
+        return response.json()["job_id"]
+
+    def update_job_settings(self, job_id: str, job_spec: Dict[str, Any]) -> None:
+        request_body = {
+            "job_id": job_id,
+            "new_settings": job_spec,
+        }
+        response = self.session.post("/reset", json=request_body)
+
+        if response.status_code != 200:
+            raise DbtRuntimeError(f"Error updating Workflow.\n {response.content!r}")
+
+        logger.info(f"Workflow update response={response.json()}")
+
+    def run(self, job_id: str, enable_queueing: bool = True) -> str:
+        request_body = {
+            "job_id": job_id,
+            "queue": {
+                "enabled": enable_queueing,
+            },
+        }
+        response = self.session.post("/run-now", json=request_body)
+
+        if response.status_code != 200:
+            raise DbtRuntimeError(f"Error triggering run for workflow.\n {response.content!r}")
+
+        response_json = response.json()
+        logger.info(f"Workflow trigger response={response_json}")
+
+        return response_json["run_id"]
+
+
 class DatabricksApiClient:
     def __init__(
         self,
@@ -375,6 +461,8 @@ def __init__(
         self.workspace = WorkspaceApi(session, host, self.folders)
         self.commands = CommandApi(session, host, polling_interval, timeout)
         self.job_runs = JobRunsApi(session, host, polling_interval, timeout)
+        self.workflows = WorkflowJobApi(session, host)
+        self.workflow_permissions = JobPermissionsApi(session, host)

     @staticmethod
     def create(
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index b6aca192..92ee29c6 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -50,6 +50,9 @@
 from dbt.adapters.databricks.python_models.python_submissions import (
     ServerlessClusterPythonJobHelper,
 )
+from dbt.adapters.databricks.python_models.python_submissions import (
+    WorkflowPythonJobHelper,
+)
 from dbt.adapters.databricks.relation import DatabricksRelation
 from dbt.adapters.databricks.relation import DatabricksRelationType
 from dbt.adapters.databricks.relation import KEY_TABLE_PROVIDER
@@ -623,6 +626,7 @@ def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
             "job_cluster": JobClusterPythonJobHelper,
             "all_purpose_cluster": AllPurposeClusterPythonJobHelper,
             "serverless_cluster": ServerlessClusterPythonJobHelper,
+            "workflow_job": WorkflowPythonJobHelper,
         }

     @available
diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py
index eb017fc2..00ab2a1e 100644
--- a/dbt/adapters/databricks/python_models/python_submissions.py
+++ b/dbt/adapters/databricks/python_models/python_submissions.py
@@ -1,13 +1,16 @@
 import uuid
 from typing import Any
 from typing import Dict
+from typing import List
 from typing import Optional
+from typing import Tuple

 from dbt.adapters.base import PythonJobHelper
 from dbt.adapters.databricks.api_client import CommandExecution
 from dbt.adapters.databricks.api_client import DatabricksApiClient
 from dbt.adapters.databricks.credentials import DatabricksCredentials
 from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
+from dbt_common.exceptions import DbtRuntimeError

 DEFAULT_TIMEOUT = 60 * 60 * 24
@@ -162,3 +165,178 @@ def submit(self, compiled_code: str) -> None:
 class ServerlessClusterPythonJobHelper(BaseDatabricksHelper):
     def submit(self, compiled_code: str) -> None:
         self._submit_through_notebook(compiled_code, {})
+
+
+class WorkflowPythonJobHelper(BaseDatabricksHelper):
+
+    @property
+    def default_job_name(self) -> str:
+        return f"dbt__{self.database}-{self.schema}-{self.identifier}"
+
+    @property
+    def notebook_path(self) -> str:
+        return f"{self.notebook_dir}/{self.identifier}"
+
+    @property
+    def notebook_dir(self) -> str:
+        return self.api_client.workspace.user_api.get_folder(self.catalog, self.schema)
+
+    @property
+    def catalog(self) -> str:
+        return self.database or "hive_metastore"
+
+    def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
+        super().__init__(parsed_model, credentials)
+
+    def check_credentials(self) -> None:
+        workflow_config = self.parsed_model["config"].get("workflow_job_config", None)
+        if not workflow_config:
+            raise ValueError(
+                "workflow_job_config is required for the `workflow_job` submission method."
+            )
+
+    def submit(self, compiled_code: str) -> None:
+        workflow_spec = self.parsed_model["config"]["workflow_job_config"]
+        cluster_spec = self.parsed_model["config"].get("job_cluster_config", None)
+
+        # This dict gets modified throughout. Settings added through dbt are popped off
+        # before the spec is sent to the Databricks API.
+        workflow_spec = self._build_job_spec(workflow_spec, cluster_spec)
+
+        self._submit_through_workflow(compiled_code, workflow_spec)
+
+    def _build_job_spec(
+        self, workflow_spec: Dict[str, Any], cluster_spec: Optional[Dict[str, Any]]
+    ) -> Dict[str, Any]:
+        workflow_spec["name"] = workflow_spec.get("name", self.default_job_name)
+
+        # Undefined cluster settings default to serverless in the Databricks API
+        cluster_settings: Dict[str, Any] = {}
+        if cluster_spec is not None:
+            cluster_settings["new_cluster"] = cluster_spec
+        elif "existing_cluster_id" in workflow_spec:
+            cluster_settings["existing_cluster_id"] = workflow_spec["existing_cluster_id"]
+
+        notebook_task = {
+            "task_key": "task_a",
+            "notebook_task": {
+                "notebook_path": self.notebook_path,
+                "source": "WORKSPACE",
+            },
+        }
+        notebook_task.update(cluster_settings)
+        notebook_task.update(workflow_spec.pop("additional_task_settings", {}))
+
+        post_hook_tasks = workflow_spec.pop("post_hook_tasks", [])
+        for task in post_hook_tasks:
+            if "existing_cluster_id" not in task and "new_cluster" not in task:
+                task.update(cluster_settings)
+
+        workflow_spec["tasks"] = [notebook_task] + post_hook_tasks
+        return workflow_spec
+
+    def _submit_through_workflow(self, compiled_code: str, workflow_spec: Dict[str, Any]) -> None:
+        self.api_client.workspace.create_python_model_dir(self.catalog, self.schema)
+        self.api_client.workspace.upload_notebook(self.notebook_path, compiled_code)
+
+        job_id, is_new = self._get_or_create_job(workflow_spec)
+
+        if not is_new:
+            self.api_client.workflows.update_job_settings(job_id, workflow_spec)
+
+        grants = workflow_spec.pop("grants", {})
+        access_control_list = self._build_job_permissions(job_id, grants)
+        self.api_client.workflow_permissions.put(job_id, access_control_list)
+
+        run_id = self.api_client.workflows.run(job_id, enable_queueing=True)
+        self.tracker.insert_run_id(run_id)
+
+        try:
+            self.api_client.job_runs.poll_for_completion(run_id)
+        finally:
+            self.tracker.remove_run_id(run_id)
+
+    def _get_or_create_job(self, workflow_spec: Dict[str, Any]) -> Tuple[str, bool]:
+        """
+        :return: tuple of job_id and whether the job is new
+        """
+        existing_job_id = workflow_spec.pop("existing_job_id", "")
+        if existing_job_id:
+            return existing_job_id, False
+
+        response_jobs = self.api_client.workflows.search_by_name(workflow_spec["name"])
+
+        if len(response_jobs) > 1:
+            raise DbtRuntimeError(
+                f"Multiple jobs found with name {workflow_spec['name']}. Use a unique job "
+                "name or specify the `existing_job_id` in the workflow_job_config."
+            )
+
+        if len(response_jobs) == 1:
+            return response_jobs[0]["job_id"], False
+        else:
+            return self.api_client.workflows.create(workflow_spec), True
+
+    def _build_job_permissions(
+        self, job_id: str, job_grants: Dict[str, List[Dict[str, Any]]]
+    ) -> List[Dict[str, Any]]:
+        access_control_list: List[Dict[str, Any]] = []
+        current_owner, permissions_attribute = self._get_current_job_owner(job_id)
+        access_control_list.append(
+            {
+                permissions_attribute: current_owner,
+                "permission_level": "IS_OWNER",
+            }
+        )
+
+        for grant in job_grants.get("view", []):
+            acl_grant = grant.copy()
+            acl_grant.update(
+                {
+                    "permission_level": "CAN_VIEW",
+                }
+            )
+            access_control_list.append(acl_grant)
+        for grant in job_grants.get("run", []):
+            acl_grant = grant.copy()
+            acl_grant.update(
+                {
+                    "permission_level": "CAN_MANAGE_RUN",
+                }
+            )
+            access_control_list.append(acl_grant)
+        for grant in job_grants.get("manage", []):
+            acl_grant = grant.copy()
+            acl_grant.update(
+                {
+                    "permission_level": "CAN_MANAGE",
+                }
+            )
+            access_control_list.append(acl_grant)
+
+        return access_control_list
+
+    def _get_current_job_owner(self, job_id: str) -> Tuple[str, str]:
+        """
+        :return: a tuple of the owner's principal and the ACL attribute it came from, i.e. one of
+        [user_name|group_name|service_principal_name].
+        For example: `("mateizaharia@databricks.com", "user_name")`
+        """
+        permissions = self.api_client.workflow_permissions.get(job_id)
+        for principal in permissions.get("access_control_list", []):
+            for permission in principal["all_permissions"]:
+                if (
+                    permission["permission_level"] == "IS_OWNER"
+                    and permission["inherited"] is False
+                ):
+                    if principal.get("user_name"):
+                        return principal["user_name"], "user_name"
+                    elif principal.get("group_name"):
+                        return principal["group_name"], "group_name"
+                    else:
+                        return principal["service_principal_name"], "service_principal_name"
+
+        raise DbtRuntimeError(
+            f"Error getting current owner for Databricks workflow.\n {permissions!r}"
+        )
\ No newline at end of file
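
For reviewers, here is a rough usage sketch of the new `workflow_job` submission method, pieced together from the config keys this diff reads (`workflow_job_config`, `job_cluster_config`, `grants`, `additional_task_settings`, `post_hook_tasks`, `existing_job_id`). The model body, task keys, principals, and cluster spec below are illustrative assumptions, not taken from this PR's docs or tests:

```python
# models/my_workflow_model.py -- hypothetical dbt Python model opting into the
# `workflow_job` submission method registered in impl.py (WorkflowPythonJobHelper).
def model(dbt, session):
    dbt.config(
        materialized="table",
        submission_method="workflow_job",
        # Keys consumed by dbt (grants, additional_task_settings, post_hook_tasks,
        # existing_job_id) are popped off in _build_job_spec(); the rest is sent to
        # the Databricks Jobs 2.1 API as the job spec.
        workflow_job_config={
            "name": "dbt__my_catalog-my_schema-my_workflow_model",  # optional; defaults to default_job_name
            "grants": {
                "view": [{"group_name": "analysts"}],              # -> CAN_VIEW
                "run": [{"user_name": "somebody@example.com"}],    # -> CAN_MANAGE_RUN
                "manage": [],                                      # -> CAN_MANAGE
            },
            # Merged into the generated notebook task (overrides "task_a").
            "additional_task_settings": {"task_key": "my_dbt_task"},
            # Extra tasks appended after the notebook task; they inherit the model's
            # cluster settings unless they declare new_cluster/existing_cluster_id.
            "post_hook_tasks": [
                {
                    "depends_on": [{"task_key": "my_dbt_task"}],
                    "task_key": "optimize_and_vacuum",
                    "notebook_task": {"notebook_path": "/some/notebook", "source": "WORKSPACE"},
                }
            ],
        },
        # Optional: omit this (and existing_cluster_id) to leave cluster settings
        # undefined, which the Databricks API treats as serverless.
        job_cluster_config={
            "spark_version": "14.3.x-scala2.12",
            "node_type_id": "i3.xlarge",
            "autoscale": {"min_workers": 1, "max_workers": 2},
        },
    )
    return dbt.ref("some_upstream_model")
```

If the workflow should reuse a pre-existing job, `existing_job_id` can be set inside `workflow_job_config`; otherwise the helper looks the job up by name and creates it only when no match exists.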