add support for SQLite user functions and collations
This commit is contained in:
parent
a1e9860040
commit
5a6c2e128f
|
@ -68,6 +68,9 @@ The following configuration options exist for the plugin class:
|
|||
* **autocommit**: Whether or not to commit outstanding transactions at the end of the request cycle (default: True).
|
||||
* **dictrows**: Whether or not to support dict-like access to row objects (default: True).
|
||||
* **text_factory**: The text_factory for the connection (default: unicode).
|
||||
* **functions**: Add user-defined functions for use in SQL, should be a dict like ``{'name': (num_params, func)}`` (default: None).
|
||||
* **aggregates**: Add user-defined aggregate functions, should be a dict like ``{'name': (num_params, aggregate_class)}`` (default: None).
|
||||
* **collations**: Add user-defined collations, should be a dict like ``{'name': callable}`` (default: None).
|
||||
|
||||
You can override each of these values on a per-route basis::
|
||||
|
||||
|
|
|
@ -59,12 +59,16 @@ class SQLitePlugin(object):
|
|||
unicode = str
|
||||
|
||||
def __init__(self, dbfile=':memory:', autocommit=True, dictrows=True,
|
||||
keyword='db', text_factory=unicode):
|
||||
keyword='db', text_factory=unicode,
|
||||
functions=None, aggregates=None, collations=None):
|
||||
self.dbfile = dbfile
|
||||
self.autocommit = autocommit
|
||||
self.dictrows = dictrows
|
||||
self.keyword = keyword
|
||||
self.text_factory = text_factory
|
||||
self.functions = functions or {}
|
||||
self.aggregates = aggregates or {}
|
||||
self.collations = collations or {}
|
||||
|
||||
def setup(self, app):
|
||||
''' Make sure that other installed plugins don't affect the same
|
||||
|
@ -99,6 +103,9 @@ class SQLitePlugin(object):
|
|||
dictrows = g('dictrows', self.dictrows)
|
||||
keyword = g('keyword', self.keyword)
|
||||
text_factory = g('text_factory', self.text_factory)
|
||||
functions = g('functions', self.functions)
|
||||
aggregates = g('aggregates', self.aggregates)
|
||||
collations = g('collations', self.collations)
|
||||
|
||||
# Test if the original callback accepts a 'db' keyword.
|
||||
# Ignore it if it does not need a database handle.
|
||||
|
@ -114,6 +121,13 @@ class SQLitePlugin(object):
|
|||
# This enables column access by name: row['column_name']
|
||||
if dictrows:
|
||||
db.row_factory = sqlite3.Row
|
||||
# Create user functions, aggregates and collations
|
||||
for name, value in functions.items():
|
||||
db.create_function(name, *value)
|
||||
for name, value in aggregates.items():
|
||||
db.create_aggregate(name, *value)
|
||||
for name, value in collations.items():
|
||||
db.create_collation(name, value)
|
||||
# Add the connection handle as a keyword argument.
|
||||
kwargs[keyword] = db
|
||||
|
||||
|
|
45
test.py
45
test.py
|
@ -84,6 +84,51 @@ class SQLiteTest(unittest.TestCase):
|
|||
|
||||
self._request('/')
|
||||
|
||||
def test_user_functions(self):
|
||||
class SumSq:
|
||||
def __init__(self):
|
||||
self.result = 0
|
||||
|
||||
def step(self, value):
|
||||
if value:
|
||||
self.result += value**2
|
||||
|
||||
def finalize(self):
|
||||
return self.result
|
||||
|
||||
def collate_reverse(string1, string2):
|
||||
if string1 == string2:
|
||||
return 0
|
||||
elif string1 < string2:
|
||||
return 1
|
||||
else:
|
||||
return -1
|
||||
|
||||
testfunc1 = lambda: 'test'
|
||||
testfunc2 = lambda x: x + 1
|
||||
|
||||
self.app.install(sqlite.Plugin(
|
||||
keyword='db4',
|
||||
functions={'testfunc1': (0, testfunc1), 'testfunc2': (1, testfunc2)},
|
||||
aggregates={'sumsq': (1, SumSq)},
|
||||
collations={'reverse': collate_reverse},
|
||||
))
|
||||
|
||||
@self.app.get('/')
|
||||
def test(db, db4):
|
||||
db4.execute("CREATE TABLE todo (id INTEGER PRIMARY KEY, task char(100) NOT NULL)")
|
||||
result = db4.execute("SELECT testfunc1(), testfunc2(2)").fetchone()
|
||||
self.assertEqual(tuple(result), ('test', 3))
|
||||
db4.execute("INSERT INTO todo VALUES (10, 'a')")
|
||||
db4.execute("INSERT INTO todo VALUES (11, 'a')")
|
||||
db4.execute("INSERT INTO todo VALUES (12, 'a')")
|
||||
result = db4.execute("SELECT sumsq(id) FROM todo WHERE task='a'").fetchone()
|
||||
self.assertEqual(tuple(result), (365,))
|
||||
result = db4.execute("SELECT ('a' < 'b' COLLATE reverse)").fetchone()
|
||||
self.assertEqual(tuple(result), (0,))
|
||||
|
||||
self._request('/')
|
||||
|
||||
def test_raise_sqlite_integrity_error(self):
|
||||
@self.app.get('/')
|
||||
def test(db):
|
||||
|
|
Loading…
Reference in New Issue
Block a user