Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py: 0%
141 statements
Report generated by coverage.py v7.2.7 at 2023-06-07 06:35 +0000
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
18from __future__ import annotations
20import json
21import logging
22import uuid
23from unittest import mock
24from unittest.mock import PropertyMock
26import pytest
27from azure.cosmos.cosmos_client import CosmosClient
29from airflow.exceptions import AirflowException
30from airflow.models import Connection
31from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
32from airflow.utils import db
33from tests.test_utils.providers import get_provider_min_airflow_version
class TestAzureCosmosDbHook:
    """Unit tests for ``AzureCosmosDBHook`` with ``CosmosClient`` patched out.

    NOTE: ``assert_has_calls`` with a chained expectation such as
    ``mock.call().get_database_client("db").create_container(...)`` only compares the name,
    args and kwargs of the *last* call in the chain -- the arguments passed to intermediate
    calls (e.g. the database given to ``get_database_client``) are NOT compared. Tests that
    care which database/collection was used therefore also assert those intermediate calls
    explicitly.
    """

    # Set up an environment to test with
    def setup_method(self):
        """Register the ``azure_cosmos_test_key_id`` connection used by every test.

        The connection extras carry the default database/collection names; the ``*_default``
        tests exercise the hook's fallback to them when no name is passed explicitly.
        """
        # set up some test variables
        self.test_end_point = "https://test_endpoint:443"
        self.test_master_key = "magic_test_key"
        self.test_database_name = "test_database_name"
        self.test_collection_name = "test_collection_name"
        self.test_database_default = "test_database_default"
        self.test_collection_default = "test_collection_default"
        db.merge_conn(
            Connection(
                conn_id="azure_cosmos_test_key_id",
                conn_type="azure_cosmos",
                login=self.test_end_point,
                password=self.test_master_key,
                extra=json.dumps(
                    {
                        "database_name": self.test_database_default,
                        "collection_name": self.test_collection_default,
                    }
                ),
            )
        )

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient", autospec=True)
    def test_client(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        # The client is created lazily, only on the first get_conn().
        assert hook._conn is None
        assert isinstance(hook.get_conn(), CosmosClient)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_create_database(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.create_database(self.test_database_name)
        expected_calls = [mock.call().create_database("test_database_name")]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_create_database_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.create_database(None)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_create_container_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.create_collection(None)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_create_container(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.create_collection(self.test_collection_name, self.test_database_name)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_name)
            .create_container(self.test_collection_name, partition_key=None)
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # The chained expectation above does not compare intermediate arguments.
        mock_cosmos.return_value.get_database_client.assert_called_with(self.test_database_name)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_create_container_default(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        # No database name given: the hook should fall back to the connection's default.
        hook.create_collection(self.test_collection_name)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .create_container(self.test_collection_name, partition_key=None)
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # The chained expectation above does not compare intermediate arguments.
        mock_cosmos.return_value.get_database_client.assert_called_with(self.test_database_default)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_upsert_document_default(self, mock_cosmos):
        test_id = str(uuid.uuid4())
        # fmt: off
        (mock_cosmos
         .return_value
         .get_database_client
         .return_value
         .get_container_client
         .return_value
         .upsert_item
         .return_value) = {'id': test_id}
        # fmt: on
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        # No database/collection given: the connection defaults should be used.
        returned_item = hook.upsert_document({"id": test_id})
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .upsert_item({"id": test_id})
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # The chained expectation above does not compare intermediate arguments.
        database_client = mock_cosmos.return_value.get_database_client
        database_client.assert_called_with(self.test_database_default)
        database_client.return_value.get_container_client.assert_called_with(self.test_collection_default)
        logging.getLogger().info(returned_item)
        assert returned_item["id"] == test_id

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_upsert_document(self, mock_cosmos):
        test_id = str(uuid.uuid4())
        # fmt: off
        (mock_cosmos
         .return_value
         .get_database_client
         .return_value
         .get_container_client
         .return_value
         .upsert_item
         .return_value) = {'id': test_id}
        # fmt: on
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        returned_item = hook.upsert_document(
            {"data1": "somedata"},
            database_name=self.test_database_name,
            collection_name=self.test_collection_name,
            document_id=test_id,
        )

        # The hook injects document_id as the document's "id" before upserting.
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_name)
            .get_container_client(self.test_collection_name)
            .upsert_item({"data1": "somedata", "id": test_id})
        ]

        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # The chained expectation above does not compare intermediate arguments.
        database_client = mock_cosmos.return_value.get_database_client
        database_client.assert_called_with(self.test_database_name)
        database_client.return_value.get_container_client.assert_called_with(self.test_collection_name)
        logging.getLogger().info(returned_item)
        assert returned_item["id"] == test_id

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_insert_documents(self, mock_cosmos):
        test_id1 = str(uuid.uuid4())
        test_id2 = str(uuid.uuid4())
        test_id3 = str(uuid.uuid4())
        documents = [
            {"id": test_id1, "data": "data1"},
            {"id": test_id2, "data": "data2"},
            {"id": test_id3, "data": "data3"},
        ]

        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        # No database/collection given: the connection defaults should be used.
        returned_item = hook.insert_documents(documents)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .create_item({"data": "data1", "id": test_id1}),
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .create_item({"data": "data2", "id": test_id2}),
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .create_item({"data": "data3", "id": test_id3}),
        ]
        logging.getLogger().info(returned_item)
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls, any_order=True)
        # The chained expectations above do not compare intermediate arguments.
        database_client = mock_cosmos.return_value.get_database_client
        database_client.assert_called_with(self.test_database_default)
        database_client.return_value.get_container_client.assert_called_with(self.test_collection_default)
        # One created item is returned per inserted document.
        assert len(returned_item) == len(documents)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_delete_database(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.delete_database(self.test_database_name)
        expected_calls = [mock.call().delete_database("test_database_name")]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_delete_database_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.delete_database(None)

    # Patch the name the hook module looked up (not azure.cosmos.cosmos_client), like every
    # other test, so no real client could be constructed even if validation order changes.
    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_delete_container_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.delete_collection(None)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_delete_container(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.delete_collection(self.test_collection_name, self.test_database_name)
        expected_calls = [
            mock.call().get_database_client(self.test_database_name).delete_container(self.test_collection_name)
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # The chained expectation above does not compare intermediate arguments.
        mock_cosmos.return_value.get_database_client.assert_called_with(self.test_database_name)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_delete_container_default(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        # No database name given: the hook should fall back to the connection's default.
        hook.delete_collection(self.test_collection_name)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .delete_container(self.test_collection_name)
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # The chained expectation above does not compare intermediate arguments.
        mock_cosmos.return_value.get_database_client.assert_called_with(self.test_database_default)

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_connection_success(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.get_conn().list_databases.return_value = {"id": self.test_database_name}
        status, msg = hook.test_connection()
        assert status is True
        assert msg == "Successfully connected to Azure Cosmos."

    @mock.patch("airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient")
    def test_connection_failure(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        # Any exception raised while listing databases should surface as a failed status.
        hook.get_conn().list_databases = PropertyMock(side_effect=Exception("Authentication failed."))
        status, msg = hook.test_connection()
        assert status is False
        assert msg == "Authentication failed."

    def test_get_ui_field_behaviour_placeholders(self):
        """
        Check that the ``_ensure_prefixes`` decorator works properly.

        Note: remove this test and the _ensure_prefixes decorator after min airflow version >= 2.5.0
        """
        assert list(AzureCosmosDBHook.get_ui_field_behaviour()["placeholders"].keys()) == [
            "login",
            "password",
            "extra__azure_cosmos__database_name",
            "extra__azure_cosmos__collection_name",
        ]
        # Deliberate tripwire: fails once the provider's minimum Airflow version makes the
        # decorator obsolete, forcing its removal.
        if get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >= (2, 5):
            raise Exception(
                "You must now remove `_ensure_prefixes` from azure utils."
                " The functionality is now taken care of by providers manager."
            )