1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
"""This module contains classes that traverse an AST and convert it to something else.

If the parser successfully accepts a valid input (the bigquery cell magic arguments),
the result is an Abstract Syntax Tree (AST) that represents the input as a tree
with nodes containing various useful metadata.

Node visitors can process such a tree and convert it to something else that can
be used for further processing, for example:

    * An optimized version of the tree with redundancy removed/simplified (not used here).
    * The same tree, but with semantic errors checked, because an otherwise syntactically
      valid input might still contain errors (not used here, semantic errors are detected
      elsewhere).
    * A form that can be directly handed to the code that operates on the input. The
      ``QueryParamsExtractor`` class, for instance, splits the input arguments into
      the "--params <...>" part and everything else.
      The "everything else" part can then be parsed by the default Jupyter argument parser,
      while the --params option is processed separately by the Python evaluator.

More info on the visitor design pattern:
https://en.wikipedia.org/wiki/Visitor_pattern

"""
38
39from __future__ import print_function
40
41
class NodeVisitor(object):
    """Base visitor class implementing the dispatch machinery.

    A subclass handles nodes of a class ``Foo`` by defining a method
    ``visit_Foo(self, node)``; :meth:`visit` looks that method up by the
    node's class name.
    """

    def visit(self, node):
        """Dispatch ``node`` to the ``visit_<ClassName>`` method for its type.

        Args:
            node: The AST node to process.

        Returns:
            Whatever the matching ``visit_<ClassName>`` method returns.

        Raises:
            NotImplementedError: If this visitor defines no method for the
                node's type (raised by :meth:`method_missing`).
        """
        method_name = "visit_{}".format(type(node).__name__)
        visitor_method = getattr(self, method_name, self.method_missing)
        return visitor_method(node)

    def method_missing(self, node):
        """Fallback used by :meth:`visit` when no handler exists for ``node``.

        Raises:
            NotImplementedError: Always. It is a subclass of ``Exception``, so
                callers that catch ``Exception`` are unaffected.
        """
        raise NotImplementedError("No visit_{} method".format(type(node).__name__))
52
53
class QueryParamsExtractor(NodeVisitor):
    """A visitor that extracts the "--params <...>" part from input line arguments.

    Visiting an ``InputLine`` node returns a ``(params, other)`` pair of strings:
    ``params`` holds the rendered fragments of the option(s) named ``params``,
    and ``other`` holds the destination variable name (if any) followed by all
    remaining options. If the parser represents ``--params`` as a
    ``ParamsOption`` node (TODO confirm against the parser), only its value is
    rendered, without the ``--params`` flag.

    Every ``visit_*`` method other than ``visit_InputLine`` returns a list of
    string fragments; the fragments are joined once at the top level.
    """

    def visit_InputLine(self, node):
        """Split the input line into the params value and everything else.

        Returns:
            Tuple[str, str]: The rendered params value and the remaining
            arguments (destination variable followed by the other options).
        """
        dest_var_parts = self.visit(node.destination_var)
        params_parts, other_option_parts = self.visit(node.option_list)

        other_parts = list(dest_var_parts)
        # Only insert a separator when both sides are non-empty, so the result
        # never has stray leading or trailing whitespace.
        if dest_var_parts and other_option_parts:
            other_parts.append(" ")
        other_parts.extend(other_option_parts)

        return "".join(params_parts), "".join(other_parts)

    def visit_DestinationVar(self, node):
        """Return the destination variable name as a fragment list (may be empty)."""
        return [node.name] if node.name is not None else []

    def visit_CmdOptionList(self, node):
        """Partition the options into ``params`` fragments and all others.

        Returns:
            Tuple[list, list]: Fragments of the option(s) named ``params`` and
            fragments of every other option. Within each list, consecutive
            options are separated by a single space.
        """
        params_opt_parts = []
        other_parts = []

        for opt in node.options:
            option_parts = self.visit(opt)
            target = params_opt_parts if opt.name == "params" else other_parts
            if target:
                target.append(" ")
            target.extend(option_parts)

        return params_opt_parts, other_parts

    def visit_CmdOption(self, node):
        """Render a regular option as ``--name`` or ``--name value``."""
        result = ["--{}".format(node.name)]
        if node.value is not None:
            result.append(" ")
            result.extend(self.visit(node.value))
        return result

    def visit_CmdOptionValue(self, node):
        """Return the option value's text as a single fragment."""
        return [node.value]

    def visit_ParamsOption(self, node):
        """Return only the rendered value of the params option (no flag)."""
        return self.visit(node.value)

    def visit_PyVarExpansion(self, node):
        """Return the variable expansion's ``raw_value`` text unchanged."""
        return [node.raw_value]

    def visit_PyDict(self, node):
        """Render a dict literal as ``{key: value, ...}``."""
        return self._render_items(node.items, "{", "}")

    def visit_PyDictItem(self, node):
        """Render a single ``key: value`` dict entry."""
        result = self.visit(node.key)
        result.append(": ")
        result.extend(self.visit(node.value))
        return result

    def visit_PyDictKey(self, node):
        """Return the dict key's ``key_value`` text unchanged."""
        return [node.key_value]

    def visit_PyScalarValue(self, node):
        """Return the scalar's ``raw_value`` text unchanged."""
        return [node.raw_value]

    def visit_PyTuple(self, node):
        """Render a tuple literal as ``(item, ...)``.

        A one-element tuple is rendered with a trailing comma, because in
        Python ``(x)`` is merely a parenthesized expression that evaluates to
        ``x`` itself, whereas ``(x,)`` is a tuple.
        """
        return self._render_items(node.items, "(", ")", single_item_suffix=",")

    def visit_PyList(self, node):
        """Render a list literal as ``[item, ...]``."""
        return self._render_items(node.items, "[", "]")

    def _render_items(self, items, opening, closing, single_item_suffix=""):
        """Render ``items`` comma-separated between ``opening`` and ``closing``.

        Args:
            items: Child nodes, each rendered with :meth:`visit`.
            opening: Fragment emitted before the first item.
            closing: Fragment emitted after the last item.
            single_item_suffix: Fragment appended after the item when there is
                exactly one item (used for one-element tuples).

        Returns:
            list: The rendered fragments.
        """
        result = [opening]
        count = 0
        for count, item in enumerate(items, start=1):
            if count > 1:
                result.append(", ")
            result.extend(self.visit(item))
        if count == 1 and single_item_suffix:
            result.append(single_item_suffix)
        result.append(closing)
        return result