#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
18"""This module contains Databricks operators."""
19
20from __future__ import annotations
21
22import re
23from functools import cached_property
24from typing import TYPE_CHECKING, Sequence
25from urllib.parse import urlsplit
26
27from airflow.exceptions import AirflowException
28from airflow.models import BaseOperator
29from airflow.providers.databricks.hooks.databricks import DatabricksHook
30
31if TYPE_CHECKING:
32 from airflow.utils.context import Context
33
34
35class DatabricksReposCreateOperator(BaseOperator):
36 """
37 Creates, and optionally checks out, a Databricks Repo using the POST api/2.0/repos API endpoint.
38
39 See: https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/create-repo
40
41 :param git_url: Required HTTPS URL of a Git repository
42 :param git_provider: Optional name of Git provider. Must be provided if we can't guess its name from URL.
43 :param repo_path: optional path for a repository. Must be in the format ``/Repos/{folder}/{repo-name}``.
44 If not specified, it will be created in the user's directory.
45 :param branch: optional name of branch to check out.
46 :param tag: optional name of tag to checkout.
47 :param ignore_existing_repo: don't throw exception if repository with given path already exists.
48 :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
49 By default and in the common case this will be ``databricks_default``. To use
50 token based authentication, provide the key ``token`` in the extra field for the
51 connection and create the key ``host`` and leave the ``host`` field empty. (templated)
52 :param databricks_retry_limit: Amount of times retry if the Databricks backend is
53 unreachable. Its value must be greater than or equal to 1.
54 :param databricks_retry_delay: Number of seconds to wait between retries (it
55 might be a floating point number).
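
    An example of creating a repo and checking out a branch (a minimal sketch; the
    ``repo_path``, ``git_url``, and ``branch`` values below are illustrative)::

        create_repo = DatabricksReposCreateOperator(
            task_id="create_repo",
            repo_path="/Repos/user@example.com/demo-repo",
            git_url="https://github.com/example/demo-repo",
            branch="main",
        )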
56 """
57
58 # Used in airflow.models.BaseOperator
59 template_fields: Sequence[str] = ("repo_path", "tag", "branch", "databricks_conn_id")
60
61 __git_providers__ = {
62 "github.com": "gitHub",
63 "dev.azure.com": "azureDevOpsServices",
64 "gitlab.com": "gitLab",
65 "bitbucket.org": "bitbucketCloud",
66 }
67 __aws_code_commit_regexp__ = re.compile(r"^git-codecommit\.[^.]+\.amazonaws.com$")
68 __repos_path_regexp__ = re.compile(r"/Repos/[^/]+/[^/]+/?$")
69
70 def __init__(
        self,
        *,
        git_url: str,
        git_provider: str | None = None,
        branch: str | None = None,
        tag: str | None = None,
        repo_path: str | None = None,
        ignore_existing_repo: bool = False,
        databricks_conn_id: str = "databricks_default",
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        **kwargs,
    ) -> None:
        """Create a new ``DatabricksReposCreateOperator``."""
        super().__init__(**kwargs)
        self.databricks_conn_id = databricks_conn_id
        self.databricks_retry_limit = databricks_retry_limit
        self.databricks_retry_delay = databricks_retry_delay
        self.git_url = git_url
        self.ignore_existing_repo = ignore_existing_repo
        if git_provider is None:
            self.git_provider = self.__detect_repo_provider__(git_url)
            if self.git_provider is None:
                raise AirflowException(
                    f"git_provider isn't specified and couldn't be guessed for URL {git_url}"
                )
        else:
            self.git_provider = git_provider
        self.repo_path = repo_path
        if branch is not None and tag is not None:
            raise AirflowException("Only one of branch or tag should be provided, but not both")
        self.branch = branch
        self.tag = tag

    @staticmethod
    def __detect_repo_provider__(url):
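        """
        Guess the Git provider from the host name of the given URL.

        Returns the Databricks provider name (e.g. ``gitHub`` for ``github.com``,
        ``awsCodeCommit`` for AWS CodeCommit hosts), or ``None`` if the host
        is not recognized or the URL can't be parsed.
        """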
        provider = None
        try:
            netloc = urlsplit(url).netloc.lower()
            _, _, netloc = netloc.rpartition("@")
            provider = DatabricksReposCreateOperator.__git_providers__.get(netloc)
            if provider is None and DatabricksReposCreateOperator.__aws_code_commit_regexp__.match(netloc):
                provider = "awsCodeCommit"
        except ValueError:
            pass
        return provider

    @cached_property
    def _hook(self) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay,
            caller="DatabricksReposCreateOperator",
        )

    def execute(self, context: Context):
        """
        Create a Databricks Repo.

        :param context: context
        :return: Repo ID
        """
        payload = {
            "url": self.git_url,
            "provider": self.git_provider,
        }
        if self.repo_path is not None:
            if not self.__repos_path_regexp__.match(self.repo_path):
                raise AirflowException(
                    f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{self.repo_path}'"
                )
            payload["path"] = self.repo_path
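        # If a path was given, check whether a repo already exists there, so it can be
        # reused (when ignore_existing_repo is set) instead of attempting a new create call.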
        existing_repo_id = None
        if self.repo_path is not None:
            existing_repo_id = self._hook.get_repo_by_path(self.repo_path)
            if existing_repo_id is not None and not self.ignore_existing_repo:
                raise AirflowException(f"Repo with path '{self.repo_path}' already exists")
        if existing_repo_id is None:
            result = self._hook.create_repo(payload)
            repo_id = result["id"]
        else:
            repo_id = existing_repo_id
        # update repo if necessary
        if self.branch is not None:
            self._hook.update_repo(str(repo_id), {"branch": str(self.branch)})
        elif self.tag is not None:
            self._hook.update_repo(str(repo_id), {"tag": str(self.tag)})

        return repo_id


class DatabricksReposUpdateOperator(BaseOperator):
164 """
165 Updates specified repository to a given branch or tag using the PATCH api/2.0/repos API endpoint.
166
167 See: https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/update-repo
168
169 :param branch: optional name of branch to update to. Should be specified if ``tag`` is omitted
170 :param tag: optional name of tag to update to. Should be specified if ``branch`` is omitted
171 :param repo_id: optional ID of existing repository. Should be specified if ``repo_path`` is omitted
172 :param repo_path: optional path of existing repository. Should be specified if ``repo_id`` is omitted
173 :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
174 By default and in the common case this will be ``databricks_default``. To use
175 token based authentication, provide the key ``token`` in the extra field for the
176 connection and create the key ``host`` and leave the ``host`` field empty. (templated)
177 :param databricks_retry_limit: Amount of times retry if the Databricks backend is
178 unreachable. Its value must be greater than or equal to 1.
179 :param databricks_retry_delay: Number of seconds to wait between retries (it
180 might be a floating point number).
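
    An example of switching an existing repo to a tag (a minimal sketch; the
    ``repo_path`` and ``tag`` values below are illustrative)::

        update_repo = DatabricksReposUpdateOperator(
            task_id="update_repo",
            repo_path="/Repos/user@example.com/demo-repo",
            tag="v1.0.0",
        )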
181 """
182
183 # Used in airflow.models.BaseOperator
184 template_fields: Sequence[str] = ("repo_path", "tag", "branch", "databricks_conn_id")
185
186 def __init__(
187 self,
188 *,
189 branch: str | None = None,
190 tag: str | None = None,
191 repo_id: str | None = None,
192 repo_path: str | None = None,
193 databricks_conn_id: str = "databricks_default",
194 databricks_retry_limit: int = 3,
195 databricks_retry_delay: int = 1,
196 **kwargs,
197 ) -> None:
198 """Create a new ``DatabricksReposUpdateOperator``."""
199 super().__init__(**kwargs)
200 self.databricks_conn_id = databricks_conn_id
201 self.databricks_retry_limit = databricks_retry_limit
202 self.databricks_retry_delay = databricks_retry_delay
203 if branch is not None and tag is not None:
204 raise AirflowException("Only one of branch or tag should be provided, but not both")
205 if branch is None and tag is None:
206 raise AirflowException("One of branch or tag should be provided")
207 if repo_id is not None and repo_path is not None:
208 raise AirflowException("Only one of repo_id or repo_path should be provided, but not both")
209 if repo_id is None and repo_path is None:
210 raise AirflowException("One of repo_id or repo_path should be provided")
211 self.repo_path = repo_path
212 self.repo_id = repo_id
213 self.branch = branch
214 self.tag = tag
215
216 @cached_property
217 def _hook(self) -> DatabricksHook:
218 return DatabricksHook(
219 self.databricks_conn_id,
220 retry_limit=self.databricks_retry_limit,
221 retry_delay=self.databricks_retry_delay,
222 caller="DatabricksReposUpdateOperator",
223 )
224
225 def execute(self, context: Context):
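        """Update the Repo to the given branch or tag and return the ID of the resulting head commit."""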
        if self.repo_path is not None:
            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
            if self.repo_id is None:
                raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")
        if self.branch is not None:
            payload = {"branch": str(self.branch)}
        else:
            payload = {"tag": str(self.tag)}

        result = self._hook.update_repo(str(self.repo_id), payload)
        return result["head_commit_id"]


class DatabricksReposDeleteOperator(BaseOperator):
240 """
241 Deletes specified repository using the DELETE api/2.0/repos API endpoint.
242
243 See: https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/delete-repo
244
245 :param repo_id: optional ID of existing repository. Should be specified if ``repo_path`` is omitted
246 :param repo_path: optional path of existing repository. Should be specified if ``repo_id`` is omitted
247 :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
248 By default and in the common case this will be ``databricks_default``. To use
249 token based authentication, provide the key ``token`` in the extra field for the
250 connection and create the key ``host`` and leave the ``host`` field empty. (templated)
251 :param databricks_retry_limit: Amount of times retry if the Databricks backend is
252 unreachable. Its value must be greater than or equal to 1.
253 :param databricks_retry_delay: Number of seconds to wait between retries (it
254 might be a floating point number).
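
    An example of deleting a repo by its path (a minimal sketch; the ``repo_path``
    value below is illustrative)::

        delete_repo = DatabricksReposDeleteOperator(
            task_id="delete_repo",
            repo_path="/Repos/user@example.com/demo-repo",
        )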
255 """
256
257 # Used in airflow.models.BaseOperator
258 template_fields: Sequence[str] = ("repo_path", "databricks_conn_id")
259
260 def __init__(
261 self,
262 *,
263 repo_id: str | None = None,
264 repo_path: str | None = None,
265 databricks_conn_id: str = "databricks_default",
266 databricks_retry_limit: int = 3,
267 databricks_retry_delay: int = 1,
268 **kwargs,
269 ) -> None:
270 """Create a new ``DatabricksReposDeleteOperator``."""
271 super().__init__(**kwargs)
272 self.databricks_conn_id = databricks_conn_id
273 self.databricks_retry_limit = databricks_retry_limit
274 self.databricks_retry_delay = databricks_retry_delay
275 if repo_id is not None and repo_path is not None:
276 raise AirflowException("Only one of repo_id or repo_path should be provided, but not both")
277 if repo_id is None and repo_path is None:
278 raise AirflowException("One of repo_id repo_path tag should be provided")
279 self.repo_path = repo_path
280 self.repo_id = repo_id
281
282 @cached_property
283 def _hook(self) -> DatabricksHook:
284 return DatabricksHook(
285 self.databricks_conn_id,
286 retry_limit=self.databricks_retry_limit,
287 retry_delay=self.databricks_retry_delay,
288 caller="DatabricksReposDeleteOperator",
289 )
290
291 def execute(self, context: Context):
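        """Delete the Repo, resolving ``repo_path`` to a repo ID first if needed."""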
        if self.repo_path is not None:
            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
            if self.repo_id is None:
                raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")

        self._hook.delete_repo(str(self.repo_id))