Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/tests/providers/microsoft/azure/hooks/test_cosmos.py: 0%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

149 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import logging 

21import uuid 

22from unittest import mock 

23from unittest.mock import PropertyMock 

24 

25import pytest 

26from azure.cosmos import PartitionKey 

27from azure.cosmos.cosmos_client import CosmosClient 

28 

29from airflow.exceptions import AirflowException 

30from airflow.models import Connection 

31from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook 

32 

# Import path of the hook module under test. Every mock.patch target below is built from it,
# so names are replaced where the hook module looks them up rather than where they are defined.
MODULE = "airflow.providers.microsoft.azure.hooks.cosmos"

34 

35 

class TestAzureCosmosDbHook:
    """Tests for ``AzureCosmosDBHook`` with the Azure SDK clients patched out.

    ``assert_has_calls`` with a chained expectation such as
    ``mock.call().get_database_client(db).create_container(name)`` only compares the
    arguments of the *last* call in the chain; the arguments passed to intermediate
    calls like ``get_database_client(...)`` or ``get_container_client(...)`` are not
    compared. Tests that must prove which database/collection was selected therefore
    also assert on those intermediate mocks directly.
    """

    # Set up an environment to test with
    @pytest.fixture(autouse=True)
    def setup_test_cases(self, create_mock_connection):
        """Register the key-based ``azure_cosmos_test_key_id`` connection used by most tests.

        Its extras hold the default database, collection and partition key that the
        ``*_default`` tests expect the hook to fall back to.
        """
        # set up some test variables
        self.test_end_point = "https://test_endpoint:443"
        self.test_master_key = "magic_test_key"
        self.test_database_name = "test_database_name"
        self.test_collection_name = "test_collection_name"
        self.test_database_default = "test_database_default"
        self.test_collection_default = "test_collection_default"
        self.test_partition_key = "/test_partition_key"
        create_mock_connection(
            Connection(
                conn_id="azure_cosmos_test_key_id",
                conn_type="azure_cosmos",
                login=self.test_end_point,
                password=self.test_master_key,
                extra={
                    "database_name": self.test_database_default,
                    "collection_name": self.test_collection_default,
                    "partition_key": self.test_partition_key,
                },
            )
        )

    @pytest.mark.parametrize(
        "mocked_connection",
        [
            Connection(
                conn_id="azure_cosmos_test_default_credential",
                conn_type="azure_cosmos",
                login="https://test_endpoint:443",
                extra={
                    "resource_group_name": "resource-group-name",
                    "subscription_id": "subscription_id",
                    "managed_identity_client_id": "test_client_id",
                    "workload_identity_tenant_id": "test_tenant_id",
                },
            )
        ],
        indirect=True,
    )
    @mock.patch(f"{MODULE}.get_sync_default_azure_credential")
    @mock.patch(f"{MODULE}.CosmosDBManagementClient")
    @mock.patch(f"{MODULE}.CosmosClient")
    def test_get_conn(self, mock_cosmos, mock_cosmos_db, mock_default_azure_credential, mocked_connection):
        """A password-less connection obtains its key through the (mocked) management client."""
        mock_cosmos_db.return_value.database_accounts.list_keys.return_value.primary_master_key = "master-key"

        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_default_credential")
        hook.get_conn()

        mock_default_azure_credential.assert_called()
        args = mock_default_azure_credential.call_args
        assert args.kwargs["managed_identity_client_id"] == "test_client_id"
        assert args.kwargs["workload_identity_tenant_id"] == "test_tenant_id"
        # The key configured on the management client above must be what the data-plane client receives.
        mock_cosmos.assert_called_once_with("https://test_endpoint:443", {"masterKey": "master-key"})

    @mock.patch(f"{MODULE}.CosmosClient", autospec=True)
    def test_client(self, mock_cosmos):
        """``get_conn`` returns (an autospecced stand-in for) a ``CosmosClient``."""
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        assert isinstance(hook.get_conn(), CosmosClient)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_create_database(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.create_database(self.test_database_name)
        expected_calls = [mock.call().create_database(self.test_database_name)]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_create_database_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.create_database(None)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_create_container_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.create_collection(None)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_create_container(self, mock_cosmos):
        """An explicitly named database and partition key are used as given."""
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.create_collection(self.test_collection_name, self.test_database_name, partition_key="/id")
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_name)
            .create_container(self.test_collection_name, partition_key=PartitionKey(path="/id"))
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # assert_has_calls does not compare intermediate-call arguments; check the database explicitly.
        mock_cosmos.return_value.get_database_client.assert_any_call(self.test_database_name)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_create_container_default(self, mock_cosmos):
        """Database and partition key fall back to the connection extras."""
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.create_collection(self.test_collection_name)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .create_container(
                self.test_collection_name, partition_key=PartitionKey(path=self.test_partition_key)
            )
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # assert_has_calls does not compare intermediate-call arguments; check the default explicitly.
        mock_cosmos.return_value.get_database_client.assert_any_call(self.test_database_default)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_upsert_document_default(self, mock_cosmos):
        """Without explicit names, the document is upserted into the default database/collection."""
        test_id = str(uuid.uuid4())

        (
            mock_cosmos.return_value.get_database_client.return_value.get_container_client.return_value.upsert_item.return_value
        ) = {"id": test_id}

        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        returned_item = hook.upsert_document({"id": test_id})
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .upsert_item({"id": test_id})
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # assert_has_calls does not compare intermediate-call arguments; check the defaults explicitly.
        database_client = mock_cosmos.return_value.get_database_client
        database_client.assert_any_call(self.test_database_default)
        database_client.return_value.get_container_client.assert_any_call(self.test_collection_default)
        logging.getLogger().info(returned_item)
        assert returned_item["id"] == test_id

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_upsert_document(self, mock_cosmos):
        """``document_id`` is merged into the document as ``id`` before upserting."""
        test_id = str(uuid.uuid4())

        (
            mock_cosmos.return_value.get_database_client.return_value.get_container_client.return_value.upsert_item.return_value
        ) = {"id": test_id}

        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        returned_item = hook.upsert_document(
            {"data1": "somedata"},
            database_name=self.test_database_name,
            collection_name=self.test_collection_name,
            document_id=test_id,
        )

        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_name)
            .get_container_client(self.test_collection_name)
            .upsert_item({"data1": "somedata", "id": test_id})
        ]

        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # assert_has_calls does not compare intermediate-call arguments; check the names explicitly.
        database_client = mock_cosmos.return_value.get_database_client
        database_client.assert_any_call(self.test_database_name)
        database_client.return_value.get_container_client.assert_any_call(self.test_collection_name)
        logging.getLogger().info(returned_item)
        assert returned_item["id"] == test_id

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_insert_documents(self, mock_cosmos):
        """Each document is created individually in the default database/collection."""
        documents = [
            {"id": str(uuid.uuid4()), "data": "data1"},
            {"id": str(uuid.uuid4()), "data": "data2"},
            {"id": str(uuid.uuid4()), "data": "data3"},
        ]

        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        returned_items = hook.insert_documents(documents)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .create_item(document)
            for document in documents
        ]
        logging.getLogger().info(returned_items)
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls, any_order=True)
        # assert_has_calls does not compare intermediate-call arguments; check the defaults explicitly.
        database_client = mock_cosmos.return_value.get_database_client
        database_client.assert_any_call(self.test_database_default)
        database_client.return_value.get_container_client.assert_any_call(self.test_collection_default)
        # insert_documents is expected to return one created item per input document.
        assert len(returned_items) == len(documents)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_delete_database(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.delete_database(self.test_database_name)
        expected_calls = [mock.call().delete_database(self.test_database_name)]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_delete_database_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.delete_database(None)

    # Patch where the hook module looks CosmosClient up (like every other test), not where the SDK
    # defines it; patching the definition site would leave the hook's own reference unpatched.
    @mock.patch(f"{MODULE}.CosmosClient")
    def test_delete_container_exception(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        with pytest.raises(AirflowException):
            hook.delete_collection(None)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_delete_container(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.delete_collection(self.test_collection_name, self.test_database_name)
        expected_calls = [
            mock.call().get_database_client(self.test_database_name).delete_container(self.test_collection_name)
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # assert_has_calls does not compare intermediate-call arguments; check the database explicitly.
        mock_cosmos.return_value.get_database_client.assert_any_call(self.test_database_name)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_delete_container_default(self, mock_cosmos):
        """Without an explicit database, the container is deleted from the default database."""
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.delete_collection(self.test_collection_name)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .delete_container(self.test_collection_name)
        ]
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})
        mock_cosmos.assert_has_calls(expected_calls)
        # assert_has_calls does not compare intermediate-call arguments; check the default explicitly.
        mock_cosmos.return_value.get_database_client.assert_any_call(self.test_database_default)

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_connection_success(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        hook.get_conn().list_databases.return_value = {"id": self.test_database_name}
        status, msg = hook.test_connection()
        assert status is True
        assert msg == "Successfully connected to Azure Cosmos."

    @mock.patch(f"{MODULE}.CosmosClient")
    def test_connection_failure(self, mock_cosmos):
        hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_test_key_id")
        # Assigned as a plain instance attribute (not on the type), the PropertyMock just acts as a
        # callable mock: invoking list_databases() raises the side effect.
        hook.get_conn().list_databases = PropertyMock(side_effect=Exception("Authentication failed."))
        status, msg = hook.test_connection()
        assert status is False
        assert msg == "Authentication failed."