add support for SQLite user functions and collations

This commit is contained in:
Dingyuan Wang 2017-08-25 15:51:32 +08:00 committed by Marcel Hellkamp
parent a1e9860040
commit 5a6c2e128f
3 changed files with 63 additions and 1 deletions

View File

@ -68,6 +68,9 @@ The following configuration options exist for the plugin class:
* **autocommit**: Whether or not to commit outstanding transactions at the end of the request cycle (default: True). * **autocommit**: Whether or not to commit outstanding transactions at the end of the request cycle (default: True).
* **dictrows**: Whether or not to support dict-like access to row objects (default: True). * **dictrows**: Whether or not to support dict-like access to row objects (default: True).
* **text_factory**: The text_factory for the connection (default: unicode). * **text_factory**: The text_factory for the connection (default: unicode).
* **functions**: Add user-defined functions for use in SQL, should be a dict like ``{'name': (num_params, func)}`` (default: None).
* **aggregates**: Add user-defined aggregate functions, should be a dict like ``{'name': (num_params, aggregate_class)}`` (default: None).
* **collations**: Add user-defined collations, should be a dict like ``{'name': callable}`` (default: None).
You can override each of these values on a per-route basis:: You can override each of these values on a per-route basis::

View File

@ -59,12 +59,16 @@ class SQLitePlugin(object):
unicode = str unicode = str
def __init__(self, dbfile=':memory:', autocommit=True, dictrows=True, def __init__(self, dbfile=':memory:', autocommit=True, dictrows=True,
keyword='db', text_factory=unicode): keyword='db', text_factory=unicode,
functions=None, aggregates=None, collations=None):
self.dbfile = dbfile self.dbfile = dbfile
self.autocommit = autocommit self.autocommit = autocommit
self.dictrows = dictrows self.dictrows = dictrows
self.keyword = keyword self.keyword = keyword
self.text_factory = text_factory self.text_factory = text_factory
self.functions = functions or {}
self.aggregates = aggregates or {}
self.collations = collations or {}
def setup(self, app): def setup(self, app):
''' Make sure that other installed plugins don't affect the same ''' Make sure that other installed plugins don't affect the same
@ -99,6 +103,9 @@ class SQLitePlugin(object):
dictrows = g('dictrows', self.dictrows) dictrows = g('dictrows', self.dictrows)
keyword = g('keyword', self.keyword) keyword = g('keyword', self.keyword)
text_factory = g('text_factory', self.text_factory) text_factory = g('text_factory', self.text_factory)
functions = g('functions', self.functions)
aggregates = g('aggregates', self.aggregates)
collations = g('collations', self.collations)
# Test if the original callback accepts a 'db' keyword. # Test if the original callback accepts a 'db' keyword.
# Ignore it if it does not need a database handle. # Ignore it if it does not need a database handle.
@ -114,6 +121,13 @@ class SQLitePlugin(object):
# This enables column access by name: row['column_name'] # This enables column access by name: row['column_name']
if dictrows: if dictrows:
db.row_factory = sqlite3.Row db.row_factory = sqlite3.Row
# Create user functions, aggregates and collations
for name, value in functions.items():
db.create_function(name, *value)
for name, value in aggregates.items():
db.create_aggregate(name, *value)
for name, value in collations.items():
db.create_collation(name, value)
# Add the connection handle as a keyword argument. # Add the connection handle as a keyword argument.
kwargs[keyword] = db kwargs[keyword] = db

45
test.py
View File

@ -84,6 +84,51 @@ class SQLiteTest(unittest.TestCase):
self._request('/') self._request('/')
def test_user_functions(self):
class SumSq:
def __init__(self):
self.result = 0
def step(self, value):
if value:
self.result += value**2
def finalize(self):
return self.result
def collate_reverse(string1, string2):
if string1 == string2:
return 0
elif string1 < string2:
return 1
else:
return -1
testfunc1 = lambda: 'test'
testfunc2 = lambda x: x + 1
self.app.install(sqlite.Plugin(
keyword='db4',
functions={'testfunc1': (0, testfunc1), 'testfunc2': (1, testfunc2)},
aggregates={'sumsq': (1, SumSq)},
collations={'reverse': collate_reverse},
))
@self.app.get('/')
def test(db, db4):
db4.execute("CREATE TABLE todo (id INTEGER PRIMARY KEY, task char(100) NOT NULL)")
result = db4.execute("SELECT testfunc1(), testfunc2(2)").fetchone()
self.assertEqual(tuple(result), ('test', 3))
db4.execute("INSERT INTO todo VALUES (10, 'a')")
db4.execute("INSERT INTO todo VALUES (11, 'a')")
db4.execute("INSERT INTO todo VALUES (12, 'a')")
result = db4.execute("SELECT sumsq(id) FROM todo WHERE task='a'").fetchone()
self.assertEqual(tuple(result), (365,))
result = db4.execute("SELECT ('a' < 'b' COLLATE reverse)").fetchone()
self.assertEqual(tuple(result), (0,))
self._request('/')
def test_raise_sqlite_integrity_error(self): def test_raise_sqlite_integrity_error(self):
@self.app.get('/') @self.app.get('/')
def test(db): def test(db):