-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
/
Copy pathcheck_for_inconsistent_pandas_namespace.py
140 lines (109 loc) · 4.25 KB
/
check_for_inconsistent_pandas_namespace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Check that test suite file doesn't use the pandas namespace inconsistently.
We check for cases of ``Series`` and ``pd.Series`` appearing in the same file
(likewise for other pandas objects).
This is meant to be run as a pre-commit hook - to run it manually, you can do:
pre-commit run inconsistent-namespace-usage --all-files
To automatically fixup a given file, you can pass `--replace`, e.g.
python scripts/check_for_inconsistent_pandas_namespace.py test_me.py --replace
though note that you may need to manually fixup some imports and that you will also
need the additional dependency `tokenize-rt` (which is left out from the pre-commit
hook so that it uses the same virtualenv as the other local ones).
The general structure is similar to that of some plugins from
https://github.com/asottile/pyupgrade .
"""
import argparse
import ast
from collections.abc import (
MutableMapping,
Sequence,
)
import sys
from typing import NamedTuple
# Template for the report written when a pandas object is referenced both via
# a namespace prefix (``prefix.name``) and as a bare imported name (``name``)
# in the same file. It is filled in with ``str.format`` using the keyword
# arguments ``path``, ``lineno``, ``col_offset``, ``prefix`` and ``name``.
ERROR_MESSAGE = (
"{path}:{lineno}:{col_offset}: "
"Found both '{prefix}.{name}' and '{name}' in {path}"
)
class OffsetWithNamespace(NamedTuple):
    """Position of a namespaced attribute access together with its prefix.

    Being a tuple, an instance is hashable and is used as a dictionary key
    identifying a single ``<namespace>.<attr>`` occurrence in a file.
    """

    # Line number of the occurrence, as reported by the AST node.
    lineno: int
    # Column offset of the occurrence, as reported by the AST node.
    col_offset: int
    # The prefix that was written in the source, e.g. ``pd`` or ``pandas``.
    namespace: str
class Visitor(ast.NodeVisitor):
    """Collect namespaced pandas accesses and names imported from pandas.

    After ``visit`` has walked a tree:

    * ``pandas_namespace`` maps every ``pd.<attr>`` / ``pandas.<attr>``
      occurrence (keyed by its position and the prefix used) to ``<attr>``.
    * ``imported_from_pandas`` holds every name listed in a
      ``from pandas import ...`` or ``from pandas.<submodule> import ...``
      statement.
    """

    def __init__(self) -> None:
        self.pandas_namespace: MutableMapping[OffsetWithNamespace, str] = {}
        self.imported_from_pandas: set[str] = set()

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Record attribute accesses whose base is the name ``pd`` or ``pandas``."""
        if isinstance(node.value, ast.Name) and node.value.id in {"pandas", "pd"}:
            offset_with_namespace = OffsetWithNamespace(
                node.lineno, node.col_offset, node.value.id
            )
            self.pandas_namespace[offset_with_namespace] = node.attr
        # Keep descending so nested attribute chains are inspected as well.
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record the names imported from ``pandas`` or one of its submodules."""
        module = node.module
        # ``module`` is None for ``from . import x``.  Compare the top-level
        # package name exactly rather than by substring, so that unrelated
        # packages whose name merely contains "pandas" (``geopandas``,
        # ``pandas_gbq``, ...) are not mistaken for pandas itself.
        if module is not None and module.split(".", 1)[0] == "pandas":
            self.imported_from_pandas.update(alias.name for alias in node.names)
        self.generic_visit(node)
def replace_inconsistent_pandas_namespace(visitor: Visitor, content: str) -> str:
    """Drop the namespace prefix from names that are also imported from pandas.

    Parameters
    ----------
    visitor : Visitor
        A visitor that has already walked the AST of ``content``.
    content : str
        The source code to rewrite.

    Returns
    -------
    str
        ``content`` where, for every recorded ``<prefix>.<attr>`` occurrence
        whose ``<attr>`` is in ``visitor.imported_from_pandas``, the prefix
        token and the token immediately following it have been emptied.
    """
    # Imported here rather than at module level: per the module docstring,
    # tokenize-rt is an extra dependency only needed for ``--replace``.
    from tokenize_rt import (
        reversed_enumerate,
        src_to_tokens,
        tokens_to_src,
    )

    recorded = visitor.pandas_namespace
    imported = visitor.imported_from_pandas

    tokens = src_to_tokens(content)
    for idx, token in reversed_enumerate(tokens):
        position = token.offset
        key = OffsetWithNamespace(position[0], position[1], token.src)
        if key not in recorded:
            continue
        if recorded[key] not in imported:
            continue
        # Blank out the prefix (e.g. `pd`) ...
        tokens[idx] = token._replace(src="")
        # ... and the token right after it (the `.`).
        following = idx + 1
        tokens[following] = tokens[following]._replace(src="")

    rewritten: str = tokens_to_src(tokens)
    return rewritten
def check_for_inconsistent_pandas_namespace(
    content: str, path: str, *, replace: bool
) -> str | None:
    """Detect (and optionally fix) mixed ``pd.X`` / ``X`` usage in ``content``.

    Parameters
    ----------
    content : str
        Python source code to inspect.
    path : str
        Path of the file ``content`` came from; only used in the error message.
    replace : bool
        If True, return rewritten source instead of reporting an error.

    Returns
    -------
    str or None
        None when no name is used both with and without the pandas prefix;
        otherwise (only when ``replace`` is True) the rewritten source.

    Raises
    ------
    SystemExit
        With exit code 1, after writing the error message to stdout, when an
        inconsistency is found and ``replace`` is False.
    SyntaxError
        Propagated from ``ast.parse`` if ``content`` is not valid Python.
    """
    tree = ast.parse(content)

    visitor = Visitor()
    visitor.visit(tree)

    inconsistencies = visitor.imported_from_pandas.intersection(
        visitor.pandas_namespace.values()
    )

    if not inconsistencies:
        # No inconsistent namespace usage, nothing to replace.
        return None

    if replace:
        return replace_inconsistent_pandas_namespace(visitor, content)

    # Report the first offending occurrence in visit order.  Popping an
    # arbitrary element from the set would make the reported name depend on
    # string hash randomisation whenever several names are inconsistent.
    (lineno, col_offset, prefix), inconsistency = next(
        (key, attr)
        for key, attr in visitor.pandas_namespace.items()
        if attr in inconsistencies
    )
    msg = ERROR_MESSAGE.format(
        lineno=lineno,
        col_offset=col_offset,
        prefix=prefix,
        name=inconsistency,
        path=path,
    )
    sys.stdout.write(msg)
    sys.exit(1)
def main(argv: Sequence[str] | None = None) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("paths", nargs="*")
parser.add_argument("--replace", action="store_true")
args = parser.parse_args(argv)
for path in args.paths:
with open(path, encoding="utf-8") as fd:
content = fd.read()
new_content = check_for_inconsistent_pandas_namespace(
content, path, replace=args.replace
)
if not args.replace or new_content is None:
continue
with open(path, "w", encoding="utf-8") as fd:
fd.write(new_content)
if __name__ == "__main__":
main()