Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/providers/microsoft/azure/hooks/cosmos.py: 0%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""
19This module contains integration with Azure CosmosDB.
21AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that an
22Airflow connection of type `azure_cosmos` exists. Authorization can be done by supplying a
23login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify
24the default database and collection to use (see connection `azure_cosmos_default` for an example).
25"""
27from __future__ import annotations
29import uuid
30from typing import TYPE_CHECKING, Any, List, Union
31from urllib.parse import urlparse
33from azure.cosmos import PartitionKey
34from azure.cosmos.cosmos_client import CosmosClient
35from azure.cosmos.exceptions import CosmosHttpResponseError
36from azure.mgmt.cosmosdb import CosmosDBManagementClient
38from airflow.exceptions import AirflowBadRequest, AirflowException
39from airflow.hooks.base import BaseHook
40from airflow.providers.microsoft.azure.utils import (
41 add_managed_identity_connection_widgets,
42 get_field,
43 get_sync_default_azure_credential,
44)
if TYPE_CHECKING:
    # Alias used only by static type checkers (this branch never runs at runtime):
    # a partition key argument is either one string or a list of strings.
    PartitionKeyType = Union[str, List[str]]
class AzureCosmosDBHook(BaseHook):
    """
    Interact with Azure CosmosDB.

    The connection ``login`` is the account endpoint URI and ``password`` is the master key.
    When no password is set, the master key is fetched through the Azure management API; this
    requires the ``resource_group_name`` (and ``subscription_id``) extras.
    The optional extras ``database_name``, ``collection_name`` and ``partition_key`` provide
    defaults for methods that are called without those arguments, e.g.
    {"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>"}.

    :param azure_cosmos_conn_id: Reference to the
        :ref:`Azure CosmosDB connection<howto/connection:azure_cosmos>`.
    """

    conn_name_attr = "azure_cosmos_conn_id"
    default_conn_name = "azure_cosmos_default"
    conn_type = "azure_cosmos"
    hook_name = "Azure CosmosDB"

    @classmethod
    @add_managed_identity_connection_widgets
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """Return connection widgets to add to connection form."""
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "database_name": StringField(
                lazy_gettext("Cosmos Database Name (optional)"), widget=BS3TextFieldWidget()
            ),
            "collection_name": StringField(
                lazy_gettext("Cosmos Collection Name (optional)"), widget=BS3TextFieldWidget()
            ),
            "subscription_id": StringField(
                lazy_gettext("Subscription ID (optional)"),
                widget=BS3TextFieldWidget(),
            ),
            "resource_group_name": StringField(
                lazy_gettext("Resource Group Name (optional)"),
                widget=BS3TextFieldWidget(),
            ),
        }

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom field behaviour."""
        return {
            "hidden_fields": ["schema", "port", "host", "extra"],
            "relabeling": {
                "login": "Cosmos Endpoint URI",
                "password": "Cosmos Master Key Token",
            },
            "placeholders": {
                "login": "endpoint uri",
                "password": "master key (not needed for Azure AD authentication)",
                "database_name": "database name",
                "collection_name": "collection name",
                "subscription_id": "Subscription ID (required for Azure AD authentication)",
                "resource_group_name": "Resource Group Name (required for Azure AD authentication)",
            },
        }

    def __init__(self, azure_cosmos_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn_id = azure_cosmos_conn_id
        self._conn: CosmosClient | None = None

        # Defaults read from the connection extras; they are only populated by get_conn(),
        # which is why the private name/key resolvers below call get_conn() first.
        self.default_database_name: str | None = None
        self.default_collection_name: str | None = None
        self.default_partition_key: PartitionKeyType | None = None

    def _get_field(self, extras, name):
        """Read extra field ``name`` for this connection via the provider ``get_field`` helper."""
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=name,
        )

    def get_conn(self) -> CosmosClient:
        """
        Return a cosmos db client, creating and caching it on first call.

        The master key is the connection password when one is set. Otherwise it is looked up
        with ``CosmosDBManagementClient.database_accounts.list_keys`` using a default Azure
        credential, which requires the ``resource_group_name`` extra.

        :raises AirflowException: if neither a password nor ``resource_group_name`` is configured.
        """
        if not self._conn:
            conn = self.get_connection(self.conn_id)
            extras = conn.extra_dejson
            endpoint_uri = conn.login
            resource_group_name = self._get_field(extras, "resource_group_name")

            if conn.password:
                master_key = conn.password
            elif resource_group_name:
                managed_identity_client_id = self._get_field(extras, "managed_identity_client_id")
                workload_identity_tenant_id = self._get_field(extras, "workload_identity_tenant_id")
                subscription_id = self._get_field(extras, "subscription_id")
                credential = get_sync_default_azure_credential(
                    managed_identity_client_id=managed_identity_client_id,
                    workload_identity_tenant_id=workload_identity_tenant_id,
                )
                management_client = CosmosDBManagementClient(
                    credential=credential,
                    subscription_id=subscription_id,
                )

                # The account name is taken as the first dot-separated label of the
                # endpoint's host (netloc).
                database_account = urlparse(endpoint_uri).netloc.split(".")[0]
                database_account_keys = management_client.database_accounts.list_keys(
                    resource_group_name, database_account
                )
                master_key = database_account_keys.primary_master_key
            else:
                raise AirflowException("Either password or resource_group_name is required")

            self.default_database_name = self._get_field(extras, "database_name")
            self.default_collection_name = self._get_field(extras, "collection_name")
            self.default_partition_key = self._get_field(extras, "partition_key")

            # Initialize the Python Azure Cosmos DB client
            self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key})
        return self._conn

    def __get_database_name(self, database_name: str | None = None) -> str:
        """Return ``database_name``, falling back to the connection default; raise if neither is set."""
        self.get_conn()  # ensures the connection-level defaults are loaded
        db_name = database_name if database_name is not None else self.default_database_name
        if db_name is None:
            raise AirflowBadRequest("Database name must be specified")
        return db_name

    def __get_collection_name(self, collection_name: str | None = None) -> str:
        """Return ``collection_name``, falling back to the connection default; raise if neither is set."""
        self.get_conn()  # ensures the connection-level defaults are loaded
        coll_name = collection_name if collection_name is not None else self.default_collection_name
        if coll_name is None:
            raise AirflowBadRequest("Collection name must be specified")
        return coll_name

    def __get_partition_key(self, partition_key: PartitionKeyType | None = None) -> PartitionKeyType:
        """Return ``partition_key``, falling back to the connection default; raise if neither is set."""
        self.get_conn()  # ensures the connection-level defaults are loaded
        part_key = partition_key if partition_key is not None else self.default_partition_key
        if part_key is None:
            raise AirflowBadRequest("Partition key must be specified")
        return part_key

    def does_collection_exist(self, collection_name: str, database_name: str | None = None) -> bool:
        """
        Check if a collection exists in CosmosDB.

        :param collection_name: ID of the collection (container) to look for.
        :param database_name: Database to search; defaults to the connection's ``database_name``.
        :raises AirflowBadRequest: if ``collection_name`` is None or no database name is available.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        # The type-ignore below works around a typing bug in azure-cosmos:
        # https://github.com/Azure/azure-sdk-for-python/issues/31811
        existing_container = list(
            self.get_conn()
            .get_database_client(self.__get_database_name(database_name))
            .query_containers(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": collection_name}],  # type: ignore[list-item]
            )
        )
        return bool(existing_container)

    def create_collection(
        self,
        collection_name: str,
        database_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ) -> None:
        """
        Create a new collection in the CosmosDB database unless it already exists.

        :param collection_name: ID of the collection (container) to create.
        :param database_name: Target database; defaults to the connection's ``database_name``.
        :param partition_key: Partition key path; defaults to the connection's ``partition_key``.
        :raises AirflowBadRequest: if a required name or partition key is missing.
        """
        # Check first so that calling this twice does not try to create the container again.
        if self.does_collection_exist(collection_name, database_name):
            return

        self.get_conn().get_database_client(self.__get_database_name(database_name)).create_container(
            collection_name,
            partition_key=PartitionKey(path=self.__get_partition_key(partition_key)),
        )

    def does_database_exist(self, database_name: str) -> bool:
        """
        Check if a database exists in CosmosDB.

        :param database_name: ID of the database to look for.
        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        # The type-ignore below works around a typing bug in azure-cosmos:
        # https://github.com/Azure/azure-sdk-for-python/issues/31811
        existing_database = list(
            self.get_conn().query_databases(
                "SELECT * FROM r WHERE r.id=@id",
                parameters=[{"name": "@id", "value": database_name}],  # type: ignore[list-item]
            )
        )
        return bool(existing_database)

    def create_database(self, database_name: str) -> None:
        """
        Create a new database in CosmosDB unless it already exists.

        :param database_name: ID of the database to create.
        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        # Check first so that calling this twice does not try to create the database again.
        if not self.does_database_exist(database_name):
            self.get_conn().create_database(database_name)

    def delete_database(self, database_name: str) -> None:
        """
        Delete an existing database in CosmosDB.

        :param database_name: ID of the database to delete.
        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        self.get_conn().delete_database(database_name)

    def delete_collection(self, collection_name: str, database_name: str | None = None) -> None:
        """
        Delete an existing collection in the CosmosDB database.

        :param collection_name: ID of the collection (container) to delete.
        :param database_name: Database holding it; defaults to the connection's ``database_name``.
        :raises AirflowBadRequest: if ``collection_name`` is None or no database name is available.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        self.get_conn().get_database_client(self.__get_database_name(database_name)).delete_container(
            collection_name
        )

    def upsert_document(
        self,
        document,
        database_name: str | None = None,
        collection_name: str | None = None,
        document_id: str | None = None,
    ):
        """
        Insert or update a document into an existing collection in the CosmosDB database.

        If ``document`` has no ``"id"``, one is set on it in place: ``document_id`` when given,
        otherwise a random UUID4 string.

        :param document: The document (mapping) to upsert; it is mutated when ``"id"`` is missing.
        :param database_name: Target database; defaults to the connection's ``database_name``.
        :param collection_name: Target collection; defaults to the connection's ``collection_name``.
        :param document_id: ID to use when the document has none.
        :return: The result of ``upsert_item``.
        :raises AirflowBadRequest: if ``document`` is None or a required name is missing.
        """
        if document is None:
            raise AirflowBadRequest("You cannot insert a None document")

        # Add document id if it isn't found, generating a unique one if none was provided.
        if document.get("id") is None:
            document["id"] = document_id if document_id is not None else str(uuid.uuid4())

        return (
            self.get_conn()
            .get_database_client(self.__get_database_name(database_name))
            .get_container_client(self.__get_collection_name(collection_name))
            .upsert_item(document)
        )

    def insert_documents(
        self, documents, database_name: str | None = None, collection_name: str | None = None
    ) -> list:
        """
        Insert a list of new documents into an existing collection in the CosmosDB database.

        :param documents: Iterable of documents to create.
        :param database_name: Target database; defaults to the connection's ``database_name``.
        :param collection_name: Target collection; defaults to the connection's ``collection_name``.
        :return: The ``create_item`` result for each document, in input order.
        :raises AirflowBadRequest: if ``documents`` is None or a required name is missing.
        """
        if documents is None:
            raise AirflowBadRequest("You cannot insert empty documents")

        created_documents = []
        container = None
        for single_document in documents:
            # Resolve the container client once, on the first document, rather than per
            # document; doing it lazily keeps an empty input from touching the connection.
            if container is None:
                container = (
                    self.get_conn()
                    .get_database_client(self.__get_database_name(database_name))
                    .get_container_client(self.__get_collection_name(collection_name))
                )
            created_documents.append(container.create_item(single_document))

        return created_documents

    def delete_document(
        self,
        document_id: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ) -> None:
        """
        Delete an existing document out of a collection in the CosmosDB database.

        :param document_id: ID of the document to delete.
        :param database_name: Database; defaults to the connection's ``database_name``.
        :param collection_name: Collection; defaults to the connection's ``collection_name``.
        :param partition_key: Partition key of the document; defaults to the connection's.
        :raises AirflowBadRequest: if ``document_id`` is None or a required value is missing.
        """
        if document_id is None:
            raise AirflowBadRequest("Cannot delete a document without an id")
        (
            self.get_conn()
            .get_database_client(self.__get_database_name(database_name))
            .get_container_client(self.__get_collection_name(collection_name))
            .delete_item(document_id, partition_key=self.__get_partition_key(partition_key))
        )

    def get_document(
        self,
        document_id: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ):
        """
        Get a document from an existing collection in the CosmosDB database.

        :param document_id: ID of the document to read.
        :param database_name: Database; defaults to the connection's ``database_name``.
        :param collection_name: Collection; defaults to the connection's ``collection_name``.
        :param partition_key: Partition key of the document; defaults to the connection's.
        :return: The document, or None if the read raised ``CosmosHttpResponseError``.
        :raises AirflowBadRequest: if ``document_id`` is None or a required value is missing.
        """
        if document_id is None:
            raise AirflowBadRequest("Cannot get a document without an id")

        try:
            return (
                self.get_conn()
                .get_database_client(self.__get_database_name(database_name))
                .get_container_client(self.__get_collection_name(collection_name))
                .read_item(document_id, partition_key=self.__get_partition_key(partition_key))
            )
        except CosmosHttpResponseError:
            # Deliberate: any HTTP-level failure (e.g. not found) is reported as "no document".
            return None

    def get_documents(
        self,
        sql_string: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: PartitionKeyType | None = None,
    ) -> list | None:
        """
        Get a list of documents from an existing collection in the CosmosDB database via SQL query.

        :param sql_string: The SQL query to run.
        :param database_name: Database; defaults to the connection's ``database_name``.
        :param collection_name: Collection; defaults to the connection's ``collection_name``.
        :param partition_key: Partition key to scope the query to; defaults to the connection's.
        :return: The matching documents, or None if the query raised ``CosmosHttpResponseError``.
        :raises AirflowBadRequest: if ``sql_string`` is None or a required value is missing.
        """
        if sql_string is None:
            raise AirflowBadRequest("SQL query string cannot be None")

        try:
            result_iterable = (
                self.get_conn()
                .get_database_client(self.__get_database_name(database_name))
                .get_container_client(self.__get_collection_name(collection_name))
                .query_items(sql_string, partition_key=self.__get_partition_key(partition_key))
            )
            # Materialize inside the try so errors raised while paging are also caught.
            return list(result_iterable)
        except CosmosHttpResponseError:
            return None

    def test_connection(self):
        """
        Test a configured Azure Cosmos connection.

        :return: ``(True, message)`` on success, ``(False, error text)`` on any exception.
        """
        try:
            # Attempt to list existing databases under the configured subscription and retrieve the first in
            # the returned iterator. The Azure Cosmos API does allow for creation of a
            # CosmosClient with incorrect values but then will fail properly once items are
            # retrieved using the client. We need to _actually_ try to retrieve an object to properly test the
            # connection.
            next(iter(self.get_conn().list_databases()), None)
        except Exception as e:
            return False, str(e)
        return True, "Successfully connected to Azure Cosmos."
def get_database_link(database_id: str) -> str:
    """
    Build the CosmosDB link of a database.

    :param database_id: ID of the database.
    :return: ``dbs/<database_id>``.
    """
    return "/".join(("dbs", database_id))
def get_collection_link(database_id: str, collection_id: str) -> str:
    """
    Build the CosmosDB link of a collection.

    :param database_id: ID of the database containing the collection.
    :param collection_id: ID of the collection.
    :return: The database link followed by ``/colls/<collection_id>``.
    """
    database_link = get_database_link(database_id)
    return "/".join((database_link, "colls", collection_id))
def get_document_link(database_id: str, collection_id: str, document_id: str) -> str:
    """
    Build the CosmosDB link of a document.

    :param database_id: ID of the database containing the collection.
    :param collection_id: ID of the collection containing the document.
    :param document_id: ID of the document.
    :return: The collection link followed by ``/docs/<document_id>``.
    """
    collection_link = get_collection_link(database_id, collection_id)
    return "/".join((collection_link, "docs", document_id))