threeperson
发布于 2015-08-10 / 0 阅读
0
0

beego orm 扩展

beego orm 扩展,开发效率提升100%

  1. 主要文件功能概要

condition.go //解析各种操作原语

db.go //orm 扩展

generic_sql_builder.go //构建sql

pagination.go //分页struct

query.go //sql操作接口

sql_builder.go //sql构建接口

  2. db

package db
import (
"github.com/astaxie/beego/orm"
"strings"
)

// DB extends beego's orm.Ormer with a fluent query-builder entry point and a
// helper for executing raw write statements.
type DB interface {
	orm.Ormer

	// Execute runs a raw SQL statement with the given bind parameters.
	// The db implementation in this package returns the last insert id for
	// INSERT statements and the number of affected rows otherwise.
	Execute(sql string, params ...interface{}) (int64, error)

	// From starts a new Querier targeting table.
	From(table string) *Querier
}
// db is the default DB implementation. It embeds the orm.Ormer it was created
// with (see NewDB), so every Ormer method is callable directly on a db value.
type db struct {
	orm.Ormer
}

// NewDB returns a DB backed by a fresh orm.NewOrm() instance.
func NewDB() DB {
	return &db{Ormer: orm.NewOrm()}
}

// From creates a Querier that executes through d and targets table.
// Every call gets its own, newly created generic SQL builder, so queries
// started from the same DB do not share clauses.
func (d db) From(table string) *Querier {
	builder := NewGenericSQLBuilder()
	querier := NewQuery(d, builder)
	querier.From(table)
	return querier
}
// Execute runs a write statement (INSERT, UPDATE, DELETE, ...) given as raw
// SQL plus bind parameters.
//
// For INSERT statements it returns the last inserted id; for any other
// statement it returns the number of affected rows. If execution fails it
// returns -1 and the error. Batch execution is not supported (per the
// original author, beego's raw SQL API does not support batch operations).
func (d db) Execute(sql string, params ...interface{}) (int64, error) {
	// Expand params so each value is handed to Raw as its own bind argument
	// rather than as one nested slice.
	result, err := d.Raw(sql, params...).Exec()
	if err != nil {
		return -1, err
	}

	// Classify by the first keyword. TrimSpace skips all leading whitespace
	// (newlines, tabs, not just spaces), so multi-line statements such as
	// "\n\tINSERT INTO ..." are still recognized; EqualFold compares
	// case-insensitively without lowercasing the whole statement.
	const insertKeyword = "insert"
	stmt := strings.TrimSpace(sql)
	if len(stmt) >= len(insertKeyword) && strings.EqualFold(stmt[:len(insertKeyword)], insertKeyword) {
		return result.LastInsertId()
	}
	return result.RowsAffected()
}

2.1 query：提供类似 JDBC 的链式查询操作原语

package db
import (
_ "github.com/astaxie/beego/utils/pagination"
)

// Querier is a fluent query builder bound to a database handle. Clause
// methods (Select, Eq, Limit, ...) record SQL fragments in the embedded
// SQLBuilder and return the Querier for chaining; terminal methods such as
// First and Pagination run the generated SQL through the embedded DB.
type Querier struct {
	// DB executes the generated SQL (e.g. via its Raw method).
	DB
	// SQLBuilder accumulates the query's clauses and bind parameters.
	SQLBuilder
	// table is the table name most recently passed to From.
	table string
}

// NewQuery returns a Querier that executes through db and accumulates its
// clauses in sqlBuilder. No table is set until From is called.
func NewQuery(db DB, sqlBuilder SQLBuilder) *Querier {
	return &Querier{
		DB:         db,
		SQLBuilder: sqlBuilder,
	}
}


// From sets the table the query targets, both on the Querier (see Table) and
// on its SQLBuilder, and returns q for chaining.
func (q *Querier) From(table string) *Querier {
	q.table = table
	builder := q.SQLBuilder
	builder.SetTable(table)
	return q
}

// Table returns the table name most recently passed to From.
func (q *Querier) Table() (name string) {
	name = q.table
	return name
}

// GroupBy forwards a GROUP BY expression to the SQLBuilder and returns q.
func (q *Querier) GroupBy(groupBy string) *Querier {
	builder := q.SQLBuilder
	builder.GroupBy(groupBy)
	return q
}

// Limit forwards offset and rowCount to SQLBuilder.Limit and returns q.
// Querier.Pagination calls it with a 0-based offset of (page-1)*pageSize.
func (q *Querier) Limit(offset int, rowCount int) *Querier {
	builder := q.SQLBuilder
	builder.Limit(offset, rowCount)
	return q
}

// OrderBy forwards an ORDER BY expression to the SQLBuilder and returns q.
// An empty string is ignored, so callers can pass an optional sort key
// straight through.
func (q *Querier) OrderBy(orderBy string) *Querier {
	if orderBy != "" {
		q.SQLBuilder.OrderBy(orderBy)
	}
	return q
}

// Select passes the given column names to SQLBuilder.Select and returns q,
// e.g. Select("id", "name").
func (q *Querier) Select(columns ...string) *Querier {
	builder := q.SQLBuilder
	builder.Select(columns)
	return q
}

// Join forwards a raw JOIN clause to the SQLBuilder and returns q.
func (q *Querier) Join(join string) *Querier {
	builder := q.SQLBuilder
	builder.Join(join)
	return q
}

// where renders c against the query's SQLBuilder and, when rendering
// succeeds, registers the resulting SQL fragment and c.Params via
// SQLBuilder.Where. It always returns q.
//
// NOTE(review): if c.ToSQL fails the condition is silently dropped and the
// query runs without it, which can widen the result set. Consider recording
// the error on the Querier and surfacing it from First/Pagination.
func (q *Querier) where(c Condition) *Querier {
	sql, err := c.ToSQL(q.SQLBuilder)
	if err != nil {
		return q
	}
	q.SQLBuilder.Where(sql, c.Params)
	return q
}

// Segment adds a hand-written SQL fragment (a SEGMENT condition) together
// with its parameters.
func (q *Querier) Segment(sql string, params ...interface{}) *Querier {
	cond := NewCondition(SEGMENT, sql, params)
	return q.where(cond)
}

// Eq adds an EQ condition on column name.
func (q *Querier) Eq(name string, value interface{}) *Querier {
	cond := NewCondition(EQ, name, value)
	return q.where(cond)
}

// Where is an alias for Eq.
func (q *Querier) Where(name string, value interface{}) *Querier {
	return q.Eq(name, value)
}

// Not adds a NOT_EQ condition on column name.
func (q *Querier) Not(name string, value interface{}) *Querier {
	cond := NewCondition(NOT_EQ, name, value)
	return q.where(cond)
}

// In adds an IN condition on column name over values.
func (q *Querier) In(name string, values []interface{}) *Querier {
	cond := NewCondition(IN, name, values)
	return q.where(cond)
}

// NotIn adds a NOT_IN condition on column name over values.
func (q *Querier) NotIn(name string, values []interface{}) *Querier {
	cond := NewCondition(NOT_IN, name, values)
	return q.where(cond)
}
// Between adds a BETWEEN condition on column name; values presumably holds
// the [low, high] bounds (the exact contract lives in condition.go).
//
// Pointer receiver, like every other Querier method: with a value receiver
// the condition would be registered through a copy of the Querier and the
// returned *Querier would point at that copy, so later chained calls (e.g.
// From) would modify the copy instead of the caller's Querier.
func (q *Querier) Between(name string, values []interface{}) *Querier {
	return q.where(NewCondition(BETWEEN, name, values))
}

// NotBetween adds a NOT_BETWEEN condition on column name with the given
// bounds.
func (q *Querier) NotBetween(name string, values []interface{}) *Querier {
	cond := NewCondition(NOT_BETWEEN, name, values)
	return q.where(cond)
}

// Less adds a LESS condition on column name.
func (q *Querier) Less(name string, value interface{}) *Querier {
	cond := NewCondition(LESS, name, value)
	return q.where(cond)
}

// LessOrEquals adds an LE condition on column name.
func (q *Querier) LessOrEquals(name string, value interface{}) *Querier {
	cond := NewCondition(LE, name, value)
	return q.where(cond)
}

// Great adds a GREAT condition on column name.
func (q *Querier) Great(name string, value interface{}) *Querier {
	cond := NewCondition(GREAT, name, value)
	return q.where(cond)
}

// GreatOrEquals adds a GE condition on column name.
func (q *Querier) GreatOrEquals(name string, value interface{}) *Querier {
	cond := NewCondition(GE, name, value)
	return q.where(cond)
}

// IsNull adds a NULL condition on column name. value is forwarded to
// NewCondition; whether it is used depends on condition.go (not shown).
func (q *Querier) IsNull(name string, value interface{}) *Querier {
	cond := NewCondition(NULL, name, value)
	return q.where(cond)
}

// IsNotNull adds a NOT_NULL condition on column name. value is forwarded to
// NewCondition; whether it is used depends on condition.go (not shown).
func (q *Querier) IsNotNull(name string, value interface{}) *Querier {
	cond := NewCondition(NOT_NULL, name, value)
	return q.where(cond)
}

// Like adds a LIKE condition on column name with the given pattern value.
//
// Pointer receiver, like every other Querier method: with a value receiver
// the condition would be registered through a copy of the Querier and the
// returned *Querier would point at that copy, so later chained calls would
// modify the copy instead of the caller's Querier.
func (q *Querier) Like(name string, value interface{}) *Querier {
	return q.where(NewCondition(LIKE, name, value))
}

// NotLike adds a NOT_LIKE condition on column name with the given pattern.
func (q *Querier) NotLike(name string, value interface{}) *Querier {
	cond := NewCondition(NOT_LIKE, name, value)
	return q.where(cond)
}

// ToSql returns the SQL statement currently generated by the SQLBuilder.
func (q *Querier) ToSql() string {
	generated := q.SQLBuilder.ToSql()
	return generated
}

// First runs the generated query with the builder's parameters and scans a
// single row into container. It returns the error reported by QueryRow.
func (q *Querier) First(container interface{}) error {
	return q.Raw(q.ToSql(), q.Parameters()).QueryRow(container)
}
// Pagination loads one page of results into container and returns the page
// metadata together with the loaded rows.
//
// page is 1-based; values below 1 are treated as 1 (otherwise the LIMIT
// offset would be negative). pageSize values below 1 fall back to
// DEFAULT_PER_PAGE (a zero page size would make TotalPages divide by zero).
// container receives the rows, e.g. &[]entities.User{} as in the usage
// example.
//
// Two statements run: ToCountSql for the total item count and ToSql for the
// page itself. Both run even if the count fails. The returned error is the
// count error if there was one, otherwise the page query's error. A non-nil
// *Pagination is always returned so callers that ignore the error keep
// working.
//
// NOTE(review): ToCountSql is rendered after Limit is applied; this assumes
// the builder omits LIMIT from the count query — confirm in
// generic_sql_builder.go.
func (q *Querier) Pagination(container interface{}, page int, pageSize int) (*Pagination, error) {
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = DEFAULT_PER_PAGE
	}
	q.Limit((page-1)*pageSize, pageSize)

	var totalItem int
	countErr := q.Raw(q.ToCountSql(), q.Parameters()).QueryRow(&totalItem)

	_, queryErr := q.Raw(q.ToSql(), q.Parameters()).QueryRows(container)

	pagination := NewPagination(page, totalItem, false)
	pagination.setPerPage(pageSize)
	pagination.hasNext = pagination.TotalPages() > page
	pagination.SetData(container)

	err := countErr
	if err == nil {
		err = queryErr
	}
	return pagination, err
}

2.2 sql_builder sql 构建组装接口

package db

// SQLBuilder accumulates the clauses of a single SQL query and renders them.
// Querier forwards its fluent methods to an SQLBuilder and executes the
// rendered SQL together with the values returned by Parameters.
type SQLBuilder interface {
	// Clause setters.
	SetTable(tableName string)
	Select(columns []string)
	ClearSelect()
	Join(joinSql string)
	Where(condition string, params ...interface{})
	GroupBy(groupBy string)
	OrderBy(orderBy string)
	Limit(offset int, rowCount int)

	// Inspection and helpers.
	HasLimit() bool
	EscapeColumn(column string) string

	// Rendering: the SQL text (page query and count query) and the bind
	// parameters that accompany it.
	ToSql() string
	ToCountSql() string
	Parameters() []interface{}
}

2.3 pagination 分页struct

package db
import (
"beego_study/utils"
"math"
"net/url"
"strconv"
)

// Pagination defaults.
const (
	// DEFAULT_PER_PAGE is the page size NewPagination starts with.
	DEFAULT_PER_PAGE = 10
	// MAX_SHOW_PAGE is the widest pagination bar Pages will produce.
	MAX_SHOW_PAGE = 9
)

// Mode enumerates pagination modes. It is not referenced elsewhere in this
// file; presumably FULL renders the whole bar and NEXT_ONLY only a "next"
// link — confirm with callers.
type Mode int

const (
	FULL Mode = iota + 1
	NEXT_ONLY
)

// Pagination describes one page of a paged result set, plus what is needed
// to render a pagination bar (Pages) and page links (PageLink*).
type Pagination struct {
	// Page is the current 1-based page number (NewPagination clamps it to >= 1).
	Page int
	// PerPage is the number of items per page (NewPagination defaults it to
	// DEFAULT_PER_PAGE).
	PerPage int
	// Total is the total number of items across all pages.
	Total int
	// Data holds the current page's items, set by SetData via utils.ToSlice.
	Data []interface{}
	// hasNext is the value reported by HasNext.
	hasNext bool
	// pageRange holds the page numbers computed by Pages.
	pageRange []int
	// url is the base URL PageLink builds on, set by SetUrl.
	url *url.URL
}

// NewPagination builds a Pagination for the given page number, total item
// count and next-page flag. Page numbers below 1 are normalised to 1, and
// PerPage starts at DEFAULT_PER_PAGE (override it with setPerPage).
func NewPagination(page int, total int, hasNext bool) *Pagination {
	if page < 1 {
		page = 1
	}
	return &Pagination{
		Page:    page,
		PerPage: DEFAULT_PER_PAGE,
		Total:   total,
		hasNext: hasNext,
	}
}

// setPerPage overrides the number of items per page.
func (p *Pagination) setPerPage(perPage int) {
	p.PerPage = perPage
}

// TotalPages returns how many pages are needed to show Total items at
// PerPage items per page (ceiling division), e.g. 25 items at 10 per page
// gives 3.
//
// It returns 0 when PerPage is not positive, instead of panicking with an
// integer division by zero (setPerPage accepts any value), and 0 when Total
// is not positive.
func (p *Pagination) TotalPages() int {
	if p.PerPage <= 0 || p.Total <= 0 {
		return 0
	}
	return (p.Total + p.PerPage - 1) / p.PerPage
}

// NextPage returns the number of the page after the current one, or -1 when
// the current page is the last one (or beyond it).
func (p *Pagination) NextPage() int {
	if p.Page >= p.TotalPages() {
		return -1
	}
	return p.Page + 1
}

// PrevPage returns the number of the page before the current one, or -1 on
// the first page.
func (p *Pagination) PrevPage() int {
	if p.Page > 1 {
		return p.Page - 1
	}
	return -1
}

// Offset returns the 1-based position of the first item on the current page,
// e.g. 11 for page 2 at 10 items per page. The 0-based SQL offset used by
// Querier.Pagination is this value minus one.
func (p *Pagination) Offset() int {
	skipped := (p.Page - 1) * p.PerPage
	return skipped + 1
}

// HasNext reports the next-page flag recorded on the Pagination.
func (p *Pagination) HasNext() (next bool) {
	next = p.hasNext
	return next
}

// SetData stores the loaded rows in Data, converting container with
// utils.ToSlice.
func (p *Pagination) SetData(container interface{}) {
	data := utils.ToSlice(container)
	p.Data = data
}

// Pages returns the page numbers to show in the pagination bar.
//
// PC and mobile screens differ in width, so the bar size is controlled by
// maxShowPages; values outside [5, MAX_SHOW_PAGE] fall back to MAX_SHOW_PAGE.
// The window is centred on the current page when possible and pinned to the
// first or last pages near either end. It returns nil when Total <= 0.
//
// The range is recomputed on every call rather than cached. A cached result
// would ignore a different maxShowPages on a later call, e.g. when a PC bar
// and a mobile bar are rendered from the same Pagination.
func (p *Pagination) Pages(maxShowPages int) []int {
	if maxShowPages < 5 || maxShowPages > MAX_SHOW_PAGE {
		maxShowPages = MAX_SHOW_PAGE
	}
	if p.Total <= 0 {
		return nil
	}

	middlePageNum := maxShowPages / 2
	pageNums := p.TotalPages()
	page := p.Page

	var pages []int
	switch {
	case page >= pageNums-middlePageNum && pageNums > maxShowPages:
		// Near the end: show the last maxShowPages pages.
		start := pageNums - maxShowPages + 1
		pages = make([]int, maxShowPages)
		for i := range pages {
			pages[i] = start + i
		}
	case page >= middlePageNum+1 && pageNums > maxShowPages:
		// In the middle: centre the window on the current page.
		start := page - middlePageNum
		pages = make([]int, int(math.Min(float64(maxShowPages), float64(page+middlePageNum+1))))
		for i := range pages {
			pages[i] = start + i
		}
	default:
		// Near the start, or few enough pages to show them all.
		pages = make([]int, int(math.Min(float64(maxShowPages), float64(pageNums))))
		for i := range pages {
			pages[i] = i + 1
		}
	}

	// Keep pageRange in sync for any package code that reads it.
	p.pageRange = pages
	return pages
}


// PageLink returns the base URL (see SetUrl) with its "page" query parameter
// set to page; all other query parameters are preserved.
//
// The link is built on a copy of the URL, so the *url.URL passed to SetUrl
// (often the live request URL) is never mutated. When no URL has been set,
// a relative link such as "?page=2" is returned instead of panicking on a
// nil pointer.
func (p *Pagination) PageLink(page int) string {
	var link url.URL
	if p.url != nil {
		link = *p.url
	}
	values := link.Query()
	values.Set("page", strconv.Itoa(page))
	link.RawQuery = values.Encode()
	return link.String()
}

// PageLinkPrev returns the URL of the previous page, or "" when HasPrev is
// false.
func (p *Pagination) PageLinkPrev() (link string) {
	if !p.HasPrev() {
		return ""
	}
	return p.PageLink(p.Page - 1)
}

// PageLinkNext returns the URL of the next page, or "" when HasNext is false.
func (p *Pagination) PageLinkNext() (link string) {
	if !p.HasNext() {
		return ""
	}
	return p.PageLink(p.Page + 1)
}

// PageLinkFirst returns the URL of page 1.
func (p *Pagination) PageLinkFirst() (link string) {
	const firstPage = 1
	link = p.PageLink(firstPage)
	return link
}

// PageLinkLast returns the URL of the last page, as given by TotalPages.
func (p *Pagination) PageLinkLast() (link string) {
	lastPage := p.TotalPages()
	link = p.PageLink(lastPage)
	return link
}

// HasPrev reports whether a page precedes the current one.
func (p *Pagination) HasPrev() bool {
	return p.Page >= 2
}

// IsActive reports whether pagea is the current page.
func (p *Pagination) IsActive(pagea int) bool {
	return pagea == p.Page
}

// SetUrl records the base URL that PageLink and the PageLink* helpers build
// on. The pointer itself is stored, not a copy of the URL.
func (p *Pagination) SetUrl(url *url.URL) {
	p.url = url
}

2.4 使用示例

// TestPagination loads the first page (10 rows) of the user table, selecting
// only id and name, and prints each row. It needs a configured database.
func TestPagination(t *testing.T) {
	pagination, err := db.NewDB().From("user").Select("id", "name").Pagination(&[]entities.User{}, 1, 10)
	if err != nil {
		t.Fatalf("pagination query: %v", err)
	}
	for _, value := range pagination.Data {
		user, ok := value.(entities.User)
		if !ok {
			t.Fatalf("unexpected row type %T", value)
		}
		fmt.Println("id", user.Id, "name", user.Name)
	}
}


评论