From cbc85adb26127be8c8bdc617c22ce40ead12369d Mon Sep 17 00:00:00 2001 From: Samy Sultan Date: Sat, 17 Jun 2023 12:44:21 +0300 Subject: [PATCH] use UDT with ass-array --- examples/UDT/main.go | 2 +- examples/udt_array/main.go | 318 +++++++++++++++++++++++++++++++++++++ examples/udt_pars/main.go | 35 ++-- v2/command.go | 1 - v2/connection.go | 166 ++++++++++++++++++- v2/lob.go | 2 +- v2/parameter.go | 27 +--- v2/parameter_encode.go | 38 ++--- v2/udt.go | 6 +- v2/utils.go | 113 ++++++++++--- 10 files changed, 615 insertions(+), 93 deletions(-) create mode 100644 examples/udt_array/main.go diff --git a/examples/UDT/main.go b/examples/UDT/main.go index 94248310..2d7cbc8d 100644 --- a/examples/UDT/main.go +++ b/examples/UDT/main.go @@ -129,7 +129,7 @@ func main() { fmt.Println("Can't drop UDT", err) } }() - err = conn.RegisterType("TEST_TYPE1", test1{}) + err = conn.RegisterType("TEST_TYPE1", "", test1{}) if err != nil { fmt.Println("Can't register UDT", err) return diff --git a/examples/udt_array/main.go b/examples/udt_array/main.go new file mode 100644 index 00000000..b4490b2f --- /dev/null +++ b/examples/udt_array/main.go @@ -0,0 +1,318 @@ +package main + +import ( + "database/sql" + "fmt" + go_ora "github.com/sijms/go-ora/v2" + "os" + "time" +) + +type test1 struct { + Id int64 `udt:"test_id"` + Name string `udt:"test_name"` + Data string `udt:"data"` + CreateAt time.Time `udt:"created_at"` +} + +func createPackage(conn *sql.DB) error { + t := time.Now() + sqlText := `create or replace package UDT_ARRAY_PKG AS + -- type t_test1 is table of UDT_ARRAY_TABLE.DATA%type index by binary_integer; + type t_id is table of UDT_ARRAY_TABLE.ID%type index by binary_integer; + procedure test_get1(p_id t_id, p_test1 out test_type2); + procedure test_get2(p_test test_type2); +end UDT_ARRAY_PKG;` + _, err := conn.Exec(sqlText) + if err != nil { + return err + } + sqlText = `create or replace PACKAGE BODY UDT_ARRAY_PKG AS + procedure test_get1(p_id t_id, p_test1 out test_type2) as + temp t_id := p_id; + cursor tempCur is select id, DATA from UDT_ARRAY_TABLE + WHERE id in (select column_value from table(p_id)); + tempRow tempCur%rowtype; + idx number := 1; + BEGIN + p_test1 := test_type2(); + for tempRow in tempCur loop + p_test1.extend; + p_test1(idx) := tempRow.DATA; + idx := idx + 1; + end loop; + END test_get1; + procedure test_get2(p_test test_type2) as + BEGIN + NULL; + END test_get2; +end UDT_ARRAY_PKG;` + _, err = conn.Exec(sqlText) + if err != nil { + return err + } + fmt.Println("finish create package: ", time.Now().Sub(t)) + return nil +} + +func dropPackage(conn *sql.DB) error { + t := time.Now() + _, err := conn.Exec(`drop package UDT_ARRAY_PKG`) + if err != nil { + return err + } + fmt.Println("Drop package: ", time.Now().Sub(t)) + return nil +} + +func queryRow(conn *sql.DB) error { + t := time.Now() + test := test1{} + err := conn.QueryRow(`SELECT DATA FROM UDT_ARRAY_TABLE WHERE ID=1`).Scan(&test) + if err != nil { + return err + } + fmt.Println("row: ", test) + fmt.Println("finish query row: ", time.Now().Sub(t)) + return nil +} +func query(conn *sql.DB) error { + t := time.Now() + var data []test1 + _, err := conn.Exec(`BEGIN UDT_ARRAY_PKG.TEST_GET1(:1, :2); END;`, []int{1, 3, 5, 7}, go_ora.Out{Dest: &data, Size: 5}) + if err != nil { + return err + } + fmt.Println("result: ", data) + fmt.Println("finish query: ", time.Now().Sub(t)) + return nil +} + +func get2(conn *sql.DB) error { + t := time.Now() + var data = []test1{ + { + Id: 1, + Name: "name_1", + Data: "data", + CreateAt: time.Now(), + }, + { + Id: 2, + Name: "name_2", + Data: "data", + CreateAt: time.Now(), + }, + { + Id: 3, + Name: "name_3", + Data: "data", + CreateAt: time.Now(), + }, + { + Id: 3, + Name: "name_4", + Data: "data", + CreateAt: time.Now(), + }, + { + Id: 4, + Name: "name5", + Data: "data5", + CreateAt: time.Now(), + }, + } + _, err := conn.Exec(`BEGIN UDT_ARRAY_PKG.TEST_GET2(:1); END;`, data) + if err != nil { + return err + } + fmt.Println("finish get2: ", time.Now().Sub(t)) + return nil +} +func insertData(conn *sql.DB) error { + t := time.Now() + sqlText := `INSERT INTO UDT_ARRAY_TABLE(ID, DATA) VALUES(:1, :2)` + stmt, err := conn.Prepare(sqlText) + if err != nil { + return err + } + //data := make([]test1, 0, 10) + //ids := make([]int, 0, 10) + for x := 0; x < 10; x++ { + _, err = stmt.Exec(x+1, test1{int64(x + 1), + fmt.Sprintf("name_%d", x+1), + "DATA", + time.Now()}) + if err != nil { + return err + } + } + //_, err := conn.ExecContext(context.Background(), sqlText, + // []driver.NamedValue{ + // driver.NamedValue{"id", 0, ids}, + // driver.NamedValue{"data", 0, data}, + // }) + if err != nil { + return err + } + fmt.Println("finish insert: ", time.Now().Sub(t)) + return nil +} +func createTable(conn *sql.DB) error { + t := time.Now() + sqlText := `CREATE TABLE UDT_ARRAY_TABLE +( + ID number(10, 0), + DATA TEST_TYPE1 +)` + _, err := conn.Exec(sqlText) + if err != nil { + return err + } + fmt.Println("finish create table: ", time.Now().Sub(t)) + return nil +} + +func dropTable(conn *sql.DB) error { + t := time.Now() + _, err := conn.Exec(`DROP TABLE UDT_ARRAY_TABLE PURGE`) + if err != nil { + return err + } + fmt.Println("finish drop table: ", time.Now().Sub(t)) + return nil +} + +func ceateUDTArray(conn *sql.DB) error { + t := time.Now() + _, err := conn.Exec(`CREATE or REPLACE TYPE TEST_TYPE2 AS TABLE of TEST_TYPE1`) + if err != nil { + return err + } + fmt.Println("Finish create UDT Array: ", time.Now().Sub(t)) + return nil +} +func dropUDTArray(conn *sql.DB) error { + t := time.Now() + _, err := conn.Exec("drop type TEST_TYPE2") + if err != nil { + return err + } + fmt.Println("Finish drop UDT Array: ", time.Now().Sub(t)) + return nil +} +func createUDT(conn *sql.DB) error { + t := time.Now() + sqlText := `create or replace TYPE TEST_TYPE1 IS OBJECT +( + TEST_ID NUMBER(10, 0), + TEST_NAME VARCHAR2(10), + DATA CLOB, + CREATED_AT DATE +)` + _, err := conn.Exec(sqlText) + if err != nil { + return err + } + fmt.Println("Finish create UDT: ", time.Now().Sub(t)) + return nil +} + +func dropUDT(conn *sql.DB) error { + t := time.Now() + + _, err := conn.Exec("drop type TEST_TYPE1") + if err != nil { + return err + } + fmt.Println("Finish drop UDT: ", time.Now().Sub(t)) + return nil +} + +func main() { + conn, err := sql.Open("oracle", os.Getenv("DSN")) + if err != nil { + fmt.Println("can't open connection: ", err) + return + } + defer func() { + err = conn.Close() + if err != nil { + fmt.Println("can't close connection: ", err) + } + }() + err = createUDT(conn) + if err != nil { + fmt.Println("can't create UDT: ", err) + return + } + defer func() { + err = dropUDT(conn) + if err != nil { + fmt.Println("can't drop UDT: ", err) + } + }() + err = ceateUDTArray(conn) + if err != nil { + fmt.Println("can't create UDT array: ", err) + return + } + defer func() { + err = dropUDTArray(conn) + if err != nil { + fmt.Println("can't drop UDT array: ", err) + } + }() + err = go_ora.RegisterType(conn, "TEST_TYPE1", "TEST_TYPE2", test1{}) + if err != nil { + fmt.Println("can't register type: ", err) + return + } + + err = createTable(conn) + if err != nil { + fmt.Println("can't create table: ", err) + return + } + + defer func() { + err = dropTable(conn) + if err != nil { + fmt.Println("can't drop table: ", err) + } + }() + + //insert some data + err = insertData(conn) + if err != nil { + fmt.Println("can't insert data: ", err) + return + } + //create package + err = createPackage(conn) + if err != nil { + fmt.Println("can't create package: ", err) + return + } + defer func() { + err = dropPackage(conn) + if err != nil { + fmt.Println("can't drop package: ", err) + } + }() + err = query(conn) + if err != nil { + fmt.Println("can't query: ", err) + return + } + //err = queryRow(conn) + //if err != nil { + // fmt.Println("can't query row: ", err) + // return + //} + err = get2(conn) + if err != nil { + fmt.Println("can't get2: ", err) + return + } +} diff --git a/examples/udt_pars/main.go b/examples/udt_pars/main.go index 55e1c72a..20d76d86 100644 --- a/examples/udt_pars/main.go +++ b/examples/udt_pars/main.go @@ -10,7 +10,7 @@ import ( "time" ) -type test1 struct { +type test2 struct { Id int64 `udt:"test_id"` Name *sql.NullString `udt:"test_name"` Data1 string `udt:"test_data1"` @@ -23,7 +23,7 @@ func createTable(conn *go_ora.Connection) error { t := time.Now() sqlText := `CREATE TABLE GOORA_TEMP_VISIT( VISIT_ID number(10) NOT NULL, - TEST_TYPE TEST_TYPE1, + TEST_TYPE UDTPAR_TYPE, PRIMARY KEY(VISIT_ID) )` _, err := conn.Exec(sqlText) @@ -42,8 +42,8 @@ func insertData(conn *go_ora.Connection) error { _ = stmt.Close() }() nameText := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - for index = 1; index <= 100; index++ { - var test test1 + for index = 1; index <= 1; index++ { + var test test2 test.Id = int64(index) test.Name = &sql.NullString{String: nameText[:index], Valid: true} test.Data1 = nameText @@ -73,7 +73,7 @@ END;` }() var ( visitId int64 - test test1 + test test2 ) _, err := stmt.Exec([]driver.Value{go_ora.Out{Dest: &visitId}, go_ora.Out{Dest: &test}}) if err != nil { @@ -100,7 +100,7 @@ func queryData(conn *go_ora.Connection) error { } var ( visitID int64 - test test1 + test test2 count int ) for rows.Next_() { @@ -135,11 +135,11 @@ func dropTable(conn *go_ora.Connection) error { func createUDT(conn *go_ora.Connection) error { t := time.Now() - sqlText := `create or replace TYPE TEST_TYPE1 IS OBJECT + sqlText := `create or replace TYPE UDTPAR_TYPE IS OBJECT ( TEST_ID NUMBER(10, 0), TEST_NAME VARCHAR2(200), - TEST_DATA1 VARCHAR2(200), + TEST_DATA1 CLOB, TEST_DATA2 VARCHAR2(200), TEST_DATA3 VARCHAR2(200), TEST_DATE DATE @@ -158,7 +158,7 @@ func createUDT(conn *go_ora.Connection) error { func dropUDT(conn *go_ora.Connection) error { t := time.Now() - stmt := go_ora.NewStmt("drop type TEST_TYPE1", conn) + stmt := go_ora.NewStmt("drop type UDTPAR_TYPE", conn) defer func() { _ = stmt.Close() }() @@ -188,11 +188,12 @@ func main() { var ( server string ) - flag.StringVar(&server, "server", "", "Server's URL, oracle://user:pass@server/service_name") flag.Parse() - connStr := os.ExpandEnv(server) + if connStr == "" { + connStr = os.Getenv("DSN") + } if connStr == "" { fmt.Println("Missing -server option") usage() @@ -238,7 +239,7 @@ func main() { fmt.Println("Can't drop table: ", err) } }() - err = conn.RegisterType("TEST_TYPE1", test1{}) + err = conn.RegisterType("UDTPAR_TYPE", "", test2{}) if err != nil { fmt.Println("Can't register UDT: ", err) return @@ -250,11 +251,11 @@ func main() { fmt.Println("Can't insert data: ", err) return } - err = queryData(conn) - if err != nil { - fmt.Println("Can't query data: ", err) - return - } + //err = queryData(conn) + //if err != nil { + // fmt.Println("Can't query data: ", err) + // return + //} err = outputPar(conn) if err != nil { fmt.Println("Can't query output par: ", err) diff --git a/v2/command.go b/v2/command.go index 67b76565..d9f8a118 100644 --- a/v2/command.go +++ b/v2/command.go @@ -377,7 +377,6 @@ func (stmt *Stmt) writePars() error { session.PutUint(size, 4, true, true) session.PutBytes(1, 1) session.PutClr(par.BValue) - par.MaxNoOfArrayElements = 3 //tempBuffer := bytes.Buffer{} //if par.MaxNoOfArrayElements > 0 { // tempBuffer.Write([]byte{0x84, 0x1, 0xfe}) diff --git a/v2/connection.go b/v2/connection.go index f2c89d2b..82f7dbbf 100755 --- a/v2/connection.go +++ b/v2/connection.go @@ -14,6 +14,7 @@ import ( "regexp" "strconv" "strings" + "sync" ) type ConnectionState int @@ -94,20 +95,24 @@ type OracleConnector struct { dialer network.DialerContext } type OracleDriver struct { - //m sync.Mutex + dataCollected bool + cusTyp map[string]customType + mu sync.Mutex + serverCharset int + serverNCharset int //Conn *Connection //Server string //Port int //Instance string //Service string //DBName string - //UserId string + UserId string //SessionId int //SerialNum int } func init() { - sql.Register("oracle", &OracleDriver{}) + sql.Register("oracle", &OracleDriver{cusTyp: map[string]customType{}}) } func (drv *OracleDriver) OpenConnector(name string) (driver.Connector, error) { @@ -118,6 +123,7 @@ func (drv *OracleDriver) OpenConnector(name string) (driver.Connector, error) { func (connector *OracleConnector) Connect(ctx context.Context) (driver.Conn, error) { conn, err := NewConnection(connector.connectString) + conn.cusTyp = connector.drv.cusTyp if err != nil { return nil, err } @@ -126,9 +132,20 @@ func (connector *OracleConnector) Connect(ctx context.Context) (driver.Conn, err if err != nil { return nil, err } + connector.drv.collectData(conn) return conn, nil } +func (driver *OracleDriver) collectData(conn *Connection) { + if !driver.dataCollected { + driver.mu.Lock() + defer driver.mu.Unlock() + driver.UserId = conn.connOption.UserID + driver.serverCharset = conn.tcpNego.ServerCharset + driver.serverNCharset = conn.tcpNego.ServernCharset + driver.dataCollected = true + } +} func (connector *OracleConnector) Driver() driver.Driver { return connector.drv } @@ -140,6 +157,7 @@ func (connector *OracleConnector) Dialer(dialer network.DialerContext) { // Open return a new open connection func (drv *OracleDriver) Open(name string) (driver.Conn, error) { conn, err := NewConnection(name) + conn.cusTyp = drv.cusTyp if err != nil { return nil, err } @@ -147,6 +165,7 @@ func (drv *OracleDriver) Open(name string) (driver.Conn, error) { if err != nil { return nil, err } + drv.collectData(conn) return conn, nil } @@ -1228,3 +1247,144 @@ func (conn *Connection) ResetSession(ctx context.Context) error { } return nil } + +func RegisterType(conn *sql.DB, typeName, arrayTypeName string, typeObj interface{}) error { + // ping first to avoid error when calling register type after open connection + err := conn.Ping() + if err != nil { + return err + } + if driver, ok := conn.Driver().(*OracleDriver); ok { + return RegisterTypeWithOwner(conn, driver.UserId, typeName, arrayTypeName, typeObj) + } + return errors.New("the driver used is not a go-ora driver type") +} + +func RegisterTypeWithOwner(conn *sql.DB, owner, typeName, arrayTypeName string, typeObj interface{}) error { + if len(owner) == 0 { + return errors.New("owner can't be empty") + } + if driver, ok := conn.Driver().(*OracleDriver); ok { + + if typeObj == nil { + return errors.New("type object cannot be nil") + } + typ := reflect.TypeOf(typeObj) + switch typ.Kind() { + case reflect.Ptr: + return errors.New("unsupported type object: Ptr") + case reflect.Array: + return errors.New("unsupported type object: Array") + case reflect.Chan: + return errors.New("unsupported type object: Chan") + case reflect.Map: + return errors.New("unsupported type object: Map") + case reflect.Slice: + return errors.New("unsupported type object: Slice") + } + if typ.Kind() != reflect.Struct { + return errors.New("type object should be of structure type") + } + cust := customType{ + owner: owner, + name: typeName, + arrayTypeName: arrayTypeName, + typ: typ, + fieldMap: map[string]int{}, + } + sqlText := `SELECT type_oid FROM ALL_TYPES WHERE UPPER(OWNER)=:1 AND UPPER(TYPE_NAME)=:2` + err := conn.QueryRow(sqlText, strings.ToUpper(owner), strings.ToUpper(typeName)).Scan(&cust.toid) + if err != nil { + return err + } + if len(cust.arrayTypeName) > 0 { + err = conn.QueryRow(sqlText, strings.ToUpper(owner), strings.ToUpper(arrayTypeName)).Scan(&cust.arrayTOID) + if err != nil { + return err + } + } + sqlText = `SELECT ATTR_NAME, ATTR_TYPE_NAME, LENGTH, ATTR_NO + FROM ALL_TYPE_ATTRS + WHERE UPPER(OWNER)=:1 AND UPPER(TYPE_NAME)=:2` + rows, err := conn.Query(sqlText, strings.ToUpper(owner), strings.ToUpper(typeName)) + if err != nil { + return err + } + var ( + attName sql.NullString + attOrder int64 + attTypeName sql.NullString + length sql.NullInt64 + ) + for rows.Next() { + err = rows.Scan(&attName, &attTypeName, &length, &attOrder) + if err != nil { + return err + } + for int(attOrder) > len(cust.attribs) { + cust.attribs = append(cust.attribs, ParameterInfo{ + Direction: Input, + Flag: 3, + }) + } + param := &cust.attribs[attOrder-1] + param.Name = attName.String + param.TypeName = attTypeName.String + switch strings.ToUpper(attTypeName.String) { + case "NUMBER": + param.DataType = NUMBER + param.MaxLen = converters.MAX_LEN_NUMBER + case "VARCHAR2": + param.DataType = NCHAR + param.CharsetForm = 1 + param.ContFlag = 16 + param.MaxCharLen = int(length.Int64) + param.CharsetID = driver.serverCharset + param.MaxLen = int(length.Int64) * converters.MaxBytePerChar(param.CharsetID) + case "NVARCHAR2": + param.DataType = NCHAR + param.CharsetForm = 2 + param.ContFlag = 16 + param.MaxCharLen = int(length.Int64) + param.CharsetID = driver.serverNCharset + param.MaxLen = int(length.Int64) * converters.MaxBytePerChar(param.CharsetID) + case "TIMESTAMP": + fallthrough + case "DATE": + param.DataType = DATE + param.MaxLen = 11 + case "RAW": + param.DataType = RAW + param.MaxLen = int(length.Int64) + case "BLOB": + param.DataType = OCIBlobLocator + param.MaxLen = int(length.Int64) + case "CLOB": + param.DataType = OCIClobLocator + param.CharsetForm = 1 + param.ContFlag = 16 + param.CharsetID = driver.serverCharset + param.MaxCharLen = int(length.Int64) + param.MaxLen = int(length.Int64) * converters.MaxBytePerChar(param.CharsetID) + case "NCLOB": + param.DataType = OCIClobLocator + param.CharsetForm = 2 + param.ContFlag = 16 + param.CharsetID = driver.serverNCharset + param.MaxCharLen = int(length.Int64) + param.MaxLen = int(length.Int64) * converters.MaxBytePerChar(param.CharsetID) + default: + return fmt.Errorf("unsupported attribute type: %s", attTypeName.String) + } + } + if len(cust.attribs) == 0 { + return fmt.Errorf("unknown or empty type: %s", typeName) + } + cust.loadFieldMap() + driver.mu.Lock() + defer driver.mu.Unlock() + driver.cusTyp[strings.ToUpper(typeName)] = cust + return nil + } + return errors.New("the driver used is not a go-ora driver type") +} diff --git a/v2/lob.go b/v2/lob.go index e3de6133..3c780504 100644 --- a/v2/lob.go +++ b/v2/lob.go @@ -139,7 +139,7 @@ func (lob *Lob) putData(data []byte) error { func (lob *Lob) putString(data string) error { conn := lob.connection conn.connOption.Tracer.Printf("Put Lob String: %d character", int64(len([]rune(data)))) - //lob.initialize() + lob.initialize() var strConv converters.IStringConverter if lob.variableWidthChar() { if conn.dBVersion.Number < 10200 && lob.littleEndianClob() { diff --git a/v2/parameter.go b/v2/parameter.go index f96322c6..8fd5ea77 100644 --- a/v2/parameter.go +++ b/v2/parameter.go @@ -1033,35 +1033,10 @@ func (par *ParameterInfo) decodePrimValue(conn *Connection, udt bool) error { case IntervalDS_DTY: par.oPrimValue = converters.ConvertIntervalDS_DTY(par.BValue) case XMLType: - newState := network.SessionState{InBuffer: par.BValue} - session.SaveState(&newState) - _, err = session.GetByte() - if err != nil { - return err - } - var ctl int - ctl, err = session.GetInt(4, true, true) + err = decodeObject(conn, par) if err != nil { return err } - if ctl == 0xFE { - _, err = session.GetInt(4, false, true) - if err != nil { - return err - } - } - pars := make([]ParameterInfo, 0) - for _, attrib := range par.cusType.attribs { - tempPar := attrib - tempPar.Direction = par.Direction - err = tempPar.decodePrimValue(conn, true) - if err != nil { - return err - } - pars = append(pars, tempPar) - } - par.oPrimValue = pars - _ = session.LoadState() default: return fmt.Errorf("unable to decode oracle type %v to its primitive value", par.DataType) } diff --git a/v2/parameter_encode.go b/v2/parameter_encode.go index a73cf322..ff2ffa81 100644 --- a/v2/parameter_encode.go +++ b/v2/parameter_encode.go @@ -174,23 +174,21 @@ func (par *ParameterInfo) encodePrimValue(conn *Connection) error { session := conn.session //arrayBuffer.Write([]byte{1}) if par.DataType == XMLType { - arrayBuffer.Write([]byte{uint8(par.MaxNoOfArrayElements)}) + // number of fields + arrayBuffer.Write([]byte{1, 3}) + //session.WriteUint(&arrayBuffer, len(par.cusType.attribs), 2, true, true) + // number of elements + session.WriteUint(&arrayBuffer, par.MaxNoOfArrayElements, 2, true, false) + //arrayBuffer.Write([]byte{uint8(par.MaxNoOfArrayElements)}) } else { session.WriteUint(&arrayBuffer, par.MaxNoOfArrayElements, 4, true, true) } - //session.WriteUint(&arrayBuffer, par.MaxNoOfArrayElements, 4, true, true) - //session.WriteUint(&arrayBuffer, par.MaxNoOfArrayElements, 2, true, false) - //arrayBuffer.Write([]byte{3}) - //03 11 00 01 10 - //arrayBuffer.Write([]byte{0x3, 0x11, 0x0, 0x1, 0x10, uint8(par.MaxNoOfArrayElements)}) for _, tempPar := range value { // get the binary representation of the item err = tempPar.encodePrimValue(conn) - if par.DataType == XMLType { - //session.WriteUint(&arrayBuffer, len(tempPar.BValue), 4, true, true) - //session.WriteUint(&arrayBuffer, 0xff, 4, true, true) - arrayBuffer.Write([]byte{0, 0, 0, 0xfe}) - } + //if par.DataType == XMLType { + // arrayBuffer.Write([]byte{0, 0, 0, 0xfe}) + //} if err != nil { return err } @@ -203,6 +201,9 @@ func (par *ParameterInfo) encodePrimValue(conn *Connection) error { // save binary representation to the buffer session.WriteClr(&arrayBuffer, tempPar.BValue) } + //if par.DataType == XMLType { + // arrayBuffer.Write([]byte{0}) + //} par.BValue = arrayBuffer.Bytes() } // for array set maxsize of nchar and raw @@ -215,23 +216,12 @@ func (par *ParameterInfo) encodePrimValue(conn *Connection) error { } if par.DataType == XMLType { par.ToID = par.cusType.arrayTOID - par.BValue = encodeObject(conn, par.BValue, true) + par.BValue = encodeObject(conn.session, par.BValue, true) par.Flag = 3 par.MaxNoOfArrayElements = 0 } } else { - par.BValue = encodeObject(conn, par.BValue, false) - // encode UDT object - //size := len(par.BValue) + 7 - //itemData := bytes.Buffer{} - //conn.session.WriteInt(&itemData, size, 4, true, true) - //itemData.Write([]byte{1, 1}) - //fieldsData := bytes.Buffer{} - //fieldsData.Write([]byte{0x84, 0x1, 0xfe}) - //conn.session.WriteInt(&fieldsData, size, 4, true, false) - //fieldsData.Write(par.BValue) - //conn.session.WriteClr(&itemData, fieldsData.Bytes()) - //par.BValue = itemData.Bytes() + par.BValue = encodeObject(conn.session, par.BValue, false) } default: return fmt.Errorf("unsupported primitive type: %v", reflect.TypeOf(par.iPrimValue).Name()) diff --git a/v2/udt.go b/v2/udt.go index 8959d1a1..c63b42a3 100644 --- a/v2/udt.go +++ b/v2/udt.go @@ -249,8 +249,10 @@ END;` attTypeName string ) for rows.Next_() { - rows.Scan(&attName, &attOrder, &attTypeName) - + err = rows.Scan(&attName, &attOrder, &attTypeName) + if err != nil { + return err + } for int(attOrder) > len(cust.attribs) { cust.attribs = append(cust.attribs, ParameterInfo{ Direction: Input, diff --git a/v2/utils.go b/v2/utils.go index 16409638..ebbec2eb 100644 --- a/v2/utils.go +++ b/v2/utils.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "github.com/sijms/go-ora/v2/converters" + "github.com/sijms/go-ora/v2/network" "io" "reflect" "regexp" @@ -601,18 +602,30 @@ func setUDTObject(value reflect.Value, cust *customType, input []ParameterInfo) } return setUDTObject(value.Elem(), cust, input) } - tempObj := reflect.New(cust.typ) - for _, par := range input { - if fieldIndex, ok := cust.fieldMap[par.Name]; ok { - err := setFieldValue(tempObj.Elem().Field(fieldIndex), par.cusType, par.oPrimValue) - if err != nil { - return err + if value.Kind() == reflect.Slice || value.Kind() == reflect.Array { + arrayObj := reflect.MakeSlice(reflect.SliceOf(cust.typ), 0, len(input)) + for _, par := range input { + if temp, ok := par.oPrimValue.([]ParameterInfo); ok { + tempObj2 := reflect.New(cust.typ) + err := setFieldValue(tempObj2.Elem(), par.cusType, temp) + if err != nil { + return err + } + arrayObj = reflect.Append(arrayObj, tempObj2.Elem()) } } - } - if value.Kind() == reflect.Ptr { - value.Elem().Set(tempObj.Elem()) + value.Set(arrayObj) } else { + tempObj := reflect.New(cust.typ) + for _, par := range input { + + if fieldIndex, ok := cust.fieldMap[par.Name]; ok { + err := setFieldValue(tempObj.Elem().Field(fieldIndex), par.cusType, par.oPrimValue) + if err != nil { + return err + } + } + } value.Set(tempObj.Elem()) } return nil @@ -963,19 +976,83 @@ func getTOID(conn *Connection, owner, typeName string) ([]byte, error) { return ret, rows.Err() } -func encodeObject(conn *Connection, objectData []byte, isArray bool) []byte { - size := len(objectData) + 7 - //itemData := bytes.Buffer{} - //conn.session.WriteInt(&itemData, size, 4, true, true) - //itemData.Write([]byte{1, 1}) +func encodeObject(session *network.Session, objectData []byte, isArray bool) []byte { + size := len(objectData) fieldsData := bytes.Buffer{} if isArray { - fieldsData.Write([]byte{0x88, 0x1, 0xfe}) + fieldsData.Write([]byte{0x88, 0x1}) } else { - fieldsData.Write([]byte{0x84, 0x1, 0xfe}) + fieldsData.Write([]byte{0x84, 0x1}) + } + if (size + 7) < 0xfe { + size += 3 + fieldsData.Write([]byte{uint8(size)}) + } else { + size += 7 + fieldsData.Write([]byte{0xfe}) + session.WriteInt(&fieldsData, size, 4, true, false) } - // from here object should encode - conn.session.WriteInt(&fieldsData, size, 4, true, false) fieldsData.Write(objectData) return fieldsData.Bytes() } + +func decodeObject(conn *Connection, parent *ParameterInfo) error { + session := conn.session + newState := network.SessionState{InBuffer: parent.BValue} + session.SaveState(&newState) + objectType, err := session.GetByte() + if err != nil { + return err + } + ctl, err := session.GetInt(4, true, true) + if err != nil { + return err + } + if ctl == 0xFE { + _, err = session.GetInt(4, false, true) + if err != nil { + return err + } + } + switch objectType { + case 0x88: + _ /*attribsLen*/, err := session.GetInt(2, true, true) + if err != nil { + return err + } + + itemsLen, err := session.GetInt(2, false, true) + if err != nil { + return err + } + pars := make([]ParameterInfo, 0, itemsLen) + for x := 0; x < itemsLen; x++ { + tempPar := parent.clone() + tempPar.Direction = parent.Direction + tempPar.BValue, err = session.GetClr() + if err != nil { + return err + } + err = decodeObject(conn, &tempPar) + if err != nil { + return err + } + pars = append(pars, tempPar) + } + parent.oPrimValue = pars + case 0x84: + pars := make([]ParameterInfo, 0, len(parent.cusType.attribs)) + for _, attrib := range parent.cusType.attribs { + tempPar := attrib + tempPar.Direction = parent.Direction + err = tempPar.decodePrimValue(conn, true) + if err != nil { + return err + } + pars = append(pars, tempPar) + } + parent.oPrimValue = pars + } + _ = session.LoadState() + return nil +}