From 8a6ccb7c6273ca173779cf8993ff05ffbada305e Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 24 Jan 2022 04:36:51 +0100 Subject: [PATCH] SQLite memory database can be shared between connections --- pony/orm/core.py | 7 +++- pony/orm/dbapiprovider.py | 3 +- pony/orm/dbproviders/sqlite.py | 36 ++++++++++++------ .../tests/test_sqlbuilding_formatstyles.py | 2 +- .../orm/tests/test_sqlite_shared_memory_db.py | 37 +++++++++++++++++++ 5 files changed, 70 insertions(+), 15 deletions(-) create mode 100644 pony/orm/tests/test_sqlite_shared_memory_db.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 2682951b..845e672a 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -724,11 +724,16 @@ def __call__(self, func=None, provider=None): OnConnectDecorator.check_provider(provider) return OnConnectDecorator(self.database, provider) + +db_id_counter = itertools.count(1) + + class Database(object): def __deepcopy__(self, memo): return self # Database cannot be cloned by deepcopy() @cut_traceback def __init__(self, *args, **kwargs): + self.id = next(db_id_counter) # argument 'self' cannot be named 'database', because 'database' can be in kwargs self.priority = 0 self._insert_cache = {} @@ -779,7 +784,7 @@ def _bind(self, *args, **kwargs): provider_module = import_module('pony.orm.dbproviders.' + provider) provider_cls = provider_module.provider_cls kwargs['pony_call_on_connect'] = self.call_on_connect - self.provider = provider_cls(*args, **kwargs) + self.provider = provider_cls(self, *args, **kwargs) @property def last_sql(database): return database._dblocal.last_sql diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 9f44c7c7..e9d8b9f4 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -122,7 +122,8 @@ class DBAPIProvider(object): fk_types = { 'SERIAL' : 'INTEGER', 'BIGSERIAL' : 'BIGINT' } - def __init__(provider, *args, **kwargs): + def __init__(provider, database, *args, **kwargs): + provider.database = database pool_mockup = kwargs.pop('pony_pool_mockup', None) call_on_connect = kwargs.pop('pony_call_on_connect', None) if pool_mockup: provider.pool = pool_mockup diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index bbce763b..4aa16ee7 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from pony.py23compat import buffer, int_types -import os.path, sys, re, json, datetime, time +import os, os.path, sys, re, json, datetime, time import sqlite3 as sqlite from decimal import Decimal from random import random @@ -337,8 +337,12 @@ class SQLiteProvider(DBAPIProvider): (Json, SQLiteJsonConverter) ] - def __init__(provider, *args, **kwargs): - DBAPIProvider.__init__(provider, *args, **kwargs) + def __init__(provider, database, filename, **kwargs): + is_shared_memory_db = filename == ':sharedmemory:' + if is_shared_memory_db: + filename = "file:memdb%d_%s?mode=memory&cache=shared" % (database.id, os.urandom(8).hex()) + kwargs["uri"] = True + DBAPIProvider.__init__(provider, database, is_shared_memory_db, filename, **kwargs) provider.pre_transaction_lock = Lock() provider.transaction_lock = Lock() @@ -434,8 +438,10 @@ def release(provider, connection, cache=None): raise DBAPIProvider.release(provider, connection, cache) - def get_pool(provider, filename, create_db=False, **kwargs): - if filename != ':memory:': + def get_pool(provider, is_shared_memory_db, filename, create_db=False, **kwargs): + if is_shared_memory_db or filename == ':memory:': + pass + else: # When relative filename is specified, it is considered # not relative to cwd, but to user module where # Database instance is created @@ -450,7 +456,7 @@ def get_pool(provider, filename, create_db=False, **kwargs): # 1 - SQLiteProvider.__init__() # 0 - pony.dbproviders.sqlite.get_pool() filename = absolutize_path(filename, frame_depth=cut_traceback_depth+5) - return SQLitePool(filename, create_db, **kwargs) + return SQLitePool(is_shared_memory_db, filename, create_db, **kwargs) def table_exists(provider, connection, table_name, case_sensitive=True): return provider._exists(connection, table_name, None, case_sensitive) @@ -645,15 +651,19 @@ def py_string_slice(s, start, end): return s[start:end] class SQLitePool(Pool): - def __init__(pool, filename, create_db, **kwargs): # called separately in each thread + def __init__(pool, is_shared_memory_db, filename, create_db, **kwargs): # called separately in each thread + pool.is_shared_memory_db = is_shared_memory_db pool.filename = filename pool.create_db = create_db pool.kwargs = kwargs pool.con = None def _connect(pool): filename = pool.filename - if filename != ':memory:' and not pool.create_db and not os.path.exists(filename): + if pool.is_shared_memory_db or pool.filename == ':memory:': + pass + elif not pool.create_db and not os.path.exists(filename): throw(IOError, "Database file is not found: %r" % filename) + pool.con = con = sqlite.connect(filename, isolation_level=None, **pool.kwargs) con.text_factory = _text_factory @@ -685,10 +695,12 @@ def create_function(name, num_params, func): con.execute('PRAGMA case_sensitive_like = true') def disconnect(pool): - if pool.filename != ':memory:': + if pool.is_shared_memory_db or pool.filename == ':memory:': + pass + else: Pool.disconnect(pool) def drop(pool, con): - if pool.filename != ':memory:': - Pool.drop(pool, con) - else: + if pool.is_shared_memory_db or pool.filename == ':memory:': con.rollback() + else: + Pool.drop(pool, con) diff --git a/pony/orm/tests/test_sqlbuilding_formatstyles.py b/pony/orm/tests/test_sqlbuilding_formatstyles.py index 70c217fd..96fbbfe4 100644 --- a/pony/orm/tests/test_sqlbuilding_formatstyles.py +++ b/pony/orm/tests/test_sqlbuilding_formatstyles.py @@ -10,7 +10,7 @@ class TestFormatStyles(unittest.TestCase): def setUp(self): self.key1 = 'KEY1' self.key2 = 'KEY2' - self.provider = DBAPIProvider(pony_pool_mockup=TestPool(None)) + self.provider = DBAPIProvider(database=None, pony_pool_mockup=TestPool(None)) self.ast = [ 'SELECT', [ 'ALL', ['COLUMN', None, 'A']], [ 'FROM', [None, 'TABLE', 'T1']], [ 'WHERE', [ 'EQ', ['COLUMN', None, 'B'], [ 'PARAM', self.key1 ] ], [ 'EQ', ['COLUMN', None, 'C'], [ 'PARAM', self.key2 ] ], diff --git a/pony/orm/tests/test_sqlite_shared_memory_db.py b/pony/orm/tests/test_sqlite_shared_memory_db.py new file mode 100644 index 00000000..0469c542 --- /dev/null +++ b/pony/orm/tests/test_sqlite_shared_memory_db.py @@ -0,0 +1,37 @@ +from __future__ import absolute_import, print_function, division + +import threading +import unittest + +from pony.orm.core import * + + +db = Database('sqlite', ':sharedmemory:') + + +class Person(db.Entity): + name = Required(str) + +db.generate_mapping(create_tables=True) + +with db_session: + Person(name='John') + Person(name='Mike') + + +class TestThread(threading.Thread): + def __init__(self, *args, **kwargs): + super().__init__(*args, *kwargs) + self.result = [] + def run(self): + with db_session: + persons = Person.select().fetch() + self.result.extend(p.name for p in persons) + + +class TestFlush(unittest.TestCase): + def test1(self): + thread1 = TestThread() + thread1.start() + thread1.join() + self.assertEqual(set(thread1.result), {'John', 'Mike'})