Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/providers/databricks/operators/databricks_repos.py: 0%

122 statements  

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains Databricks operators."""

from __future__ import annotations

import re
from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from urllib.parse import urlsplit

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.databricks.hooks.databricks import DatabricksHook

if TYPE_CHECKING:
    from airflow.utils.context import Context


class DatabricksReposCreateOperator(BaseOperator):
    """
    Creates, and optionally checks out, a Databricks Repo using the POST api/2.0/repos API endpoint.

    See: https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/create-repo

    :param git_url: Required HTTPS URL of a Git repository.
    :param git_provider: Optional name of the Git provider. Must be provided if it cannot be
        guessed from the URL.
    :param repo_path: optional path for the repository. Must be in the format
        ``/Repos/{folder}/{repo-name}``. If not specified, the repository is created in the
        user's directory.
    :param branch: optional name of the branch to check out.
    :param tag: optional name of the tag to check out.
    :param ignore_existing_repo: don't raise an exception if a repository with the given path
        already exists.
    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
        By default and in the common case this will be ``databricks_default``. To use
        token based authentication, provide the key ``token`` in the extra field for the
        connection and create the key ``host`` in the extra field, leaving the connection's
        ``host`` field empty. (templated)
    :param databricks_retry_limit: Number of times to retry if the Databricks backend is
        unreachable. Its value must be greater than or equal to 1.
    :param databricks_retry_delay: Number of seconds to wait between retries (it
        might be a floating point number).
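
    A minimal usage sketch (illustrative values only; the URL, path, and branch below are
    assumptions, not defaults):

    .. code-block:: python

        create_repo = DatabricksReposCreateOperator(
            task_id="create_repo",
            git_url="https://github.com/apache/airflow.git",
            repo_path="/Repos/user@example.com/airflow",
            branch="main",
        )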

    """

    # Used in airflow.models.BaseOperator
    template_fields: Sequence[str] = ("repo_path", "tag", "branch", "databricks_conn_id")

    __git_providers__ = {
        "github.com": "gitHub",
        "dev.azure.com": "azureDevOpsServices",
        "gitlab.com": "gitLab",
        "bitbucket.org": "bitbucketCloud",
    }
    # Matches AWS CodeCommit hosts such as git-codecommit.us-east-1.amazonaws.com
    __aws_code_commit_regexp__ = re.compile(r"^git-codecommit\.[^.]+\.amazonaws\.com$")
    # Matches repo paths of the form /Repos/{folder}/{repo-name}, with an optional trailing slash
    __repos_path_regexp__ = re.compile(r"/Repos/[^/]+/[^/]+/?$")

    def __init__(
        self,
        *,
        git_url: str,
        git_provider: str | None = None,
        branch: str | None = None,
        tag: str | None = None,
        repo_path: str | None = None,
        ignore_existing_repo: bool = False,
        databricks_conn_id: str = "databricks_default",
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        **kwargs,
    ) -> None:
        """Create a new ``DatabricksReposCreateOperator``."""
        super().__init__(**kwargs)
        self.databricks_conn_id = databricks_conn_id
        self.databricks_retry_limit = databricks_retry_limit
        self.databricks_retry_delay = databricks_retry_delay
        self.git_url = git_url
        self.ignore_existing_repo = ignore_existing_repo
        if git_provider is None:
            self.git_provider = self.__detect_repo_provider__(git_url)
            if self.git_provider is None:
                raise AirflowException(
                    f"git_provider isn't specified and couldn't be guessed for URL {git_url}"
                )
        else:
            self.git_provider = git_provider
        self.repo_path = repo_path
        if branch is not None and tag is not None:
            raise AirflowException("Only one of branch or tag should be provided, but not both")
        self.branch = branch
        self.tag = tag

    @staticmethod
    def __detect_repo_provider__(url):
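        # Illustrative examples of the host-to-provider mapping (assumed URLs, not
        # taken from the source):
        #   "https://github.com/apache/airflow.git"            -> "gitHub"
        #   "https://user@dev.azure.com/org/project/_git/repo" -> "azureDevOpsServices"
        #   "https://git-codecommit.us-east-1.amazonaws.com/v1/repos/demo" -> "awsCodeCommit"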

        provider = None
        try:
            netloc = urlsplit(url).netloc.lower()
            _, _, netloc = netloc.rpartition("@")
            provider = DatabricksReposCreateOperator.__git_providers__.get(netloc)
            if provider is None and DatabricksReposCreateOperator.__aws_code_commit_regexp__.match(netloc):
                provider = "awsCodeCommit"
        except ValueError:
            pass
        return provider

    @cached_property
    def _hook(self) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay,
            caller="DatabricksReposCreateOperator",
        )

    def execute(self, context: Context):
        """
        Create a Databricks Repo.

        :param context: context
        :return: Repo ID
        """
        payload = {
            "url": self.git_url,
            "provider": self.git_provider,
        }
        if self.repo_path is not None:
            if not self.__repos_path_regexp__.match(self.repo_path):
                raise AirflowException(
                    f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{self.repo_path}'"
                )
            payload["path"] = self.repo_path
        existing_repo_id = None
        if self.repo_path is not None:
            existing_repo_id = self._hook.get_repo_by_path(self.repo_path)
            if existing_repo_id is not None and not self.ignore_existing_repo:
                raise AirflowException(f"Repo with path '{self.repo_path}' already exists")
        if existing_repo_id is None:
            result = self._hook.create_repo(payload)
            repo_id = result["id"]
        else:
            repo_id = existing_repo_id
        # update repo if necessary
        if self.branch is not None:
            self._hook.update_repo(str(repo_id), {"branch": str(self.branch)})
        elif self.tag is not None:
            self._hook.update_repo(str(repo_id), {"tag": str(self.tag)})

        return repo_id


class DatabricksReposUpdateOperator(BaseOperator):
    """
    Updates the specified repository to a given branch or tag using the PATCH api/2.0/repos API endpoint.

    See: https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/update-repo

    :param branch: optional name of the branch to update to. Should be specified if ``tag`` is omitted.
    :param tag: optional name of the tag to update to. Should be specified if ``branch`` is omitted.
    :param repo_id: optional ID of an existing repository. Should be specified if ``repo_path`` is omitted.
    :param repo_path: optional path of an existing repository. Should be specified if ``repo_id`` is omitted.
    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
        By default and in the common case this will be ``databricks_default``. To use
        token based authentication, provide the key ``token`` in the extra field for the
        connection and create the key ``host`` in the extra field, leaving the connection's
        ``host`` field empty. (templated)
    :param databricks_retry_limit: Number of times to retry if the Databricks backend is
        unreachable. Its value must be greater than or equal to 1.
    :param databricks_retry_delay: Number of seconds to wait between retries (it
        might be a floating point number).
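
    A minimal usage sketch (illustrative values only; the path and branch are assumptions):

    .. code-block:: python

        update_repo = DatabricksReposUpdateOperator(
            task_id="update_repo",
            repo_path="/Repos/user@example.com/airflow",
            branch="main",
        )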

    """

    # Used in airflow.models.BaseOperator
    template_fields: Sequence[str] = ("repo_path", "tag", "branch", "databricks_conn_id")

    def __init__(
        self,
        *,
        branch: str | None = None,
        tag: str | None = None,
        repo_id: str | None = None,
        repo_path: str | None = None,
        databricks_conn_id: str = "databricks_default",
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        **kwargs,
    ) -> None:
        """Create a new ``DatabricksReposUpdateOperator``."""
        super().__init__(**kwargs)
        self.databricks_conn_id = databricks_conn_id
        self.databricks_retry_limit = databricks_retry_limit
        self.databricks_retry_delay = databricks_retry_delay
        if branch is not None and tag is not None:
            raise AirflowException("Only one of branch or tag should be provided, but not both")
        if branch is None and tag is None:
            raise AirflowException("One of branch or tag should be provided")
        if repo_id is not None and repo_path is not None:
            raise AirflowException("Only one of repo_id or repo_path should be provided, but not both")
        if repo_id is None and repo_path is None:
            raise AirflowException("One of repo_id or repo_path should be provided")
        self.repo_path = repo_path
        self.repo_id = repo_id
        self.branch = branch
        self.tag = tag

    @cached_property
    def _hook(self) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay,
            caller="DatabricksReposUpdateOperator",
        )

    def execute(self, context: Context):
        if self.repo_path is not None:
            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
            if self.repo_id is None:
                raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")
        if self.branch is not None:
            payload = {"branch": str(self.branch)}
        else:
            payload = {"tag": str(self.tag)}

        result = self._hook.update_repo(str(self.repo_id), payload)
        return result["head_commit_id"]


class DatabricksReposDeleteOperator(BaseOperator):
    """
    Deletes the specified repository using the DELETE api/2.0/repos API endpoint.

    See: https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/delete-repo

    :param repo_id: optional ID of an existing repository. Should be specified if ``repo_path`` is omitted.
    :param repo_path: optional path of an existing repository. Should be specified if ``repo_id`` is omitted.
    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
        By default and in the common case this will be ``databricks_default``. To use
        token based authentication, provide the key ``token`` in the extra field for the
        connection and create the key ``host`` in the extra field, leaving the connection's
        ``host`` field empty. (templated)
    :param databricks_retry_limit: Number of times to retry if the Databricks backend is
        unreachable. Its value must be greater than or equal to 1.
    :param databricks_retry_delay: Number of seconds to wait between retries (it
        might be a floating point number).
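
    A minimal usage sketch (illustrative value; the path is an assumption):

    .. code-block:: python

        delete_repo = DatabricksReposDeleteOperator(
            task_id="delete_repo",
            repo_path="/Repos/user@example.com/airflow",
        )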

    """

    # Used in airflow.models.BaseOperator
    template_fields: Sequence[str] = ("repo_path", "databricks_conn_id")

    def __init__(
        self,
        *,
        repo_id: str | None = None,
        repo_path: str | None = None,
        databricks_conn_id: str = "databricks_default",
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        **kwargs,
    ) -> None:
        """Create a new ``DatabricksReposDeleteOperator``."""
        super().__init__(**kwargs)
        self.databricks_conn_id = databricks_conn_id
        self.databricks_retry_limit = databricks_retry_limit
        self.databricks_retry_delay = databricks_retry_delay
        if repo_id is not None and repo_path is not None:
            raise AirflowException("Only one of repo_id or repo_path should be provided, but not both")
        if repo_id is None and repo_path is None:
            raise AirflowException("One of repo_id or repo_path should be provided")
        self.repo_path = repo_path
        self.repo_id = repo_id

    @cached_property
    def _hook(self) -> DatabricksHook:
        return DatabricksHook(
            self.databricks_conn_id,
            retry_limit=self.databricks_retry_limit,
            retry_delay=self.databricks_retry_delay,
            caller="DatabricksReposDeleteOperator",
        )

    def execute(self, context: Context):
        if self.repo_path is not None:
            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
            if self.repo_id is None:
                raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")

        self._hook.delete_repo(str(self.repo_id))
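
# A combined usage sketch (illustrative, not part of the provider module): the three
# operators chained in a DAG. The dag_id, schedule, and repo values are assumptions.
#
#     from airflow import DAG
#     import pendulum
#
#     with DAG(
#         "databricks_repos_demo",
#         start_date=pendulum.datetime(2023, 1, 1),
#         schedule=None,
#     ) as dag:
#         create = DatabricksReposCreateOperator(
#             task_id="create",
#             git_url="https://github.com/apache/airflow.git",
#             repo_path="/Repos/user@example.com/airflow",
#             ignore_existing_repo=True,
#         )
#         update = DatabricksReposUpdateOperator(
#             task_id="update",
#             repo_path="/Repos/user@example.com/airflow",
#             branch="main",
#         )
#         delete = DatabricksReposDeleteOperator(
#             task_id="delete",
#             repo_path="/Repos/user@example.com/airflow",
#         )
#         create >> update >> delete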