add support for SQLite user functions and collations
This commit is contained in:
parent
a1e9860040
commit
5a6c2e128f
|
@ -68,6 +68,9 @@ The following configuration options exist for the plugin class:
|
||||||
* **autocommit**: Whether or not to commit outstanding transactions at the end of the request cycle (default: True).
|
* **autocommit**: Whether or not to commit outstanding transactions at the end of the request cycle (default: True).
|
||||||
* **dictrows**: Whether or not to support dict-like access to row objects (default: True).
|
* **dictrows**: Whether or not to support dict-like access to row objects (default: True).
|
||||||
* **text_factory**: The text_factory for the connection (default: unicode).
|
* **text_factory**: The text_factory for the connection (default: unicode).
|
||||||
|
* **functions**: Add user-defined functions for use in SQL, should be a dict like ``{'name': (num_params, func)}`` (default: None).
|
||||||
|
* **aggregates**: Add user-defined aggregate functions, should be a dict like ``{'name': (num_params, aggregate_class)}`` (default: None).
|
||||||
|
* **collations**: Add user-defined collations, should be a dict like ``{'name': callable}`` (default: None).
|
||||||
|
|
||||||
You can override each of these values on a per-route basis::
|
You can override each of these values on a per-route basis::
|
||||||
|
|
||||||
|
|
|
@ -59,12 +59,16 @@ class SQLitePlugin(object):
|
||||||
unicode = str
|
unicode = str
|
||||||
|
|
||||||
def __init__(self, dbfile=':memory:', autocommit=True, dictrows=True,
|
def __init__(self, dbfile=':memory:', autocommit=True, dictrows=True,
|
||||||
keyword='db', text_factory=unicode):
|
keyword='db', text_factory=unicode,
|
||||||
|
functions=None, aggregates=None, collations=None):
|
||||||
self.dbfile = dbfile
|
self.dbfile = dbfile
|
||||||
self.autocommit = autocommit
|
self.autocommit = autocommit
|
||||||
self.dictrows = dictrows
|
self.dictrows = dictrows
|
||||||
self.keyword = keyword
|
self.keyword = keyword
|
||||||
self.text_factory = text_factory
|
self.text_factory = text_factory
|
||||||
|
self.functions = functions or {}
|
||||||
|
self.aggregates = aggregates or {}
|
||||||
|
self.collations = collations or {}
|
||||||
|
|
||||||
def setup(self, app):
|
def setup(self, app):
|
||||||
''' Make sure that other installed plugins don't affect the same
|
''' Make sure that other installed plugins don't affect the same
|
||||||
|
@ -99,6 +103,9 @@ class SQLitePlugin(object):
|
||||||
dictrows = g('dictrows', self.dictrows)
|
dictrows = g('dictrows', self.dictrows)
|
||||||
keyword = g('keyword', self.keyword)
|
keyword = g('keyword', self.keyword)
|
||||||
text_factory = g('text_factory', self.text_factory)
|
text_factory = g('text_factory', self.text_factory)
|
||||||
|
functions = g('functions', self.functions)
|
||||||
|
aggregates = g('aggregates', self.aggregates)
|
||||||
|
collations = g('collations', self.collations)
|
||||||
|
|
||||||
# Test if the original callback accepts a 'db' keyword.
|
# Test if the original callback accepts a 'db' keyword.
|
||||||
# Ignore it if it does not need a database handle.
|
# Ignore it if it does not need a database handle.
|
||||||
|
@ -114,6 +121,13 @@ class SQLitePlugin(object):
|
||||||
# This enables column access by name: row['column_name']
|
# This enables column access by name: row['column_name']
|
||||||
if dictrows:
|
if dictrows:
|
||||||
db.row_factory = sqlite3.Row
|
db.row_factory = sqlite3.Row
|
||||||
|
# Create user functions, aggregates and collations
|
||||||
|
for name, value in functions.items():
|
||||||
|
db.create_function(name, *value)
|
||||||
|
for name, value in aggregates.items():
|
||||||
|
db.create_aggregate(name, *value)
|
||||||
|
for name, value in collations.items():
|
||||||
|
db.create_collation(name, value)
|
||||||
# Add the connection handle as a keyword argument.
|
# Add the connection handle as a keyword argument.
|
||||||
kwargs[keyword] = db
|
kwargs[keyword] = db
|
||||||
|
|
||||||
|
|
45
test.py
45
test.py
|
@ -84,6 +84,51 @@ class SQLiteTest(unittest.TestCase):
|
||||||
|
|
||||||
self._request('/')
|
self._request('/')
|
||||||
|
|
||||||
|
def test_user_functions(self):
|
||||||
|
class SumSq:
|
||||||
|
def __init__(self):
|
||||||
|
self.result = 0
|
||||||
|
|
||||||
|
def step(self, value):
|
||||||
|
if value:
|
||||||
|
self.result += value**2
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
return self.result
|
||||||
|
|
||||||
|
def collate_reverse(string1, string2):
|
||||||
|
if string1 == string2:
|
||||||
|
return 0
|
||||||
|
elif string1 < string2:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
testfunc1 = lambda: 'test'
|
||||||
|
testfunc2 = lambda x: x + 1
|
||||||
|
|
||||||
|
self.app.install(sqlite.Plugin(
|
||||||
|
keyword='db4',
|
||||||
|
functions={'testfunc1': (0, testfunc1), 'testfunc2': (1, testfunc2)},
|
||||||
|
aggregates={'sumsq': (1, SumSq)},
|
||||||
|
collations={'reverse': collate_reverse},
|
||||||
|
))
|
||||||
|
|
||||||
|
@self.app.get('/')
|
||||||
|
def test(db, db4):
|
||||||
|
db4.execute("CREATE TABLE todo (id INTEGER PRIMARY KEY, task char(100) NOT NULL)")
|
||||||
|
result = db4.execute("SELECT testfunc1(), testfunc2(2)").fetchone()
|
||||||
|
self.assertEqual(tuple(result), ('test', 3))
|
||||||
|
db4.execute("INSERT INTO todo VALUES (10, 'a')")
|
||||||
|
db4.execute("INSERT INTO todo VALUES (11, 'a')")
|
||||||
|
db4.execute("INSERT INTO todo VALUES (12, 'a')")
|
||||||
|
result = db4.execute("SELECT sumsq(id) FROM todo WHERE task='a'").fetchone()
|
||||||
|
self.assertEqual(tuple(result), (365,))
|
||||||
|
result = db4.execute("SELECT ('a' < 'b' COLLATE reverse)").fetchone()
|
||||||
|
self.assertEqual(tuple(result), (0,))
|
||||||
|
|
||||||
|
self._request('/')
|
||||||
|
|
||||||
def test_raise_sqlite_integrity_error(self):
|
def test_raise_sqlite_integrity_error(self):
|
||||||
@self.app.get('/')
|
@self.app.get('/')
|
||||||
def test(db):
|
def test(db):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user