Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/providers/microsoft/azure/hooks/cosmos.py: 0%
142 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""
19This module contains integration with Azure CosmosDB.
AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that an
Airflow connection of type `azure_cosmos` exists. Authorization is done by supplying a
login (= endpoint URI) and a password (= master key). The optional extra fields
``database_name`` and ``collection_name`` specify the default database and collection
to use (see connection `azure_cosmos_default` for an example).
25"""
26from __future__ import annotations
28import json
29import uuid
30from typing import Any
32from azure.cosmos.cosmos_client import CosmosClient
33from azure.cosmos.exceptions import CosmosHttpResponseError
35from airflow.exceptions import AirflowBadRequest
36from airflow.hooks.base import BaseHook
37from airflow.providers.microsoft.azure.utils import _ensure_prefixes, get_field
class AzureCosmosDBHook(BaseHook):
    """
    Interacts with Azure CosmosDB.

    ``login`` should be the endpoint URI and ``password`` the master key.
    Optionally, the following connection extras can be used to set default values:
    {"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>"}.

    :param azure_cosmos_conn_id: Reference to the
        :ref:`Azure CosmosDB connection<howto/connection:azure_cosmos>`.
    """

    conn_name_attr = "azure_cosmos_conn_id"
    default_conn_name = "azure_cosmos_default"
    conn_type = "azure_cosmos"
    hook_name = "Azure CosmosDB"

    # Parameterised query used to look up a database or container by its id.
    _QUERY_BY_ID = "SELECT * FROM r WHERE r.id=@id"

    @staticmethod
    def get_connection_form_widgets() -> dict[str, Any]:
        """Returns connection widgets to add to connection form"""
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "database_name": StringField(
                lazy_gettext("Cosmos Database Name (optional)"), widget=BS3TextFieldWidget()
            ),
            "collection_name": StringField(
                lazy_gettext("Cosmos Collection Name (optional)"), widget=BS3TextFieldWidget()
            ),
        }

    @staticmethod
    @_ensure_prefixes(conn_type="azure_cosmos")  # todo: remove when min airflow version >= 2.5
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Returns custom field behaviour"""
        return {
            "hidden_fields": ["schema", "port", "host", "extra"],
            "relabeling": {
                "login": "Cosmos Endpoint URI",
                "password": "Cosmos Master Key Token",
            },
            "placeholders": {
                "login": "endpoint uri",
                "password": "master key",
                "database_name": "database name",
                "collection_name": "collection name",
            },
        }

    def __init__(self, azure_cosmos_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn_id = azure_cosmos_conn_id
        self._conn: CosmosClient | None = None

        # Filled in from the connection extras the first time get_conn() builds the client.
        self.default_database_name: str | None = None
        self.default_collection_name: str | None = None

    def _get_field(self, extras, name):
        """Read ``name`` from the connection extras via the provider's ``get_field`` helper."""
        return get_field(
            conn_id=self.conn_id,
            conn_type=self.conn_type,
            extras=extras,
            field_name=name,
        )

    def get_conn(self) -> CosmosClient:
        """
        Return a cosmos db client.

        The client is created on first use and cached on the hook. Creating it also loads
        the default database/collection names from the connection extras.
        """
        if not self._conn:
            conn = self.get_connection(self.conn_id)
            extras = conn.extra_dejson
            endpoint_uri = conn.login
            master_key = conn.password

            self.default_database_name = self._get_field(extras, "database_name")
            self.default_collection_name = self._get_field(extras, "collection_name")

            # Initialize the Python Azure Cosmos DB client
            self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key})
        return self._conn

    def __get_database_name(self, database_name: str | None = None) -> str:
        """
        Return ``database_name``, falling back to the connection's default database.

        :raises AirflowBadRequest: if neither is set.
        """
        # Ensures the defaults have been loaded from the connection extras.
        self.get_conn()
        db_name = database_name
        if db_name is None:
            db_name = self.default_database_name

        if db_name is None:
            raise AirflowBadRequest("Database name must be specified")

        return db_name

    def __get_collection_name(self, collection_name: str | None = None) -> str:
        """
        Return ``collection_name``, falling back to the connection's default collection.

        :raises AirflowBadRequest: if neither is set.
        """
        # Ensures the defaults have been loaded from the connection extras.
        self.get_conn()
        coll_name = collection_name
        if coll_name is None:
            coll_name = self.default_collection_name

        if coll_name is None:
            raise AirflowBadRequest("Collection name must be specified")

        return coll_name

    @staticmethod
    def _id_query_parameters(value: str) -> list[dict[str, Any]]:
        """Build the parameter list for ``_QUERY_BY_ID``.

        The Cosmos SDK expects each query parameter as a ``{"name": ..., "value": ...}``
        dict, not as a serialized JSON string.
        """
        return [{"name": "@id", "value": value}]

    def _get_database_client(self, database_name: str | None = None):
        """Return the SDK database client for ``database_name`` (or the default database)."""
        return self.get_conn().get_database_client(self.__get_database_name(database_name))

    def _get_container_client(self, database_name: str | None = None, collection_name: str | None = None):
        """Return the SDK container client for the given (or default) database and collection."""
        return self._get_database_client(database_name).get_container_client(
            self.__get_collection_name(collection_name)
        )

    def does_collection_exist(self, collection_name: str, database_name: str | None) -> bool:
        """
        Checks if a collection exists in CosmosDB.

        :param collection_name: id of the collection (container) to look for.
        :param database_name: database to search; ``None`` uses the connection default.
        :raises AirflowBadRequest: if ``collection_name`` is None or no database can be resolved.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        existing_container = list(
            self._get_database_client(database_name).query_containers(
                self._QUERY_BY_ID,
                parameters=self._id_query_parameters(collection_name),
            )
        )
        return bool(existing_container)

    def create_collection(
        self,
        collection_name: str,
        database_name: str | None = None,
        partition_key: str | None = None,
    ) -> None:
        """
        Creates a new collection in the CosmosDB database.

        Does nothing if a collection with that id already exists.

        :param collection_name: id of the collection (container) to create.
        :param database_name: target database; ``None`` uses the connection default.
        :param partition_key: partition key passed to ``create_container``.
        :raises AirflowBadRequest: if ``collection_name`` is None or no database can be resolved.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        # Check first so we don't try to create the same container twice.
        if not self.does_collection_exist(collection_name, database_name):
            self._get_database_client(database_name).create_container(
                collection_name, partition_key=partition_key
            )

    def does_database_exist(self, database_name: str) -> bool:
        """
        Checks if a database exists in CosmosDB.

        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        existing_database = list(
            self.get_conn().query_databases(
                self._QUERY_BY_ID,
                parameters=self._id_query_parameters(database_name),
            )
        )
        return bool(existing_database)

    def create_database(self, database_name: str) -> None:
        """
        Creates a new database in CosmosDB.

        Does nothing if a database with that id already exists.

        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        # Check first so we don't try to create the same database twice.
        if not self.does_database_exist(database_name):
            self.get_conn().create_database(database_name)

    def delete_database(self, database_name: str) -> None:
        """
        Deletes an existing database in CosmosDB.

        :raises AirflowBadRequest: if ``database_name`` is None.
        """
        if database_name is None:
            raise AirflowBadRequest("Database name cannot be None.")

        self.get_conn().delete_database(database_name)

    def delete_collection(self, collection_name: str, database_name: str | None = None) -> None:
        """
        Deletes an existing collection in the CosmosDB database.

        :param collection_name: id of the collection (container) to delete.
        :param database_name: database holding it; ``None`` uses the connection default.
        :raises AirflowBadRequest: if ``collection_name`` is None or no database can be resolved.
        """
        if collection_name is None:
            raise AirflowBadRequest("Collection name cannot be None.")

        self._get_database_client(database_name).delete_container(collection_name)

    def upsert_document(self, document, database_name=None, collection_name=None, document_id=None):
        """
        Inserts a new document (or updates an existing one) into an existing
        collection in the CosmosDB database.

        If ``document`` has no ``"id"`` (missing or ``None``), it is set **in place** to
        ``document_id``, or to a random UUID4 string when ``document_id`` is not given.
        An id already present in the document takes precedence over ``document_id``.

        :return: the value returned by the SDK's ``upsert_item``.
        :raises AirflowBadRequest: if ``document`` is None or no database/collection can be resolved.
        """
        if document is None:
            raise AirflowBadRequest("You cannot insert a None document")

        if document.get("id") is None:
            document["id"] = document_id if document_id is not None else str(uuid.uuid4())

        return self._get_container_client(database_name, collection_name).upsert_item(document)

    def insert_documents(
        self, documents, database_name: str | None = None, collection_name: str | None = None
    ) -> list:
        """
        Insert a list of new documents into an existing collection in the CosmosDB database.

        :param documents: iterable of documents, each passed to the SDK's ``create_item``.
        :return: list of ``create_item`` results, in input order (empty for empty input).
        :raises AirflowBadRequest: if ``documents`` is None, or (for non-empty input) no
            database/collection can be resolved.
        """
        if documents is None:
            raise AirflowBadRequest("You cannot insert empty documents")

        documents = list(documents)
        # Empty input never touches the connection, so no database/collection is required.
        if not documents:
            return []

        # Resolve the container once rather than once per document.
        container = self._get_container_client(database_name, collection_name)
        return [container.create_item(single_document) for single_document in documents]

    def delete_document(
        self,
        document_id: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: str | None = None,
    ) -> None:
        """
        Delete an existing document out of a collection in the CosmosDB database.

        :raises AirflowBadRequest: if ``document_id`` is None or no database/collection can be resolved.
        """
        if document_id is None:
            raise AirflowBadRequest("Cannot delete a document without an id")

        self._get_container_client(database_name, collection_name).delete_item(
            document_id, partition_key=partition_key
        )

    def get_document(
        self,
        document_id: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: str | None = None,
    ):
        """
        Get a document from an existing collection in the CosmosDB database.

        :return: the SDK's ``read_item`` result, or ``None`` if the request raised
            ``CosmosHttpResponseError``.
        :raises AirflowBadRequest: if ``document_id`` is None or no database/collection can be resolved.
        """
        if document_id is None:
            raise AirflowBadRequest("Cannot get a document without an id")

        try:
            return self._get_container_client(database_name, collection_name).read_item(
                document_id, partition_key=partition_key
            )
        except CosmosHttpResponseError:
            # Deliberate best-effort lookup: any Cosmos HTTP error is reported as "no document".
            return None

    def get_documents(
        self,
        sql_string: str,
        database_name: str | None = None,
        collection_name: str | None = None,
        partition_key: str | None = None,
    ) -> list | None:
        """
        Get a list of documents from an existing collection in the CosmosDB database via SQL query.

        :return: all query results as a list, or ``None`` if the request raised
            ``CosmosHttpResponseError``.
        :raises AirflowBadRequest: if ``sql_string`` is None or no database/collection can be resolved.
        """
        if sql_string is None:
            raise AirflowBadRequest("SQL query string cannot be None")

        try:
            result_iterable = self._get_container_client(database_name, collection_name).query_items(
                sql_string, partition_key=partition_key
            )
            # Materialize inside the try: paging errors surface while iterating.
            return list(result_iterable)
        except CosmosHttpResponseError:
            return None

    def test_connection(self) -> tuple[bool, str]:
        """
        Test a configured Azure Cosmos connection.

        :return: ``(True, message)`` on success, ``(False, error text)`` on any exception.
        """
        try:
            # The Azure Cosmos API allows a CosmosClient to be created with incorrect values;
            # it only fails once items are retrieved. So we _actually_ fetch the first database
            # (if any) from the listing to properly test the connection.
            next(iter(self.get_conn().list_databases()), None)
        except Exception as e:
            return False, str(e)
        return True, "Successfully connected to Azure Cosmos."
def get_database_link(database_id: str) -> str:
    """
    Build the Azure CosmosDB resource link for a database.

    :param database_id: id of the database.
    :return: the link in the form ``dbs/<database_id>``.
    """
    return "/".join(("dbs", database_id))
def get_collection_link(database_id: str, collection_id: str) -> str:
    """
    Build the Azure CosmosDB resource link for a collection.

    :param database_id: id of the database holding the collection.
    :param collection_id: id of the collection.
    :return: the link in the form ``dbs/<database_id>/colls/<collection_id>``.
    """
    database_link = get_database_link(database_id)
    return "/".join((database_link, "colls", collection_id))
def get_document_link(database_id: str, collection_id: str, document_id: str) -> str:
    """
    Build the Azure CosmosDB resource link for a document.

    :param database_id: id of the database.
    :param collection_id: id of the collection holding the document.
    :param document_id: id of the document.
    :return: the link in the form ``dbs/<database_id>/colls/<collection_id>/docs/<document_id>``.
    """
    collection_link = get_collection_link(database_id, collection_id)
    return "/".join((collection_link, "docs", document_id))