Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/providers/microsoft/azure/hooks/cosmos.py: 0%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

166 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18""" 

19This module contains integration with Azure CosmosDB. 

20 

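AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that an
Airflow connection of type `azure_cosmos` exists. Authorization can be done by supplying a
login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify
the default database and collection to use (see connection `azure_cosmos_default` for an example).
Alternatively, leave the password empty and supply the extra fields resource_group_name and
subscription_id; the master key is then fetched through the Azure management API using Azure AD credentials.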

25""" 

26 

27from __future__ import annotations 

28 

29import uuid 

30from typing import TYPE_CHECKING, Any, List, Union 

31from urllib.parse import urlparse 

32 

33from azure.cosmos import PartitionKey 

34from azure.cosmos.cosmos_client import CosmosClient 

35from azure.cosmos.exceptions import CosmosHttpResponseError 

36from azure.mgmt.cosmosdb import CosmosDBManagementClient 

37 

38from airflow.exceptions import AirflowBadRequest, AirflowException 

39from airflow.hooks.base import BaseHook 

40from airflow.providers.microsoft.azure.utils import ( 

41 add_managed_identity_connection_widgets, 

42 get_field, 

43 get_sync_default_azure_credential, 

44) 

45 

if TYPE_CHECKING:
    # Partition key accepted by the hook methods: either a single string or a list of strings.
    # Defined only for static type checkers; the name does not exist at runtime, which is safe
    # because ``from __future__ import annotations`` keeps every annotation using it unevaluated.
    PartitionKeyType = Union[str, List[str]]

48 

49 

class AzureCosmosDBHook(BaseHook):
    """
    Interact with Azure CosmosDB.

    The connection ``login`` is the account endpoint URI and ``password`` is the master key.
    If no password is set, the master key is fetched through the Azure management API instead,
    using the ``resource_group_name`` and ``subscription_id`` extras.
    Optionally, the extras ``database_name``, ``collection_name`` and ``partition_key`` provide
    defaults for methods called without those arguments, e.g.
    {"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>"}.

    :param azure_cosmos_conn_id: Reference to the
        :ref:`Azure CosmosDB connection<howto/connection:azure_cosmos>`.
    """

    conn_name_attr = "azure_cosmos_conn_id"
    default_conn_name = "azure_cosmos_default"
    conn_type = "azure_cosmos"
    hook_name = "Azure CosmosDB"

    @classmethod
    @add_managed_identity_connection_widgets
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """Return connection widgets to add to connection form."""
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "database_name": StringField(
                lazy_gettext("Cosmos Database Name (optional)"), widget=BS3TextFieldWidget()
            ),
            "collection_name": StringField(
                lazy_gettext("Cosmos Collection Name (optional)"), widget=BS3TextFieldWidget()
            ),
            "subscription_id": StringField(
                lazy_gettext("Subscription ID (optional)"),
                widget=BS3TextFieldWidget(),
            ),
            "resource_group_name": StringField(
                lazy_gettext("Resource Group Name (optional)"),
                widget=BS3TextFieldWidget(),
            ),
        }

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom field behaviour."""
        return {
            "hidden_fields": ["schema", "port", "host", "extra"],
            "relabeling": {
                "login": "Cosmos Endpoint URI",
                "password": "Cosmos Master Key Token",
            },
            "placeholders": {
                "login": "endpoint uri",
                "password": "master key (not needed for Azure AD authentication)",
                "database_name": "database name",
                "collection_name": "collection name",
                "subscription_id": "Subscription ID (required for Azure AD authentication)",
                "resource_group_name": "Resource Group Name (required for Azure AD authentication)",
            },
        }

    def __init__(self, azure_cosmos_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn_id = azure_cosmos_conn_id
        self._conn: CosmosClient | None = None

        # Populated from the connection extras the first time get_conn() builds the client.
        self.default_database_name = None
        self.default_collection_name = None
        self.default_partition_key = None

    def _get_field(self, extras, name):
        """Read extra field ``name`` from ``extras`` via the shared ``get_field`` helper."""
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=name,
        )

    def get_conn(self) -> CosmosClient:
        """
        Return a cosmos db client, creating and caching it on first use.

        The first call also loads the default database, collection and partition key
        from the connection extras.

        :raises AirflowException: if the connection has neither a password nor a
            ``resource_group_name`` extra.
        """
        if not self._conn:
            conn = self.get_connection(self.conn_id)
            extras = conn.extra_dejson
            endpoint_uri = conn.login
            resource_group_name = self._get_field(extras, "resource_group_name")

            if conn.password:
                master_key = conn.password
            elif resource_group_name:
                # No key configured: obtain an Azure credential and read the account's
                # primary key through the management API.
                managed_identity_client_id = self._get_field(extras, "managed_identity_client_id")
                workload_identity_tenant_id = self._get_field(extras, "workload_identity_tenant_id")
                subscription_id = self._get_field(extras, "subscription_id")
                credential = get_sync_default_azure_credential(
                    managed_identity_client_id=managed_identity_client_id,
                    workload_identity_tenant_id=workload_identity_tenant_id,
                )
                management_client = CosmosDBManagementClient(
                    credential=credential,
                    subscription_id=subscription_id,
                )

                # The account name is the first label of the endpoint host,
                # e.g. https://<account>.documents.azure.com:443/ -> <account>.
                database_account = urlparse(endpoint_uri).netloc.split(".")[0]
                database_account_keys = management_client.database_accounts.list_keys(
                    resource_group_name, database_account
                )
                master_key = database_account_keys.primary_master_key
            else:
                raise AirflowException("Either password or resource_group_name is required")

            self.default_database_name = self._get_field(extras, "database_name")
            self.default_collection_name = self._get_field(extras, "collection_name")
            self.default_partition_key = self._get_field(extras, "partition_key")

            # Initialize the Python Azure Cosmos DB client
            self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key})
        return self._conn

    def __get_database_name(self, database_name: str | None = None) -> str:
        """Return ``database_name``, or the connection default when it is None."""
        # Ensure the defaults have been loaded from the connection extras.
        self.get_conn()
        db_name = database_name
        if db_name is None:
            db_name = self.default_database_name

        if db_name is None:
            raise AirflowBadRequest("Database name must be specified")

        return db_name

    def __get_collection_name(self, collection_name: str | None = None) -> str:
        """Return ``collection_name``, or the connection default when it is None."""
        self.get_conn()
        coll_name = collection_name
        if coll_name is None:
            coll_name = self.default_collection_name

        if coll_name is None:
            raise AirflowBadRequest("Collection name must be specified")

        return coll_name

    def __get_partition_key(self, partition_key: PartitionKeyType | None = None) -> PartitionKeyType:
        """
        Return ``partition_key``, or the connection default when it is None.

        NOTE(review): the same resolved value is used as the container's partition-key *path*
        in create_collection and as the item partition-key *value* in the document methods;
        these usually differ, so callers should pass the argument explicitly.
        """
        self.get_conn()
        if partition_key is None:
            part_key = self.default_partition_key
        else:
            part_key = partition_key

        if part_key is None:
            raise AirflowBadRequest("Partition key must be specified")

        return part_key

    def __get_container_client(self, database_name: str | None = None, collection_name: str | None = None):
        """Return the container proxy for the resolved (or default) database and collection."""
        return (
            self.get_conn()
            .get_database_client(self.__get_database_name(database_name))
            .get_container_client(self.__get_collection_name(collection_name))
        )

    def does_collection_exist(self, collection_name: str, database_name: str | None = None) -> bool:
        """
        Check if a collection exists in CosmosDB.

        :param collection_name: id of the container to look up.
        :param database_name: database to search; falls back to the connection default.
        :raises AirflowBadRequest: if ``collection_name`` is None or no database name resolves.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        # The ignores below is due to typing bug in azure-cosmos 9.2.0
        # https://github.com/Azure/azure-sdk-for-python/issues/31811
        existing_container = list(
            self.get_conn()
            .get_database_client(self.__get_database_name(database_name))
            .query_containers(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": collection_name}],  # type: ignore[list-item]
            )
        )
        return bool(existing_container)

    def create_collection(
        self,
        collection_name: str,
        database_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ) -> None:
        """
        Create a new collection in the CosmosDB database unless it already exists.

        :param collection_name: id of the container to create.
        :param database_name: target database; falls back to the connection default.
        :param partition_key: partition-key path; falls back to the connection default.
        :raises AirflowBadRequest: if ``collection_name`` is None or a default cannot resolve.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        # Check first so we don't try to create an existing container twice.
        if not self.does_collection_exist(collection_name, database_name):
            self.get_conn().get_database_client(self.__get_database_name(database_name)).create_container(
                collection_name,
                partition_key=PartitionKey(path=self.__get_partition_key(partition_key)),
            )

    def does_database_exist(self, database_name: str) -> bool:
        """
        Check if a database exists in CosmosDB.

        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        # The ignores below is due to typing bug in azure-cosmos 9.2.0
        # https://github.com/Azure/azure-sdk-for-python/issues/31811
        existing_database = list(
            self.get_conn().query_databases(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": database_name}],  # type: ignore[list-item]
            )
        )
        return bool(existing_database)

    def create_database(self, database_name: str) -> None:
        """
        Create a new database in CosmosDB unless it already exists.

        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        # Check first so we don't try to create an existing database twice.
        if not self.does_database_exist(database_name):
            self.get_conn().create_database(database_name)

    def delete_database(self, database_name: str) -> None:
        """
        Delete an existing database in CosmosDB.

        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        self.get_conn().delete_database(database_name)

    def delete_collection(self, collection_name: str, database_name: str | None = None) -> None:
        """
        Delete an existing collection in the CosmosDB database.

        :raises AirflowBadRequest: if ``collection_name`` is None or no database name resolves.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        self.get_conn().get_database_client(self.__get_database_name(database_name)).delete_container(
            collection_name
        )

    def upsert_document(self, document, database_name=None, collection_name=None, document_id=None):
        """
        Insert or update a document into an existing collection in the CosmosDB database.

        If ``document`` has no ``id``, one is set in place: ``document_id`` when given,
        otherwise a random UUID4 string. The caller's dict is mutated.

        :return: the result of ``upsert_item``.
        :raises AirflowBadRequest: if ``document`` is None or a name cannot be resolved.
        """
        if document is None:
            raise AirflowBadRequest("You cannot insert a None document")

        # Assign unique ID if one isn't provided
        if document_id is None:
            document_id = str(uuid.uuid4())

        # Add document id if isn't found
        if document.get("id") is None:
            document["id"] = document_id

        return self.__get_container_client(database_name, collection_name).upsert_item(document)

    def insert_documents(
        self, documents, database_name: str | None = None, collection_name: str | None = None
    ) -> list:
        """
        Insert a list of new documents into an existing collection in the CosmosDB database.

        :param documents: iterable of documents; each is passed to ``create_item``.
        :param database_name: target database; falls back to the connection default.
        :param collection_name: target collection; falls back to the connection default.
        :return: the created items, in input order.
        :raises AirflowBadRequest: if ``documents`` is None or a name cannot be resolved.
        """
        if documents is None:
            raise AirflowBadRequest("You cannot insert empty documents")

        created_documents = []
        container = None
        for single_document in documents:
            # Resolve the container once, on the first document, so an empty input
            # still never touches the connection.
            if container is None:
                container = self.__get_container_client(database_name, collection_name)
            created_documents.append(container.create_item(single_document))

        return created_documents

    def delete_document(
        self,
        document_id: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ) -> None:
        """
        Delete an existing document out of a collection in the CosmosDB database.

        :raises AirflowBadRequest: if ``document_id`` is None or a default cannot resolve.
        """
        if document_id is None:
            raise AirflowBadRequest("Cannot delete a document without an id")
        self.__get_container_client(database_name, collection_name).delete_item(
            document_id, partition_key=self.__get_partition_key(partition_key)
        )

    def get_document(
        self,
        document_id: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ):
        """
        Get a document from an existing collection in the CosmosDB database.

        :return: the item returned by ``read_item``, or None if Cosmos responds with any
            ``CosmosHttpResponseError`` (not only "not found").
        :raises AirflowBadRequest: if ``document_id`` is None or a default cannot resolve.
        """
        if document_id is None:
            raise AirflowBadRequest("Cannot get a document without an id")

        try:
            return self.__get_container_client(database_name, collection_name).read_item(
                document_id, partition_key=self.__get_partition_key(partition_key)
            )
        except CosmosHttpResponseError:
            return None

    def get_documents(
        self,
        sql_string: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ) -> list | None:
        """
        Get a list of documents from an existing collection in the CosmosDB database via SQL query.

        :return: all query results as a list, or None on any ``CosmosHttpResponseError``.
        :raises AirflowBadRequest: if ``sql_string`` is None or a default cannot resolve.
        """
        if sql_string is None:
            raise AirflowBadRequest("SQL query string cannot be None")

        try:
            result_iterable = self.__get_container_client(database_name, collection_name).query_items(
                sql_string, partition_key=self.__get_partition_key(partition_key)
            )
            # Materialize inside the try: errors may surface while paging through results.
            return list(result_iterable)
        except CosmosHttpResponseError:
            return None

    def test_connection(self):
        """Test a configured Azure Cosmos connection."""
        try:
            # Attempt to list existing databases under the configured subscription and retrieve the first in
            # the returned iterator. The Azure Cosmos API does allow for creation of a
            # CosmosClient with incorrect values but then will fail properly once items are
            # retrieved using the client. We need to _actually_ try to retrieve an object to properly test the
            # connection.
            next(iter(self.get_conn().list_databases()), None)
        except Exception as e:
            return False, str(e)
        return True, "Successfully connected to Azure Cosmos."

417 

418 

def get_database_link(database_id: str) -> str:
    """Build the link path of a CosmosDB database, i.e. ``dbs/<database_id>``."""
    return "/".join(("dbs", database_id))

422 

423 

def get_collection_link(database_id: str, collection_id: str) -> str:
    """Build the link path of a collection: ``<database link>/colls/<collection_id>``."""
    database_link = get_database_link(database_id)
    return "/".join((database_link, "colls", collection_id))

427 

428 

def get_document_link(database_id: str, collection_id: str, document_id: str) -> str:
    """Build the link path of a document: ``<collection link>/docs/<document_id>``."""
    collection_link = get_collection_link(database_id, collection_id)
    return "/".join((collection_link, "docs", document_id))