From 515de2a2ae8e466a0b3fcefceeb7a6976846eac5 Mon Sep 17 00:00:00 2001 From: wetor Date: Sat, 14 Oct 2023 01:31:15 +0800 Subject: [PATCH] compile,py: fix closure and decorator --- compile/compile.go | 15 +- py/code.go | 61 ++++---- vm/tests/class.py | 23 ++- vm/tests/decorators.py | 327 +++++++++++++++++++++++++++++++++++++++++ vm/tests/functions.py | 38 +++-- vm/tests/libtest.py | 57 +++++++ 6 files changed, 466 insertions(+), 55 deletions(-) create mode 100644 vm/tests/decorators.py create mode 100644 vm/tests/libtest.py diff --git a/compile/compile.go b/compile/compile.go index 76d46c1a..fd8d7a2b 100644 --- a/compile/compile.go +++ b/compile/compile.go @@ -213,6 +213,8 @@ func (c *compiler) compileAst(Ast ast.Ast, filename string, futureFlags int, don case *ast.Suite: panic("suite should not be possible") case *ast.Lambda: + code.Argcount = int32(len(node.Args.Args)) + code.Kwonlyargcount = int32(len(node.Args.Kwonlyargs)) // Make None the first constant as lambda can't have a docstring c.Const(py.None) code.Name = "" @@ -220,6 +222,8 @@ func (c *compiler) compileAst(Ast ast.Ast, filename string, futureFlags int, don c.Expr(node.Body) valueOnStack = true case *ast.FunctionDef: + code.Argcount = int32(len(node.Args.Args)) + code.Kwonlyargcount = int32(len(node.Args.Kwonlyargs)) code.Name = string(node.Name) c.setQualname() c.Stmts(c.docString(node.Body, true)) @@ -299,6 +303,7 @@ func (c *compiler) compileAst(Ast ast.Ast, filename string, futureFlags int, don code.Stacksize = int32(c.OpCodes.StackDepth()) code.Nlocals = int32(len(code.Varnames)) code.Lnotab = string(c.OpCodes.Lnotab()) + code.InitCell2arg() return nil } @@ -479,7 +484,8 @@ func (c *compiler) makeClosure(code *py.Code, args uint32, child *compiler, qual if reftype == symtable.ScopeCell { arg = c.FindId(name, c.Code.Cellvars) } else { /* (reftype == FREE) */ - arg = c.FindId(name, c.Code.Freevars) + // using CellAndFreeVars in closures requires skipping Cellvars + arg = len(c.Code.Cellvars) + c.FindId(name, c.Code.Freevars) } if arg < 0 { panic(fmt.Sprintf("compile: makeClosure: lookup %q in %q %v %v\nfreevars of %q: %v\n", name, c.SymTable.Name, reftype, arg, code.Name, code.Freevars)) @@ -1363,7 +1369,12 @@ func (c *compiler) NameOp(name string, ctx ast.ExprContext) { if op == 0 { panic("NameOp: Op not set") } - c.OpArg(op, c.Index(mangled, dict)) + i := c.Index(mangled, dict) + // using CellAndFreeVars in closures requires skipping Cellvars + if scope == symtable.ScopeFree { + i += uint32(len(c.Code.Cellvars)) + } + c.OpArg(op, i) } // Call a function which is already on the stack with n arguments already on the stack diff --git a/py/code.go b/py/code.go index 09027497..ad9a42af 100644 --- a/py/code.go +++ b/py/code.go @@ -112,8 +112,6 @@ func NewCode(argcount int32, kwonlyargcount int32, filename_ Object, name_ Object, firstlineno int32, lnotab_ Object) *Code { - var cell2arg []byte - // Type assert the objects consts := consts_.(Tuple) namesTuple := names_.(Tuple) @@ -154,7 +152,6 @@ func NewCode(argcount int32, kwonlyargcount int32, // return nil; // } - n_cellvars := len(cellvars) intern_strings(namesTuple) intern_strings(varnamesTuple) intern_strings(freevarsTuple) @@ -167,13 +164,40 @@ func NewCode(argcount int32, kwonlyargcount int32, } } } + + co := &Code{ + Argcount: argcount, + Kwonlyargcount: kwonlyargcount, + Nlocals: nlocals, + Stacksize: stacksize, + Flags: flags, + Code: code, + Consts: consts, + Names: names, + Varnames: varnames, + Freevars: freevars, + Cellvars: cellvars, + Filename: filename, + Name: name, + Firstlineno: firstlineno, + Lnotab: lnotab, + Weakreflist: nil, + } + co.InitCell2arg() + return co +} + +// Create mapping between cells and arguments if needed. +func (co *Code) InitCell2arg() { + var cell2arg []byte + n_cellvars := len(co.Cellvars) /* Create mapping between cells and arguments if needed. */ if n_cellvars != 0 { - total_args := argcount + kwonlyargcount - if flags&CO_VARARGS != 0 { + total_args := co.Argcount + co.Kwonlyargcount + if co.Flags&CO_VARARGS != 0 { total_args++ } - if flags&CO_VARKEYWORDS != 0 { + if co.Flags&CO_VARKEYWORDS != 0 { total_args++ } used_cell2arg := false @@ -182,9 +206,9 @@ func NewCode(argcount int32, kwonlyargcount int32, cell2arg[i] = CO_CELL_NOT_AN_ARG } // Find cells which are also arguments. - for i, cell := range cellvars { + for i, cell := range co.Cellvars { for j := int32(0); j < total_args; j++ { - arg := varnames[j] + arg := co.Varnames[j] if cell == arg { cell2arg[i] = byte(j) used_cell2arg = true @@ -196,26 +220,7 @@ func NewCode(argcount int32, kwonlyargcount int32, cell2arg = nil } } - - return &Code{ - Argcount: argcount, - Kwonlyargcount: kwonlyargcount, - Nlocals: nlocals, - Stacksize: stacksize, - Flags: flags, - Code: code, - Consts: consts, - Names: names, - Varnames: varnames, - Freevars: freevars, - Cellvars: cellvars, - Cell2arg: cell2arg, - Filename: filename, - Name: name, - Firstlineno: firstlineno, - Lnotab: lnotab, - Weakreflist: nil, - } + co.Cell2arg = cell2arg } // Return number of free variables diff --git a/vm/tests/class.py b/vm/tests/class.py index 2c8fd70b..f9781cd7 100644 --- a/vm/tests/class.py +++ b/vm/tests/class.py @@ -47,17 +47,16 @@ def method1(self, x): c = x() assert c.method1(1) == 2 -# FIXME doesn't work -# doc="CLASS_DEREF2" -# def classderef2(x): -# class DeRefTest: -# VAR = x -# def method1(self, x): -# "method1" -# return self.VAR+x -# return DeRefTest -# x = classderef2(1) -# c = x() -# assert c.method1(1) == 2 +doc="CLASS_DEREF2" +def classderef2(x): + class DeRefTest: + VAR = x + def method1(self, x): + "method1" + return self.VAR+x + return DeRefTest +x = classderef2(1) +c = x() +assert c.method1(1) == 2 doc="finished" diff --git a/vm/tests/decorators.py b/vm/tests/decorators.py new file mode 100644 index 00000000..b7e2d703 --- /dev/null +++ b/vm/tests/decorators.py @@ -0,0 +1,327 @@ +# Copyright 2023 The go-python Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Copied from Python-3.4.9\Lib\test\test_decorators.py + +import libtest as self + +def funcattrs(**kwds): + def decorate(func): + # FIXME func.__dict__.update(kwds) + for k, v in kwds.items(): + func.__dict__[k] = v + return func + return decorate + +class MiscDecorators (object): + @staticmethod + def author(name): + def decorate(func): + func.__dict__['author'] = name + return func + return decorate + +# ----------------------------------------------- + +class DbcheckError (Exception): + def __init__(self, exprstr, func, args, kwds): + # A real version of this would set attributes here + Exception.__init__(self, "dbcheck %r failed (func=%s args=%s kwds=%s)" % + (exprstr, func, args, kwds)) + + +def dbcheck(exprstr, globals=None, locals=None): + "Decorator to implement debugging assertions" + def decorate(func): + expr = compile(exprstr, "dbcheck-%s" % func.__name__, "eval") + def check(*args, **kwds): + if not eval(expr, globals, locals): + raise DbcheckError(exprstr, func, args, kwds) + return func(*args, **kwds) + return check + return decorate + +# ----------------------------------------------- + +def countcalls(counts): + "Decorator to count calls to a function" + def decorate(func): + func_name = func.__name__ + counts[func_name] = 0 + def call(*args, **kwds): + counts[func_name] += 1 + return func(*args, **kwds) + call.__name__ = func_name + return call + return decorate + +# ----------------------------------------------- + +# FIXME: dict can only have string keys +# def memoize(func): +# saved = {} +# def call(*args): +# try: +# return saved[args] +# except KeyError: +# res = func(*args) +# saved[args] = res +# return res +# except TypeError: +# # Unhashable argument +# return func(*args) +# call.__name__ = func.__name__ +# return call +def memoize(func): + saved = {} + def call(*args): + try: + if isinstance(args[0], list): + raise TypeError + return saved[str(args)] + except KeyError: + res = func(*args) + saved[str(args)] = res + return res + except TypeError: + # Unhashable argument + return func(*args) + call.__name__ = func.__name__ + return call + +# ----------------------------------------------- + +doc="test_single" +# FIXME staticmethod +# class C(object): +# @staticmethod +# def foo(): return 42 +# self.assertEqual(C.foo(), 42) +# self.assertEqual(C().foo(), 42) + +doc="test_staticmethod_function" +@staticmethod +def notamethod(x): + return x +self.assertRaises(TypeError, notamethod, 1) + +doc="test_dotted" +# FIXME class decorator +# decorators = MiscDecorators() +# @decorators.author('Cleese') +# def foo(): return 42 +# self.assertEqual(foo(), 42) +# self.assertEqual(foo.author, 'Cleese') + +doc="test_argforms" +def noteargs(*args, **kwds): + def decorate(func): + setattr(func, 'dbval', (args, kwds)) + return func + return decorate + +args = ( 'Now', 'is', 'the', 'time' ) +kwds = dict(one=1, two=2) +@noteargs(*args, **kwds) +def f1(): return 42 +self.assertEqual(f1(), 42) +self.assertEqual(f1.dbval, (args, kwds)) + +@noteargs('terry', 'gilliam', eric='idle', john='cleese') +def f2(): return 84 +self.assertEqual(f2(), 84) +self.assertEqual(f2.dbval, (('terry', 'gilliam'), + dict(eric='idle', john='cleese'))) + +@noteargs(1, 2,) +def f3(): pass +self.assertEqual(f3.dbval, ((1, 2), {})) + +doc="test_dbcheck" +# FIXME TypeError: "catching 'BaseException' that does not inherit from BaseException is not allowed" +# @dbcheck('args[1] is not None') +# def f(a, b): +# return a + b +# self.assertEqual(f(1, 2), 3) +# self.assertRaises(DbcheckError, f, 1, None) + +doc="test_memoize" +counts = {} + +@memoize +@countcalls(counts) +def double(x): + return x * 2 +self.assertEqual(double.__name__, 'double') + +self.assertEqual(counts, dict(double=0)) + +# Only the first call with a given argument bumps the call count: +# +# Only the first call with a given argument bumps the call count: +# +self.assertEqual(double(2), 4) +self.assertEqual(counts['double'], 1) +self.assertEqual(double(2), 4) +self.assertEqual(counts['double'], 1) +self.assertEqual(double(3), 6) +self.assertEqual(counts['double'], 2) + +# Unhashable arguments do not get memoized: +# +self.assertEqual(double([10]), [10, 10]) +self.assertEqual(counts['double'], 3) +self.assertEqual(double([10]), [10, 10]) +self.assertEqual(counts['double'], 4) + +doc="test_errors" +# Test syntax restrictions - these are all compile-time errors: +# +for expr in [ "1+2", "x[3]", "(1, 2)" ]: + # Sanity check: is expr is a valid expression by itself? + compile(expr, "testexpr", "exec") + + codestr = "@%s\ndef f(): pass" % expr + self.assertRaises(SyntaxError, compile, codestr, "test", "exec") + +# You can't put multiple decorators on a single line: +# +self.assertRaises(SyntaxError, compile, + "@f1 @f2\ndef f(): pass", "test", "exec") + +# Test runtime errors + +def unimp(func): + raise NotImplementedError +context = dict(nullval=None, unimp=unimp) + +for expr, exc in [ ("undef", NameError), + ("nullval", TypeError), + ("nullval.attr", NameError), # FIXME ("nullval.attr", AttributeError), + ("unimp", NotImplementedError)]: + codestr = "@%s\ndef f(): pass\nassert f() is None" % expr + code = compile(codestr, "test", "exec") + self.assertRaises(exc, eval, code, context) + +doc="test_double" +class C(object): + @funcattrs(abc=1, xyz="haha") + @funcattrs(booh=42) + def foo(self): return 42 +self.assertEqual(C().foo(), 42) +self.assertEqual(C.foo.abc, 1) +self.assertEqual(C.foo.xyz, "haha") +self.assertEqual(C.foo.booh, 42) + + +doc="test_order" +# Test that decorators are applied in the proper order to the function +# they are decorating. +def callnum(num): + """Decorator factory that returns a decorator that replaces the + passed-in function with one that returns the value of 'num'""" + def deco(func): + return lambda: num + return deco +@callnum(2) +@callnum(1) +def foo(): return 42 +self.assertEqual(foo(), 2, + "Application order of decorators is incorrect") + + +doc="test_eval_order" +# Evaluating a decorated function involves four steps for each +# decorator-maker (the function that returns a decorator): +# +# 1: Evaluate the decorator-maker name +# 2: Evaluate the decorator-maker arguments (if any) +# 3: Call the decorator-maker to make a decorator +# 4: Call the decorator +# +# When there are multiple decorators, these steps should be +# performed in the above order for each decorator, but we should +# iterate through the decorators in the reverse of the order they +# appear in the source. +# FIXME class decorator +# actions = [] +# +# def make_decorator(tag): +# actions.append('makedec' + tag) +# def decorate(func): +# actions.append('calldec' + tag) +# return func +# return decorate +# +# class NameLookupTracer (object): +# def __init__(self, index): +# self.index = index +# +# def __getattr__(self, fname): +# if fname == 'make_decorator': +# opname, res = ('evalname', make_decorator) +# elif fname == 'arg': +# opname, res = ('evalargs', str(self.index)) +# else: +# assert False, "Unknown attrname %s" % fname +# actions.append('%s%d' % (opname, self.index)) +# return res +# +# c1, c2, c3 = map(NameLookupTracer, [ 1, 2, 3 ]) +# +# expected_actions = [ 'evalname1', 'evalargs1', 'makedec1', +# 'evalname2', 'evalargs2', 'makedec2', +# 'evalname3', 'evalargs3', 'makedec3', +# 'calldec3', 'calldec2', 'calldec1' ] +# +# actions = [] +# @c1.make_decorator(c1.arg) +# @c2.make_decorator(c2.arg) +# @c3.make_decorator(c3.arg) +# def foo(): return 42 +# self.assertEqual(foo(), 42) +# +# self.assertEqual(actions, expected_actions) +# +# # Test the equivalence claim in chapter 7 of the reference manual. +# # +# actions = [] +# def bar(): return 42 +# bar = c1.make_decorator(c1.arg)(c2.make_decorator(c2.arg)(c3.make_decorator(c3.arg)(bar))) +# self.assertEqual(bar(), 42) +# self.assertEqual(actions, expected_actions) + +doc="test_simple" +def plain(x): + x.extra = 'Hello' + return x +@plain +class C(object): pass +self.assertEqual(C.extra, 'Hello') + +doc="test_double" +def ten(x): + x.extra = 10 + return x +def add_five(x): + x.extra += 5 + return x + +@add_five +@ten +class C(object): pass +self.assertEqual(C.extra, 15) + +doc="test_order" +def applied_first(x): + x.extra = 'first' + return x +def applied_second(x): + x.extra = 'second' + return x +@applied_second +@applied_first +class C(object): pass +self.assertEqual(C.extra, 'second') +doc="finished" diff --git a/vm/tests/functions.py b/vm/tests/functions.py index da4bf924..aab079f9 100644 --- a/vm/tests/functions.py +++ b/vm/tests/functions.py @@ -21,18 +21,32 @@ def fn2(x,y=1): assert fn2(1,y=4) == 5 # Closure +doc="closure1" +closure1 = lambda x: lambda y: x+y +cf1 = closure1(1) +assert cf1(1) == 2 +assert cf1(2) == 3 + +doc="closure2" +def closure2(*args, **kwargs): + def inc(): + kwargs['x'] += 1 + return kwargs['x'] + return inc +cf2 = closure2(x=1) +assert cf2() == 2 +assert cf2() == 3 -# FIXME something wrong with closures over function arguments... -# doc="counter3" -# def counter3(x): -# def inc(): -# nonlocal x -# x += 1 -# return x -# return inc -# fn3 = counter3(1) -# assert fn3() == 2 -# assert fn3() == 3 +doc="counter3" +def counter3(x): + def inc(): + nonlocal x + x += 1 + return x + return inc +fn3 = counter3(1) +assert fn3() == 2 +assert fn3() == 3 doc="counter4" def counter4(initial): @@ -238,6 +252,4 @@ def fn16_6(*,a,b,c): ck(fn16_5, "fn16_5() missing 2 required keyword-only arguments: 'a' and 'b'") ck(fn16_6, "fn16_6() missing 3 required keyword-only arguments: 'a', 'b', and 'c'") -#FIXME decorators - doc="finished" diff --git a/vm/tests/libtest.py b/vm/tests/libtest.py new file mode 100644 index 00000000..8038556d --- /dev/null +++ b/vm/tests/libtest.py @@ -0,0 +1,57 @@ +# Copyright 2023 The go-python Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Imitate the calling method of unittest + +def assertRaises(expecting, fn, *args, **kwargs): + """Check the exception was raised - don't check the text""" + try: + fn(*args, **kwargs) + except expecting as e: + pass + else: + assert False, "%s not raised" % (expecting,) + +def assertEqual(first, second, msg=None): + if msg: + assert first == second, "%s not equal" % (msg,) + else: + assert first == second + +def assertIs(expr1, expr2, msg=None): + if msg: + assert expr1 is expr2, "%s is not None" % (msg,) + else: + assert expr1 is expr2 + +def assertIsNone(obj, msg=None): + if msg: + assert obj is None, "%s is not None" % (msg,) + else: + assert obj is None + +def assertTrue(obj, msg=None): + if msg: + assert obj, "%s is not True" % (msg,) + else: + assert obj + +def assertRaisesText(expecting, text, fn, *args, **kwargs): + """Check the exception with text in is raised""" + try: + fn(*args, **kwargs) + except expecting as e: + assert text in e.args[0], "'%s' not found in '%s'" % (text, e.args[0]) + else: + assert False, "%s not raised" % (expecting,) + +def assertTypedEqual(actual, expect, msg=None): + assertEqual(actual, expect, msg) + def recurse(actual, expect): + if isinstance(expect, (tuple, list)): + for x, y in zip(actual, expect): + recurse(x, y) + else: + assertIs(type(actual), type(expect)) + recurse(actual, expect)