1 Star 0 Fork 13

KelvinQ / shorm

forked from jongsz / shorm 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
sqlgen.go 17.98 KB
一键复制 编辑 原始数据 按行查看 历史
shengzhi 提交于 2017-02-07 16:48 . -- bug fix
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
// Copyright 2016 The shorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package shorm
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"sync"
)
// opType identifies the kind of a sqlClause. The numeric value doubles as the
// sort key used by sqlClauseList.Less, so the declaration order below is the
// order in which GenSelect and GenCount process clauses after sorting.
type opType int8
const (
// opType_rawQuery: the clause text and params are returned verbatim by every generator.
opType_rawQuery opType = iota + 1
// opType_limit: params hold the two values GenSelect writes as "limit a,b".
opType_limit
// opType_top: GenSelect turns this into "limit 0,1".
opType_top
// opType_cols: comma-separated list of columns to select/insert/update.
opType_cols
// opType_omit: comma-separated list of columns to leave out.
opType_omit
// opType_table: table name; appended internally by GenSelect.
opType_table
// opType_unlockTable: emits " with(nolock) " in GenSelect and GenCount.
opType_unlockTable
// opType_id: condition "<id column>=?" using the table's IdColumn.
opType_id
// opType_where: condition that opens the where clause (or is and-ed to an existing one).
opType_where
// opType_in: "<clause> in (...)" with the values inlined by makeInArgs.
opType_in
// opType_in_or: same as opType_in but joined with "or".
opType_in_or
// opType_between: "<clause> between ? and ?".
opType_between
// opType_between_or: same as opType_between but joined with "or".
opType_between_or
// opType_and: condition appended with " and ".
opType_and
// opType_or: parenthesised condition appended with " or ".
opType_or
// opType_orderby: expression written after " order by " by GenSelect.
opType_orderby
)
// sqlClause is one fragment of a statement under construction: the kind of
// fragment, its SQL text and the arguments bound to its placeholders.
type sqlClause struct {
// op selects how the generators interpret clause and params.
op opType
// clause is the SQL text (a condition, a column name or list, or a raw query).
clause string
// params are the values that belong to the clause.
params []interface{}
}
// sqlClauseList is a list of clauses that implements sort.Interface, ordering
// its elements by ascending opType.
type sqlClauseList []sqlClause

// Len reports the number of clauses in the list.
func (l sqlClauseList) Len() int { return len(l) }

// Less reports whether the clause at index i has a smaller opType than the
// clause at index j.
func (l sqlClauseList) Less(i, j int) bool {
	left, right := l[i].op, l[j].op
	return left < right
}

// Swap exchanges the clauses at indexes i and j.
func (l sqlClauseList) Swap(i, j int) {
	tmp := l[i]
	l[i] = l[j]
	l[j] = tmp
}
// SqlGenerator generates SQL statements for a mapped table. Every method
// returns the statement text together with the arguments for its placeholders.
type SqlGenerator interface {
// GenSelect generates a select statement for table from the given clauses.
GenSelect(table *TableMetadata, sqls sqlClauseList) (string, []interface{})
// GenInsert generates an insert statement for value. hasMultiRows tells the
// generator that further row tuples will follow this one.
GenInsert(value reflect.Value, table *TableMetadata, sqls sqlClauseList, hasMultiRows bool) (string, []interface{})
// GenMultiInsert generates the placeholder tuple and arguments for one
// additional row of a multi-row insert.
GenMultiInsert(value reflect.Value, table *TableMetadata, sqls sqlClauseList) (string, []interface{})
// GenUpdate generates an update statement for value.
GenUpdate(value reflect.Value, table *TableMetadata, sqls sqlClauseList) (string, []interface{})
// GenDelete generates a delete statement.
GenDelete(table *TableMetadata, sqls sqlClauseList) (string, []interface{})
// GenCount generates a row-count statement.
GenCount(table *TableMetadata, sqls sqlClauseList) (string, []interface{})
}
// BaseGenerator is the default statement generator. Its methods build SQL
// text in pooled byte buffers.
type BaseGenerator struct {
// bufPool supplies reusable *bytes.Buffer scratch buffers (see getBuf/putBuf).
bufPool *sync.Pool
// wrapFunc, when non-nil, quotes table and column identifiers; when nil,
// wrapColumn falls back to wrapping the name in backticks.
wrapFunc func(string) string
}
// newBaseGenerator returns a BaseGenerator whose buffer pool creates empty
// bytes.Buffer values on demand. wrapFunc is left nil.
func newBaseGenerator() *BaseGenerator {
	pool := &sync.Pool{
		New: func() interface{} { return &bytes.Buffer{} },
	}
	return &BaseGenerator{bufPool: pool}
}
// putBuf empties buf and hands it back to the generator's buffer pool so a
// later getBuf call can reuse it.
func (b *BaseGenerator) putBuf(buf *bytes.Buffer) {
	buf.Reset()
	pool := b.bufPool
	pool.Put(buf)
}
// getBuf takes a buffer from the generator's pool. Pair every call with
// putBuf once the buffer's contents are no longer needed.
func (b *BaseGenerator) getBuf() *bytes.Buffer {
	pooled := b.bufPool.Get()
	return pooled.(*bytes.Buffer)
}
// GenCount builds a "select count(1) from <table>" statement from sqls and
// returns it together with the arguments for its "?" placeholders.
//
// Clauses are processed in ascending opType order. A raw-query clause short
// circuits everything else: its text and params are returned unchanged.
// Values of "in" clauses are inlined by makeInArgs and therefore do not show
// up in the returned arguments.
//
// The generated SQL is returned as-is; it is never used as a format string,
// so "%" characters inside conditions (for example LIKE patterns) are safe.
func (m *BaseGenerator) GenCount(table *TableMetadata, sqls sqlClauseList) (string, []interface{}) {
	buf := m.getBuf()
	defer m.putBuf(buf)

	// Sort a private copy so the caller's list is not reordered; the stable
	// sort keeps clauses of the same kind in the order they were added.
	ordered := make(sqlClauseList, len(sqls))
	copy(ordered, sqls)
	sort.Stable(ordered)

	var args []interface{}
	hasWhere := false
	// lead returns " where " for the first condition and " and " afterwards.
	lead := func() string {
		if hasWhere {
			return " and "
		}
		hasWhere = true
		return " where "
	}

	fmt.Fprintf(buf, "select count(1) from %s", m.wrapColumn(table.Name))
	for _, s := range ordered {
		switch s.op {
		case opType_rawQuery:
			return s.clause, s.params
		case opType_unlockTable:
			buf.WriteString(" with(nolock) ")
		case opType_id:
			fmt.Fprintf(buf, "%s%s=?", lead(), table.IdColumn.name)
			args = append(args, s.params...)
		case opType_where:
			fmt.Fprintf(buf, "%s%s", lead(), s.clause)
			args = append(args, s.params...)
		case opType_and:
			fmt.Fprintf(buf, " and %s", s.clause)
			args = append(args, s.params...)
		case opType_or:
			fmt.Fprintf(buf, " or (%s)", s.clause)
			args = append(args, s.params...)
		case opType_in:
			// An empty value list would produce "in ()", so the clause is skipped.
			if len(s.params) > 0 {
				fmt.Fprintf(buf, "%s%s in (%s)", lead(), s.clause, m.makeInArgs(s.params))
			}
		case opType_in_or:
			if len(s.params) > 0 {
				fmt.Fprintf(buf, " or(%s in (%s))", s.clause, m.makeInArgs(s.params))
			}
		case opType_between:
			fmt.Fprintf(buf, "%s%s between ? and ?", lead(), s.clause)
			args = append(args, s.params...)
		case opType_between_or:
			fmt.Fprintf(buf, " or (%s between ? and ?)", s.clause)
			args = append(args, s.params...)
		}
	}
	return buf.String(), args
}
// GenSelect builds a select statement for table from sqls and returns it
// together with the arguments for its "?" placeholders.
//
// Clauses are processed in ascending opType order. A raw-query clause short
// circuits everything else. When no explicit column list is given, every
// column whose rwType has the io_type_ro bit set is selected, minus the
// columns named by an omit clause. Values of "in" clauses are inlined by
// makeInArgs and are not part of the returned arguments.
//
// The column list is concatenated in front of the rest of the statement
// rather than substituted through a format string, so "%" characters inside
// conditions (for example LIKE patterns) are emitted unchanged.
func (m *BaseGenerator) GenSelect(table *TableMetadata, sqls sqlClauseList) (string, []interface{}) {
	buf := m.getBuf()
	defer m.putBuf(buf)

	// Work on a private copy: appending to and sorting the caller's slice
	// could silently rewrite it. The stable sort keeps clauses of the same
	// kind in the order they were added.
	ordered := make(sqlClauseList, 0, len(sqls)+1)
	ordered = append(ordered, sqls...)
	ordered = append(ordered, sqlClause{op: opType_table, clause: m.wrapColumn(table.Name)})
	sort.Stable(ordered)

	var (
		args        []interface{}
		colNames    string
		omitCols    []string
		pagingParam []interface{}
	)
	isPaging := false
	hasWhere := false
	// lead returns " where " for the first condition and " and " afterwards.
	lead := func() string {
		if hasWhere {
			return " and "
		}
		hasWhere = true
		return " where "
	}

	// buf receives everything that follows the column list, starting with
	// " from <table>"; the clauses sorted before opType_table write nothing.
	for _, s := range ordered {
		switch s.op {
		case opType_rawQuery:
			return s.clause, s.params
		case opType_top:
			// "top" is expressed as a one-row limit.
			isPaging = true
			pagingParam = []interface{}{0, 1}
		case opType_cols:
			colNames = s.clause
		case opType_omit:
			omitCols = strings.Split(strings.ToLower(s.clause), ",")
			for i := range omitCols {
				omitCols[i] = strings.TrimSpace(omitCols[i])
			}
		case opType_table:
			buf.WriteString(" from ")
			buf.WriteString(s.clause)
		case opType_unlockTable:
			buf.WriteString(" with(nolock) ")
		case opType_id:
			fmt.Fprintf(buf, "%s%s=?", lead(), table.IdColumn.name)
			args = append(args, s.params...)
		case opType_where:
			fmt.Fprintf(buf, "%s%s", lead(), s.clause)
			args = append(args, s.params...)
		case opType_and:
			fmt.Fprintf(buf, " and %s", s.clause)
			args = append(args, s.params...)
		case opType_or:
			fmt.Fprintf(buf, " or (%s)", s.clause)
			args = append(args, s.params...)
		case opType_in:
			// An empty value list would produce "in ()", so the clause is skipped.
			if len(s.params) > 0 {
				fmt.Fprintf(buf, "%s%s in (%s)", lead(), s.clause, m.makeInArgs(s.params))
			}
		case opType_in_or:
			if len(s.params) > 0 {
				fmt.Fprintf(buf, " or(%s in (%s))", s.clause, m.makeInArgs(s.params))
			}
		case opType_between:
			fmt.Fprintf(buf, "%s%s between ? and ?", lead(), s.clause)
			args = append(args, s.params...)
		case opType_between_or:
			fmt.Fprintf(buf, " or (%s between ? and ?)", s.clause)
			args = append(args, s.params...)
		case opType_limit:
			isPaging = true
			pagingParam = s.params
		case opType_orderby:
			buf.WriteString(" order by ")
			buf.WriteString(s.clause)
		}
	}

	if isPaging {
		// Guard against a short parameter list instead of panicking on it.
		switch {
		case len(pagingParam) >= 2:
			fmt.Fprintf(buf, " limit %v,%v", pagingParam[0], pagingParam[1])
		case len(pagingParam) == 1:
			fmt.Fprintf(buf, " limit %v", pagingParam[0])
		}
	}

	if len(colNames) == 0 {
		cols := make([]string, 0, len(table.Columns))
		table.Columns.Foreach(func(colKey string, col *columnMetadata) {
			if col.rwType&io_type_ro != io_type_ro {
				return
			}
			for _, omitted := range omitCols {
				if colKey == omitted {
					return
				}
			}
			cols = append(cols, m.wrapColumn(col.name))
		})
		colNames = strings.Join(cols, ",")
	}
	return "select " + colNames + buf.String(), args
}
// makeInArgs renders params as the comma-separated literal list that goes
// inside "in (...)". The values are inlined into the SQL text, not bound as
// placeholders, so each one is rendered according to its own kind:
//
//   - signed/unsigned integers and floats are written as plain numbers
//     (floats with the shortest representation that round-trips);
//   - nil values and nil pointers become the SQL keyword null;
//   - everything else is formatted with %v and emitted as a quoted string
//     with backslashes and single quotes escaped, so a value cannot terminate
//     the literal early.
//
// Pointers are dereferenced before formatting. An empty params list yields
// the empty string.
func (m *BaseGenerator) makeInArgs(params []interface{}) string {
	if len(params) == 0 {
		return ""
	}
	var buf bytes.Buffer
	for i, arg := range params {
		if i > 0 {
			buf.WriteByte(',')
		}
		v := reflect.ValueOf(arg)
		// Elem of a nil pointer is the zero Value, which ends the loop.
		for v.Kind() == reflect.Ptr {
			v = v.Elem()
		}
		if !v.IsValid() {
			buf.WriteString("null")
			continue
		}
		switch v.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			buf.WriteString(strconv.FormatInt(v.Int(), 10))
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			buf.WriteString(strconv.FormatUint(v.Uint(), 10))
		case reflect.Float32, reflect.Float64:
			buf.WriteString(strconv.FormatFloat(v.Float(), 'f', -1, v.Type().Bits()))
		default:
			text := fmt.Sprintf("%v", v.Interface())
			// Escape the backslash first so the escapes added for quotes are
			// not doubled again.
			text = strings.Replace(text, `\`, `\\`, -1)
			text = strings.Replace(text, "'", "''", -1)
			buf.WriteByte('\'')
			buf.WriteString(text)
			buf.WriteByte('\'')
		}
	}
	return buf.String()
}
// wrapColumn quotes an identifier for use in a statement. Without a custom
// wrapFunc the name is enclosed in backticks; otherwise wrapFunc decides.
func (m *BaseGenerator) wrapColumn(colName string) string {
	wrap := m.wrapFunc
	if wrap == nil {
		return fmt.Sprintf("`%s`", colName)
	}
	return wrap(colName)
}
// isCustomType reports whether t has a non-empty package path, which is how
// getValue tells named, package-declared types apart from predeclared ones.
func (b *BaseGenerator) isCustomType(t reflect.Type) bool {
	pkg := t.PkgPath()
	return pkg != ""
}
// getValue reads the field described by colMeta out of the struct value and
// converts it into something that can be passed as a statement argument.
//
// Conversion depends on the kind of colMeta.goType:
//
//   - pointer, slice and struct columns are serialized to compact JSON,
//     except that DB converters are asked for their ToDB value, time columns
//     are passed through unchanged and empty slices become "";
//   - named string types are formatted with %v; named integer and float
//     types are reduced to int64 and float64;
//   - everything else is returned as stored in the field.
//
// A nil pointer field yields nil instead of panicking. If JSON encoding
// fails the result is the empty string, because the signature has no way to
// report the error.
func (b *BaseGenerator) getValue(colMeta *columnMetadata, value reflect.Value) interface{} {
	if len(colMeta.parentFieldIndex) > 0 {
		value = value.FieldByIndex(colMeta.parentFieldIndex)
	}
	field := value.FieldByIndex(colMeta.fieldIndex)
	originField := field
	if field.Kind() == reflect.Ptr {
		if field.IsNil() {
			// Elem of a nil pointer is invalid and Interface would panic on it.
			return nil
		}
		field = field.Elem()
	}
	result := field.Interface()

	// toJSON returns the compact JSON encoding of v, or "" when v cannot be
	// encoded.
	toJSON := func(v interface{}) string {
		data, err := json.Marshal(v)
		if err != nil {
			return ""
		}
		return string(data)
	}

	switch colMeta.goType.Kind() {
	case reflect.Ptr:
		if colMeta.isDBConverter {
			return originField.Interface().(Marshaler).ToDB()
		}
		return toJSON(result)
	case reflect.Slice:
		if field.Len() <= 0 {
			return ""
		}
		return toJSON(result)
	case reflect.Struct:
		if colMeta.specialType == specialType_time {
			return result
		}
		if colMeta.isDBConverter {
			return result.(Marshaler).ToDB()
		}
		return toJSON(result)
	case reflect.String:
		if b.isCustomType(colMeta.goType) {
			return fmt.Sprintf("%v", result)
		}
		return result
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if b.isCustomType(colMeta.goType) {
			return field.Int()
		}
		return result
	case reflect.Float32, reflect.Float64:
		if b.isCustomType(colMeta.goType) {
			// Go through the shortest decimal text so a float32 such as 0.1
			// is stored as 0.1 rather than its widened float64 value. Text
			// produced by FormatFloat always parses, so the error is ignored.
			text := strconv.FormatFloat(field.Float(), 'g', -1, field.Type().Bits())
			val, _ := strconv.ParseFloat(text, 64)
			return val
		}
		return result
	default:
		return result
	}
}
// GenMultiInsert produces the placeholder tuple "(?,?,...)," and the matching
// arguments for one row of value. The trailing comma is part of the result.
// NOTE(review): presumably the caller appends these tuples to the statement
// produced by GenInsert with hasMultiRows set — confirm against the caller.
//
// Columns are chosen the same way as in GenInsert: auto-id columns and
// columns without the io_type_wo bit are skipped; a cols clause restricts the
// row to the listed columns (and ends clause scanning), an omit clause
// removes the listed columns. A raw-query clause is returned unchanged.
func (m *BaseGenerator) GenMultiInsert(value reflect.Value, table *TableMetadata, sqls sqlClauseList) (string, []interface{}) {
	// splitNames lower-cases a comma-separated column list and trims the
	// blanks around each name.
	splitNames := func(clause string) []string {
		names := strings.Split(strings.ToLower(clause), ",")
		for i := range names {
			names[i] = strings.TrimSpace(names[i])
		}
		return names
	}

	args := make([]interface{}, 0, len(table.Columns))
	var colNames []string
	include := true
Loop:
	for _, s := range sqls {
		switch s.op {
		case opType_rawQuery:
			return s.clause, s.params
		case opType_cols:
			// Reset include so an earlier omit clause cannot turn this list
			// into an omit list.
			colNames, include = splitNames(s.clause), true
			break Loop
		case opType_omit:
			colNames, include = splitNames(s.clause), false
		}
	}

	table.Columns.Foreach(func(col string, meta *columnMetadata) {
		if meta.isAutoId || meta.rwType&io_type_wo != io_type_wo {
			return
		}
		if len(colNames) > 0 {
			listed := false
			for _, name := range colNames {
				if name == col {
					listed = true
					break
				}
			}
			// Keep listed columns for a cols clause, unlisted ones for omit.
			if listed != include {
				return
			}
		}
		args = append(args, m.getValue(meta, value))
	})

	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(args)), ",")
	return "(" + placeholders + "),", args
}
// GenInsert builds "insert into <table>(<columns>) values(<placeholders>)"
// for value and returns it together with the column values.
//
// Auto-id columns and columns without the io_type_wo bit are skipped. A cols
// clause restricts the statement to the listed columns (and ends clause
// scanning), an omit clause removes the listed columns. A raw-query clause is
// returned unchanged. The statement ends with "," when hasMultiRows is set,
// so further row tuples can be appended, and with ";" otherwise.
func (m *BaseGenerator) GenInsert(value reflect.Value, table *TableMetadata, sqls sqlClauseList, hasMultiRows bool) (string, []interface{}) {
	buf := m.getBuf()
	defer m.putBuf(buf)

	// splitNames lower-cases a comma-separated column list and trims the
	// blanks around each name.
	splitNames := func(clause string) []string {
		names := strings.Split(strings.ToLower(clause), ",")
		for i := range names {
			names[i] = strings.TrimSpace(names[i])
		}
		return names
	}

	args := make([]interface{}, 0, len(table.Columns))
	var colNames []string
	include := true
Loop:
	for _, s := range sqls {
		switch s.op {
		case opType_rawQuery:
			return s.clause, s.params
		case opType_cols:
			// Reset include so an earlier omit clause cannot turn this list
			// into an omit list.
			colNames, include = splitNames(s.clause), true
			break Loop
		case opType_omit:
			colNames, include = splitNames(s.clause), false
		}
	}

	columns := make([]string, 0, len(table.Columns))
	table.Columns.Foreach(func(col string, meta *columnMetadata) {
		if meta.isAutoId || meta.rwType&io_type_wo != io_type_wo {
			return
		}
		if len(colNames) > 0 {
			listed := false
			for _, name := range colNames {
				if name == col {
					listed = true
					break
				}
			}
			// Keep listed columns for a cols clause, unlisted ones for omit.
			if listed != include {
				return
			}
		}
		columns = append(columns, m.wrapColumn(meta.name))
		args = append(args, m.getValue(meta, value))
	})

	terminator := ";"
	if hasMultiRows {
		terminator = ","
	}
	buf.WriteString("insert into ")
	buf.WriteString(m.wrapColumn(table.Name))
	buf.WriteString("(")
	// Joining the names keeps the "(" intact even when no column qualifies.
	buf.WriteString(strings.Join(columns, ","))
	buf.WriteString(") values(")
	buf.WriteString(strings.TrimSuffix(strings.Repeat("?,", len(args)), ","))
	buf.WriteString(")")
	buf.WriteString(terminator)
	return buf.String(), args
}
// GenUpdate builds "update <table> set <col>=?,... [where ...]" for value and
// returns it together with the column values followed by the where arguments.
//
// Auto-id columns and columns without the io_type_wo bit are never updated.
// A cols clause restricts the update to the listed columns, an omit clause
// removes the listed columns; when both are present the later one wins. A
// raw-query clause is returned unchanged. Conditions are emitted in the order
// the clauses appear in sqls. Values of "in" clauses are inlined by
// makeInArgs and are not part of the returned arguments.
func (m *BaseGenerator) GenUpdate(value reflect.Value, table *TableMetadata, sqls sqlClauseList) (string, []interface{}) {
	buf := m.getBuf()
	sqlWhere := m.getBuf()
	defer m.putBuf(buf)
	defer m.putBuf(sqlWhere)

	// splitNames lower-cases a comma-separated column list and trims the
	// blanks around each name.
	splitNames := func(clause string) []string {
		names := strings.Split(strings.ToLower(clause), ",")
		for i := range names {
			names[i] = strings.TrimSpace(names[i])
		}
		return names
	}

	args := make([]interface{}, 0, len(table.Columns))
	whereArgs := make([]interface{}, 0)
	var colNames []string
	include := true
	hasWhere := false
	// lead returns " where " for the first condition and " and " afterwards.
	lead := func() string {
		if hasWhere {
			return " and "
		}
		hasWhere = true
		return " where "
	}

	for _, s := range sqls {
		switch s.op {
		case opType_rawQuery:
			return s.clause, s.params
		case opType_cols:
			// Reset include so an earlier omit clause cannot turn this list
			// into an omit list.
			colNames, include = splitNames(s.clause), true
		case opType_omit:
			colNames, include = splitNames(s.clause), false
		case opType_id:
			fmt.Fprintf(sqlWhere, "%s%s=?", lead(), table.IdColumn.name)
			whereArgs = append(whereArgs, s.params...)
		case opType_where:
			fmt.Fprintf(sqlWhere, "%s%s", lead(), s.clause)
			whereArgs = append(whereArgs, s.params...)
		case opType_and:
			fmt.Fprintf(sqlWhere, " and %s", s.clause)
			whereArgs = append(whereArgs, s.params...)
		case opType_or:
			fmt.Fprintf(sqlWhere, " or (%s)", s.clause)
			whereArgs = append(whereArgs, s.params...)
		case opType_in:
			// An empty value list would produce "in ()", so the clause is skipped.
			if len(s.params) > 0 {
				fmt.Fprintf(sqlWhere, "%s%s in (%s)", lead(), s.clause, m.makeInArgs(s.params))
			}
		case opType_in_or:
			if len(s.params) > 0 {
				fmt.Fprintf(sqlWhere, " or(%s in (%s))", s.clause, m.makeInArgs(s.params))
			}
		case opType_between:
			fmt.Fprintf(sqlWhere, "%s%s between ? and ?", lead(), s.clause)
			whereArgs = append(whereArgs, s.params...)
		case opType_between_or:
			fmt.Fprintf(sqlWhere, " or (%s between ? and ?)", s.clause)
			whereArgs = append(whereArgs, s.params...)
		}
	}

	assignments := make([]string, 0, len(table.Columns))
	table.Columns.Foreach(func(col string, meta *columnMetadata) {
		if meta.isAutoId || meta.rwType&io_type_wo != io_type_wo {
			return
		}
		if len(colNames) > 0 {
			listed := false
			for _, name := range colNames {
				if name == col {
					listed = true
					break
				}
			}
			// Keep listed columns for a cols clause, unlisted ones for omit.
			if listed != include {
				return
			}
		}
		assignments = append(assignments, m.wrapColumn(meta.name)+"=?")
		args = append(args, m.getValue(meta, value))
	})

	buf.WriteString("update ")
	buf.WriteString(m.wrapColumn(table.Name))
	buf.WriteString(" set ")
	buf.WriteString(strings.Join(assignments, ","))
	if sqlWhere.Len() > 0 {
		buf.Write(sqlWhere.Bytes())
		args = append(args, whereArgs...)
	}
	return buf.String(), args
}
// GenDelete builds "delete from <table> [where ...]" from sqls and returns it
// together with the arguments for its "?" placeholders.
//
// Conditions are emitted in the order the clauses appear in sqls; a raw-query
// clause is returned unchanged. Values of "in" clauses are inlined by
// makeInArgs and are not part of the returned arguments. With no condition
// clause the statement has no where part.
func (m *BaseGenerator) GenDelete(table *TableMetadata, sqls sqlClauseList) (string, []interface{}) {
	out := m.getBuf()
	defer m.putBuf(out)

	bound := make([]interface{}, 0, len(table.Columns))
	started := false
	// prefix yields " where " for the first condition and " and " afterwards.
	prefix := func() string {
		if started {
			return " and "
		}
		started = true
		return " where "
	}

	out.WriteString("delete from ")
	out.WriteString(m.wrapColumn(table.Name))
	for _, c := range sqls {
		switch c.op {
		case opType_rawQuery:
			return c.clause, c.params
		case opType_id:
			out.WriteString(prefix() + table.IdColumn.name + "=?")
			bound = append(bound, c.params...)
		case opType_where:
			out.WriteString(prefix() + c.clause)
			bound = append(bound, c.params...)
		case opType_and:
			out.WriteString(" and " + c.clause)
			bound = append(bound, c.params...)
		case opType_or:
			fmt.Fprintf(out, " or (%s)", c.clause)
			bound = append(bound, c.params...)
		case opType_in:
			if len(c.params) == 0 {
				continue
			}
			out.WriteString(prefix() + c.clause + " in (" + m.makeInArgs(c.params) + ")")
		case opType_in_or:
			if len(c.params) == 0 {
				continue
			}
			fmt.Fprintf(out, " or(%s in (%s))", c.clause, m.makeInArgs(c.params))
		case opType_between:
			out.WriteString(prefix() + c.clause + " between ? and ?")
			bound = append(bound, c.params...)
		case opType_between_or:
			fmt.Fprintf(out, " or (%s between ? and ?)", c.clause)
			bound = append(bound, c.params...)
		}
	}
	return out.String(), bound
}
Go
1
https://gitee.com/copicojeje/shorm.git
git@gitee.com:copicojeje/shorm.git
copicojeje
shorm
shorm
master

搜索帮助

53164aa7 5694891 3bd8fe86 5694891