Skip to content

Commit

Permalink
suppot enum , add column 、table desc
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikaelemmmm committed Apr 1, 2022
1 parent 2a8ba49 commit 0060ac0
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 29 deletions.
111 changes: 83 additions & 28 deletions core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ func GenerateSchema(db *sql.DB, table string, ignoreTables []string, serviceName
s.GoPackage = "./" + s.Package
}




cols, err := dbColumns(db, dbs, table)
if nil != err {
return nil, err
Expand Down Expand Up @@ -87,7 +84,7 @@ func typesFromColumns(s *Schema, cols []Column, ignoreTables []string) error {

msg, ok := messageMap[messageName]
if !ok {
messageMap[messageName] = &Message{Name: messageName}
messageMap[messageName] = &Message{Name: messageName,Comment: c.TableComment}
msg = messageMap[messageName]
}

Expand Down Expand Up @@ -116,15 +113,16 @@ func dbColumns(db *sql.DB, schema, table string) ([]Column, error) {

tableArr:= strings.Split(table,",")

q := "SELECT TABLE_NAME, COLUMN_NAME, IS_NULLABLE, DATA_TYPE, " +
"CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_SCALE, COLUMN_TYPE " +
"FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = ?"
q := "SELECT c.TABLE_NAME, c.COLUMN_NAME, c.IS_NULLABLE, c.DATA_TYPE, " +
"c.CHARACTER_MAXIMUM_LENGTH, c.NUMERIC_PRECISION, c.NUMERIC_SCALE, c.COLUMN_TYPE ,c.COLUMN_COMMENT,t.TABLE_COMMENT " +
"FROM INFORMATION_SCHEMA.COLUMNS as c LEFT JOIN INFORMATION_SCHEMA.TABLES as t on c.TABLE_NAME = t.TABLE_NAME and c.TABLE_SCHEMA = t.TABLE_SCHEMA" +
" WHERE c.TABLE_SCHEMA = ?"

if table != "" && table != "*" {
q += " AND TABLE_NAME IN('" + strings.TrimRight(strings.Join(tableArr,"' ,'"),",") + "')"
q += " AND c.TABLE_NAME IN('" + strings.TrimRight(strings.Join(tableArr,"' ,'"),",") + "')"
}

q += " ORDER BY TABLE_NAME, ORDINAL_POSITION"
q += " ORDER BY c.TABLE_NAME, c.ORDINAL_POSITION"

rows, err := db.Query(q, schema)
defer rows.Close()
Expand All @@ -137,11 +135,15 @@ func dbColumns(db *sql.DB, schema, table string) ([]Column, error) {
for rows.Next() {
cs := Column{}
err := rows.Scan(&cs.TableName, &cs.ColumnName, &cs.IsNullable, &cs.DataType,
&cs.CharacterMaximumLength, &cs.NumericPrecision, &cs.NumericScale, &cs.ColumnType)
&cs.CharacterMaximumLength, &cs.NumericPrecision, &cs.NumericScale, &cs.ColumnType,&cs.ColumnComment,&cs.TableComment)
if err != nil {
log.Fatal(err)
}

if cs.TableComment == ""{
cs.TableComment = stringx.From(cs.TableName).ToCamelWithStartLower()
}

cols = append(cols, cs)
}
if err := rows.Err(); nil != err {
Expand All @@ -151,6 +153,8 @@ func dbColumns(db *sql.DB, schema, table string) ([]Column, error) {
return cols, nil
}



// Schema is a representation of a protobuf schema.
type Schema struct {
Syntax string
Expand Down Expand Up @@ -223,6 +227,8 @@ func (s *Schema) String() string {
buf.WriteString("// ------------------------------------ \n\n")

for _, m := range s.Messages {
buf.WriteString("//--------------------------------" + m.Comment+"--------------------------------")
buf.WriteString("\n")
m.GenDefaultMessage(buf)
m.GenRpcAddReqRespMessage(buf)
m.GenRpcUpdateReqMessage(buf)
Expand All @@ -233,13 +239,24 @@ func (s *Schema) String() string {

buf.WriteString("\n")

if len(s.Enums) > 0{
buf.WriteString("// ------------------------------------ \n")
buf.WriteString("// Enums\n")
buf.WriteString("// ------------------------------------ \n\n")

for _, e := range s.Enums {
buf.WriteString(fmt.Sprintf("%s\n", e))
}
}

buf.WriteString("\n")
buf.WriteString("// ------------------------------------ \n")
buf.WriteString("// Rpc Func\n")
buf.WriteString("// ------------------------------------ \n\n")

funcTpl := "service " + s.ServiceName + "{ \n"
funcTpl := "service " + s.ServiceName + "{ \n\n"
for _, m := range s.Messages {
funcTpl+= "\t //-----------------------" + m.Comment+"----------------------- \n"
funcTpl += "\t rpc Add" + m.Name + "(Add" + m.Name + "Req) returns (Add" + m.Name + "Resp); \n"
funcTpl += "\t rpc Update" + m.Name + "(Update" + m.Name + "Req) returns (Update" + m.Name + "Resp); \n"
funcTpl += "\t rpc Del" + m.Name + "(Del" + m.Name + "Req) returns (Del" + m.Name + "Resp); \n"
Expand All @@ -255,13 +272,15 @@ func (s *Schema) String() string {
// Enum represents a protocol buffer enumerated type.
type Enum struct {
Name string
Comment string
Fields []EnumField
}

// String returns a string representation of an Enum.
func (e *Enum) String() string {
buf := new(bytes.Buffer)

buf.WriteString(fmt.Sprintf("// %s \n", e.Comment))
buf.WriteString(fmt.Sprintf("enum %s {\n", e.Name))

for _, f := range e.Fields {
Expand Down Expand Up @@ -318,9 +337,10 @@ func (ef EnumField) Tag() int {
}

// newEnumFromStrings creates an enum from a name and a slice of strings that represent the names of each field.
func newEnumFromStrings(name string, ss []string) (*Enum, error) {
func newEnumFromStrings(name ,comment string, ss []string) (*Enum, error) {
enum := &Enum{}
enum.Name = name
enum.Comment = comment

for i, s := range ss {
err := enum.AppendField(NewEnumField(s, i))
Expand All @@ -338,8 +358,9 @@ type Service struct{}

// Message represents a protocol buffer message.
type Message struct {
Name string
Fields []MessageField
Name string
Comment string
Fields []MessageField
}

//gen default message
Expand All @@ -348,11 +369,17 @@ func (m Message) GenDefaultMessage(buf *bytes.Buffer) {
mOrginFields := m.Fields

curFields := []MessageField{}
var filedTag int
for _, field := range m.Fields {
if isInSlice([]string{"version", "del_state", "delete_time"}, field.Name) {
continue
}
filedTag++
field.tag = filedTag
field.Name = stringx.From(field.Name).ToCamelWithStartLower()
if field.Comment == ""{
field.Comment = field.Name
}
curFields = append(curFields, field)
}
m.Fields = curFields
Expand All @@ -371,11 +398,17 @@ func (m Message) GenRpcAddReqRespMessage(buf *bytes.Buffer) {
//req
m.Name = "Add" + mOrginName + "Req"
curFields := []MessageField{}
var filedTag int
for _, field := range m.Fields {
if isInSlice([]string{"id", "create_time", "update_time", "version", "del_state", "delete_time"}, field.Name) {
continue
}
filedTag++
field.tag = filedTag
field.Name = stringx.From(field.Name).ToCamelWithStartLower()
if field.Comment == ""{
field.Comment = field.Name
}
curFields = append(curFields, field)
}
m.Fields = curFields
Expand Down Expand Up @@ -403,11 +436,17 @@ func (m Message) GenRpcUpdateReqMessage(buf *bytes.Buffer) {

m.Name = "Update" + mOrginName + "Req"
curFields := []MessageField{}
var filedTag int
for _, field := range m.Fields {
if isInSlice([]string{"create_time", "update_time", "version", "del_state", "delete_time"}, field.Name) {
continue
}
filedTag++
field.tag = filedTag
field.Name = stringx.From(field.Name).ToCamelWithStartLower()
if field.Comment == ""{
field.Comment = field.Name
}
curFields = append(curFields, field)
}
m.Fields = curFields
Expand All @@ -434,7 +473,7 @@ func (m Message) GenRpcDelReqMessage(buf *bytes.Buffer) {

m.Name = "Del" + mOrginName + "Req"
m.Fields = []MessageField{
{Name: "id", Typ: "int64", tag: 1},
{Name: "id", Typ: "int64", tag: 1,Comment: "id"},
}
buf.WriteString(fmt.Sprintf("%s\n", m))

Expand All @@ -459,7 +498,7 @@ func (m Message) GenRpcGetByIdReqMessage(buf *bytes.Buffer) {

m.Name = "Get" + mOrginName + "ByIdReq"
m.Fields = []MessageField{
{Name: "id", Typ: "int64", tag: 1},
{Name: "id", Typ: "int64", tag: 1,Comment: "id"},
}
buf.WriteString(fmt.Sprintf("%s\n", m))

Expand All @@ -471,7 +510,7 @@ func (m Message) GenRpcGetByIdReqMessage(buf *bytes.Buffer) {
firstWord := strings.ToLower(string(m.Name[0]))
m.Name = "Get" + mOrginName + "ByIdResp"
m.Fields = []MessageField{
{Typ: mOrginName, Name: stringx.From(firstWord + mOrginName[1:]).ToCamelWithStartLower(), tag: 1},
{Typ: mOrginName, Name: stringx.From(firstWord + mOrginName[1:]).ToCamelWithStartLower(), tag: 1,Comment: stringx.From(firstWord + mOrginName[1:]).ToCamelWithStartLower()},
}
buf.WriteString(fmt.Sprintf("%s\n", m))

Expand All @@ -486,12 +525,21 @@ func (m Message) GenRpcSearchReqMessage(buf *bytes.Buffer) {
mOrginFields := m.Fields

m.Name = "Search" + mOrginName + "Req"
curFields := []MessageField{}
curFields := []MessageField{
{Typ: "int64",Name: "page",tag: 1,Comment: "page"},
{Typ: "int64",Name: "pageSize",tag: 2,Comment: "pageSize"},
}
var filedTag = len(curFields)
for _, field := range m.Fields {
if isInSlice([]string{"version", "del_state", "delete_time"}, field.Name) {
continue
}
filedTag++
field.tag = filedTag
field.Name = stringx.From(field.Name).ToCamelWithStartLower()
if field.Comment == ""{
field.Comment = field.Name
}
curFields = append(curFields, field)
}
m.Fields = curFields
Expand All @@ -505,7 +553,7 @@ func (m Message) GenRpcSearchReqMessage(buf *bytes.Buffer) {
firstWord := strings.ToLower(string(m.Name[0]))
m.Name = "Search" + mOrginName + "Resp"
m.Fields = []MessageField{
{Typ: "repeated " + mOrginName, Name: stringx.From(firstWord + mOrginName[1:]).ToCamelWithStartLower(), tag: 1},
{Typ: "repeated " + mOrginName, Name: stringx.From(firstWord + mOrginName[1:]).ToCamelWithStartLower(), tag: 1,Comment: stringx.From(firstWord + mOrginName[1:]).ToCamelWithStartLower()},
}
buf.WriteString(fmt.Sprintf("%s\n", m))

Expand All @@ -520,7 +568,7 @@ func (m Message) String() string {

buf.WriteString(fmt.Sprintf("message %s {\n", m.Name))
for _, f := range m.Fields {
buf.WriteString(fmt.Sprintf("%s%s;\n", indent, f))
buf.WriteString(fmt.Sprintf("%s%s; //%s\n", indent, f,f.Comment))
}
buf.WriteString("}\n")

Expand All @@ -545,11 +593,12 @@ type MessageField struct {
Typ string
Name string
tag int
Comment string
}

// NewMessageField creates a new message field.
func NewMessageField(typ, name string, tag int) MessageField {
return MessageField{typ, name, tag}
func NewMessageField(typ, name string, tag int,comment string) MessageField {
return MessageField{typ, name, tag,comment}
}

// Tag returns the unique numbered tag of the message field.
Expand All @@ -565,13 +614,20 @@ func (f MessageField) String() string {
// Column represents a database column.
type Column struct {
TableName string
TableComment string
ColumnName string
IsNullable string
DataType string
CharacterMaximumLength sql.NullInt64
NumericPrecision sql.NullInt64
NumericScale sql.NullInt64
ColumnType string
ColumnComment string
}
// Table represents a database table.
type Table struct {
TableName string
ColumnName string
}

// parseColumn parses a column and inserts the relevant fields in the Message. If an enumerated type is encountered, an Enum will
Expand All @@ -592,7 +648,7 @@ func parseColumn(s *Schema, msg *Message, col Column) error {
})

enumName := inflect.Singularize(snaker.SnakeToCamel(col.TableName)) + snaker.SnakeToCamel(col.ColumnName)
enum, err := newEnumFromStrings(enumName, enums)
enum, err := newEnumFromStrings(enumName,col.ColumnComment, enums)
if nil != err {
return err
}
Expand All @@ -604,21 +660,20 @@ func parseColumn(s *Schema, msg *Message, col Column) error {
fieldType = "bytes"
case "date", "time", "datetime", "timestamp":
//s.AppendImport("google/protobuf/timestamp.proto")

fieldType = "int64"
case "tinyint", "bool":
case "bool":
fieldType = "bool"
case "smallint", "int", "mediumint", "bigint":
case "tinyint","smallint", "int", "mediumint", "bigint":
fieldType = "int64"
case "float", "decimal", "double":
fieldType = "float"
fieldType = "double"
}

if "" == fieldType {
return fmt.Errorf("no compatible protobuf type found for `%s`. column: `%s`.`%s`", col.DataType, col.TableName, col.ColumnName)
}

field := NewMessageField(fieldType, col.ColumnName, len(msg.Fields)+1)
field := NewMessageField(fieldType, col.ColumnName, len(msg.Fields)+1,col.ColumnComment)

err := msg.AppendField(field)
if nil != err {
Expand Down
2 changes: 1 addition & 1 deletion sql2pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func main() {
port := flag.Int("port", 3306, "the database port")
user := flag.String("user", "root", "the database user")
password := flag.String("password", "root", "the database password")
schema := flag.String("schema", "order", "the database schema")
schema := flag.String("schema", "", "the database schema")
table := flag.String("table", "*", "the table schema,multiple tables ',' split. ")
serviceName := flag.String("service_name", *schema, "the protobuf service name , defaults to the database schema.")
packageName := flag.String("package", *schema, "the protocol buffer package. defaults to the database schema.")
Expand Down

0 comments on commit 0060ac0

Please sign in to comment.