Skip to content

Commit d931917

Browse files
committed
fix type check error
1 parent 796bd89 commit d931917

File tree

1 file changed

+51
-30
lines changed
  • aws_lambda_powertools/utilities/data_masking/provider

1 file changed

+51
-30
lines changed

aws_lambda_powertools/utilities/data_masking/provider/base.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import functools
44
import json
55
import re
6-
from typing import Any, Callable, Iterable
6+
from typing import Any, Callable
77

8+
# , Iterable
89
from aws_lambda_powertools.utilities.data_masking.constants import DATA_MASKING_STRING
910

1011
PRESERVE_CHARS = set("-_. ")
@@ -69,14 +70,14 @@ def decrypt(self, data, provider_options: dict | None = None, **encryption_conte
6970

7071
def erase(
7172
self,
72-
data,
73+
data: Any,
7374
dynamic_mask: bool | None = None,
7475
custom_mask: str | None = None,
7576
regex_pattern: str | None = None,
7677
mask_format: str | None = None,
7778
masking_rules: dict | None = None,
7879
**kwargs,
79-
) -> Iterable[str]:
80+
) -> str | dict | list | tuple | set:
8081
"""
8182
This method irreversibly erases data.
8283
@@ -85,47 +86,68 @@ def erase(
8586
8687
If the data to be erased is of an iterable type like `list`, `tuple`,
8788
or `set`, this method will return a new object of the same type as the
88-
input data but with each element replaced by the string "*****" or following one of the custom masks.
89+
input data but with each element masked according to the specified rules.
8990
"""
90-
result = DATA_MASKING_STRING
91-
92-
if data:
93-
if isinstance(data, str):
94-
if dynamic_mask:
95-
result = self._custom_erase(data, **kwargs)
96-
if custom_mask:
97-
result = self._pattern_mask(data, custom_mask)
98-
if regex_pattern and mask_format:
99-
result = self._regex_mask(data, regex_pattern, mask_format)
100-
elif isinstance(data, dict):
101-
if masking_rules:
102-
result = self._apply_masking_rules(data, masking_rules)
103-
elif isinstance(data, (list, tuple, set)):
104-
result = type(data)(
105-
self.erase(
106-
item,
107-
dynamic_mask=dynamic_mask,
108-
custom_mask=custom_mask,
109-
regex_pattern=regex_pattern,
110-
mask_format=mask_format,
111-
masking_rules=masking_rules,
112-
**kwargs,
113-
)
114-
for item in data
91+
result = None
92+
93+
# Handle empty or None data
94+
if not data:
95+
result = DATA_MASKING_STRING if isinstance(data, (str, bytes)) else data
96+
97+
# Handle string data
98+
elif isinstance(data, str):
99+
if regex_pattern and mask_format:
100+
result = self._regex_mask(data, regex_pattern, mask_format)
101+
elif custom_mask:
102+
result = self._pattern_mask(data, custom_mask)
103+
elif dynamic_mask:
104+
result = self._custom_erase(data, **kwargs)
105+
else:
106+
result = DATA_MASKING_STRING
107+
108+
# Handle dictionary data
109+
elif isinstance(data, dict):
110+
result = (
111+
self._apply_masking_rules(data, masking_rules)
112+
if masking_rules
113+
else {k: DATA_MASKING_STRING for k in data}
114+
)
115+
116+
# Handle iterable data (list, tuple, set)
117+
elif isinstance(data, (list, tuple, set)):
118+
masked_data = (
119+
self.erase(
120+
item,
121+
dynamic_mask=dynamic_mask,
122+
custom_mask=custom_mask,
123+
regex_pattern=regex_pattern,
124+
mask_format=mask_format,
125+
masking_rules=masking_rules,
126+
**kwargs,
115127
)
128+
for item in data
129+
)
130+
result = type(data)(masked_data)
131+
132+
# Default case
133+
else:
134+
result = DATA_MASKING_STRING
116135

117136
return result
118137

119138
def _apply_masking_rules(self, data: dict, masking_rules: dict) -> dict:
139+
"""Apply masking rules to dictionary data."""
120140
return {
121141
key: self.erase(str(value), **masking_rules[key]) if key in masking_rules else str(value)
122142
for key, value in data.items()
123143
}
124144

125145
def _pattern_mask(self, data: str, pattern: str) -> str:
146+
"""Apply pattern masking to string data."""
126147
return pattern[: len(data)] if len(pattern) >= len(data) else pattern
127148

128149
def _regex_mask(self, data: str, regex_pattern: str, mask_format: str) -> str:
150+
"""Apply regex masking to string data."""
129151
try:
130152
if regex_pattern not in _regex_cache:
131153
_regex_cache[regex_pattern] = re.compile(regex_pattern)
@@ -137,5 +159,4 @@ def _custom_erase(self, data: str, **kwargs) -> str:
137159
if not data:
138160
return ""
139161

140-
# Use join with list comprehension instead of building list incrementally
141162
return "".join("*" if char not in PRESERVE_CHARS else char for char in data)

0 commit comments

Comments
 (0)