Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/providers/microsoft/azure/hooks/cosmos.py: 0%

142 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18""" 

19This module contains integration with Azure CosmosDB. 

20 

21AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that an 

22Airflow connection of type `azure_cosmos` exists. Authorization can be done by supplying a 

23login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify 

24the default database and collection to use (see connection `azure_cosmos_default` for an example). 

25""" 

26from __future__ import annotations 

27 

28import json 

29import uuid 

30from typing import Any 

31 

32from azure.cosmos.cosmos_client import CosmosClient 

33from azure.cosmos.exceptions import CosmosHttpResponseError 

34 

35from airflow.exceptions import AirflowBadRequest 

36from airflow.hooks.base import BaseHook 

37from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field 

38 

39 

class AzureCosmosDBHook(BaseHook):
    """
    Interacts with Azure CosmosDB.

    The connection login should be the endpoint URI and the password the master key.
    Optionally, you can use the following extras to set default values:
    {"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>"}.

    :param azure_cosmos_conn_id: Reference to the
        :ref:`Azure CosmosDB connection<howto/connection:azure_cosmos>`.
    """

    conn_name_attr = "azure_cosmos_conn_id"
    default_conn_name = "azure_cosmos_default"
    conn_type = "azure_cosmos"
    hook_name = "Azure CosmosDB"

    @staticmethod
    def get_connection_form_widgets() -> dict[str, Any]:
        """Returns connection widgets to add to connection form"""
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "database_name": StringField(
                lazy_gettext("Cosmos Database Name (optional)"), widget=BS3TextFieldWidget()
            ),
            "collection_name": StringField(
                lazy_gettext("Cosmos Collection Name (optional)"), widget=BS3TextFieldWidget()
            ),
        }

    @staticmethod
    @_ensure_prefixes(conn_type="azure_cosmos")  # todo: remove when min airflow version >= 2.5
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Returns custom field behaviour"""
        return {
            "hidden_fields": ["schema", "port", "host", "extra"],
            "relabeling": {
                "login": "Cosmos Endpoint URI",
                "password": "Cosmos Master Key Token",
            },
            "placeholders": {
                "login": "endpoint uri",
                "password": "master key",
                "database_name": "database name",
                "collection_name": "collection name",
            },
        }

    def __init__(self, azure_cosmos_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn_id = azure_cosmos_conn_id
        self._conn: CosmosClient | None = None

        # Filled in from the connection extras the first time get_conn() creates the client.
        self.default_database_name = None
        self.default_collection_name = None

    def _get_field(self, extras: dict, name: str):
        """Read ``name`` from the connection ``extras`` using the provider's ``get_field`` helper."""
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=name,
        )

    def get_conn(self) -> CosmosClient:
        """
        Return a cosmos db client.

        The client is created on the first call and cached on the hook. That first call also
        stores the ``database_name`` and ``collection_name`` extras as the hook's defaults.
        """
        if self._conn is None:
            conn = self.get_connection(self.conn_id)
            extras = conn.extra_dejson
            endpoint_uri = conn.login
            master_key = conn.password

            self.default_database_name = self._get_field(extras, "database_name")
            self.default_collection_name = self._get_field(extras, "collection_name")

            # Initialize the Python Azure Cosmos DB client
            self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key})
        return self._conn

    def __get_database_name(self, database_name: str | None = None) -> str:
        """
        Resolve the database name to use.

        :param database_name: Explicit database name; the connection default is used when None.
        :raises AirflowBadRequest: if neither an explicit nor a default database name is available.
        """
        # get_conn() loads default_database_name from the connection on first use.
        self.get_conn()
        db_name = self.default_database_name if database_name is None else database_name
        if db_name is None:
            raise AirflowBadRequest("Database name must be specified")
        return db_name

    def __get_collection_name(self, collection_name: str | None = None) -> str:
        """
        Resolve the collection name to use.

        :param collection_name: Explicit collection name; the connection default is used when None.
        :raises AirflowBadRequest: if neither an explicit nor a default collection name is available.
        """
        # get_conn() loads default_collection_name from the connection on first use.
        self.get_conn()
        coll_name = self.default_collection_name if collection_name is None else collection_name
        if coll_name is None:
            raise AirflowBadRequest("Collection name must be specified")
        return coll_name

    def _get_database_client(self, database_name: str | None = None):
        """Return a database client for the resolved database name."""
        return self.get_conn().get_database_client(self.__get_database_name(database_name))

    def _get_container_client(
        self, database_name: str | None = None, collection_name: str | None = None
    ):
        """Return a container client for the resolved collection inside the resolved database."""
        return self._get_database_client(database_name).get_container_client(
            self.__get_collection_name(collection_name)
        )

    def does_collection_exist(self, collection_name: str, database_name: str | None = None) -> bool:
        """
        Checks if a collection exists in CosmosDB.

        :param collection_name: ID of the collection (container) to look for.
        :param database_name: Database to search; defaults to the connection's ``database_name`` extra.
        :return: True if a container with this ID exists in the database, otherwise False.
        :raises AirflowBadRequest: if ``collection_name`` is None or no database name is available.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        # The SDK expects each query parameter as a {"name", "value"} dict, not a JSON string.
        existing_containers = list(
            self._get_database_client(database_name).query_containers(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": collection_name}],
            )
        )
        return len(existing_containers) > 0

    def create_collection(
        self,
        collection_name: str,
        database_name: str | None = None,
        partition_key: str | None = None,
    ) -> None:
        """
        Creates a new collection in the CosmosDB database.

        Nothing is created if a collection with this ID already exists.

        :param collection_name: ID of the collection (container) to create.
        :param database_name: Target database; defaults to the connection's ``database_name`` extra.
        :param partition_key: Passed through to ``create_container``.
        :raises AirflowBadRequest: if ``collection_name`` is None or no database name is available.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        # Check first so we don't try to create the same container twice.
        if not self.does_collection_exist(collection_name, database_name):
            self._get_database_client(database_name).create_container(
                collection_name, partition_key=partition_key
            )

    def does_database_exist(self, database_name: str) -> bool:
        """
        Checks if a database exists in CosmosDB.

        :param database_name: ID of the database to look for.
        :return: True if a database with this ID exists, otherwise False.
        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        # The SDK expects each query parameter as a {"name", "value"} dict, not a JSON string.
        existing_databases = list(
            self.get_conn().query_databases(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": database_name}],
            )
        )
        return len(existing_databases) > 0

    def create_database(self, database_name: str) -> None:
        """
        Creates a new database in CosmosDB.

        Nothing is created if a database with this ID already exists.

        :param database_name: ID of the database to create.
        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        # Check first so we don't try to create the same database twice.
        if not self.does_database_exist(database_name):
            self.get_conn().create_database(database_name)

    def delete_database(self, database_name: str) -> None:
        """
        Deletes an existing database in CosmosDB.

        :param database_name: ID of the database to delete.
        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        self.get_conn().delete_database(database_name)

    def delete_collection(self, collection_name: str, database_name: str | None = None) -> None:
        """
        Deletes an existing collection in the CosmosDB database.

        :param collection_name: ID of the collection (container) to delete.
        :param database_name: Database holding it; defaults to the connection's ``database_name`` extra.
        :raises AirflowBadRequest: if ``collection_name`` is None or no database name is available.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        self._get_database_client(database_name).delete_container(collection_name)

    def upsert_document(self, document, database_name=None, collection_name=None, document_id=None):
        """
        Inserts a new document (or updates an existing one) into an existing
        collection in the CosmosDB database.

        If ``document`` has no ``"id"`` key, or its ``"id"`` is None, the id is set in place to
        ``document_id``, or to a random UUID4 string when ``document_id`` is None.

        :param document: The document to upsert; modified in place when it lacks an id.
        :param database_name: Target database; defaults to the connection's ``database_name`` extra.
        :param collection_name: Target collection; defaults to the connection's ``collection_name`` extra.
        :param document_id: ID to assign when the document does not already carry one.
        :return: The result of the container's ``upsert_item`` call.
        :raises AirflowBadRequest: if ``document`` is None or a database/collection name is unavailable.
        """
        if document is None:
            raise AirflowBadRequest("You cannot insert a None document")

        # Add a document id if one isn't already present
        if "id" not in document or document["id"] is None:
            document["id"] = document_id if document_id is not None else str(uuid.uuid4())

        return self._get_container_client(database_name, collection_name).upsert_item(document)

    def insert_documents(
        self, documents, database_name: str | None = None, collection_name: str | None = None
    ) -> list:
        """
        Insert a list of new documents into an existing collection in the CosmosDB database.

        :param documents: Iterable of documents, each passed to the container's ``create_item``.
        :param database_name: Target database; defaults to the connection's ``database_name`` extra.
        :param collection_name: Target collection; defaults to the connection's ``collection_name`` extra.
        :return: The ``create_item`` results, in input order.
        :raises AirflowBadRequest: if ``documents`` is None or a database/collection name is unavailable.
        """
        if documents is None:
            raise AirflowBadRequest("You cannot insert empty documents")

        created_documents = []
        container = None
        for single_document in documents:
            if container is None:
                # Resolved once, on the first document, so an empty iterable performs no lookups.
                container = self._get_container_client(database_name, collection_name)
            created_documents.append(container.create_item(single_document))

        return created_documents

    def delete_document(
        self,
        document_id: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: str | None = None,
    ) -> None:
        """
        Delete an existing document out of a collection in the CosmosDB database.

        :param document_id: ID of the document to delete.
        :param database_name: Target database; defaults to the connection's ``database_name`` extra.
        :param collection_name: Target collection; defaults to the connection's ``collection_name`` extra.
        :param partition_key: Passed through to ``delete_item``.
        :raises AirflowBadRequest: if ``document_id`` is None or a database/collection name is unavailable.
        """
        if document_id is None:
            raise AirflowBadRequest("Cannot delete a document without an id")

        self._get_container_client(database_name, collection_name).delete_item(
            document_id, partition_key=partition_key
        )

    def get_document(
        self,
        document_id: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: str | None = None,
    ):
        """
        Get a document from an existing collection in the CosmosDB database.

        :param document_id: ID of the document to read.
        :param database_name: Target database; defaults to the connection's ``database_name`` extra.
        :param collection_name: Target collection; defaults to the connection's ``collection_name`` extra.
        :param partition_key: Passed through to ``read_item``.
        :return: The ``read_item`` result, or None if a ``CosmosHttpResponseError`` is raised.
        :raises AirflowBadRequest: if ``document_id`` is None or a database/collection name is unavailable.
        """
        if document_id is None:
            raise AirflowBadRequest("Cannot get a document without an id")

        try:
            return self._get_container_client(database_name, collection_name).read_item(
                document_id, partition_key=partition_key
            )
        except CosmosHttpResponseError:
            # Deliberately best-effort: any Cosmos HTTP error is reported as "no document".
            return None

    def get_documents(
        self,
        sql_string: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: str | None = None,
    ) -> list | None:
        """
        Get a list of documents from an existing collection in the CosmosDB database via SQL query.

        :param sql_string: SQL query passed to the container's ``query_items``.
        :param database_name: Target database; defaults to the connection's ``database_name`` extra.
        :param collection_name: Target collection; defaults to the connection's ``collection_name`` extra.
        :param partition_key: Passed through to ``query_items``.
        :return: All query results as a list, or None if a ``CosmosHttpResponseError`` is raised.
        :raises AirflowBadRequest: if ``sql_string`` is None or a database/collection name is unavailable.
        """
        if sql_string is None:
            raise AirflowBadRequest("SQL query string cannot be None")

        try:
            result_iterable = self._get_container_client(database_name, collection_name).query_items(
                sql_string, partition_key=partition_key
            )
            # Materialize inside the try so errors raised while paging are handled too.
            return list(result_iterable)
        except CosmosHttpResponseError:
            return None

    def test_connection(self):
        """
        Test a configured Azure Cosmos connection.

        :return: ``(True, message)`` on success, or ``(False, error text)`` if any exception occurs.
        """
        try:
            # Attempt to list existing databases under the configured subscription and retrieve the first in
            # the returned iterator. The Azure Cosmos API does allow for creation of a
            # CosmosClient with incorrect values but then will fail properly once items are
            # retrieved using the client. We need to _actually_ try to retrieve an object to properly test the
            # connection.
            next(iter(self.get_conn().list_databases()), None)
        except Exception as e:
            return False, str(e)
        return True, "Successfully connected to Azure Cosmos."

357 

358 

def get_database_link(database_id: str) -> str:
    """
    Build the relative Azure CosmosDB resource link for a database.

    :param database_id: ID of the database.
    :return: A link of the form ``dbs/<database_id>``.
    """
    return "/".join(("dbs", database_id))

362 

363 

def get_collection_link(database_id: str, collection_id: str) -> str:
    """
    Build the relative Azure CosmosDB resource link for a collection.

    :param database_id: ID of the database containing the collection.
    :param collection_id: ID of the collection.
    :return: The database link followed by ``/colls/<collection_id>``.
    """
    return "/".join((get_database_link(database_id), "colls", collection_id))

367 

368 

def get_document_link(database_id: str, collection_id: str, document_id: str) -> str:
    """
    Build the relative Azure CosmosDB resource link for a document.

    :param database_id: ID of the database containing the collection.
    :param collection_id: ID of the collection containing the document.
    :param document_id: ID of the document.
    :return: The collection link followed by ``/docs/<document_id>``.
    """
    return "/".join((get_collection_link(database_id, collection_id), "docs", document_id))