1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Shared helper functions for tqdm progress bar."""
16
17import concurrent.futures
18import sys
19import time
20import typing
21from typing import Optional
22import warnings
23
24try:
25 import tqdm # type: ignore
26except ImportError:
27 tqdm = None
28
29try:
30 import tqdm.notebook as tqdm_notebook # type: ignore
31except ImportError:
32 tqdm_notebook = None
33
34if typing.TYPE_CHECKING: # pragma: NO COVER
35 from google.cloud.bigquery import QueryJob
36 from google.cloud.bigquery.table import RowIterator
37
# Warning text emitted (as a ``UserWarning``) when a progress bar was requested
# but one could not be created, either because tqdm failed to import or
# because constructing the bar raised an error.
_NO_TQDM_ERROR = (
    "A progress bar was requested, but there was an error loading the tqdm "
    "library. Please install tqdm to use the progress bar functionality."
)

# Value passed as ``timeout`` to each ``query_job.result()`` call while polling
# in ``wait_for_query``; it bounds how long one poll waits before the progress
# bar is refreshed (presumably seconds -- confirm against ``QueryJob.result``).
_PROGRESS_BAR_UPDATE_INTERVAL = 0.5
44
45
def get_progress_bar(progress_bar_type, description, total, unit):
    """Build the tqdm progress bar selected by ``progress_bar_type``.

    Args:
        progress_bar_type:
            ``"tqdm"``, ``"tqdm_notebook"`` or ``"tqdm_gui"`` to pick the bar
            flavour. Any other value (including ``None``) yields no bar.
        description:
            Text forwarded to tqdm as ``desc``.
        total:
            Value forwarded to tqdm as ``total``.
        unit:
            Value forwarded to tqdm as ``unit``.

    Returns:
        The constructed progress bar, or ``None`` when tqdm (or the notebook
        variant, if that was requested) is not importable, when the requested
        type is not recognised, or when tqdm raised ``KeyError``/``TypeError``
        during construction.
    """
    notebook_unavailable = (
        progress_bar_type == "tqdm_notebook" and tqdm_notebook is None
    )
    if tqdm is None or notebook_unavailable:
        # Only complain when the caller actually asked for a progress bar.
        if progress_bar_type is not None:
            warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3)
        return None

    # Keyword arguments common to every flavour of progress bar.
    shared_kwargs = {"desc": description, "total": total, "unit": unit}

    try:
        if progress_bar_type == "tqdm":
            return tqdm.tqdm(
                bar_format="{l_bar}{bar}|",
                colour="green",
                file=sys.stdout,
                **shared_kwargs,
            )
        if progress_bar_type == "tqdm_notebook":
            return tqdm_notebook.tqdm(
                bar_format="{l_bar}{bar}|",
                file=sys.stdout,
                **shared_kwargs,
            )
        if progress_bar_type == "tqdm_gui":
            return tqdm.tqdm_gui(**shared_kwargs)
    except (KeyError, TypeError):  # pragma: NO COVER
        # Guard against unexpected tqdm behavior: rather than failing the
        # caller, warn and fall through to "no progress bar".
        warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3)

    return None
79
80
def wait_for_query(
    query_job: "QueryJob",
    progress_bar_type: Optional[str] = None,
    max_results: Optional[int] = None,
) -> "RowIterator":
    """Return the query result, showing a progress bar while the query runs.

    If no progress bar can be created (``get_progress_bar`` returns ``None``)
    this simply blocks on ``query_job.result(max_results=max_results)``.

    Otherwise the job is polled: each iteration waits for the result with a
    timeout of ``_PROGRESS_BAR_UPDATE_INTERVAL``; on timeout the job is
    reloaded and, if the stage currently displayed is ``"COMPLETE"``, the bar
    advances by one step and moves on to the next stage of the query plan.
    The bar is always closed before returning or propagating an error.

    Args:
        query_job:
            The job representing the execution of the query on the server.
        progress_bar_type:
            The type of progress bar to use to show query progress.
        max_results:
            The maximum number of rows the row iterator should return.

    Returns:
        A row iterator over the query results.
    """
    default_total = 1
    current_stage = None
    start_time = time.perf_counter()

    progress_bar = get_progress_bar(
        progress_bar_type, "Query is running", default_total, "query"
    )
    if progress_bar is None:
        return query_job.result(max_results=max_results)

    # Index of the stage being displayed. Because the bar advances by exactly
    # one per completed stage, it is also the number of steps already applied
    # to the bar.
    i = 0
    try:
        while True:
            # Read the plan once per poll instead of once per use.
            query_plan = query_job.query_plan
            if query_plan:
                default_total = len(query_plan)
                # Clamp so a plan that is shorter than on a previous poll
                # cannot raise IndexError.
                current_stage = query_plan[min(i, default_total - 1)]
                progress_bar.total = default_total
                elapsed = time.perf_counter() - start_time
                progress_bar.set_description(
                    f"Query executing stage {current_stage.name} and status "
                    f"{current_stage.status} : {elapsed:.2f}s"
                )

            try:
                query_result = query_job.result(
                    timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=max_results
                )
            except concurrent.futures.TimeoutError:
                query_job.reload()  # Refreshes the state via a GET request.
                if (
                    current_stage
                    and current_stage.status == "COMPLETE"
                    and i < default_total - 1
                ):
                    # tqdm's ``update`` is incremental (it adds to the
                    # counter), so one finished stage is one step.
                    progress_bar.update(1)
                    i += 1
                continue

            # Fill only what is still missing so the bar ends exactly at
            # ``total`` instead of overshooting it.
            progress_bar.update(max(default_total - i, 0))
            progress_bar.set_description(
                f"Job ID {query_job.job_id} successfully executed",
            )
            return query_result
    finally:
        # Release the bar even when the query fails or the wait is interrupted.
        progress_bar.close()