Skip to content

Commit

Permalink
SQLite memory database can be shared between connections
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky committed Jan 24, 2022
1 parent 8390233 commit 8a6ccb7
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 15 deletions.
7 changes: 6 additions & 1 deletion pony/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,11 +724,16 @@ def __call__(self, func=None, provider=None):
OnConnectDecorator.check_provider(provider)
return OnConnectDecorator(self.database, provider)


db_id_counter = itertools.count(1)


class Database(object):
def __deepcopy__(self, memo):
return self # Database cannot be cloned by deepcopy()
@cut_traceback
def __init__(self, *args, **kwargs):
self.id = next(db_id_counter)
# argument 'self' cannot be named 'database', because 'database' can be in kwargs
self.priority = 0
self._insert_cache = {}
Expand Down Expand Up @@ -779,7 +784,7 @@ def _bind(self, *args, **kwargs):
provider_module = import_module('pony.orm.dbproviders.' + provider)
provider_cls = provider_module.provider_cls
kwargs['pony_call_on_connect'] = self.call_on_connect
self.provider = provider_cls(*args, **kwargs)
self.provider = provider_cls(self, *args, **kwargs)
@property
def last_sql(database):
return database._dblocal.last_sql
Expand Down
3 changes: 2 additions & 1 deletion pony/orm/dbapiprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ class DBAPIProvider(object):

fk_types = { 'SERIAL' : 'INTEGER', 'BIGSERIAL' : 'BIGINT' }

def __init__(provider, *args, **kwargs):
def __init__(provider, database, *args, **kwargs):
provider.database = database
pool_mockup = kwargs.pop('pony_pool_mockup', None)
call_on_connect = kwargs.pop('pony_call_on_connect', None)
if pool_mockup: provider.pool = pool_mockup
Expand Down
36 changes: 24 additions & 12 deletions pony/orm/dbproviders/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import absolute_import
from pony.py23compat import buffer, int_types

import os.path, sys, re, json, datetime, time
import os, os.path, sys, re, json, datetime, time
import sqlite3 as sqlite
from decimal import Decimal
from random import random
Expand Down Expand Up @@ -337,8 +337,12 @@ class SQLiteProvider(DBAPIProvider):
(Json, SQLiteJsonConverter)
]

def __init__(provider, *args, **kwargs):
DBAPIProvider.__init__(provider, *args, **kwargs)
def __init__(provider, database, filename, **kwargs):
is_shared_memory_db = filename == ':sharedmemory:'
if is_shared_memory_db:
filename = "file:memdb%d_%s?mode=memory&cache=shared" % (database.id, os.urandom(8).hex())
kwargs["uri"] = True
DBAPIProvider.__init__(provider, database, is_shared_memory_db, filename, **kwargs)
provider.pre_transaction_lock = Lock()
provider.transaction_lock = Lock()

Expand Down Expand Up @@ -434,8 +438,10 @@ def release(provider, connection, cache=None):
raise
DBAPIProvider.release(provider, connection, cache)

def get_pool(provider, filename, create_db=False, **kwargs):
if filename != ':memory:':
def get_pool(provider, is_shared_memory_db, filename, create_db=False, **kwargs):
if is_shared_memory_db or filename == ':memory:':
pass
else:
# When relative filename is specified, it is considered
# not relative to cwd, but to user module where
# Database instance is created
Expand All @@ -450,7 +456,7 @@ def get_pool(provider, filename, create_db=False, **kwargs):
# 1 - SQLiteProvider.__init__()
# 0 - pony.dbproviders.sqlite.get_pool()
filename = absolutize_path(filename, frame_depth=cut_traceback_depth+5)
return SQLitePool(filename, create_db, **kwargs)
return SQLitePool(is_shared_memory_db, filename, create_db, **kwargs)

def table_exists(provider, connection, table_name, case_sensitive=True):
return provider._exists(connection, table_name, None, case_sensitive)
Expand Down Expand Up @@ -645,15 +651,19 @@ def py_string_slice(s, start, end):
return s[start:end]

class SQLitePool(Pool):
def __init__(pool, filename, create_db, **kwargs): # called separately in each thread
def __init__(pool, is_shared_memory_db, filename, create_db, **kwargs): # called separately in each thread
pool.is_shared_memory_db = is_shared_memory_db
pool.filename = filename
pool.create_db = create_db
pool.kwargs = kwargs
pool.con = None
def _connect(pool):
filename = pool.filename
if filename != ':memory:' and not pool.create_db and not os.path.exists(filename):
if pool.is_shared_memory_db or pool.filename == ':memory:':
pass
elif not pool.create_db and not os.path.exists(filename):
throw(IOError, "Database file is not found: %r" % filename)

pool.con = con = sqlite.connect(filename, isolation_level=None, **pool.kwargs)
con.text_factory = _text_factory

Expand Down Expand Up @@ -685,10 +695,12 @@ def create_function(name, num_params, func):

con.execute('PRAGMA case_sensitive_like = true')
def disconnect(pool):
if pool.filename != ':memory:':
if pool.is_shared_memory_db or pool.filename == ':memory:':
pass
else:
Pool.disconnect(pool)
def drop(pool, con):
if pool.filename != ':memory:':
Pool.drop(pool, con)
else:
if pool.is_shared_memory_db or pool.filename == ':memory:':
con.rollback()
else:
Pool.drop(pool, con)
2 changes: 1 addition & 1 deletion pony/orm/tests/test_sqlbuilding_formatstyles.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TestFormatStyles(unittest.TestCase):
def setUp(self):
self.key1 = 'KEY1'
self.key2 = 'KEY2'
self.provider = DBAPIProvider(pony_pool_mockup=TestPool(None))
self.provider = DBAPIProvider(database=None, pony_pool_mockup=TestPool(None))
self.ast = [ 'SELECT', [ 'ALL', ['COLUMN', None, 'A']], [ 'FROM', [None, 'TABLE', 'T1']],
[ 'WHERE', [ 'EQ', ['COLUMN', None, 'B'], [ 'PARAM', self.key1 ] ],
[ 'EQ', ['COLUMN', None, 'C'], [ 'PARAM', self.key2 ] ],
Expand Down
37 changes: 37 additions & 0 deletions pony/orm/tests/test_sqlite_shared_memory_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import absolute_import, print_function, division

import threading
import unittest

from pony.orm.core import *


db = Database('sqlite', ':sharedmemory:')


class Person(db.Entity):
name = Required(str)

db.generate_mapping(create_tables=True)

with db_session:
Person(name='John')
Person(name='Mike')


class TestThread(threading.Thread):
def __init__(self, *args, **kwargs):
super().__init__(*args, *kwargs)
self.result = []
def run(self):
with db_session:
persons = Person.select().fetch()
self.result.extend(p.name for p in persons)


class TestFlush(unittest.TestCase):
def test1(self):
thread1 = TestThread()
thread1.start()
thread1.join()
self.assertEqual(set(thread1.result), {'John', 'Mike'})

0 comments on commit 8a6ccb7

Please sign in to comment.