1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import sys
21
22import libcst as cst
23
24
class _autostacklevel_warn:
    """
    Stand-in for the :mod:`warnings` module that shifts ``stacklevel`` in :meth:`warn`.

    :meth:`warn` adds a fixed ``delta`` to the caller's ``stacklevel`` before delegating to
    the real ``warnings.warn``. Every other attribute (``simplefilter``, ``catch_warnings``,
    ...) is looked up on the real module.

    :param delta: number of extra stack frames to skip in every ``warn`` call
    """

    def __init__(self, delta):
        # Reference to the real ``warnings`` module that all calls are delegated to.
        self.warnings = __import__("warnings")
        self.delta = delta

    def __getattr__(self, name: str):
        # Only called for names not found on the instance. An instance created without
        # ``__init__`` (as ``copy.copy`` and unpickling do) has no ``warnings`` attribute yet;
        # without this guard, ``self.warnings`` below would re-enter ``__getattr__`` forever.
        if name == "warnings":
            raise AttributeError(name)
        return getattr(self.warnings, name)

    def __dir__(self):
        return dir(self.warnings)

    def warn(self, message, category=None, stacklevel=1, source=None, **kwargs):
        """
        Call ``warnings.warn`` with ``stacklevel`` increased by ``self.delta``.

        Extra keyword arguments (e.g. ``skip_file_prefixes`` on Python 3.12+) are passed
        through unchanged to the real ``warnings.warn``.
        """
        self.warnings.warn(message, category, stacklevel + self.delta, source, **kwargs)
38
39
def fixup_decorator_warning_stack(func, delta: int = 2):
    """
    Make ``warnings.warn`` calls in the module defining *func* skip extra stack frames.

    If that module has bound the global name ``warnings`` to the real :mod:`warnings`
    module, the global is replaced with an ``_autostacklevel_warn`` proxy that adds *delta*
    to every ``stacklevel`` passed to ``warnings.warn``. The rebinding applies to the whole
    module's globals, not only to *func*. Modules that never bound ``warnings``, modules
    that are already patched, and callables without ``__globals__`` are left untouched.

    :param func: function whose module globals are patched
    :param delta: number of extra frames to skip in ``warnings.warn`` calls
    """
    func_globals = getattr(func, "__globals__", None)
    if func_globals is None:
        # Not a plain Python function (e.g. a ``functools.partial``): nothing to patch.
        return
    current = func_globals.get("warnings")
    # Require a real match: ``sys.modules`` may lack "warnings", and ``None is None`` must not
    # count as "the module imported warnings". An already-installed proxy is not the module,
    # so repeated calls are no-ops.
    if current is not None and current is sys.modules.get("warnings"):
        # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to
        # `warnings.warn` to ignore the decorator.
        func_globals["warnings"] = _autostacklevel_warn(delta)
45
46
class _TaskDecoratorRemover(cst.CSTTransformer):
    """
    libcst transformer that strips task-related decorators from every function definition.

    A decorator is removed when its dotted name, ignoring any call arguments, is one of
    ``setup``, ``teardown``, ``task.skip_if``, ``task.run_if`` or the configured task
    decorator name. Other decorators are kept in their original order.
    """

    def __init__(self, task_decorator_name: str) -> None:
        """
        :param task_decorator_name: name of the task decorator to remove; leading or
            trailing ``@`` characters are stripped (e.g. ``"@task"`` or ``"task.virtualenv"``)
        """
        # Initialise libcst's transformer/metadata state before adding our own attributes.
        super().__init__()
        self.decorators_to_remove: set[str] = {
            "setup",
            "teardown",
            "task.skip_if",
            "task.run_if",
            task_decorator_name.strip("@"),
        }

    @staticmethod
    def _dotted_name(expr: cst.BaseExpression) -> str | None:
        """Return ``"a.b.c"`` for a ``Name``/``Attribute`` chain, or ``None`` for anything else."""
        parts: list[str] = []
        while isinstance(expr, cst.Attribute):
            parts.append(expr.attr.value)
            expr = expr.value
        if not isinstance(expr, cst.Name):
            # e.g. ``@foo().bar`` or ``@registry["x"]``: no plain dotted name to compare.
            return None
        parts.append(expr.value)
        return ".".join(reversed(parts))

    def _is_task_decorator(self, decorator_node: cst.Decorator) -> bool:
        """Return True if *decorator_node* names one of the decorators to remove."""
        expr = decorator_node.decorator
        # ``@task(...)``, ``@task.docker(image=...)``, ``@task()()``: judge by the callee.
        while isinstance(expr, cst.Call):
            expr = expr.func
        name = self._dotted_name(expr)
        return name is not None and name in self.decorators_to_remove

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Drop matching decorators; return the node unchanged when nothing was removed."""
        kept = [dec for dec in updated_node.decorators if not self._is_task_decorator(dec)]
        if len(kept) == len(updated_node.decorators):
            return updated_node
        return updated_node.with_changes(decorators=kept)
74
75
def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
    """
    Strip the task decorator, plus ``@setup``, ``@teardown``, ``@task.skip_if`` and
    ``@task.run_if``, from every function definition in *python_source*.

    :param python_source: The python source code
    :param task_decorator_name: the decorator name, with or without a leading ``@``
    :return: the regenerated source code with the matching decorators removed
    """
    parsed = cst.parse_module(python_source)
    return parsed.visit(_TaskDecoratorRemover(task_decorator_name)).code