1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import logging
21import uuid
22from unittest import mock
23from unittest.mock import PropertyMock
24
25import pytest
26from azure.cosmos import PartitionKey
27from azure.cosmos.cosmos_client import CosmosClient
28
29from airflow.exceptions import AirflowException
30from airflow.models import Connection
31from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
32
# Dotted path of the hook module under test. Patch targets are built from it so that
# mocks replace names where the hook looks them up, not where they are defined.
MODULE: str = "airflow.providers.microsoft.azure.hooks.cosmos"
34
35
class TestAzureCosmosDbHook:
    """Unit tests for ``AzureCosmosDBHook`` with the Azure Cosmos SDK client classes mocked out.

    When ``assert_has_calls`` matches a chained expectation such as
    ``mock.call().get_database_client(x).create_container(y)`` against recorded calls, only
    the arguments of the last link (``create_container``) are compared; ``x`` is ignored.
    Database and container names are therefore asserted explicitly with
    ``_assert_only_called_with``.
    """

    # Set up an environment to test with
    @pytest.fixture(autouse=True)
    def setup_test_cases(self, create_mock_connection):
        """Register a key-based Cosmos connection whose extras carry default names."""
        # set up some test variables
        self.test_end_point = "https://test_endpoint:443"
        self.test_master_key = "magic_test_key"
        self.test_database_name = "test_database_name"
        self.test_collection_name = "test_collection_name"
        self.test_database_default = "test_database_default"
        self.test_collection_default = "test_collection_default"
        self.test_partition_key = "/test_partition_key"
        create_mock_connection(
            Connection(
                conn_id="azure_cosmos_test_key_id",
                conn_type="azure_cosmos",
                login=self.test_end_point,
                password=self.test_master_key,
                extra={
                    "database_name": self.test_database_default,
                    "collection_name": self.test_collection_default,
                    "partition_key": self.test_partition_key,
                },
            )
        )

    @staticmethod
    def _assert_only_called_with(mocked_method, *args):
        """Assert ``mocked_method`` was called at least once and every call received exactly ``args``."""
        calls = mocked_method.call_args_list
        assert calls, f"expected {mocked_method!r} to be called"
        assert all(recorded == mock.call(*args) for recorded in calls), calls

    @pytest.mark.parametrize(
        "mocked_connection",
        [
            Connection(
                conn_id="azure_cosmos_test_default_credential",
                conn_type="azure_cosmos",
                login="https://test_endpoint:443",
                extra={
                    "resource_group_name": "resource-group-name",
                    "subscription_id": "subscription_id",
                    "managed_identity_client_id": "test_client_id",
                    "workload_identity_tenant_id": "test_tenant_id",
                },
            )
        ],
        indirect=True,
    )
    @mock.patch(f"{MODULE}.get_sync_default_azure_credential")
    @mock.patch(f"{MODULE}.CosmosDBManagementClient")
    @mock.patch(f"{MODULE}.CosmosClient")
    def test_get_conn(self, mock_cosmos, mock_cosmos_db, mock_default_azure_credential, mocked_connection):
        """With no password on the connection, the identity extras must reach the credential factory."""
        mock_cosmos_db.return_value.database_accounts.list_keys.return_value.primary_master_key = "master-key"

        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_default_credential")
        hook.get_conn()

        mock_default_azure_credential.assert_called()
        args = mock_default_azure_credential.call_args
        assert args.kwargs["managed_identity_client_id"] == "test_client_id"
        assert args.kwargs["workload_identity_tenant_id"] == "test_tenant_id"

    @mock.patch(f"{MODULE}.CosmosClient", autospec=True)
    def test_client(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        assert isinstance(hook.get_conn(), CosmosClient)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_create_database(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.create_database(self.test_database_name)
        expected_calls = [mock.call().create_database("test_database_name")]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_create_database_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.create_database(None)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_create_container_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.create_collection(None)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_create_container(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.create_collection(self.test_collection_name, self.test_database_name, partition_key="/id")
        expected_calls = [
            mock.call()
            .get_database_client("test_database_name")
            .create_container("test_collection_name", partition_key=PartitionKey(path="/id"))
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        self._assert_only_called_with(mock_cosmos.return_value.get_database_client, self.test_database_name)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_create_container_default(self, mock_cosmos):
        """Without explicit names, the database and partition key come from the connection extras."""
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.create_collection(self.test_collection_name)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .create_container(
                "test_collection_name", partition_key=PartitionKey(path=self.test_partition_key)
            )
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # The chained expectation above does not compare get_database_client's argument.
        self._assert_only_called_with(mock_cosmos.return_value.get_database_client, self.test_database_default)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_upsert_document_default(self, mock_cosmos):
        """Without explicit names, the upsert targets the default database and collection."""
        test_id = str(uuid.uuid4())

        (
            mock_cosmos.return_value.get_database_client.return_value.get_container_client.return_value.upsert_item.return_value
        ) = {"id": test_id}

        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        returned_item = hook.upsert_document({"id": test_id})
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .upsert_item({"id": test_id})
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        database_client = mock_cosmos.return_value.get_database_client
        self._assert_only_called_with(database_client, self.test_database_default)
        self._assert_only_called_with(
            database_client.return_value.get_container_client, self.test_collection_default
        )
        logging.getLogger().info(returned_item)
        assert returned_item["id"] == test_id

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_upsert_document(self, mock_cosmos):
        test_id = str(uuid.uuid4())

        (
            mock_cosmos.return_value.get_database_client.return_value.get_container_client.return_value.upsert_item.return_value
        ) = {"id": test_id}

        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        returned_item = hook.upsert_document(
            {"data1": "somedata"},
            database_name=self.test_database_name,
            collection_name=self.test_collection_name,
            document_id=test_id,
        )

        expected_calls = [
            mock.call()
            .get_database_client("test_database_name")
            .get_container_client("test_collection_name")
            .upsert_item({"data1": "somedata", "id": test_id})
        ]

        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        database_client = mock_cosmos.return_value.get_database_client
        self._assert_only_called_with(database_client, self.test_database_name)
        self._assert_only_called_with(
            database_client.return_value.get_container_client, self.test_collection_name
        )
        logging.getLogger().info(returned_item)
        assert returned_item["id"] == test_id

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_insert_documents(self, mock_cosmos):
        """Each document is created through the default database and collection clients."""
        test_id1 = str(uuid.uuid4())
        test_id2 = str(uuid.uuid4())
        test_id3 = str(uuid.uuid4())
        documents = [
            {"id": test_id1, "data": "data1"},
            {"id": test_id2, "data": "data2"},
            {"id": test_id3, "data": "data3"},
        ]

        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        returned_items = hook.insert_documents(documents)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .create_item({"data": "data1", "id": test_id1}),
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .create_item({"data": "data2", "id": test_id2}),
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .create_item({"data": "data3", "id": test_id3}),
        ]
        logging.getLogger().info(returned_items)
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls, any_order=True)
        database_client = mock_cosmos.return_value.get_database_client
        self._assert_only_called_with(database_client, self.test_database_default)
        self._assert_only_called_with(
            database_client.return_value.get_container_client, self.test_collection_default
        )
        # One created-item result is expected per input document.
        assert len(returned_items) == len(documents)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_delete_database(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.delete_database(self.test_database_name)
        expected_calls = [mock.call().delete_database("test_database_name")]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_delete_database_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.delete_database(None)

    # Patch the name inside the hook module (like every other test); patching
    # azure.cosmos.cosmos_client.CosmosClient would not affect the hook's own binding.
    @mock.patch(f"{MODULE}.CosmosClient")
    def test_delete_container_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.delete_collection(None)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_delete_container(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.delete_collection(self.test_collection_name, self.test_database_name)
        expected_calls = [
            mock.call().get_database_client("test_database_name").delete_container("test_collection_name")
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        self._assert_only_called_with(mock_cosmos.return_value.get_database_client, self.test_database_name)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_delete_container_default(self, mock_cosmos):
        """Without an explicit database, the delete targets the default database from the extras."""
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.delete_collection(self.test_collection_name)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .delete_container("test_collection_name")
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        self._assert_only_called_with(mock_cosmos.return_value.get_database_client, self.test_database_default)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_connection_success(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.get_conn().list_databases.return_value = {"id": self.test_database_name}
        status, msg = hook.test_connection()
        assert status is True
        assert msg == "Successfully connected to Azure Cosmos."

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_connection_failure(self, mock_cosmos):
        """A raising ``list_databases`` makes test_connection report failure with the exception text."""
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.get_conn().list_databases = PropertyMock(side_effect=Exception("Authentication failed."))
        status, msg = hook.test_connection()
        assert status is False
        assert msg == "Authentication failed."