Skip to content

Commit

Permalink
Pony ORM release 0.4.7:
Browse files Browse the repository at this point in the history
- @db_session decorator is required for any database interaction;
- support of pickling/unpickling (queries and objects can be stored in memcached);
- lazy collections - don't load all the items if only one is needed;
- incremental query construction: q = q.filter(...);
- datetime precision now can be specified;
- multiple bugs were fixed.
  • Loading branch information
kozlovsky committed Jun 19, 2013
2 parents b07bc82 + 2fa989b commit 89e455e
Show file tree
Hide file tree
Showing 40 changed files with 599 additions and 327 deletions.
6 changes: 5 additions & 1 deletion pony/orm/asttranslation.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ def postSliceobj(translator, node):
return ':'.join(item.src for item in node.nodes)
def postConst(translator, node):
node.priority = 1
return repr(node.value)
value = node.value
if type(value) is float: # for Python < 2.7
s = str(value)
if float(s) == value: return s
return repr(value)
def postList(translator, node):
node.priority = 1
return '[%s]' % ', '.join(item.src for item in node.nodes)
Expand Down
365 changes: 224 additions & 141 deletions pony/orm/core.py

Large diffs are not rendered by default.

68 changes: 59 additions & 9 deletions pony/orm/dbapiprovider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from decimal import Decimal, InvalidOperation
from datetime import datetime, date, time
import re

from pony.utils import is_utf8, simple_decorator, throw, localbase
from pony.converting import str2date, str2datetime
Expand Down Expand Up @@ -56,11 +57,22 @@ def wrap_dbapi_exceptions(func, provider, *args, **kwargs):
except dbapi_module.Error, e: raise Error(e)
except dbapi_module.Warning, e: raise Warning(e)

version_re = re.compile('[0-9\.]+')

def get_version_tuple(s):
m = version_re.match(s)
if m is not None:
return tuple(map(int, m.group(0).split('.')))
return None

class DBAPIProvider(object):
paramstyle = 'qmark'
quote_char = '"'
max_params_count = 200

table_if_not_exists_syntax = True
max_time_precision = default_time_precision = 6

dbapi_module = None
dbschema_cls = None
translator_cls = None
Expand All @@ -75,7 +87,7 @@ def __init__(provider, *args, **kwargs):
provider.release(connection)

def inspect_connection(provider, connection):
provider.table_if_not_exists_syntax = True
pass

def get_default_entity_table_name(provider, entity):
return entity.__name__
Expand Down Expand Up @@ -362,8 +374,8 @@ def __init__(converter, py_type, attr=None):
def init(converter, kwargs):
attr = converter.attr
args = attr.args
if len(args) > 2: throw(TypeError, 'Too many positional parameters for Decimal (expected: precision and scale)')

if len(args) > 2: throw(TypeError, 'Too many positional parameters for Decimal '
'(expected: precision and scale), got: %s' % args)
if args: precision = args[0]
else: precision = kwargs.pop('precision', 12)
if not isinstance(precision, (int, long)):
Expand Down Expand Up @@ -398,6 +410,10 @@ def init(converter, kwargs):
converter.min_val = min_val
converter.max_val = max_val
def validate(converter, val):
if isinstance(val, float):
s = str(val)
if float(s) != val: s = repr(val)
val = Decimal(s)
try: val = Decimal(val)
except InvalidOperation, exc:
throw(TypeError, 'Invalid value for attribute %s: %r' % (converter.attr, val))
Expand Down Expand Up @@ -432,18 +448,52 @@ def validate(converter, val):
throw(TypeError, "Attribute %r: expected type is 'date'. Got: %r" % (converter.attr, val))
def sql2py(converter, val):
if not isinstance(val, date): throw(ValueError,
'Value of unexpected type received from database: instead of date got %s', type(val))
'Value of unexpected type received from database: instead of date got %s' % type(val))
return val
def sql_type(converter):
return 'DATE'

class DatetimeConverter(Converter):
sql_type_name = 'DATETIME'
def __init__(converter, py_type, attr=None):
converter.precision = None # for the case when attr is None
Converter.__init__(converter, py_type, attr)
def init(converter, kwargs):
attr = converter.attr
args = attr.args
if len(args) > 1: throw(TypeError, 'Too many positional parameters for datetime attribute %s. '
'Expected: precision, got: %r' % (attr, args))
provider = attr.entity._database_.provider
if args:
precision = args[0]
if 'precision' in kwargs: throw(TypeError,
'Precision for datetime attribute %s has both positional and keyword value' % attr)
else: precision = kwargs.pop('precision', provider.default_time_precision)
if not isinstance(precision, int) or not 0 <= precision <= 6: throw(ValueError,
'Precision value of datetime attribute %s must be between 0 and 6. Got: %r' % (attr, precision))
if precision > provider.max_time_precision: throw(ValueError,
'Precision value (%d) of attribute %s exceeds max datetime precision (%d) of %s %s'
% (precision, attr, provider.max_time_precision, provider.dialect, provider.server_version))
converter.precision = precision
def validate(converter, val):
if isinstance(val, datetime): return val
if isinstance(val, basestring): return str2datetime(val)
throw(TypeError, "Attribute %r: expected type is 'datetime'. Got: %r" % (converter.attr, val))
if isinstance(val, datetime): pass
elif isinstance(val, basestring): val = str2datetime(val)
else: throw(TypeError, "Attribute %r: expected type is 'datetime'. Got: %r" % (converter.attr, val))
p = converter.precision
if not p: val = val.replace(microsecond=0)
elif p == 6: pass
else:
rounding = 10 ** (6-p)
microsecond = (val.microsecond // rounding) * rounding
val = val.replace(microsecond=microsecond)
return val
def sql2py(converter, val):
if not isinstance(val, datetime): raise ValueError
if not isinstance(val, datetime): throw(ValueError,
'Value of unexpected type received from database: instead of datetime got %s' % type(val))
return val
def sql_type(converter):
return 'DATETIME'
attr = converter.attr
precision = converter.precision
if not attr or precision == attr.entity._database_.provider.default_time_precision:
return converter.sql_type_name
return converter.sql_type_name + '(%d)' % precision
5 changes: 3 additions & 2 deletions pony/orm/dbproviders/_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class PGTable(dbschema.Table):
pass

class PGSchema(dbschema.DBSchema):
dialect = 'PostgreSQL'
table_class = PGTable
column_class = PGColumn

Expand Down Expand Up @@ -68,10 +69,10 @@ def sql_type(converter):
return 'BYTEA'

class PGDatetimeConverter(dbapiprovider.DatetimeConverter):
def sql_type(converter):
return 'TIMESTAMP'
sql_type_name = 'TIMESTAMP'

class PGProvider(DBAPIProvider):
dialect = 'PostgreSQL'
paramstyle = 'pyformat'

dbapi_module = None # pgdb or psycopg2
Expand Down
34 changes: 32 additions & 2 deletions pony/orm/dbproviders/mysql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from decimal import Decimal, InvalidOperation
from datetime import datetime, date, time
from datetime import datetime, date, time, timedelta

import warnings
warnings.filterwarnings('ignore', '^Table.+already exists$', Warning, '^pony\\.orm\\.dbapiprovider$')
Expand All @@ -10,14 +10,15 @@

from pony.orm import dbschema
from pony.orm import dbapiprovider
from pony.orm.dbapiprovider import DBAPIProvider, Pool
from pony.orm.dbapiprovider import DBAPIProvider, Pool, get_version_tuple
from pony.orm.sqltranslation import SQLTranslator
from pony.orm.sqlbuilding import Value, SQLBuilder, join

class MySQLColumn(dbschema.Column):
auto_template = '%(type)s PRIMARY KEY AUTO_INCREMENT'

class MySQLSchema(dbschema.DBSchema):
dialect = 'MySQL'
column_class = MySQLColumn

class MySQLTranslator(SQLTranslator):
Expand Down Expand Up @@ -74,9 +75,12 @@ def sql_type(converter):
return 'LONGBLOB'

class MySQLProvider(DBAPIProvider):
dialect = 'MySQL'
paramstyle = 'format'
quote_char = "`"

max_time_precision = default_time_precision = 0

dbapi_module = MySQLdb
dbschema_cls = MySQLSchema
translator_cls = MySQLTranslator
Expand All @@ -95,13 +99,39 @@ class MySQLProvider(DBAPIProvider):
(date, dbapiprovider.DateConverter)
]

def inspect_connection(provider, connection):
cursor = connection.cursor()
cursor.execute('select version()')
row = cursor.fetchone()
assert row is not None
provider.server_version = get_version_tuple(row[0])
if provider.server_version >= (5, 6, 4):
provider.max_time_precision = 6

def get_pool(provider, *args, **kwargs):
if 'conv' not in kwargs:
conv = MySQLdb.converters.conversions.copy()
conv[FIELD_TYPE.BLOB] = [(FLAG.BINARY, buffer)]
conv[FIELD_TYPE.TIMESTAMP] = str2datetime
conv[FIELD_TYPE.DATETIME] = str2datetime
conv[FIELD_TYPE.TIME] = str2timedelta
kwargs['conv'] = conv
if 'charset' not in kwargs:
kwargs['charset'] = 'utf8'
return Pool(MySQLdb, *args, **kwargs)

provider_cls = MySQLProvider

def str2datetime(s):
if 19 < len(s) < 26: s += '000000'[:26-len(s)]
s = s.replace('-', ' ').replace(':', ' ').replace('.', ' ').replace('T', ' ')
return datetime(*map(int, s.split()))

def str2timedelta(s):
if '.' in s:
s, fractional = s.split('.')
microseconds = int((fractional + '000000')[:6])
else: microseconds = 0
h, m, s = map(int, s.split(':'))
td = timedelta(hours=abs(h), minutes=m, seconds=s, microseconds=microseconds)
return -td if h < 0 else td
25 changes: 20 additions & 5 deletions pony/orm/dbproviders/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pony.orm import core, dbschema, sqlbuilding, dbapiprovider, sqltranslation
from pony.orm.core import log_orm, log_sql, DatabaseError
from pony.orm.dbapiprovider import DBAPIProvider, wrap_dbapi_exceptions
from pony.orm.dbapiprovider import DBAPIProvider, wrap_dbapi_exceptions, get_version_tuple
from pony.utils import is_utf8, throw

trigger_template = """
Expand Down Expand Up @@ -56,6 +56,7 @@ class OraColumn(dbschema.Column):
auto_template = None

class OraSchema(dbschema.DBSchema):
dialect = 'Oracle'
table_class = OraTable
column_class = OraColumn

Expand All @@ -64,9 +65,16 @@ def __init__(monad, translator, value=None):
assert value in (None, '')
sqltranslation.ConstMonad.__init__(monad, translator, None)

class OraConstMonad(sqltranslation.ConstMonad):
@staticmethod
def new(translator, value):
if value == '': value = None
return sqltranslation.ConstMonad.new(translator, value)

class OraTranslator(sqltranslation.SQLTranslator):
dialect = 'Oracle'
NoneMonad = OraNoneMonad
ConstMonad = OraConstMonad

@classmethod
def get_normalized_type_of(translator, value):
Expand Down Expand Up @@ -179,19 +187,26 @@ def sql2py(converter, val):
return val

class OraDatetimeConverter(dbapiprovider.DatetimeConverter):
def sql_type(converter):
return 'TIMESTAMP(6)'
sql_type_name = 'TIMESTAMP'

class OraProvider(DBAPIProvider):
dialect = 'Oracle'
paramstyle = 'named'

table_if_not_exists_syntax = False

dbapi_module = cx_Oracle
dbschema_cls = OraSchema
translator_cls = OraTranslator
sqlbuilder_cls = OraBuilder

def inspect_connection(provider, connection):
provider.table_if_not_exists_syntax = False
sql = "select version from product_component_version where product like 'Oracle Database %'"
cursor = connection.cursor()
cursor.execute(sql)
row = cursor.fetchone()
assert row is not None
provider.server_version = get_version_tuple(row[0])

def get_default_entity_table_name(provider, entity):
return DBAPIProvider.get_default_entity_table_name(provider, entity).upper()
Expand Down Expand Up @@ -220,7 +235,7 @@ def executemany(provider, cursor, sql, arguments_list):
@wrap_dbapi_exceptions
def execute_returning_id(provider, cursor, sql, arguments):
set_input_sizes(cursor, arguments)
var = cursor.var(cx_Oracle.NUMBER)
var = cursor.var(cx_Oracle.STRING, 40, cursor.arraysize, outconverter=int)
arguments['new_id'] = var
cursor.execute(sql, arguments)
return var.getvalue()
Expand Down
4 changes: 4 additions & 0 deletions pony/orm/dbproviders/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def get_create_command(foreign_key):
return None

class SQLiteSchema(dbschema.DBSchema):
dialect = 'SQLite'
fk_class = SQLiteForeignKey

class SQLiteTranslator(sqltranslation.SQLTranslator):
Expand Down Expand Up @@ -88,11 +89,14 @@ def py2sql(converter, val):
return datetime2timestamp(val)

class SQLiteProvider(DBAPIProvider):
dialect = 'SQLite'
dbapi_module = sqlite
dbschema_cls = SQLiteSchema
translator_cls = SQLiteTranslator
sqlbuilder_cls = SQLiteBuilder

server_version = sqlite.sqlite_version_info

converter_classes = [
(bool, dbapiprovider.BoolConverter),
(unicode, dbapiprovider.UnicodeConverter),
Expand Down
5 changes: 4 additions & 1 deletion pony/orm/dbschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pony.utils import throw

class DBSchema(object):
dialect = None
def __init__(schema, provider, uppercase=True):
schema.provider = provider
schema.tables = {}
Expand Down Expand Up @@ -152,7 +153,9 @@ def get_sql(column, created_tables=None):
append(case(column.auto_template % dict(type=column.sql_type)))
else:
append(case(column.sql_type))
if column.is_pk: append(case('PRIMARY KEY'))
if column.is_pk:
if schema.dialect == 'SQLite': append(case('NOT NULL'))
append(case('PRIMARY KEY'))
else:
if column.is_unique: append(case('UNIQUE'))
if column.is_not_null: append(case('NOT NULL'))
Expand Down
3 changes: 1 addition & 2 deletions pony/orm/examples/estore.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class Category(db.Entity):
DELIVERED = 'DELIVERED'
CANCELLED = 'CANCELLED'

@db_session
def populate_database():
c1 = Customer(email='[email protected]', password='***',
name='John Smith', country='USA', address='address 1')
Expand Down Expand Up @@ -158,8 +159,6 @@ def populate_database():
OrderItem(order=o5, product=p1, price=Decimal('284.00'), quantity=1)
OrderItem(order=o5, product=p2, price=Decimal('478.50'), quantity=1)

commit()


def test_queries():

Expand Down
4 changes: 2 additions & 2 deletions pony/orm/examples/presentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class Student(db.Entity):
db.generate_mapping(create_tables=True)
# db.generate_mapping(check_tables=True)

@with_transaction
@db_session
def populate_database():
if select(s for s in Student).count() > 0:
return
Expand Down Expand Up @@ -98,7 +98,7 @@ def print_students(students):
print s.name
print

@with_transaction
@db_session
def test_queries():
students = select(s for s in Student)
print_students(students)
Expand Down
4 changes: 2 additions & 2 deletions pony/orm/integration/bottle_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from bottle import HTTPResponse, HTTPError
from pony.orm.core import with_transaction
from pony.orm.core import db_session

def is_allowed_exception(e):
return isinstance(e, HTTPResponse) and not isinstance(e, HTTPError)
Expand All @@ -8,4 +8,4 @@ class PonyPlugin(object):
name = 'pony'
api = 2
def apply(self, callback, route):
return with_transaction(allowed_exceptions=is_allowed_exception)(callback)
return db_session(allowed_exceptions=is_allowed_exception)(callback)
Loading

0 comments on commit 89e455e

Please sign in to comment.