#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
18"""This module contains Databricks operators."""
19from __future__ import annotations
21import re
22from functools import cached_property
23from typing import TYPE_CHECKING, Sequence
24from urllib.parse import urlsplit
26from airflow.exceptions import AirflowException
27from airflow.models import BaseOperator
28from airflow.providers.databricks.hooks.databricks import DatabricksHook
30if TYPE_CHECKING:
31 from airflow.utils.context import Context
34class DatabricksReposCreateOperator(BaseOperator):
35 """
36 Creates a Databricks Repo
37 using
38 `POST api/2.0/repos <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/create-repo>`_
39 API endpoint and optionally checking it out to a specific branch or tag.
41 :param git_url: Required HTTPS URL of a Git repository
42 :param git_provider: Optional name of Git provider. Must be provided if we can't guess its name from URL.
43 :param repo_path: optional path for a repository. Must be in the format ``/Repos/{folder}/{repo-name}``.
44 If not specified, it will be created in the user's directory.
45 :param branch: optional name of branch to check out.
46 :param tag: optional name of tag to checkout.
47 :param ignore_existing_repo: don't throw exception if repository with given path already exists.
48 :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
49 By default and in the common case this will be ``databricks_default``. To use
50 token based authentication, provide the key ``token`` in the extra field for the
51 connection and create the key ``host`` and leave the ``host`` field empty. (templated)
52 :param databricks_retry_limit: Amount of times retry if the Databricks backend is
53 unreachable. Its value must be greater than or equal to 1.
54 :param databricks_retry_delay: Number of seconds to wait between retries (it
55 might be a floating point number).
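
    A minimal usage sketch (the ``task_id``, Git URL, and repo path below are
    illustrative, not part of this module)::

        create_repo = DatabricksReposCreateOperator(
            task_id="create_repo",
            git_url="https://github.com/example/my-project.git",
            repo_path="/Repos/user@example.com/my-project",
            branch="main",
        )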
56 """
58 # Used in airflow.models.BaseOperator
59 template_fields: Sequence[str] = ("repo_path", "tag", "branch", "databricks_conn_id")
61 __git_providers__ = {
62 "github.com": "gitHub",
63 "dev.azure.com": "azureDevOpsServices",
64 "gitlab.com": "gitLab",
65 "bitbucket.org": "bitbucketCloud",
66 }
67 __aws_code_commit_regexp__ = re.compile(r"^git-codecommit\.[^.]+\.amazonaws.com$")
68 __repos_path_regexp__ = re.compile(r"/Repos/[^/]+/[^/]+/?$")
70 def __init__(
        self,
        *,
        git_url: str,
        git_provider: str | None = None,
        branch: str | None = None,
        tag: str | None = None,
        repo_path: str | None = None,
        ignore_existing_repo: bool = False,
        databricks_conn_id: str = "databricks_default",
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        **kwargs,
    ) -> None:
        """Creates a new ``DatabricksReposCreateOperator``."""
        super().__init__(**kwargs)
        self.databricks_conn_id = databricks_conn_id
        self.databricks_retry_limit = databricks_retry_limit
        self.databricks_retry_delay = databricks_retry_delay
        self.git_url = git_url
        self.ignore_existing_repo = ignore_existing_repo
        if git_provider is None:
            self.git_provider = self.__detect_repo_provider__(git_url)
            if self.git_provider is None:
                raise AirflowException(
                    f"git_provider isn't specified and couldn't be guessed for URL {git_url}"
                )
        else:
            self.git_provider = git_provider
        self.repo_path = repo_path
        if branch is not None and tag is not None:
            raise AirflowException("Only one of branch or tag should be provided, but not both")
        self.branch = branch
        self.tag = tag

    @staticmethod
    def __detect_repo_provider__(url):
        provider = None
        try:
            netloc = urlsplit(url).netloc
            # Strip any "user@" prefix so only the host name is matched.
            idx = netloc.rfind("@")
            if idx != -1:
                netloc = netloc[(idx + 1) :]
            netloc = netloc.lower()
            provider = DatabricksReposCreateOperator.__git_providers__.get(netloc)
            # AWS CodeCommit host names are region-specific, so match them with a regexp.
            if provider is None and DatabricksReposCreateOperator.__aws_code_commit_regexp__.match(netloc):
                provider = "awsCodeCommit"
        except ValueError:
            pass
        return provider

    @cached_property
    def _hook(self) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay,
            caller="DatabricksReposCreateOperator",
        )

    def execute(self, context: Context):
131 """
132 Creates a Databricks Repo.
134 :param context: context
135 :return: Repo ID
136 """
137 payload = {
138 "url": self.git_url,
139 "provider": self.git_provider,
140 }
141 if self.repo_path is not None:
142 if not self.__repos_path_regexp__.match(self.repo_path):
143 raise AirflowException(
144 f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{self.repo_path}'"
145 )
146 payload["path"] = self.repo_path
147 existing_repo_id = None
148 if self.repo_path is not None:
149 existing_repo_id = self._hook.get_repo_by_path(self.repo_path)
150 if existing_repo_id is not None and not self.ignore_existing_repo:
151 raise AirflowException(f"Repo with path '{self.repo_path}' already exists")
152 if existing_repo_id is None:
153 result = self._hook.create_repo(payload)
154 repo_id = result["id"]
155 else:
156 repo_id = existing_repo_id
157 # update repo if necessary
158 if self.branch is not None:
159 self._hook.update_repo(str(repo_id), {"branch": str(self.branch)})
160 elif self.tag is not None:
161 self._hook.update_repo(str(repo_id), {"tag": str(self.tag)})
163 return repo_id
166class DatabricksReposUpdateOperator(BaseOperator):
167 """
168 Updates specified repository to a given branch or tag
169 using `PATCH api/2.0/repos
170 <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/update-repo>`_ API endpoint.
172 :param branch: optional name of branch to update to. Should be specified if ``tag`` is omitted
173 :param tag: optional name of tag to update to. Should be specified if ``branch`` is omitted
174 :param repo_id: optional ID of existing repository. Should be specified if ``repo_path`` is omitted
175 :param repo_path: optional path of existing repository. Should be specified if ``repo_id`` is omitted
176 :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
177 By default and in the common case this will be ``databricks_default``. To use
178 token based authentication, provide the key ``token`` in the extra field for the
179 connection and create the key ``host`` and leave the ``host`` field empty. (templated)
180 :param databricks_retry_limit: Amount of times retry if the Databricks backend is
181 unreachable. Its value must be greater than or equal to 1.
182 :param databricks_retry_delay: Number of seconds to wait between retries (it
183 might be a floating point number).
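
    A minimal usage sketch (the ``task_id`` and repo path below are illustrative)::

        update_repo = DatabricksReposUpdateOperator(
            task_id="update_repo",
            repo_path="/Repos/user@example.com/my-project",
            branch="releases",
        )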
184 """
186 # Used in airflow.models.BaseOperator
187 template_fields: Sequence[str] = ("repo_path", "tag", "branch", "databricks_conn_id")
189 def __init__(
        self,
        *,
        branch: str | None = None,
        tag: str | None = None,
        repo_id: str | None = None,
        repo_path: str | None = None,
        databricks_conn_id: str = "databricks_default",
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        **kwargs,
    ) -> None:
        """Creates a new ``DatabricksReposUpdateOperator``."""
        super().__init__(**kwargs)
        self.databricks_conn_id = databricks_conn_id
        self.databricks_retry_limit = databricks_retry_limit
        self.databricks_retry_delay = databricks_retry_delay
        if branch is not None and tag is not None:
            raise AirflowException("Only one of branch or tag should be provided, but not both")
        if branch is None and tag is None:
            raise AirflowException("One of branch or tag should be provided")
        if repo_id is not None and repo_path is not None:
            raise AirflowException("Only one of repo_id or repo_path should be provided, but not both")
        if repo_id is None and repo_path is None:
            raise AirflowException("One of repo_id or repo_path should be provided")
        self.repo_path = repo_path
        self.repo_id = repo_id
        self.branch = branch
        self.tag = tag

    @cached_property
    def _hook(self) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay,
            caller="DatabricksReposUpdateOperator",
        )

    def execute(self, context: Context):
        if self.repo_path is not None:
            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
            if self.repo_id is None:
                raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")
        if self.branch is not None:
            payload = {"branch": str(self.branch)}
        else:
            payload = {"tag": str(self.tag)}

        result = self._hook.update_repo(str(self.repo_id), payload)
        return result["head_commit_id"]


class DatabricksReposDeleteOperator(BaseOperator):
243 """
244 Deletes specified repository
245 using `DELETE api/2.0/repos
246 <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/delete-repo>`_ API endpoint.
248 :param repo_id: optional ID of existing repository. Should be specified if ``repo_path`` is omitted
249 :param repo_path: optional path of existing repository. Should be specified if ``repo_id`` is omitted
250 :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
251 By default and in the common case this will be ``databricks_default``. To use
252 token based authentication, provide the key ``token`` in the extra field for the
253 connection and create the key ``host`` and leave the ``host`` field empty. (templated)
254 :param databricks_retry_limit: Amount of times retry if the Databricks backend is
255 unreachable. Its value must be greater than or equal to 1.
256 :param databricks_retry_delay: Number of seconds to wait between retries (it
257 might be a floating point number).
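
    A minimal usage sketch (the ``task_id`` and repo path below are illustrative)::

        delete_repo = DatabricksReposDeleteOperator(
            task_id="delete_repo",
            repo_path="/Repos/user@example.com/my-project",
        )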
258 """
260 # Used in airflow.models.BaseOperator
261 template_fields: Sequence[str] = ("repo_path", "databricks_conn_id")
263 def __init__(
        self,
        *,
        repo_id: str | None = None,
        repo_path: str | None = None,
        databricks_conn_id: str = "databricks_default",
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        **kwargs,
    ) -> None:
        """Creates a new ``DatabricksReposDeleteOperator``."""
        super().__init__(**kwargs)
        self.databricks_conn_id = databricks_conn_id
        self.databricks_retry_limit = databricks_retry_limit
        self.databricks_retry_delay = databricks_retry_delay
        if repo_id is not None and repo_path is not None:
            raise AirflowException("Only one of repo_id or repo_path should be provided, but not both")
        if repo_id is None and repo_path is None:
            raise AirflowException("One of repo_id or repo_path should be provided")
        self.repo_path = repo_path
        self.repo_id = repo_id

    @cached_property
    def _hook(self) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay,
            caller="DatabricksReposDeleteOperator",
        )

    def execute(self, context: Context):
        if self.repo_path is not None:
            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
            if self.repo_id is None:
                raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")

        self._hook.delete_repo(str(self.repo_id))