Skip to content

Commit

Permalink
port across the stored procedure code
Browse files Browse the repository at this point in the history
  • Loading branch information
damoxc committed Mar 3, 2010
1 parent 0b08093 commit ee81e77
Show file tree
Hide file tree
Showing 2 changed files with 250 additions and 36 deletions.
14 changes: 14 additions & 0 deletions _mssql.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from sqlfront cimport DBPROCESS, BYTE

ctypedef struct _mssql_parameter_node:
_mssql_parameter_node *next
BYTE *value

cdef class MSSQLConnection:

# class property variables
Expand All @@ -21,10 +25,20 @@ cdef class MSSQLConnection:

cdef void clear_metadata(self)
cdef convert_db_value(self, BYTE *, int, int)
cdef BYTE *convert_python_value(self, value, int*, int*)
cdef fetch_next_row_dict(self, int)
cdef format_and_run_query(self, query_string, params=?)
cdef get_result(self)
cdef get_row(self, int)

cdef class MSSQLRowIterator:
cdef MSSQLConnection conn

cdef class MSSQLStoredProcedure:
cdef MSSQLConnection conn
cdef DBPROCESS *dbproc
cdef char *procname
cdef dict params
cdef _mssql_parameter_node *params_list

cdef void _bind(self, value, int, char *, int, int, int)
272 changes: 236 additions & 36 deletions _mssql.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import datetime
from sqlfront cimport *
from stdio cimport fprintf, sprintf, FILE
from stdlib cimport strlen
from python_mem cimport PyMem_Malloc, PyMem_Free

cdef extern int rmv_lcl(char *, char *, size_t)

Expand All @@ -33,42 +34,40 @@ cdef list connection_object_list = list()
#############################
## DB-API type definitions ##
#############################
cdef enum:
STRING = 1
BINARY = 2
NUMBER = 3
DATETIME = 4
DECIMAL = 5
STRING = 1
BINARY = 2
NUMBER = 3
DATETIME = 4
DECIMAL = 5

##################
## DB-LIB types ##
##################
cdef enum:
SQLBINARY = SYBBINARY
SQLBIT = SYBBIT
SQLCHAR = SYBCHAR
SQLDATETIME = SYBDATETIME
SQLDATETIM4 = SYBDATETIME4
SQLDATETIMN = SYBDATETIMN
SQLDECIMAL = SYBDECIMAL
SQLFLT4 = SYBREAL
SQLFLT8 = SYBFLT8
SQLFLTN = SYBFLTN
SQLIMAGE = SYBIMAGE
SQLINT1 = SYBINT1
SQLINT2 = SYBINT2
SQLINT4 = SYBINT4
SQLINT8 = SYBINT8
SQLINTN = SYBINTN
SQLMONEY = SYBMONEY
SQLMONEY4 = SYBMONEY4
SQLMONEYN = SYBMONEYN
SQLNUMERIC = SYBNUMERIC
SQLREAL = SYBREAL
SQLTEXT = SYBTEXT
SQLVARBINARY = SYBVARBINARY
SQLVARCHAR = SYBVARCHAR
SQLUUID = 36
SQLBINARY = SYBBINARY
SQLBIT = SYBBIT
SQLCHAR = SYBCHAR
SQLDATETIME = SYBDATETIME
SQLDATETIM4 = SYBDATETIME4
SQLDATETIMN = SYBDATETIMN
SQLDECIMAL = SYBDECIMAL
SQLFLT4 = SYBREAL
SQLFLT8 = SYBFLT8
SQLFLTN = SYBFLTN
SQLIMAGE = SYBIMAGE
SQLINT1 = SYBINT1
SQLINT2 = SYBINT2
SQLINT4 = SYBINT4
SQLINT8 = SYBINT8
SQLINTN = SYBINTN
SQLMONEY = SYBMONEY
SQLMONEY4 = SYBMONEY4
SQLMONEYN = SYBMONEYN
SQLNUMERIC = SYBNUMERIC
SQLREAL = SYBREAL
SQLTEXT = SYBTEXT
SQLVARBINARY = SYBVARBINARY
SQLVARCHAR = SYBVARCHAR
SQLUUID = 36

#######################
## Exception classes ##
Expand Down Expand Up @@ -440,6 +439,66 @@ cdef class MSSQLConnection:
else:
return (<char *>data)[:length]

cdef BYTE *convert_python_value(self, value, int *dbtype, int *length):
cdef int *intValue
cdef double *dblValue
cdef long *longValue

if value is None:
return NULL

if dbtype[0] == SQLBIT:
intValue = <int *> PyMem_Malloc(sizeof(int))
intValue[0] = <int>value
return <BYTE *><DBBIT *>intValue

elif dbtype[0] == SQLINT1:
intValue = <int *> PyMem_Malloc(sizeof(int))
intValue[0] = <int>value
return <BYTE *><DBTINYINT *>intValue

elif dbtype[0] == SQLINT2:
intValue = <int *> PyMem_Malloc(sizeof(int))
intValue[0] = <int>value
return <BYTE *><DBSMALLINT *>intValue

elif dbtype[0] == SQLINT4:
intValue = <int *> PyMem_Malloc(sizeof(int))
intValue[0] = <int>value
return <BYTE *><DBINT *>intValue

elif dbtype[0] == SQLINT8:
longValue = <long *>PyMem_Malloc(sizeof(long))
longValue[0] = <long>value
return <BYTE *>longValue

elif dbtype[0] == SQLFLT4:
dblValue = <double *>PyMem_Malloc(sizeof(double))
dblValue[0] = <double>value
return <BYTE *><DBREAL *>dblValue

elif dbtype[0] == SQLFLT8:
dblValue = <double *>PyMem_Malloc(sizeof(double))
dblValue[0] = <double>value
return <BYTE *><DBFLT8 *>dblValue

elif dbtype[0] in (SQLVARCHAR, SQLCHAR, SQLTEXT):
if type(value) not in (str, unicode):
raise TypeError()

if self._charset and type(value) is unicode:
value = value.encode(self._charset)

return <BYTE *><char *>value

elif dbtype[0] in (SQLBINARY, SQLIMAGE):
if type(value) is not str:
raise TypeError()
return <BYTE *><char *>value

# No conversion was possible so just return NULL
return NULL

def execute_non_query(self, query_string, params=None):
"""
execute_non_query(query_string, params=None)
Expand Down Expand Up @@ -687,7 +746,7 @@ cdef class MSSQLConnection:
record += (self.convert_db_value(data, col_type, len),)
return record

def init_procedure(self):
def init_procedure(self, procname):
"""
init_procedure(procname) -- creates and returns a MSSQLStoredProcedure
object.
Expand All @@ -696,7 +755,8 @@ cdef class MSSQLConnection:
and creates a MSSQLStoredProcedure object that allows parameters to
be bound.
"""

return MSSQLStoredProcedure(procname, self)

def next_result(self):
"""
nextresult() -- move to the next result, skipping all pending rows.
Expand Down Expand Up @@ -761,8 +821,145 @@ cdef class MSSQLConnection:

cdef class MSSQLStoredProcedure:

cdef bind(self, value, data_type, param_name, output, null, max_length):
pass
property connection:
"""The underlying MSSQLConnection object."""
def __get__(self):
return self.conn

property name:
"""The name of the procedure that this object represents."""
def __get__(self):
return self.procname

property parameters:
"""The parameters that have been bound to this procedure."""
def __get__(self):
return self.params


def __init__(self, str name, MSSQLConnection connection):
cdef RETCODE rtc

# We firstly want to check if tdsver is >= 8 as anything less
# doesn't support remote procedure calls.
if dbtds(connection.dbproc) < 8:
raise MSSQLDriverException("Stored Procedures aren't " \
"supported with a TDS version less than 8.")

self.conn = connection
self.dbproc = connection.dbproc
self.procname = name
self.params = dict()
self.params_list = NULL

with nogil:
rtc = dbrpcinit(self.dbproc, self.procname, 0)

check_cancel_and_raise(rtc, self.conn)

def bind(self, value, dbtype, str param_name=None, output=False,
null=False, int max_length=-1):
"""
bind(value, data_type, param_name = None, output = False,
null = False, max_length = -1) -- bind a parameter
This method binds a parameter to the stored procedure.
"""
self._bind(value, dbtype, param_name, output, null, max_length)

cdef void _bind(self, value, int dbtype, char *name, int output,
int null, int max_length):
cdef int length = 0
cdef BYTE status, *data
cdef RETCODE rtc
cdef _mssql_parameter_node *pn

# Set status according to output being True or False
status = DBRPCRETURN if output else <BYTE>0

# Store the value in the parameters dictionary for returning
# later.
self.params[name] = value

# Convert the PyObject to the db type
data = self.conn.convert_python_value(value, &dbtype, &length)

# Store the converted parameter in our parameter list so we can
# free() it later.
if data != NULL:
pn = <_mssql_parameter_node *>PyMem_Malloc(sizeof(_mssql_parameter_node))
if pn == NULL:
raise MSSQLDriverException('Out of memory')
pn.next = self.params_list
pn.value = data
self.params_list = pn

# We may need to set the data length depending on the type being
# passed to the server here.
if dbtype in (SQLVARCHAR, SQLCHAR, SQLTEXT, SQLBINARY,
SQLIMAGE):
if null or data == NULL:
length = 0
if not output:
max_length = -1
else:
length = strlen(<char *>data)
else:
# Fixed length data type
if null or output:
length = 0
max_length = -1

if status != DBRPCRETURN:
max_length = -1

IF PYMSSQL_DEBUG == 1:
pass

with nogil:
rtc = dbrpcparam(self.dbproc, name, status, dbtype,
max_length, length, data)
check_cancel_and_raise(rtc, self.conn)

def execute(self):
cdef RETCODE rtc
cdef int output_count, i, type, length
cdef char *name
cdef BYTE *data

# Cancel any pending results as this throws a server error
# otherwise.
db_cancel(self.conn)

# Send the RPC request
with nogil:
rtc = dbrpcsend(self.dbproc)
check_cancel_and_raise(rtc, self.conn)

# Wait for the server to return
with nogil:
rtc = dbsqlok(self.dbproc)
check_cancel_and_raise(rtc, self.conn)

# Need to call thsi regardless of wether or not there are output
# parameters in roder for the return status to be correct.
output_count = dbnumrets(self.dbproc)

# If there are any output parameters then we are going to want to
# set the values in the parameters dictionary.
if output_count:
for i in xrange(1, output_count + 1):
with nogil:
type = dbrettype(self.dbproc, i)
name = dbretname(self.dbproc, i)
length = dbretlen(self.dbproc, i)
data = dbretdata(self.dbproc, i)

value = self.conn.convert_db_value(data, type, length)
self.params[name] = value

# Get the return value from the procedure ready for return.
return dbretstatus(self.dbproc)

cdef void check_and_raise(RETCODE rtc, MSSQLConnection conn):
if rtc == FAIL:
Expand Down Expand Up @@ -922,6 +1119,9 @@ def quote_or_flatten(data):
def quote_data(data):
return _quote_data(data)

def connect(*args, **kwargs):
return MSSQLConnection(*args, **kwargs)

cdef void init_mssql():
global _decimal_context
cdef RETCODE rtc
Expand Down

0 comments on commit ee81e77

Please sign in to comment.