
Commit 8a68827

justinchuby authored and pytorchmergebot committed
[BE] Enable ruff's UP rules and autoformat dynamo / functorch and refs (pytorch#105432)
Pull Request resolved: pytorch#105432
Approved by: https://github.com/ezyang
1 parent 88f1197 commit 8a68827


47 files changed: +188 −242 lines

functorch/benchmarks/operator_authoring.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def out_setup(n):
 def test_backwards(make_args, nnc=nnc_add, aten=torch.add):
     def backwards_setup(n):
         args = make_args(n)
-        (grad_var,) = [a for a in args if a.requires_grad]
+        (grad_var,) = (a for a in args if a.requires_grad)
         aten(*args).sum().backward()
         correct = grad_var.grad.clone()
         grad_var.grad.zero_()
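
The only change here is pyupgrade's preference for a generator expression over a one-element list comprehension when the result is immediately unpacked. A minimal sketch (synthetic arguments, not taken from the benchmark) showing the two forms bind the same object:

```python
import torch

# Hypothetical stand-ins for make_args(n); only the second tensor requires grad.
args = (torch.randn(3), torch.randn(3, requires_grad=True))

(grad_var_from_list,) = [a for a in args if a.requires_grad]
(grad_var_from_gen,) = (a for a in args if a.requires_grad)

# Both forms unpack the single matching tensor; the generator just avoids
# materializing an intermediate list.
assert grad_var_from_list is grad_var_from_gen
```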

functorch/einops/rearrange.py

Lines changed: 10 additions & 12 deletions
@@ -108,18 +108,16 @@ class dims."""
 
     custom_rearrange_callable_name = "do_rearrange"
     custom_rearrange_callable_code = (
-        (
-            f"def {custom_rearrange_callable_name}(tensor):\n"
-            f"    {comma_separate(first_class_dims)} = dims({n_dims})\n"
-            + (
-                "".join(f"    {dim}.size = {length}\n" for (dim, length) in specified_lengths)
-                if specified_lengths else ""
-            )
-            + f"    tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
-            + (
-                f"    return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
-                if anon_dims else "    return tensor\n"
-            )
+        f"def {custom_rearrange_callable_name}(tensor):\n"
+        f"    {comma_separate(first_class_dims)} = dims({n_dims})\n"
+        + (
+            "".join(f"    {dim}.size = {length}\n" for (dim, length) in specified_lengths)
+            if specified_lengths else ""
+        )
+        + f"    tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
+        + (
+            f"    return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
+            if anon_dims else "    return tensor\n"
         )
     )
 
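
This rewrite only drops a redundant pair of parentheses around the first two f-string literals; adjacent string literals concatenate implicitly, so the generated source is unchanged. A small sketch of that equivalence (the names are illustrative, not taken from rearrange.py):

```python
name = "do_rearrange"
n_dims = 3

with_parens = (
    (
        f"def {name}(tensor):\n"
        f"    dims_count = {n_dims}\n"
    )
    + "    return tensor\n"
)
without_parens = (
    f"def {name}(tensor):\n"
    f"    dims_count = {n_dims}\n"
    + "    return tensor\n"
)

# Adjacent literals concatenate either way; the wrapping parentheses are redundant.
assert with_parens == without_parens
```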

functorch/examples/compilation/linear_train.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def bench(f, iters=100, warmup=10):
     begin = time.time()
     for _ in range(iters):
         f()
-    print((time.time() - begin))
+    print(time.time() - begin)
 
 
 class Foo(nn.Module):
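
As in the rearrange change, this is purely an extraneous-parentheses cleanup; wrapping the argument in an extra pair of parentheses does not change what `print` receives. A trivial illustration (not from the example script):

```python
import time

begin = time.time()
elapsed = time.time() - begin

# (elapsed) denotes the same expression as elapsed; the extra parentheses are a no-op.
assert (elapsed) == elapsed
print(elapsed)
```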

functorch/examples/maml_omniglot/support/omniglot_loaders.py

Lines changed: 4 additions & 4 deletions
@@ -121,7 +121,7 @@ def find_classes(root_dir):
                     r = root.split('/')
                     lr = len(r)
                     retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
-        print("== Found %d items " % len(retour))
+        print(f"== Found {len(retour)} items ")
         return retour
 
 
@@ -130,7 +130,7 @@ def index_classes(items):
         for i in items:
             if i[1] not in idx:
                 idx[i[1]] = len(idx)
-        print("== Found %d classes" % len(idx))
+        print(f"== Found {len(idx)} classes")
         return idx
 
 
@@ -276,10 +276,10 @@ def load_data_cache(self, data_pack):
             x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
             y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz)
 
-            x_spts, y_spts, x_qrys, y_qrys = [
+            x_spts, y_spts, x_qrys, y_qrys = (
                 torch.from_numpy(z).to(self.device) for z in
                 [x_spts, y_spts, x_qrys, y_qrys]
-            ]
+            )
 
             data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
 
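
Two pyupgrade patterns show up in this file: %-formatting replaced by f-strings, and a list comprehension replaced by a generator expression when its four results are unpacked straight into names. A quick sanity sketch (synthetic data, not the Omniglot arrays) showing both rewrites are behavior-preserving:

```python
import numpy as np
import torch

retour = ["alphabet/char/a.png", "alphabet/char/b.png"]
# The f-string renders exactly what the %-format did.
assert "== Found %d items " % len(retour) == f"== Found {len(retour)} items "

arrays = [np.zeros((2, 2), dtype=np.float32) for _ in range(4)]
# Unpacking a 4-element generator works exactly like unpacking a 4-element list.
x_spts, y_spts, x_qrys, y_qrys = (torch.from_numpy(z) for z in arrays)
assert all(isinstance(t, torch.Tensor) for t in (x_spts, y_spts, x_qrys, y_qrys))
```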

functorch/op_analysis/gen_data.py

Lines changed: 5 additions & 5 deletions
@@ -23,7 +23,7 @@ def gen_data(special_op_lists, analysis_name):
     composite_ops = get_ops_for_key('CompositeImplicitAutograd')
     noncomposite_ops = all_ops - composite_ops
 
-    ops = yaml.load(open('../../aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
+    ops = yaml.load(open('../../aten/src/ATen/native/native_functions.yaml').read(), Loader=yaml.CLoader)
 
     annotated_ops = {a.strip(): b.strip() for a, b in list(csv.reader(open('annotated_ops')))}
     from collections import defaultdict
@@ -132,19 +132,19 @@ def remove_prefix(input_string, prefix):
 
 
 if True:
-    with open('run_ops.txt', 'r') as f:
+    with open('run_ops.txt') as f:
         opinfo_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
-    with open('count_ops.txt', 'r') as f:
+    with open('count_ops.txt') as f:
         opinfo_counts = [i.strip() for i in f.readlines()]
     opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))
 
     def count_fn(x):
         return opinfo_counts[x['full_name']]
 
-    with open('run_decompositions.txt', 'r') as f:
+    with open('run_decompositions.txt') as f:
         decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
 
-    with open('public_api', 'r') as f:
+    with open('public_api') as f:
        ref_api = [i.strip() for i in f.readlines()]
 
     def has_ref_impl(x):
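
The only change here is dropping the explicit `'r'` mode from `open()`; reading is the default mode, so behavior is identical. A minimal check (using a throwaway file, not the repository's data files):

```python
with open("example.txt", "w") as f:
    f.write("aten::add\n")

# open(path) defaults to mode "r", so the explicit argument was redundant.
with open("example.txt") as f:
    assert f.read() == "aten::add\n"
```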

test/dynamo/test_autograd_function.py

Lines changed: 2 additions & 2 deletions
@@ -207,7 +207,7 @@ def backward(ctx, grad_output):
 
 class ModuleWithGradFunc(torch.nn.Module):
     def __init__(self, func):
-        super(ModuleWithGradFunc, self).__init__()
+        super().__init__()
         self.f = func.apply
 
     def forward(self, x):
@@ -336,7 +336,7 @@ def backward(ctx, grad_output):
 
 class MyMod(torch.nn.Module):
     def __init__(self):
-        super(MyMod, self).__init__()
+        super().__init__()
         self.gamma = torch.nn.Parameter(torch.rand([4, 128, 32, 32]))
 
     def forward(self, x):
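
These hunks (and the similar ones in test_compile.py, test_logging.py, test_misc.py, and test_modules.py below) apply pyupgrade's zero-argument `super()` rewrite: in Python 3, `super()` with no arguments resolves the same class and instance as the explicit two-argument form. A standalone sketch (the class name is illustrative, not from the test suite):

```python
import torch


class ModuleZeroArgSuper(torch.nn.Module):
    def __init__(self):
        # Equivalent to super(ModuleZeroArgSuper, self).__init__() under Python 3.
        super().__init__()
        self.f = torch.nn.Linear(4, 4)


m = ModuleZeroArgSuper()
assert isinstance(m.f, torch.nn.Linear)
```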

test/dynamo/test_compile.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 
 class ToyModel(torch.nn.Module):
     def __init__(self):
-        super(ToyModel, self).__init__()
+        super().__init__()
         self.linear = torch.nn.Linear(10, 10)
         self.relu = torch.nn.ReLU()
 

test/dynamo/test_logging.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def throw(x):
     def test_ddp_graphs(self, records):
         class ToyModel(torch.nn.Module):
             def __init__(self):
-                super(ToyModel, self).__init__()
+                super().__init__()
                 self.layers = torch.nn.Sequential(
                     torch.nn.Linear(1024, 1024),
                     torch.nn.Linear(1024, 1024),

test/dynamo/test_misc.py

Lines changed: 17 additions & 17 deletions
@@ -822,7 +822,7 @@ def fn(a, b):
         v2 = torch.randn((10, 10))
         correct = fn(v1, v2)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize((cnts))(fn)
+        opt_fn = torch._dynamo.optimize(cnts)(fn)
         self.assertEqual(opt_fn(v1, v2), correct)
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 3)
@@ -836,7 +836,7 @@ def fn(a, b):
         v2 = torch.randn((10, 10))
         correct = fn(v1, v2)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize((cnts))(fn)
+        opt_fn = torch._dynamo.optimize(cnts)(fn)
         self.assertEqual(opt_fn(v1, v2), correct)
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 2)
@@ -2201,7 +2201,7 @@ def fn():
 def fn():
     foo.bar(1, 2, 3)
 {str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))}
-    l = [{str(' ').join('x' + str(i) + ',' for i in range(1 << 9))}]
+    l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}]
 """
         locals = {}
         exec(fn_str, {}, locals)
@@ -3086,7 +3086,7 @@ def foo(self, memo=None, prefix="", remove_duplicate=False):
                     memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
                 ):
                     for pn, p in self.named_parameters():
-                        fpn = "%s.%s" % (mn, pn) if mn else pn
+                        fpn = f"{mn}.{pn}" if mn else pn
                         self.names.append(fpn)
 
         # Test plain recurse
@@ -5031,11 +5031,11 @@ def test_compute_exception_table_nested(self):
             (15, 16, 7),
             (17, 17, 6),
         ]
-        self.assertEquals(len(tab), len(expected))
+        self.assertEqual(len(tab), len(expected))
         for entry, exp in zip(tab, expected):
-            self.assertEquals(entry.start, exp[0] * 2)
-            self.assertEquals(entry.end, exp[1] * 2)
-            self.assertEquals(entry.target, exp[2] * 2)
+            self.assertEqual(entry.start, exp[0] * 2)
+            self.assertEqual(entry.end, exp[1] * 2)
+            self.assertEqual(entry.target, exp[2] * 2)
 
     @skipIfNotPy311
     def test_remove_dead_code_with_exn_table_entries(self):
@@ -5059,17 +5059,17 @@ def test_remove_dead_code_with_exn_table_entries(self):
         )
         bytecode_transformation.propagate_inst_exn_table_entries(insts)
         insts = bytecode_analysis.remove_dead_code(insts)
-        self.assertEquals(len(insts), 5)
+        self.assertEqual(len(insts), 5)
         self.assertNotIn(exn_start, insts)
         self.assertNotIn(exn_end, insts)
         self.assertIn(target2, insts)
         self.assertIn(target3, insts)
         bytecode_transformation.update_offsets(insts)
         tab = bytecode_transformation.compute_exception_table(insts)
-        self.assertEquals(len(tab), 1)
-        self.assertEquals(tab[0].start, 2)
-        self.assertEquals(tab[0].end, 4)
-        self.assertEquals(tab[0].target, 6)
+        self.assertEqual(len(tab), 1)
+        self.assertEqual(tab[0].start, 2)
+        self.assertEqual(tab[0].end, 4)
+        self.assertEqual(tab[0].target, 6)
 
     def test_unhandled_exception_in_dynamo(self):
         # traceback.format_exc() approximates an unhandled exception
@@ -5756,7 +5756,7 @@ def guard(L):
     def test_dynamo_compiling_fake_tensor_to_vararg_int(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
-                super(MyModule, self).__init__()
+                super().__init__()
 
             def forward(self, x):
                 # use numpy int so it's wrapped as fake tensor in dynamo
@@ -5775,7 +5775,7 @@ def forward(self, x):
     def test_scalar_tensor_is_equivalent_to_symint_argument(self):
         class GumbelTopKSampler(torch.nn.Module):
             def __init__(self, T, k):
-                super(GumbelTopKSampler, self).__init__()
+                super().__init__()
                 self.T = torch.nn.Parameter(
                     torch.tensor(T, dtype=torch.float32), requires_grad=False
                 )
@@ -5802,7 +5802,7 @@ def forward(self, logits):
     def test_scalar_tensor_is_equivalent_to_symint_list_argument(self):
         class Jitter(torch.nn.Module):
             def __init__(self, jitter_val):
-                super(Jitter, self).__init__()
+                super().__init__()
                 self.jitter_val = jitter_val
 
             def roll_tensor(self, input):
@@ -5987,7 +5987,7 @@ def _prepare_for_translation_validation(self):
 
         # Z3 symbols.
         [validator.add_var(s, int) for s in (s0, s1, s2)]
-        z0, z1, z2 = [validator.z3var(s) for s in (s0, s1, s2)]
+        z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
 
         return (s0, s1, s2), (z0, z1, z2), validator
 
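
Besides more `super()` rewrites and extraneous-parentheses cleanups (`optimize((cnts))` → `optimize(cnts)`, `str(' ').join` → `' '.join`), this file swaps the deprecated unittest alias `assertEquals` for `assertEqual`; the alias dispatches to the same assertion, so the tests behave identically. A minimal illustration (not one of the dynamo tests):

```python
import unittest


class AliasCheck(unittest.TestCase):
    def test_equal(self):
        # assertEquals is a deprecated alias of assertEqual; the modern spelling
        # performs the identical comparison without a DeprecationWarning.
        self.assertEqual(2 * 2, 4)


if __name__ == "__main__":
    unittest.main()
```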

test/dynamo/test_modules.py

Lines changed: 6 additions & 6 deletions
@@ -762,21 +762,21 @@ def forward(self, x):
 
 class ConvCallSuperForwardDirectly(torch.nn.Conv1d):
     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
-        super(ConvCallSuperForwardDirectly, self).__init__(
+        super().__init__(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
             **kwargs,
         )
 
     def forward(self, inputs, mask=None):
-        outputs = super(ConvCallSuperForwardDirectly, self).forward(inputs)
+        outputs = super().forward(inputs)
         return outputs
 
 
 class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d):
     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
-        super(ConvTransposeCallSuperForwardDirectly, self).__init__(
+        super().__init__(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -785,7 +785,7 @@ def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
 
     def forward(self, x):
         if x.numel() > 0:
-            return super(ConvTransposeCallSuperForwardDirectly, self).forward(x)
+            return super().forward(x)
         output_shape = [
             ((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op)
             for i, p, di, k, d, op in zip(
@@ -923,7 +923,7 @@ def forward(self, x):
 class SequentialWithDuplicatedModule(torch.nn.Module):
     # Sequential module(self.layer) contains three duplicated ReLU module.
     def __init__(self):
-        super(SequentialWithDuplicatedModule, self).__init__()
+        super().__init__()
         self.relu = torch.nn.ReLU()
         self.layer = torch.nn.Sequential(
             torch.nn.Linear(10, 20),
@@ -940,7 +940,7 @@ def forward(self, x):
 
 class SequentialWithDuplicatedModule2(torch.nn.Module):
     def __init__(self):
-        super(SequentialWithDuplicatedModule2, self).__init__()
+        super().__init__()
         self.relu = torch.nn.ReLU()
         self.layer = torch.nn.Sequential(
             collections.OrderedDict(
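
The same zero-argument `super()` rewrite also covers calls that forward to the parent's `forward()`, as in the ConvCallSuperForwardDirectly changes above. A small sketch (a hypothetical Linear subclass, not from the test suite):

```python
import torch


class LinearPassthrough(torch.nn.Linear):
    def forward(self, x):
        # Same as super(LinearPassthrough, self).forward(x).
        return super().forward(x)


layer = LinearPassthrough(4, 2)
out = layer(torch.randn(3, 4))
assert out.shape == (3, 2)
```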
