Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/providers/databricks/operators/databricks_repos.py: 0%

125 statements  

coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains Databricks operators."""
from __future__ import annotations

import re
from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from urllib.parse import urlsplit

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.databricks.hooks.databricks import DatabricksHook

if TYPE_CHECKING:
    from airflow.utils.context import Context


class DatabricksReposCreateOperator(BaseOperator):
    """
    Creates a Databricks Repo using the
    `POST api/2.0/repos <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/create-repo>`_
    API endpoint, optionally checking it out to a specific branch or tag.

    :param git_url: Required HTTPS URL of a Git repository
    :param git_provider: Optional name of the Git provider. Must be provided if it can't be
        guessed from the URL.
    :param repo_path: optional path for the repository. Must be in the format
        ``/Repos/{folder}/{repo-name}``. If not specified, the repository is created in the
        user's directory.
    :param branch: optional name of a branch to check out.
    :param tag: optional name of a tag to check out.
    :param ignore_existing_repo: don't raise an exception if a repository with the given path
        already exists.
    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
        By default and in the common case this will be ``databricks_default``. To use
        token based authentication, provide the key ``token`` in the extra field for the
        connection, create the key ``host`` there as well, and leave the connection's
        ``host`` field empty. (templated)
    :param databricks_retry_limit: Number of times to retry if the Databricks backend is
        unreachable. Its value must be greater than or equal to 1.
    :param databricks_retry_delay: Number of seconds to wait between retries (it
        might be a floating point number).
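
    A minimal usage sketch (the repo URL and path below are hypothetical, and the
    ``databricks_default`` connection is assumed to exist):

    .. code-block:: python

        create_repo = DatabricksReposCreateOperator(
            task_id="create_repo",
            git_url="https://github.com/example/example-repo.git",
            repo_path="/Repos/user@example.com/example-repo",
            branch="main",
            ignore_existing_repo=True,
        )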

    """

    # Used in airflow.models.BaseOperator
    template_fields: Sequence[str] = ("repo_path", "tag", "branch", "databricks_conn_id")

    __git_providers__ = {
        "github.com": "gitHub",
        "dev.azure.com": "azureDevOpsServices",
        "gitlab.com": "gitLab",
        "bitbucket.org": "bitbucketCloud",
    }
    __aws_code_commit_regexp__ = re.compile(r"^git-codecommit\.[^.]+\.amazonaws\.com$")
    __repos_path_regexp__ = re.compile(r"/Repos/[^/]+/[^/]+/?$")

    def __init__(
        self,
        *,
        git_url: str,
        git_provider: str | None = None,
        branch: str | None = None,
        tag: str | None = None,
        repo_path: str | None = None,
        ignore_existing_repo: bool = False,
        databricks_conn_id: str = "databricks_default",
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        **kwargs,
    ) -> None:
        """Creates a new ``DatabricksReposCreateOperator``."""
        super().__init__(**kwargs)
        self.databricks_conn_id = databricks_conn_id
        self.databricks_retry_limit = databricks_retry_limit
        self.databricks_retry_delay = databricks_retry_delay
        self.git_url = git_url
        self.ignore_existing_repo = ignore_existing_repo
        if git_provider is None:
            self.git_provider = self.__detect_repo_provider__(git_url)
            if self.git_provider is None:
                raise AirflowException(
                    f"git_provider isn't specified and couldn't be guessed for URL {git_url}"
                )
        else:
            self.git_provider = git_provider
        self.repo_path = repo_path
        if branch is not None and tag is not None:
            raise AirflowException("Only one of branch or tag should be provided, but not both")
        self.branch = branch
        self.tag = tag

    @staticmethod
    def __detect_repo_provider__(url):
        provider = None
        try:
            netloc = urlsplit(url).netloc
            # Strip any "user@" prefix, then normalize the host for the provider lookup.
            idx = netloc.rfind("@")
            if idx != -1:
                netloc = netloc[(idx + 1) :]
            netloc = netloc.lower()
            provider = DatabricksReposCreateOperator.__git_providers__.get(netloc)
            if provider is None and DatabricksReposCreateOperator.__aws_code_commit_regexp__.match(netloc):
                provider = "awsCodeCommit"
        except ValueError:
            pass
        return provider

    @cached_property
    def _hook(self) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay,
            caller="DatabricksReposCreateOperator",
        )

129 

130 def execute(self, context: Context): 

131 """ 

132 Creates a Databricks Repo. 

133 

134 :param context: context 

135 :return: Repo ID 

136 """ 

137 payload = { 

138 "url": self.git_url, 

139 "provider": self.git_provider, 

140 } 

141 if self.repo_path is not None: 

142 if not self.__repos_path_regexp__.match(self.repo_path): 

143 raise AirflowException( 

144 f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{self.repo_path}'" 

145 ) 

146 payload["path"] = self.repo_path 

147 existing_repo_id = None 

148 if self.repo_path is not None: 

149 existing_repo_id = self._hook.get_repo_by_path(self.repo_path) 

150 if existing_repo_id is not None and not self.ignore_existing_repo: 

151 raise AirflowException(f"Repo with path '{self.repo_path}' already exists") 

152 if existing_repo_id is None: 

153 result = self._hook.create_repo(payload) 

154 repo_id = result["id"] 

155 else: 

156 repo_id = existing_repo_id 

157 # update repo if necessary 

158 if self.branch is not None: 

159 self._hook.update_repo(str(repo_id), {"branch": str(self.branch)}) 

160 elif self.tag is not None: 

161 self._hook.update_repo(str(repo_id), {"tag": str(self.tag)}) 

162 

163 return repo_id 

164 

165 


class DatabricksReposUpdateOperator(BaseOperator):
    """
    Updates a specified repository to a given branch or tag using the
    `PATCH api/2.0/repos
    <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/update-repo>`_ API endpoint.

    :param branch: optional name of the branch to update to. Should be specified if ``tag`` is omitted
    :param tag: optional name of the tag to update to. Should be specified if ``branch`` is omitted
    :param repo_id: optional ID of an existing repository. Should be specified if ``repo_path`` is omitted
    :param repo_path: optional path of an existing repository. Should be specified if ``repo_id`` is omitted
    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
        By default and in the common case this will be ``databricks_default``. To use
        token based authentication, provide the key ``token`` in the extra field for the
        connection, create the key ``host`` there as well, and leave the connection's
        ``host`` field empty. (templated)
    :param databricks_retry_limit: Number of times to retry if the Databricks backend is
        unreachable. Its value must be greater than or equal to 1.
    :param databricks_retry_delay: Number of seconds to wait between retries (it
        might be a floating point number).
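
    A minimal usage sketch (the repo path below is hypothetical, and the
    ``databricks_default`` connection is assumed to exist):

    .. code-block:: python

        update_repo = DatabricksReposUpdateOperator(
            task_id="update_repo",
            repo_path="/Repos/user@example.com/example-repo",
            branch="releases",
        )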

    """

    # Used in airflow.models.BaseOperator
    template_fields: Sequence[str] = ("repo_path", "tag", "branch", "databricks_conn_id")

    def __init__(
        self,
        *,
        branch: str | None = None,
        tag: str | None = None,
        repo_id: str | None = None,
        repo_path: str | None = None,
        databricks_conn_id: str = "databricks_default",
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        **kwargs,
    ) -> None:
        """Creates a new ``DatabricksReposUpdateOperator``."""
        super().__init__(**kwargs)
        self.databricks_conn_id = databricks_conn_id
        self.databricks_retry_limit = databricks_retry_limit
        self.databricks_retry_delay = databricks_retry_delay
        if branch is not None and tag is not None:
            raise AirflowException("Only one of branch or tag should be provided, but not both")
        if branch is None and tag is None:
            raise AirflowException("One of branch or tag should be provided")
        if repo_id is not None and repo_path is not None:
            raise AirflowException("Only one of repo_id or repo_path should be provided, but not both")
        if repo_id is None and repo_path is None:
            raise AirflowException("One of repo_id or repo_path should be provided")
        self.repo_path = repo_path
        self.repo_id = repo_id
        self.branch = branch
        self.tag = tag

    @cached_property
    def _hook(self) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay,
            caller="DatabricksReposUpdateOperator",
        )

    def execute(self, context: Context):
        if self.repo_path is not None:
            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
            if self.repo_id is None:
                raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")
        if self.branch is not None:
            payload = {"branch": str(self.branch)}
        else:
            payload = {"tag": str(self.tag)}

        result = self._hook.update_repo(str(self.repo_id), payload)
        return result["head_commit_id"]


class DatabricksReposDeleteOperator(BaseOperator):
    """
    Deletes a specified repository using the
    `DELETE api/2.0/repos
    <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/delete-repo>`_ API endpoint.

    :param repo_id: optional ID of an existing repository. Should be specified if ``repo_path`` is omitted
    :param repo_path: optional path of an existing repository. Should be specified if ``repo_id`` is omitted
    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
        By default and in the common case this will be ``databricks_default``. To use
        token based authentication, provide the key ``token`` in the extra field for the
        connection, create the key ``host`` there as well, and leave the connection's
        ``host`` field empty. (templated)
    :param databricks_retry_limit: Number of times to retry if the Databricks backend is
        unreachable. Its value must be greater than or equal to 1.
    :param databricks_retry_delay: Number of seconds to wait between retries (it
        might be a floating point number).
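
    A minimal usage sketch (the repo path below is hypothetical, and the
    ``databricks_default`` connection is assumed to exist):

    .. code-block:: python

        delete_repo = DatabricksReposDeleteOperator(
            task_id="delete_repo",
            repo_path="/Repos/user@example.com/example-repo",
        )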

    """

    # Used in airflow.models.BaseOperator
    template_fields: Sequence[str] = ("repo_path", "databricks_conn_id")

    def __init__(
        self,
        *,
        repo_id: str | None = None,
        repo_path: str | None = None,
        databricks_conn_id: str = "databricks_default",
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        **kwargs,
    ) -> None:
        """Creates a new ``DatabricksReposDeleteOperator``."""
        super().__init__(**kwargs)
        self.databricks_conn_id = databricks_conn_id
        self.databricks_retry_limit = databricks_retry_limit
        self.databricks_retry_delay = databricks_retry_delay
        if repo_id is not None and repo_path is not None:
            raise AirflowException("Only one of repo_id or repo_path should be provided, but not both")
        if repo_id is None and repo_path is None:
            raise AirflowException("One of repo_id or repo_path should be provided")
        self.repo_path = repo_path
        self.repo_id = repo_id

    @cached_property
    def _hook(self) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay,
            caller="DatabricksReposDeleteOperator",
        )

    def execute(self, context: Context):
        if self.repo_path is not None:
            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
            if self.repo_id is None:
                raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")

        self._hook.delete_repo(str(self.repo_id))