1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""
This module contains integration with Azure CosmosDB.

AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that an
Airflow connection of type `azure_cosmos` exists. Authorization can be done by supplying a
login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify
the default database and collection to use (see connection `azure_cosmos_default` for an example).
25"""
26
27from __future__ import annotations
28
29import uuid
30from typing import TYPE_CHECKING, Any, List, Union
31from urllib.parse import urlparse
32
33from azure.cosmos import PartitionKey
34from azure.cosmos.cosmos_client import CosmosClient
35from azure.cosmos.exceptions import CosmosHttpResponseError
36from azure.mgmt.cosmosdb import CosmosDBManagementClient
37
38from airflow.exceptions import AirflowBadRequest, AirflowException
39from airflow.hooks.base import BaseHook
40from airflow.providers.microsoft.azure.utils import (
41 add_managed_identity_connection_widgets,
42 get_field,
43 get_sync_default_azure_credential,
44)
45
if TYPE_CHECKING:
    # Typing-only alias for partition key values: either a single string or a list of strings.
    # It is defined under TYPE_CHECKING, so the name does not exist at runtime and may only be
    # used inside annotations (which stay unevaluated thanks to `from __future__ import annotations`).
    PartitionKeyType = Union[str, List[str]]
48
49
class AzureCosmosDBHook(BaseHook):
    """
    Interact with Azure CosmosDB.

    The connection login is the Cosmos endpoint URI. The account master key is taken from the
    connection password when one is set; otherwise, when the ``resource_group_name`` extra is set,
    the primary master key is looked up through ``CosmosDBManagementClient`` using the
    ``subscription_id`` extra and a default Azure credential.

    Optionally, you can use the following extras to default these values
    {"database_name": "<DATABASE_NAME>", "collection_name": "COLLECTION_NAME"}.
    A ``partition_key`` extra, when present, is used as the default partition key.

    :param azure_cosmos_conn_id: Reference to the
        :ref:`Azure CosmosDB connection<howto/connection:azure_cosmos>`.
    """

    conn_name_attr = "azure_cosmos_conn_id"
    default_conn_name = "azure_cosmos_default"
    conn_type = "azure_cosmos"
    hook_name = "Azure CosmosDB"

    @classmethod
    @add_managed_identity_connection_widgets
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """Return connection widgets to add to connection form."""
        # Imported lazily so the UI libraries are only needed when the form is actually rendered.
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "database_name": StringField(
                lazy_gettext("Cosmos Database Name (optional)"), widget=BS3TextFieldWidget()
            ),
            "collection_name": StringField(
                lazy_gettext("Cosmos Collection Name (optional)"), widget=BS3TextFieldWidget()
            ),
            "subscription_id": StringField(
                lazy_gettext("Subscription ID (optional)"),
                widget=BS3TextFieldWidget(),
            ),
            "resource_group_name": StringField(
                lazy_gettext("Resource Group Name (optional)"),
                widget=BS3TextFieldWidget(),
            ),
        }

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom field behaviour."""
        return {
            "hidden_fields": ["schema", "port", "host", "extra"],
            "relabeling": {
                "login": "Cosmos Endpoint URI",
                "password": "Cosmos Master Key Token",
            },
            "placeholders": {
                "login": "endpoint uri",
                "password": "master key (not needed for Azure AD authentication)",
                "database_name": "database name",
                "collection_name": "collection name",
                "subscription_id": "Subscription ID (required for Azure AD authentication)",
                "resource_group_name": "Resource Group Name (required for Azure AD authentication)",
            },
        }

    def __init__(self, azure_cosmos_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn_id = azure_cosmos_conn_id
        # The client is built lazily by get_conn() and cached here afterwards.
        self._conn: CosmosClient | None = None

        # Defaults are filled in from the connection extras when get_conn() first builds the client.
        self.default_database_name: str | None = None
        self.default_collection_name: str | None = None
        self.default_partition_key: PartitionKeyType | None = None

    def _get_field(self, extras, name):
        """Read the field ``name`` from the connection ``extras`` via the provider's ``get_field`` helper."""
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=name,
        )

    def _fetch_primary_master_key(self, extras, endpoint_uri, resource_group_name):
        """
        Look up the account's primary master key through the Azure management API.

        :param extras: Connection extras; ``managed_identity_client_id``,
            ``workload_identity_tenant_id`` and ``subscription_id`` are read from them.
        :param endpoint_uri: Cosmos endpoint URI; the account name is the first label of its host.
        :param resource_group_name: Resource group that contains the database account.
        :return: The ``primary_master_key`` reported by ``database_accounts.list_keys``.
        """
        managed_identity_client_id = self._get_field(extras, "managed_identity_client_id")
        workload_identity_tenant_id = self._get_field(extras, "workload_identity_tenant_id")
        subscription_id = self._get_field(extras, "subscription_id")
        credential = get_sync_default_azure_credential(
            managed_identity_client_id=managed_identity_client_id,
            workload_identity_tenant_id=workload_identity_tenant_id,
        )
        management_client = CosmosDBManagementClient(
            credential=credential,
            subscription_id=subscription_id,
        )

        # e.g. "https://<account>.documents.azure.com:443/" -> "<account>"
        database_account = urlparse(endpoint_uri).netloc.split(".")[0]
        database_account_keys = management_client.database_accounts.list_keys(
            resource_group_name, database_account
        )
        return database_account_keys.primary_master_key

    def get_conn(self) -> CosmosClient:
        """
        Return a cosmos db client.

        The client is created on first use and cached on the hook. Creating it also loads the
        default database name, collection name and partition key from the connection extras.

        :raises AirflowException: If the connection has neither a password nor a
            ``resource_group_name`` extra.
        """
        if self._conn:
            return self._conn

        conn = self.get_connection(self.conn_id)
        extras = conn.extra_dejson
        endpoint_uri = conn.login
        resource_group_name = self._get_field(extras, "resource_group_name")

        if conn.password:
            # An explicit master key always wins over the management-API lookup.
            master_key = conn.password
        elif resource_group_name:
            master_key = self._fetch_primary_master_key(extras, endpoint_uri, resource_group_name)
        else:
            raise AirflowException("Either password or resource_group_name is required")

        self.default_database_name = self._get_field(extras, "database_name")
        self.default_collection_name = self._get_field(extras, "collection_name")
        self.default_partition_key = self._get_field(extras, "partition_key")

        # Initialize the Python Azure Cosmos DB client
        self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key})
        return self._conn

    def __get_database_name(self, database_name: str | None = None) -> str:
        """Return ``database_name``, or the connection default when it is None; raise if both are None."""
        # get_conn() is what populates the defaults, so make sure it has run.
        self.get_conn()
        db_name = database_name
        if db_name is None:
            db_name = self.default_database_name

        if db_name is None:
            raise AirflowBadRequest("Database name must be specified")

        return db_name

    def __get_collection_name(self, collection_name: str | None = None) -> str:
        """Return ``collection_name``, or the connection default when it is None; raise if both are None."""
        self.get_conn()
        coll_name = collection_name
        if coll_name is None:
            coll_name = self.default_collection_name

        if coll_name is None:
            raise AirflowBadRequest("Collection name must be specified")

        return coll_name

    def __get_partition_key(self, partition_key: PartitionKeyType | None = None) -> PartitionKeyType:
        """Return ``partition_key``, or the connection default when it is None; raise if both are None."""
        self.get_conn()
        part_key = self.default_partition_key if partition_key is None else partition_key

        if part_key is None:
            raise AirflowBadRequest("Partition key must be specified")

        return part_key

    def _get_container_client(self, database_name: str | None, collection_name: str | None):
        """
        Return the SDK container client for the given (or default) database and collection.

        :raises AirflowBadRequest: If the database or collection name cannot be resolved.
        """
        return (
            self.get_conn()
            .get_database_client(self.__get_database_name(database_name))
            .get_container_client(self.__get_collection_name(collection_name))
        )

    def does_collection_exist(self, collection_name: str, database_name: str) -> bool:
        """
        Check if a collection exists in CosmosDB.

        :param collection_name: Id of the collection to look for.
        :param database_name: Database to search; None falls back to the connection default.
        :return: True when the query for the collection id returns at least one result.
        :raises AirflowBadRequest: If ``collection_name`` is None or no database name is available.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        # The ignores below is due to typing bug in azure-cosmos 9.2.0
        # https://github.com/Azure/azure-sdk-for-python/issues/31811
        existing_container = list(
            self.get_conn()
            .get_database_client(self.__get_database_name(database_name))
            .query_containers(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": collection_name}],  # type: ignore[list-item]
            )
        )
        return bool(existing_container)

    def create_collection(
        self,
        collection_name: str,
        database_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ) -> None:
        """
        Create a new collection in the CosmosDB database.

        Nothing is created when a collection with this id already exists.

        :param collection_name: Id of the collection to create.
        :param database_name: Target database; None falls back to the connection default.
        :param partition_key: Partition key path; None falls back to the connection default.
        :raises AirflowBadRequest: If ``collection_name`` is None, or if the database name or
            (when the collection has to be created) the partition key cannot be resolved.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        database_client = self.get_conn().get_database_client(self.__get_database_name(database_name))

        # We need to check to see if this container already exists so we don't try
        # to create it twice
        # The ignores below is due to typing bug in azure-cosmos 9.2.0
        # https://github.com/Azure/azure-sdk-for-python/issues/31811
        existing_container = list(
            database_client.query_containers(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": collection_name}],  # type: ignore[list-item]
            )
        )

        # Only create if we did not find it already existing
        if not existing_container:
            database_client.create_container(
                collection_name,
                partition_key=PartitionKey(path=self.__get_partition_key(partition_key)),
            )

    def does_database_exist(self, database_name: str) -> bool:
        """
        Check if a database exists in CosmosDB.

        :param database_name: Id of the database to look for.
        :return: True when the query for the database id returns at least one result.
        :raises AirflowBadRequest: If ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        # The ignores below is due to typing bug in azure-cosmos 9.2.0
        # https://github.com/Azure/azure-sdk-for-python/issues/31811
        existing_database = list(
            self.get_conn().query_databases(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": database_name}],  # type: ignore[list-item]
            )
        )
        return bool(existing_database)

    def create_database(self, database_name: str) -> None:
        """
        Create a new database in CosmosDB.

        Nothing is created when a database with this id already exists.

        :param database_name: Id of the database to create.
        :raises AirflowBadRequest: If ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        client = self.get_conn()

        # We need to check to see if this database already exists so we don't try
        # to create it twice
        # The ignores below is due to typing bug in azure-cosmos 9.2.0
        # https://github.com/Azure/azure-sdk-for-python/issues/31811
        existing_database = list(
            client.query_databases(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": database_name}],  # type: ignore[list-item]
            )
        )

        # Only create if we did not find it already existing
        if not existing_database:
            client.create_database(database_name)

    def delete_database(self, database_name: str) -> None:
        """
        Delete an existing database in CosmosDB.

        :param database_name: Id of the database to delete.
        :raises AirflowBadRequest: If ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        self.get_conn().delete_database(database_name)

    def delete_collection(self, collection_name: str, database_name: str | None = None) -> None:
        """
        Delete an existing collection in the CosmosDB database.

        :param collection_name: Id of the collection to delete.
        :param database_name: Database holding the collection; None falls back to the connection default.
        :raises AirflowBadRequest: If ``collection_name`` is None or no database name is available.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        self.get_conn().get_database_client(self.__get_database_name(database_name)).delete_container(
            collection_name
        )

    def upsert_document(self, document, database_name=None, collection_name=None, document_id=None):
        """
        Insert or update a document into an existing collection in the CosmosDB database.

        When the document has no ``id`` (missing or None), one is assigned in place:
        ``document_id`` if given, otherwise a freshly generated UUID4 string.

        :param document: Mapping to upsert; it is mutated when an id has to be assigned.
        :param database_name: Target database; None falls back to the connection default.
        :param collection_name: Target collection; None falls back to the connection default.
        :param document_id: Id to assign when the document does not carry one.
        :return: Whatever the SDK's ``upsert_item`` returns for the document.
        :raises AirflowBadRequest: If ``document`` is None or a database/collection name is missing.
        """
        if document is None:
            raise AirflowBadRequest("You cannot insert a None document")

        # Add document id if isn't found; only generate a UUID when it is actually needed.
        if document.get("id") is None:
            document["id"] = str(uuid.uuid4()) if document_id is None else document_id

        return self._get_container_client(database_name, collection_name).upsert_item(document)

    def insert_documents(
        self, documents, database_name: str | None = None, collection_name: str | None = None
    ) -> list:
        """
        Insert a list of new documents into an existing collection in the CosmosDB database.

        :param documents: Iterable of documents; each one is passed to the SDK's ``create_item``.
        :param database_name: Target database; None falls back to the connection default.
        :param collection_name: Target collection; None falls back to the connection default.
        :return: List of the ``create_item`` results, in input order.
        :raises AirflowBadRequest: If ``documents`` is None, or a database/collection name is
            missing when there is at least one document to insert.
        """
        if documents is None:
            raise AirflowBadRequest("You cannot insert empty documents")

        # Resolve the container once, but only when there is something to insert, so that an
        # empty iterable still performs no connection or name lookup.
        container = None
        created_documents = []
        for single_document in documents:
            if container is None:
                container = self._get_container_client(database_name, collection_name)
            created_documents.append(container.create_item(single_document))

        return created_documents

    def delete_document(
        self,
        document_id: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ) -> None:
        """
        Delete an existing document out of a collection in the CosmosDB database.

        :param document_id: Id of the document to delete.
        :param database_name: Target database; None falls back to the connection default.
        :param collection_name: Target collection; None falls back to the connection default.
        :param partition_key: Partition key of the document; None falls back to the connection default.
        :raises AirflowBadRequest: If ``document_id`` is None or a name/partition key is missing.
        """
        if document_id is None:
            raise AirflowBadRequest("Cannot delete a document without an id")

        container = self._get_container_client(database_name, collection_name)
        container.delete_item(document_id, partition_key=self.__get_partition_key(partition_key))

    def get_document(
        self,
        document_id: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ):
        """
        Get a document from an existing collection in the CosmosDB database.

        :param document_id: Id of the document to read.
        :param database_name: Target database; None falls back to the connection default.
        :param collection_name: Target collection; None falls back to the connection default.
        :param partition_key: Partition key of the document; None falls back to the connection default.
        :return: The result of the SDK's ``read_item``, or None when the SDK raises
            ``CosmosHttpResponseError`` (any HTTP error, not only "not found").
        :raises AirflowBadRequest: If ``document_id`` is None or a name/partition key is missing.
        """
        if document_id is None:
            raise AirflowBadRequest("Cannot get a document without an id")

        try:
            container = self._get_container_client(database_name, collection_name)
            return container.read_item(document_id, partition_key=self.__get_partition_key(partition_key))
        except CosmosHttpResponseError:
            # Best effort by design: HTTP-level failures are reported to the caller as "no document".
            return None

    def get_documents(
        self,
        sql_string: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ) -> list | None:
        """
        Get a list of documents from an existing collection in the CosmosDB database via SQL query.

        :param sql_string: Query text passed to the SDK's ``query_items``.
        :param database_name: Target database; None falls back to the connection default.
        :param collection_name: Target collection; None falls back to the connection default.
        :param partition_key: Partition key for the query; None falls back to the connection default.
        :return: The query results as a list, or None when the SDK raises ``CosmosHttpResponseError``.
        :raises AirflowBadRequest: If ``sql_string`` is None or a name/partition key is missing.
        """
        if sql_string is None:
            raise AirflowBadRequest("SQL query string cannot be None")

        try:
            container = self._get_container_client(database_name, collection_name)
            result_iterable = container.query_items(
                sql_string, partition_key=self.__get_partition_key(partition_key)
            )
            # The results are consumed inside the try so errors raised while iterating are caught too.
            return list(result_iterable)
        except CosmosHttpResponseError:
            # Best effort by design: HTTP-level failures are reported to the caller as None.
            return None

    def test_connection(self) -> tuple[bool, str]:
        """
        Test a configured Azure Cosmos connection.

        :return: ``(True, message)`` on success, or ``(False, error text)`` if anything raised.
        """
        try:
            # Attempt to list existing databases under the configured subscription and retrieve the first in
            # the returned iterator. The Azure Cosmos API does allow for creation of a
            # CosmosClient with incorrect values but then will fail properly once items are
            # retrieved using the client. We need to _actually_ try to retrieve an object to properly test the
            # connection.
            next(iter(self.get_conn().list_databases()), None)
        except Exception as e:
            # Deliberately broad: any failure is reported back as the test result.
            return False, str(e)
        return True, "Successfully connected to Azure Cosmos."
417
418
def get_database_link(database_id: str) -> str:
    """Build the Azure CosmosDB resource link of a database, i.e. ``dbs/<database_id>``."""
    segments = ("dbs", database_id)
    return "/".join(segments)
422
423
def get_collection_link(database_id: str, collection_id: str) -> str:
    """Build the Azure CosmosDB resource link of a collection: the database link plus ``/colls/<collection_id>``."""
    database_link = get_database_link(database_id)
    return "/".join((database_link, "colls", collection_id))
427
428
def get_document_link(database_id: str, collection_id: str, document_id: str) -> str:
    """Build the Azure CosmosDB resource link of a document: the collection link plus ``/docs/<document_id>``."""
    collection_link = get_collection_link(database_id, collection_id)
    return "/".join((collection_link, "docs", document_id))