From 5a6c2e128f16d076b2300d270aad426fc723f745 Mon Sep 17 00:00:00 2001 From: Dingyuan Wang Date: Fri, 25 Aug 2017 15:51:32 +0800 Subject: [PATCH] add support for SQLite user functions and collations --- README.rst | 3 +++ bottle_sqlite.py | 16 +++++++++++++++- test.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 01c863b..acacc7b 100644 --- a/README.rst +++ b/README.rst @@ -68,6 +68,9 @@ The following configuration options exist for the plugin class: * **autocommit**: Whether or not to commit outstanding transactions at the end of the request cycle (default: True). * **dictrows**: Whether or not to support dict-like access to row objects (default: True). * **text_factory**: The text_factory for the connection (default: unicode). +* **functions**: Add user-defined functions for use in SQL, should be a dict like ``{'name': (num_params, func)}`` (default: None). +* **aggregates**: Add user-defined aggregate functions, should be a dict like ``{'name': (num_params, aggregate_class)}`` (default: None). +* **collations**: Add user-defined collations, should be a dict like ``{'name': callable}`` (default: None). You can override each of these values on a per-route basis:: diff --git a/bottle_sqlite.py b/bottle_sqlite.py index 62cb826..cdd9449 100755 --- a/bottle_sqlite.py +++ b/bottle_sqlite.py @@ -59,12 +59,16 @@ class SQLitePlugin(object): unicode = str def __init__(self, dbfile=':memory:', autocommit=True, dictrows=True, - keyword='db', text_factory=unicode): + keyword='db', text_factory=unicode, + functions=None, aggregates=None, collations=None): self.dbfile = dbfile self.autocommit = autocommit self.dictrows = dictrows self.keyword = keyword self.text_factory = text_factory + self.functions = functions or {} + self.aggregates = aggregates or {} + self.collations = collations or {} def setup(self, app): ''' Make sure that other installed plugins don't affect the same @@ -99,6 +103,9 @@ class SQLitePlugin(object): dictrows = g('dictrows', self.dictrows) keyword = g('keyword', self.keyword) text_factory = g('text_factory', self.text_factory) + functions = g('functions', self.functions) + aggregates = g('aggregates', self.aggregates) + collations = g('collations', self.collations) # Test if the original callback accepts a 'db' keyword. # Ignore it if it does not need a database handle. @@ -114,6 +121,13 @@ class SQLitePlugin(object): # This enables column access by name: row['column_name'] if dictrows: db.row_factory = sqlite3.Row + # Create user functions, aggregates and collations + for name, value in functions.items(): + db.create_function(name, *value) + for name, value in aggregates.items(): + db.create_aggregate(name, *value) + for name, value in collations.items(): + db.create_collation(name, value) # Add the connection handle as a keyword argument. kwargs[keyword] = db diff --git a/test.py b/test.py index 681a79e..766c0bf 100644 --- a/test.py +++ b/test.py @@ -84,6 +84,51 @@ class SQLiteTest(unittest.TestCase): self._request('/') + def test_user_functions(self): + class SumSq: + def __init__(self): + self.result = 0 + + def step(self, value): + if value: + self.result += value**2 + + def finalize(self): + return self.result + + def collate_reverse(string1, string2): + if string1 == string2: + return 0 + elif string1 < string2: + return 1 + else: + return -1 + + testfunc1 = lambda: 'test' + testfunc2 = lambda x: x + 1 + + self.app.install(sqlite.Plugin( + keyword='db4', + functions={'testfunc1': (0, testfunc1), 'testfunc2': (1, testfunc2)}, + aggregates={'sumsq': (1, SumSq)}, + collations={'reverse': collate_reverse}, + )) + + @self.app.get('/') + def test(db, db4): + db4.execute("CREATE TABLE todo (id INTEGER PRIMARY KEY, task char(100) NOT NULL)") + result = db4.execute("SELECT testfunc1(), testfunc2(2)").fetchone() + self.assertEqual(tuple(result), ('test', 3)) + db4.execute("INSERT INTO todo VALUES (10, 'a')") + db4.execute("INSERT INTO todo VALUES (11, 'a')") + db4.execute("INSERT INTO todo VALUES (12, 'a')") + result = db4.execute("SELECT sumsq(id) FROM todo WHERE task='a'").fetchone() + self.assertEqual(tuple(result), (365,)) + result = db4.execute("SELECT ('a' < 'b' COLLATE reverse)").fetchone() + self.assertEqual(tuple(result), (0,)) + + self._request('/') + def test_raise_sqlite_integrity_error(self): @self.app.get('/') def test(db):