Change to use the get_callback_args route function to return the right function to analyse for keywords and insert the sqlite connection.

This commit is contained in:
Benoit Masson 2017-01-09 02:37:56 +01:00
parent a640affc81
commit 2cc4e8d066
4 changed files with 16 additions and 15 deletions

View File

@ -580,6 +580,7 @@ class Route(object):
func = getattr(func, '__func__' if py3k else 'im_func', func) func = getattr(func, '__func__' if py3k else 'im_func', func)
closure_attr = '__closure__' if py3k else 'func_closure' closure_attr = '__closure__' if py3k else 'func_closure'
while hasattr(func, closure_attr) and getattr(func, closure_attr): while hasattr(func, closure_attr) and getattr(func, closure_attr):
func_previous = func
attributes = getattr(func, closure_attr) attributes = getattr(func, closure_attr)
func = attributes[0].cell_contents func = attributes[0].cell_contents
@ -588,6 +589,9 @@ class Route(object):
# pick first FunctionType instance from multiple arguments # pick first FunctionType instance from multiple arguments
func = filter(lambda x: isinstance(x, FunctionType), func = filter(lambda x: isinstance(x, FunctionType),
map(lambda x: x.cell_contents, attributes)) map(lambda x: x.cell_contents, attributes))
if len(list(func))==0:
func = func_previous
break
func = list(func)[0] # py3 support func = list(func)[0] # py3 support
return func return func

View File

@ -83,9 +83,11 @@ class SQLitePlugin(object):
if bottle.__version__.startswith('0.9'): if bottle.__version__.startswith('0.9'):
config = route['config'] config = route['config']
_callback = route['callback'] _callback = route['callback']
argspec = inspect.getargspec(_callback).args
else: else:
config = route.config config = route.config
_callback = route.callback _callback = route.callback
argspec = route.get_callback_args()
# Override global configuration with route-specific values. # Override global configuration with route-specific values.
if "sqlite" in config: if "sqlite" in config:
@ -100,21 +102,8 @@ class SQLitePlugin(object):
keyword = g('keyword', self.keyword) keyword = g('keyword', self.keyword)
text_factory = g('keyword', self.text_factory) text_factory = g('keyword', self.text_factory)
# Test if the original callback accepts a 'db' keyword. if keyword not in argspec:
# Ignore it if it does not need a database handle. return callback
argspec = inspect.getargspec(_callback)
if keyword not in argspec.args:
#check for closure
no_keyword_arg = True
for closure in _callback.func_closure:
contents = closure.cell_contents
if callable(contents):
argspec = inspect.getargspec(contents)
if keyword in argspec.args:
no_keyword_arg = False
break
if no_keyword_arg:
return callback
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
# Connect to the database # Connect to the database

View File

@ -28,6 +28,13 @@ class SQLiteTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
os.unlink(self.plugin.dbfile) os.unlink(self.plugin.dbfile)
def test_with_view(self):
@self.app.get('/')
@bottle.view('test_view')
def test(db):
self.assertEqual(type(db), type(sqlite3.connect(':memory:')))
self._request('/')
def test_with_keyword(self): def test_with_keyword(self):
@self.app.get('/') @self.app.get('/')
def test(db): def test(db):

1
views/test_view.tpl Normal file
View File

@ -0,0 +1 @@
test_view