Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py: 0%

141 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import json 

21import logging 

22import uuid 

23from unittest import mock 

24from unittest.mock import PropertyMock 

25 

26import pytest 

27from azure.cosmos.cosmos_client import CosmosClient 

28 

29from airflow.exceptions import AirflowException 

30from airflow.models import Connection 

31from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook 

32from airflow.utils import db 

33from tests.test_utils.providers import get_provider_min_airflow_version 

34 

35 

class TestAzureCosmosDbHook:
    """Tests for ``AzureCosmosDBHook`` with the Cosmos SDK client replaced by a mock.

    A chained expectation such as ``mock.call().get_database_client(x).create_container(y)``
    only compares the arguments of its *last* call; ``x`` is not checked by
    ``assert_has_calls``. The database/container names are therefore also asserted
    directly on the intermediate child mocks (``_assert_database_used`` and
    ``_assert_collection_used``).
    """

    # Patch target for the SDK client. Every test patches the name inside the hook
    # module so that the hook itself receives the mock.
    COSMOS_CLIENT = "airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient"
    CONN_ID = "azure_cosmos_test_key_id"

    def setup_method(self):
        """Register the test connection; its extras carry the default database/collection."""
        self.test_end_point = "https://test_endpoint:443"
        self.test_master_key = "magic_test_key"
        self.test_database_name = "test_database_name"
        self.test_collection_name = "test_collection_name"
        self.test_database_default = "test_database_default"
        self.test_collection_default = "test_collection_default"
        db.merge_conn(
            Connection(
                conn_id=self.CONN_ID,
                conn_type="azure_cosmos",
                login=self.test_end_point,
                password=self.test_master_key,
                extra=json.dumps(
                    {
                        "database_name": self.test_database_default,
                        "collection_name": self.test_collection_default,
                    }
                ),
            )
        )

    def _make_hook(self) -> AzureCosmosDBHook:
        """Build a hook bound to the connection registered in ``setup_method``."""
        return AzureCosmosDBHook(azure_cosmos_conn_id=self.CONN_ID)

    def _assert_connected(self, mock_cosmos):
        """Check the client was built with the connection's endpoint and master key."""
        mock_cosmos.assert_any_call(self.test_end_point, {"masterKey": self.test_master_key})

    @staticmethod
    def _container(mock_cosmos):
        """Return the mock container client reached through the patched client."""
        return mock_cosmos.return_value.get_database_client.return_value.get_container_client.return_value

    @staticmethod
    def _assert_database_used(mock_cosmos, database_name):
        """Check ``get_database_client`` was called with ``database_name``."""
        mock_cosmos.return_value.get_database_client.assert_any_call(database_name)

    @staticmethod
    def _assert_collection_used(mock_cosmos, collection_name):
        """Check ``get_container_client`` was called with ``collection_name``."""
        database = mock_cosmos.return_value.get_database_client.return_value
        database.get_container_client.assert_any_call(collection_name)

    @mock.patch(COSMOS_CLIENT, autospec=True)
    def test_client(self, mock_cosmos):
        hook = self._make_hook()
        # No client is cached until get_conn() is called.
        assert hook._conn is None
        assert isinstance(hook.get_conn(), CosmosClient)

    @mock.patch(COSMOS_CLIENT)
    def test_create_database(self, mock_cosmos):
        self._make_hook().create_database(self.test_database_name)
        self._assert_connected(mock_cosmos)
        mock_cosmos.assert_has_calls([mock.call().create_database(self.test_database_name)])

    @mock.patch(COSMOS_CLIENT)
    def test_create_database_exception(self, mock_cosmos):
        hook = self._make_hook()
        with pytest.raises(AirflowException):
            hook.create_database(None)

    @mock.patch(COSMOS_CLIENT)
    def test_create_container_exception(self, mock_cosmos):
        hook = self._make_hook()
        with pytest.raises(AirflowException):
            hook.create_collection(None)

    @mock.patch(COSMOS_CLIENT)
    def test_create_container(self, mock_cosmos):
        self._make_hook().create_collection(self.test_collection_name, self.test_database_name)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_name)
            .create_container(self.test_collection_name, partition_key=None)
        ]
        self._assert_connected(mock_cosmos)
        mock_cosmos.assert_has_calls(expected_calls)
        self._assert_database_used(mock_cosmos, self.test_database_name)

    @mock.patch(COSMOS_CLIENT)
    def test_create_container_default(self, mock_cosmos):
        # No database given: the hook must fall back to the connection's default database.
        self._make_hook().create_collection(self.test_collection_name)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .create_container(self.test_collection_name, partition_key=None)
        ]
        self._assert_connected(mock_cosmos)
        mock_cosmos.assert_has_calls(expected_calls)
        self._assert_database_used(mock_cosmos, self.test_database_default)

    @mock.patch(COSMOS_CLIENT)
    def test_upsert_document_default(self, mock_cosmos):
        test_id = str(uuid.uuid4())
        self._container(mock_cosmos).upsert_item.return_value = {"id": test_id}
        # No database/collection given: the connection defaults must be used.
        returned_item = self._make_hook().upsert_document({"id": test_id})
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .upsert_item({"id": test_id})
        ]
        self._assert_connected(mock_cosmos)
        mock_cosmos.assert_has_calls(expected_calls)
        self._assert_database_used(mock_cosmos, self.test_database_default)
        self._assert_collection_used(mock_cosmos, self.test_collection_default)
        logging.getLogger().info(returned_item)
        assert returned_item["id"] == test_id

    @mock.patch(COSMOS_CLIENT)
    def test_upsert_document(self, mock_cosmos):
        test_id = str(uuid.uuid4())
        self._container(mock_cosmos).upsert_item.return_value = {"id": test_id}
        returned_item = self._make_hook().upsert_document(
            {"data1": "somedata"},
            database_name=self.test_database_name,
            collection_name=self.test_collection_name,
            document_id=test_id,
        )

        # The explicit document_id is expected to be merged into the upserted document.
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_name)
            .get_container_client(self.test_collection_name)
            .upsert_item({"data1": "somedata", "id": test_id})
        ]

        self._assert_connected(mock_cosmos)
        mock_cosmos.assert_has_calls(expected_calls)
        self._assert_database_used(mock_cosmos, self.test_database_name)
        self._assert_collection_used(mock_cosmos, self.test_collection_name)
        logging.getLogger().info(returned_item)
        assert returned_item["id"] == test_id

    @mock.patch(COSMOS_CLIENT)
    def test_insert_documents(self, mock_cosmos):
        documents = [{"id": str(uuid.uuid4()), "data": f"data{index}"} for index in range(1, 4)]

        # No database/collection given: the connection defaults must be used.
        returned_items = self._make_hook().insert_documents(documents)
        expected_calls = [
            mock.call()
            .get_database_client(self.test_database_default)
            .get_container_client(self.test_collection_default)
            .create_item(document)
            for document in documents
        ]
        logging.getLogger().info(returned_items)
        self._assert_connected(mock_cosmos)
        mock_cosmos.assert_has_calls(expected_calls, any_order=True)
        self._assert_database_used(mock_cosmos, self.test_database_default)
        self._assert_collection_used(mock_cosmos, self.test_collection_default)
        assert len(returned_items) == len(documents)

    @mock.patch(COSMOS_CLIENT)
    def test_delete_database(self, mock_cosmos):
        self._make_hook().delete_database(self.test_database_name)
        self._assert_connected(mock_cosmos)
        mock_cosmos.assert_has_calls([mock.call().delete_database(self.test_database_name)])

    @mock.patch(COSMOS_CLIENT)
    def test_delete_database_exception(self, mock_cosmos):
        hook = self._make_hook()
        with pytest.raises(AirflowException):
            hook.delete_database(None)

    @mock.patch(COSMOS_CLIENT)
    def test_delete_container_exception(self, mock_cosmos):
        # Patch the same target as the other tests so that no real SDK client could be
        # built even if the hook reached get_conn() before validating its argument.
        hook = self._make_hook()
        with pytest.raises(AirflowException):
            hook.delete_collection(None)

    @mock.patch(COSMOS_CLIENT)
    def test_delete_container(self, mock_cosmos):
        self._make_hook().delete_collection(self.test_collection_name, self.test_database_name)
        expected_calls = [
            mock.call().get_database_client(self.test_database_name).delete_container(self.test_collection_name)
        ]
        self._assert_connected(mock_cosmos)
        mock_cosmos.assert_has_calls(expected_calls)
        self._assert_database_used(mock_cosmos, self.test_database_name)

    @mock.patch(COSMOS_CLIENT)
    def test_delete_container_default(self, mock_cosmos):
        # No database given: the hook must fall back to the connection's default database.
        self._make_hook().delete_collection(self.test_collection_name)
        expected_calls = [
            mock.call().get_database_client(self.test_database_default).delete_container(self.test_collection_name)
        ]
        self._assert_connected(mock_cosmos)
        mock_cosmos.assert_has_calls(expected_calls)
        self._assert_database_used(mock_cosmos, self.test_database_default)

    @mock.patch(COSMOS_CLIENT)
    def test_connection_success(self, mock_cosmos):
        hook = self._make_hook()
        hook.get_conn().list_databases.return_value = {"id": self.test_database_name}
        status, msg = hook.test_connection()
        assert status is True
        assert msg == "Successfully connected to Azure Cosmos."

    @mock.patch(COSMOS_CLIENT)
    def test_connection_failure(self, mock_cosmos):
        hook = self._make_hook()
        # Calling list_databases() raises, simulating an authentication error from the service.
        hook.get_conn().list_databases = PropertyMock(side_effect=Exception("Authentication failed."))
        status, msg = hook.test_connection()
        assert status is False
        assert msg == "Authentication failed."

    def test_get_ui_field_behaviour_placeholders(self):
        """
        Check that ensure_prefixes decorator working properly

        Note: remove this test and the _ensure_prefixes decorator after min airflow version >= 2.5.0
        """
        assert list(AzureCosmosDBHook.get_ui_field_behaviour()["placeholders"].keys()) == [
            "login",
            "password",
            "extra__azure_cosmos__database_name",
            "extra__azure_cosmos__collection_name",
        ]
        if get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >= (2, 5):
            raise Exception(
                "You must now remove `_ensure_prefixes` from azure utils."
                " The functionality is now taken care of by providers manager."
            )