Skip to content

Commit de6c5dd

Browse files
committed
compile: symtable: implement list/set/dict/generator comprehensions
1 parent 940b430 commit de6c5dd

File tree

4 files changed

+1336
-92
lines changed

4 files changed

+1336
-92
lines changed

compile/make_symtable_test.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
inp = [
2121
('''1''', "eval"),
2222
('''a*b*c''', "eval"),
23+
# Functions
2324
('''def fn(): pass''', "exec"),
2425
('''def fn(a,b):\n e=1\n return a*b*c*d*e''', "exec"),
2526
('''def fn(a,b):\n def nested(c,d):\n return a*b*c*d*e''', "exec"),
@@ -48,6 +49,27 @@ def outer():
4849
def inner():
4950
nonlocal x
5051
x = 2''', "exec", SyntaxError),
52+
# List Comp
53+
('''[ x for x in xs ]''', "exec"),
54+
('''[ x+y for x in xs for y in ys ]''', "exec"),
55+
('''[ x+y+z for x in xs if x if y if z if r for y in ys if x if y if z if p for z in zs if x if y if z if q]''', "exec"),
56+
('''[ x+y for x in [ x for x in xs ] ]''', "exec"),
57+
('''[ x for x in xs ]\n[ y for y in ys ]''', "exec"),
58+
# Generator expr
59+
('''( x for x in xs )''', "exec"),
60+
('''( x+y for x in xs for y in ys )''', "exec"),
61+
('''( x+y+z for x in xs if x if y if z if r for y in ys if x if y if z if p for z in zs if x if y if z if q)''', "exec"),
62+
('''( x+y for x in ( x for x in xs ) )''', "exec"),
63+
# Set comp
64+
('''{ x for x in xs }''', "exec"),
65+
('''{ x+y for x in xs for y in ys }''', "exec"),
66+
('''{ x+y+z for x in xs if x if y if z if r for y in ys if x if y if z if p for z in zs if x if y if z if q}''', "exec"),
67+
('''{ x+y for x in { x for x in xs } }''', "exec"),
68+
# Dict comp
69+
('''{ x:1 for x in xs }''', "exec"),
70+
('''{ x+y:1 for x in xs for y in ys }''', "exec"),
71+
('''{ x+y+z:1 for x in xs if x if y if z if r for y in ys if x if y if z if p for z in zs if x if y if z if q}''', "exec"),
72+
('''{ x+y:k for k, x in { x:1 for x in xs } }''', "exec"),
5173
# FIXME need with x as y
5274
]
5375

@@ -137,21 +159,13 @@ def dump_symtable(st):
137159
#out += 'ImportStar:%s,\n' % dump_bool(st.has_import_star()) # Return True if the block uses a starred from-import.
138160
out += 'Varnames:%s,\n' % dump_strings(st._table.varnames)
139161
out += 'Symbols: Symbols{\n'
140-
children = dict()
141162
for name in sorted(st.get_identifiers()):
142163
s = st.lookup(name)
143164
out += '"%s":%s,\n' % (name, dump_symbol(s))
144-
ns = s.get_namespaces()
145-
if len(ns) == 0:
146-
pass
147-
elif len(ns) == 1:
148-
children[name] = ns[0]
149-
else:
150-
raise AssertionError("More than one namespace")
151165
out += '},\n'
152-
out += 'Children:map[string]*SymTable{\n'
153-
for name, symtable in sorted(children.items()):
154-
out += '"%s":%s,\n' % (name, dump_symtable(symtable))
166+
out += 'Children:Children{\n'
167+
for symtable in st.get_children():
168+
out += '%s,\n' % dump_symtable(symtable)
155169
out += '},\n'
156170
out += "}"
157171
return out

compile/symtable.go

+59-39
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package compile
1111

1212
import (
13+
"fmt"
1314
"log"
1415
"strings"
1516

@@ -84,9 +85,9 @@ type Symbol struct {
8485

8586
type Symbols map[string]Symbol
8687

87-
func NewSymbols() Symbols {
88-
return make(Symbols)
89-
}
88+
type Children []*SymTable
89+
90+
type LookupChild map[ast.Ast]*SymTable
9091

9192
type SymTable struct {
9293
Type BlockType // 'class', 'module', and 'function'
@@ -104,13 +105,14 @@ type SymTable struct {
104105
// col_offset int // offset of first line of block
105106
// opt_lineno int // lineno of last exec or import *
106107
// opt_col_offset int // offset of last exec or import *
107-
// tmpname int // counter for listcomp temp vars
108-
109-
Symbols Symbols
110-
Global *SymTable // symbol table entry for module
111-
Parent *SymTable
112-
Varnames []string // list of function parameters
113-
Children map[string]*SymTable // Child SymTables keyed by symbol name
108+
TmpName int // counter for listcomp temp vars
109+
110+
Symbols Symbols
111+
Global *SymTable // symbol table entry for module
112+
Parent *SymTable
113+
Varnames []string // list of function parameters
114+
Children Children // Child SymTables
115+
LookupChild LookupChild // Child symtables keyed by ast
114116
}
115117

116118
// Make a new top symbol table from the ast supplied
@@ -127,11 +129,12 @@ func NewSymTable(Ast ast.Ast) *SymTable {
127129
// Make a new symbol table from the ast supplied of the given type
128130
func newSymTable(Type BlockType, Name string, parent *SymTable) *SymTable {
129131
st := &SymTable{
130-
Type: Type,
131-
Name: Name,
132-
Parent: parent,
133-
Symbols: NewSymbols(),
134-
Children: make(map[string]*SymTable),
132+
Type: Type,
133+
Name: Name,
134+
Parent: parent,
135+
Symbols: make(Symbols),
136+
Children: make(Children, 0),
137+
LookupChild: make(LookupChild),
135138
}
136139
if parent == nil {
137140
st.Global = st
@@ -142,6 +145,15 @@ func newSymTable(Type BlockType, Name string, parent *SymTable) *SymTable {
142145
return st
143146
}
144147

148+
// Make a new symtable and add it to parent
149+
func newSymTableBlock(Ast ast.Ast, Type BlockType, Name string, parent *SymTable) *SymTable {
150+
stNew := newSymTable(Type, Name, parent)
151+
parent.Children = append(parent.Children, stNew)
152+
parent.LookupChild[Ast] = stNew
153+
// FIXME set stNew.Lineno
154+
return stNew
155+
}
156+
145157
// Parse the ast into the symbol table
146158
func (st *SymTable) Parse(Ast ast.Ast) {
147159
ast.Walk(Ast, func(Ast ast.Ast) bool {
@@ -199,10 +211,8 @@ func (st *SymTable) Parse(Ast ast.Ast) {
199211
st.AddDef(node.Name, defLocal)
200212
name := string(node.Name)
201213

202-
// Make a new symtable and add it to parent
203-
stNew := newSymTable(FunctionBlock, name, st)
204-
st.Children[name] = stNew
205-
// FIXME set stNew.Lineno
214+
// Make a new symtable
215+
stNew := newSymTableBlock(Ast, FunctionBlock, name, st)
206216

207217
// Walk the Decorators and Returns in this Symtable
208218
for _, expr := range node.DecoratorList {
@@ -292,34 +302,44 @@ func (st *SymTable) Parse(Ast ast.Ast) {
292302
})
293303
}
294304

295-
func (st *SymTable) parseComprehension(Ast ast.Ast, scope_name ast.Identifier, generators []ast.Comprehension, elt ast.Expr, value ast.Expr) {
296-
/* FIXME
297-
_, is_generator := Ast.(*ast.GeneratorExp)
298-
needs_tmp := !is_generator
305+
// make a new temporary name
306+
func (st *SymTable) newTmpName() {
307+
st.TmpName++
308+
id := ast.Identifier(fmt.Sprintf("_[%d]", st.TmpName))
309+
st.AddDef(id, defLocal)
310+
}
311+
312+
func (st *SymTable) parseComprehension(Ast ast.Ast, scopeName string, generators []ast.Comprehension, elt ast.Expr, value ast.Expr) {
313+
_, isGenerator := Ast.(*ast.GeneratorExp)
314+
needsTmp := !isGenerator
299315
outermost := generators[0]
300316
// Outermost iterator is evaluated in current scope
301317
st.Parse(outermost.Iter)
302318
// Create comprehension scope for the rest
303-
if scope_name == "" || !symtable_enter_block(st, scope_name, FunctionBlock, e, e.lineno, e.col_offset) {
304-
return 0
305-
}
306-
st.st_cur.ste_generator = is_generator
319+
stNew := newSymTableBlock(Ast, FunctionBlock, scopeName, st)
320+
stNew.Generator = isGenerator
307321
// Outermost iter is received as an argument
308-
id := ast.Identifier(fmt.Sprintf(".%d", pos))
309-
st.AddDef(id, defParam)
322+
id := ast.Identifier(fmt.Sprintf(".%d", 0))
323+
stNew.AddDef(id, defParam)
310324
// Allocate temporary name if needed
311-
if needs_tmp {
312-
symtable_new_tmpname(st)
325+
if needsTmp {
326+
stNew.newTmpName()
313327
}
314-
VISIT(st, expr, outermost.target)
315-
parseSeq(st, expr, outermost.ifs)
316-
parseSeq_tail(st, comprehension, generators, 1)
317-
if value {
318-
VISIT(st, expr, value)
328+
stNew.Parse(outermost.Target)
329+
for _, expr := range outermost.Ifs {
330+
stNew.Parse(expr)
319331
}
320-
VISIT(st, expr, elt)
321-
return symtable_exit_block(st, e)
322-
*/
332+
for _, comprehension := range generators[1:] {
333+
stNew.Parse(comprehension.Target)
334+
stNew.Parse(comprehension.Iter)
335+
for _, expr := range comprehension.Ifs {
336+
stNew.Parse(expr)
337+
}
338+
}
339+
if value != nil {
340+
stNew.Parse(value)
341+
}
342+
stNew.Parse(elt)
323343
}
324344

325345
const duplicateArgument = "duplicate argument %q in function definition"

0 commit comments

Comments
 (0)