#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
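"""Tests for DatabricksCopyIntoOperator: generated COPY INTO SQL, parameter validation, and templating."""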
from __future__ import annotations

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator
from airflow.utils import timezone

DATE = "2017-04-20"
TASK_ID = "databricks-sql-operator"
DEFAULT_CONN_ID = "databricks_default"
COPY_FILE_LOCATION = "s3://my-bucket/jsonData"


def test_copy_with_files():
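    """An explicit file list and format options render as FILES and FORMAT_OPTIONS clauses."""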
    op = DatabricksCopyIntoOperator(
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
        files=["file1", "file2", "file3"],
        format_options={"dateFormat": "yyyy-MM-dd"},
        task_id=TASK_ID,
    )
    assert (
        op._create_sql_query()
        == f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}'
FILEFORMAT = JSON
FILES = ('file1','file2','file3')
FORMAT_OPTIONS ('dateFormat' = 'yyyy-MM-dd')
""".strip()
    )


def test_copy_with_expression():
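    """An expression list wraps the source in a SELECT; pattern, format and copy options follow."""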
    expression = "col1, col2"
    op = DatabricksCopyIntoOperator(
        file_location=COPY_FILE_LOCATION,
        file_format="CSV",
        table_name="test",
        task_id=TASK_ID,
        pattern="folder1/file_[a-g].csv",
        expression_list=expression,
        format_options={"header": "true"},
        force_copy=True,
    )
    assert (
        op._create_sql_query()
        == f"""COPY INTO test
FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}')
FILEFORMAT = CSV
PATTERN = 'folder1/file_[a-g].csv'
FORMAT_OPTIONS ('header' = 'true')
COPY_OPTIONS ('force' = 'true')
""".strip()
    )


def test_copy_with_credential():
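    """A source credential renders as a WITH (CREDENTIAL ...) clause on the file location."""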
    expression = "col1, col2"
    op = DatabricksCopyIntoOperator(
        file_location=COPY_FILE_LOCATION,
        file_format="CSV",
        table_name="test",
        task_id=TASK_ID,
        expression_list=expression,
        credential={"AZURE_SAS_TOKEN": "abc"},
    )
    assert (
        op._create_sql_query()
        == f"""COPY INTO test
FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') ))
FILEFORMAT = CSV
""".strip()
    )


def test_copy_with_target_credential():
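    """storage_credential additionally attaches a WITH (CREDENTIAL ...) clause to the target table."""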
    expression = "col1, col2"
    op = DatabricksCopyIntoOperator(
        file_location=COPY_FILE_LOCATION,
        file_format="CSV",
        table_name="test",
        task_id=TASK_ID,
        expression_list=expression,
        storage_credential="abc",
        credential={"AZURE_SAS_TOKEN": "abc"},
    )
    assert (
        op._create_sql_query()
        == f"""COPY INTO test WITH (CREDENTIAL abc)
FROM (SELECT {expression} FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') ))
FILEFORMAT = CSV
""".strip()
    )


def test_copy_with_encryption():
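    """Encryption settings render as a WITH ( ENCRYPTION ...) clause on the file location."""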
    op = DatabricksCopyIntoOperator(
        file_location=COPY_FILE_LOCATION,
        file_format="CSV",
        table_name="test",
        task_id=TASK_ID,
        encryption={"TYPE": "AWS_SSE_C", "MASTER_KEY": "abc"},
    )
    assert (
        op._create_sql_query()
        == f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}' WITH ( ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc'))
FILEFORMAT = CSV
""".strip()
    )


def test_copy_with_encryption_and_credential():
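    """Credential and encryption settings are combined inside a single WITH (...) clause."""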
    op = DatabricksCopyIntoOperator(
        file_location=COPY_FILE_LOCATION,
        file_format="CSV",
        table_name="test",
        task_id=TASK_ID,
        encryption={"TYPE": "AWS_SSE_C", "MASTER_KEY": "abc"},
        credential={"AZURE_SAS_TOKEN": "abc"},
    )
    assert (
        op._create_sql_query()
        == f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') """
        """ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc'))
FILEFORMAT = CSV
""".strip()
    )


def test_copy_with_validate_all():
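    """validate=True renders a VALIDATE ALL clause."""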
    op = DatabricksCopyIntoOperator(
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
        task_id=TASK_ID,
        validate=True,
    )
    assert (
        op._create_sql_query()
        == f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}'
FILEFORMAT = JSON
VALIDATE ALL
""".strip()
    )


def test_copy_with_validate_N_rows():
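    """An integer validate value renders as VALIDATE <n> ROWS."""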
    op = DatabricksCopyIntoOperator(
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
        task_id=TASK_ID,
        validate=10,
    )
    assert (
        op._create_sql_query()
        == f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}'
FILEFORMAT = JSON
VALIDATE 10 ROWS
""".strip()
    )


def test_incorrect_params_files_patterns():
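    """Passing both 'files' and 'pattern' is rejected at construction time."""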
    exception_message = "Only one of 'pattern' or 'files' should be specified"
    with pytest.raises(AirflowException, match=exception_message):
        DatabricksCopyIntoOperator(
            task_id=TASK_ID,
            file_location=COPY_FILE_LOCATION,
            file_format="JSON",
            table_name="test",
            files=["file1", "file2", "file3"],
            pattern="abc",
        )


def test_incorrect_params_empty_table():
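    """An empty table_name is rejected at construction time."""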
    exception_message = "table_name shouldn't be empty"
    with pytest.raises(AirflowException, match=exception_message):
        DatabricksCopyIntoOperator(
            task_id=TASK_ID,
            file_location=COPY_FILE_LOCATION,
            file_format="JSON",
            table_name="",
        )


def test_incorrect_params_empty_location():
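    """An empty file_location is rejected at construction time."""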
    exception_message = "file_location shouldn't be empty"
    with pytest.raises(AirflowException, match=exception_message):
        DatabricksCopyIntoOperator(
            task_id=TASK_ID,
            file_location="",
            file_format="JSON",
            table_name="abc",
        )


def test_incorrect_params_wrong_format():
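    """An unsupported file_format is rejected at construction time."""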
    file_format = "JSONL"
    exception_message = f"file_format '{file_format}' isn't supported"
    with pytest.raises(AirflowException, match=exception_message):
        DatabricksCopyIntoOperator(
            task_id=TASK_ID,
            file_location=COPY_FILE_LOCATION,
            file_format=file_format,
            table_name="abc",
        )


@pytest.mark.db_test
def test_templating(create_task_instance_of_operator):
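    """Templated fields are rendered by the Jinja engine before execution."""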
    ti = create_task_instance_of_operator(
        DatabricksCopyIntoOperator,
        # Templated fields
        file_location="{{ 'file-location' }}",
        files="{{ 'files' }}",
        table_name="{{ 'table-name' }}",
        databricks_conn_id="{{ 'databricks-conn-id' }}",
        # Other parameters
        file_format="JSON",
        dag_id="test_template_body_templating_dag",
        task_id="test_template_body_templating_task",
        execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
    )
    ti.render_templates()
    task: DatabricksCopyIntoOperator = ti.task
    assert task.file_location == "file-location"
    assert task.files == "files"
    assert task.table_name == "table-name"
    assert task.databricks_conn_id == "databricks-conn-id"