235
235
< div class ="pytorch-left-menu-search ">
236
236
237
237
< div class ="version ">
238
- < a href ='https://pytorch.org/docs/versions.html '> main (2.1.0a0+git660a0d8 ) ▼</ a >
238
+ < a href ='https://pytorch.org/docs/versions.html '> main (2.1.0a0+git6308563 ) ▼</ a >
239
239
</ div >
240
240
241
241
@@ -1258,8 +1258,6 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1258
1258
< span class ="k "> if</ span > < span class ="n "> pre_autograd</ span > < span class ="p "> :</ span >
1259
1259
< span class ="k "> assert</ span > < span class ="n "> aten_graph</ span > < span class ="p "> ,</ span > < span class ="s2 "> "pre_autograd=True can only be used when aten_graph=True"</ span >
1260
1260
< span class ="n "> f</ span > < span class ="o "> =</ span > < span class ="n "> innermost_fn</ span > < span class ="p "> (</ span > < span class ="n "> f</ span > < span class ="p "> )</ span >
1261
- < span class ="n "> call_to_inspect</ span > < span class ="o "> =</ span > < span class ="n "> f</ span > < span class ="o "> .</ span > < span class ="n "> forward</ span > < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> f</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Module</ span > < span class ="p "> )</ span > < span class ="k "> else</ span > < span class ="n "> f</ span >
1262
- < span class ="n "> original_signature</ span > < span class ="o "> =</ span > < span class ="n "> inspect</ span > < span class ="o "> .</ span > < span class ="n "> signature</ span > < span class ="p "> (</ span > < span class ="n "> call_to_inspect</ span > < span class ="p "> )</ span >
1263
1261
1264
1262
< span class ="k "> if</ span > < span class ="n "> functionalize</ span > < span class ="ow "> and</ span > < span class ="ow "> not</ span > < span class ="n "> aten_graph</ span > < span class ="p "> :</ span >
1265
1263
< span class ="k "> raise</ span > < span class ="n "> UserError</ span > < span class ="p "> (</ span >
@@ -1308,7 +1306,6 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1308
1306
1309
1307
< span class ="n "> fake_mode</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
1310
1308
< span class ="n "> example_inputs</ span > < span class ="o "> =</ span > < span class ="p "> []</ span >
1311
- < span class ="n "> var_to_range_map</ span > < span class ="o "> =</ span > < span class ="p "> {}</ span >
1312
1309
1313
1310
< span class ="k "> def</ span > < span class ="nf "> dynamo_normalization_capturing_compiler</ span > < span class ="p "> (</ span >
1314
1311
< span class ="n "> gm</ span > < span class ="p "> :</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> fx</ span > < span class ="o "> .</ span > < span class ="n "> GraphModule</ span > < span class ="p "> ,</ span > < span class ="n "> inner_example_inputs</ span >
@@ -1319,11 +1316,9 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1319
1316
< span class ="p "> ),</ span > < span class ="s2 "> "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."</ span >
1320
1317
< span class ="n "> graph</ span > < span class ="o "> =</ span > < span class ="n "> gm</ span >
1321
1318
1322
- < span class ="k "> nonlocal</ span > < span class ="n "> fake_mode</ span > < span class ="p "> ,</ span > < span class ="n "> example_inputs</ span > < span class =" p " > , </ span > < span class =" n " > var_to_range_map </ span >
1319
+ < span class ="k "> nonlocal</ span > < span class ="n "> fake_mode</ span > < span class ="p "> ,</ span > < span class ="n "> example_inputs</ span >
1323
1320
< span class ="n "> fake_mode</ span > < span class ="o "> =</ span > < span class ="n "> _guards</ span > < span class ="o "> .</ span > < span class ="n "> detect_fake_mode</ span > < span class ="p "> (</ span > < span class ="n "> inner_example_inputs</ span > < span class ="p "> )</ span >
1324
1321
< span class ="n "> example_inputs</ span > < span class ="o "> =</ span > < span class ="n "> inner_example_inputs</ span >
1325
- < span class ="k "> if</ span > < span class ="n "> fake_mode</ span > < span class ="ow "> and</ span > < span class ="n "> fake_mode</ span > < span class ="o "> .</ span > < span class ="n "> shape_env</ span > < span class ="p "> :</ span >
1326
- < span class ="n "> var_to_range_map</ span > < span class ="o "> =</ span > < span class ="n "> fake_mode</ span > < span class ="o "> .</ span > < span class ="n "> shape_env</ span > < span class ="o "> .</ span > < span class ="n "> var_to_range</ span >
1327
1322
1328
1323
< span class ="k "> def</ span > < span class ="nf "> result_capturing_wrapper</ span > < span class ="p "> (</ span > < span class ="o "> *</ span > < span class ="n "> graph_inputs</ span > < span class ="p "> ):</ span >
1329
1324
< span class ="k "> nonlocal</ span > < span class ="n "> graph_captured_result</ span >
@@ -1366,7 +1361,7 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1366
1361
< span class ="n "> dim_constraints</ span > < span class ="o "> =</ span > < span class ="n "> shape_env</ span > < span class ="o "> .</ span > < span class ="n "> dim_constraints</ span >
1367
1362
< span class ="k "> assert</ span > < span class ="n "> dim_constraints</ span > < span class ="ow "> is</ span > < span class ="ow "> not</ span > < span class ="kc "> None</ span >
1368
1363
< span class ="n "> dim_constraints</ span > < span class ="o "> .</ span > < span class ="n "> solve</ span > < span class ="p "> ()</ span >
1369
- < span class ="n "> msg</ span > < span class ="o "> =</ span > < span class ="n "> dim_constraints</ span > < span class ="o "> .</ span > < span class ="n "> prettify_results</ span > < span class ="p "> (</ span > < span class ="n "> original_signature </ span > < span class ="p "> )</ span >
1364
+ < span class ="n "> msg</ span > < span class ="o "> =</ span > < span class ="n "> dim_constraints</ span > < span class ="o "> .</ span > < span class ="n "> prettify_results</ span > < span class ="p "> (</ span > < span class ="n "> inspect </ span > < span class ="o " > . </ span > < span class =" n " > signature </ span > < span class =" p "> ( </ span > < span class =" n " > f </ span > < span class =" p " > ) )</ span >
1370
1365
< span class ="k "> if</ span > < span class ="n "> constraint_violation_error</ span > < span class ="p "> :</ span >
1371
1366
< span class ="n "> constraint_violation_error</ span > < span class ="o "> .</ span > < span class ="n "> args</ span > < span class ="o "> =</ span > < span class ="p "> (</ span >
1372
1367
< span class ="n "> constraint_violation_error</ span > < span class ="o "> .</ span > < span class ="n "> args</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span > < span class ="o "> +</ span > < span class ="n "> msg</ span > < span class ="p "> ,</ span >
@@ -1491,13 +1486,22 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1491
1486
< span class ="k "> if</ span > < span class ="n "> constraints</ span >
1492
1487
< span class ="k "> else</ span > < span class ="kc "> None</ span >
1493
1488
< span class ="p "> )</ span >
1494
- < span class ="n "> new_graph</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "inline_constraints"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="p "> {</ span >
1495
- < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "val"</ span > < span class ="p "> ]</ span > < span class ="o "> .</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> expr</ span > < span class ="p "> :</ span > < span class ="n "> var_to_range_map</ span > < span class ="p "> [</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "val"</ span > < span class ="p "> ]</ span > < span class ="o "> .</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> expr</ span > < span class ="p "> ]</ span >
1496
- < span class ="k "> for</ span > < span class ="n "> node</ span > < span class ="ow "> in</ span > < span class ="n "> new_graph</ span > < span class ="o "> .</ span > < span class ="n "> graph</ span > < span class ="o "> .</ span > < span class ="n "> nodes</ span >
1497
- < span class ="k "> if</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> op</ span > < span class ="o "> !=</ span > < span class ="s2 "> "placeholder"</ span > < span class ="ow "> and</ span > < span class ="s2 "> "val"</ span > < span class ="ow "> in</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span >
1498
- < span class ="c1 "> # Find constraints frome unbacked symints</ span >
1499
- < span class ="ow "> and</ span > < span class ="n "> re</ span > < span class ="o "> .</ span > < span class ="n "> match</ span > < span class ="p "> (</ span > < span class ="sa "> r</ span > < span class ="s2 "> "^i\d+$"</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> (</ span > < span class ="n "> node</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "val"</ span > < span class ="p "> ]))</ span >
1500
- < span class ="p "> }</ span >
1489
+
1490
+ < span class ="k "> if</ span > < span class ="p "> (</ span > < span class ="n "> shape_env</ span > < span class ="o "> :=</ span > < span class ="nb "> getattr</ span > < span class ="p "> (</ span > < span class ="n "> fake_mode</ span > < span class ="p "> ,</ span > < span class ="s2 "> "shape_env"</ span > < span class ="p "> ,</ span > < span class ="kc "> None</ span > < span class ="p "> ))</ span > < span class ="ow "> is</ span > < span class ="ow "> not</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
1491
+ < span class ="n "> dim_constraints</ span > < span class ="o "> =</ span > < span class ="n "> shape_env</ span > < span class ="o "> .</ span > < span class ="n "> dim_constraints</ span >
1492
+ < span class ="k "> assert</ span > < span class ="n "> dim_constraints</ span > < span class ="ow "> is</ span > < span class ="ow "> not</ span > < span class ="kc "> None</ span >
1493
+ < span class ="n "> dim_constraints</ span > < span class ="o "> .</ span > < span class ="n "> solve</ span > < span class ="p "> ()</ span >
1494
+ < span class ="n "> log</ span > < span class ="o "> .</ span > < span class ="n "> warning</ span > < span class ="p "> (</ span >
1495
+ < span class ="s2 "> "Summary of dimension constraints:</ span > < span class ="si "> %s</ span > < span class ="s2 "> "</ span > < span class ="p "> ,</ span >
1496
+ < span class ="n "> dim_constraints</ span > < span class ="o "> .</ span > < span class ="n "> prettify_results</ span > < span class ="p "> (</ span > < span class ="n "> inspect</ span > < span class ="o "> .</ span > < span class ="n "> signature</ span > < span class ="p "> (</ span > < span class ="n "> f</ span > < span class ="p "> )),</ span >
1497
+ < span class ="p "> )</ span >
1498
+
1499
+ < span class ="c1 "> # Inline constraints added by users correspond to unbacked symbols in shape_env,</ span >
1500
+ < span class ="n "> new_graph</ span > < span class ="o "> .</ span > < span class ="n "> meta</ span > < span class ="p "> [</ span > < span class ="s2 "> "inline_constraints"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="p "> {</ span >
1501
+ < span class ="n "> k</ span > < span class ="p "> :</ span > < span class ="n "> v</ span >
1502
+ < span class ="k "> for</ span > < span class ="n "> k</ span > < span class ="p "> ,</ span > < span class ="n "> v</ span > < span class ="ow "> in</ span > < span class ="n "> shape_env</ span > < span class ="o "> .</ span > < span class ="n "> var_to_range</ span > < span class ="o "> .</ span > < span class ="n "> items</ span > < span class ="p "> ()</ span >
1503
+ < span class ="k "> if</ span > < span class ="n "> re</ span > < span class ="o "> .</ span > < span class ="n "> match</ span > < span class ="p "> (</ span > < span class ="sa "> r</ span > < span class ="s2 "> "^[if]\d+$"</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> (</ span > < span class ="n "> k</ span > < span class ="p "> ))</ span >
1504
+ < span class ="p "> }</ span >
1501
1505
1502
1506
< span class ="k "> def</ span > < span class ="nf "> signature_to_fullargspec</ span > < span class ="p "> (</ span > < span class ="n "> sig</ span > < span class ="p "> :</ span > < span class ="n "> inspect</ span > < span class ="o "> .</ span > < span class ="n "> Signature</ span > < span class ="p "> ):</ span >
1503
1507
< span class ="c1 "> # Get a list of Parameter objects from the Signature object</ span >
@@ -1541,7 +1545,10 @@ <h1>Source code for torch._dynamo.eval_frame</h1><div class="highlight"><pre>
1541
1545
1542
1546
< span class ="c1 "> # Make dynamo graph to have same input/output spec as user code</ span >
1543
1547
< span class ="k "> def</ span > < span class ="nf "> argument_names</ span > < span class ="p "> (</ span > < span class ="n "> f</ span > < span class ="p "> :</ span > < span class ="n "> Callable</ span > < span class ="p "> [</ span > < span class ="o "> ...</ span > < span class ="p "> ,</ span > < span class ="n "> Any</ span > < span class ="p "> ],</ span > < span class ="o "> *</ span > < span class ="n "> args</ span > < span class ="p "> ,</ span > < span class ="o "> **</ span > < span class ="n "> kwargs</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> List</ span > < span class ="p "> [</ span > < span class ="nb "> str</ span > < span class ="p "> ]:</ span >
1544
- < span class ="n "> fullargspec</ span > < span class ="o "> =</ span > < span class ="n "> signature_to_fullargspec</ span > < span class ="p "> (</ span > < span class ="n "> original_signature</ span > < span class ="p "> )</ span >
1548
+ < span class ="n "> call_to_inspect</ span > < span class ="o "> =</ span > < span class ="n "> f</ span > < span class ="o "> .</ span > < span class ="n "> forward</ span > < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> f</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> nn</ span > < span class ="o "> .</ span > < span class ="n "> Module</ span > < span class ="p "> )</ span > < span class ="k "> else</ span > < span class ="n "> f</ span >
1549
+
1550
+ < span class ="n "> sig</ span > < span class ="o "> =</ span > < span class ="n "> inspect</ span > < span class ="o "> .</ span > < span class ="n "> signature</ span > < span class ="p "> (</ span > < span class ="n "> call_to_inspect</ span > < span class ="p "> )</ span >
1551
+ < span class ="n "> fullargspec</ span > < span class ="o "> =</ span > < span class ="n "> signature_to_fullargspec</ span > < span class ="p "> (</ span > < span class ="n "> sig</ span > < span class ="p "> )</ span >
1545
1552
1546
1553
< span class ="c1 "> # 1. Map `args` 1-to-1 to positional arguments in original signature.</ span >
1547
1554
< span class ="n "> input_strs</ span > < span class ="o "> =</ span > < span class ="n "> fullargspec</ span > < span class ="o "> .</ span > < span class ="n "> args</ span > < span class ="p "> [:</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> args</ span > < span class ="p "> )]</ span >
0 commit comments