123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243 |
- package model
- import (
- "database/sql"
- "strings"
- "github.com/zeromicro/go-zero/core/stores/sqlx"
- )
// p2m translates PostgreSQL type names (pg_type.typname, lower-cased) into
// the MySQL-style type names used for the generated column metadata.
// Names that are not listed are passed through unchanged by
// convertPostgreSqlTypeIntoMysqlType.
//
// NOTE(review): numeric can carry a fractional scale, so mapping it to
// bigint looks lossy — confirm this is intended.
var p2m = map[string]string{
	"float4":      "float",
	"float8":      "double",
	"int2":        "smallint",
	"int4":        "integer",
	"int8":        "bigint",
	"numeric":     "bigint",
	"timestamptz": "timestamp",
}
- // PostgreSqlModel gets table information from information_schema、pg_catalog
- type PostgreSqlModel struct {
- conn sqlx.SqlConn
- }
// PostgreColumn is one row of the column-metadata query issued by
// FindColumns. Every field is a database/sql Null* wrapper so that NULL
// values (for example those produced by the query's outer joins) scan
// without error.
type PostgreColumn struct {
	// Num is the column's attribute number (pg_attribute.attnum).
	Num sql.NullInt32 `db:"num"`

	// Field is the column name (pg_attribute.attname).
	Field sql.NullString `db:"field"`

	// Type is the PostgreSQL type name (pg_type.typname).
	Type sql.NullString `db:"type"`

	// NotNull mirrors pg_attribute.attnotnull.
	NotNull sql.NullBool `db:"not_null"`

	// Comment is the column description from pg_description, if any.
	Comment sql.NullString `db:"comment"`

	// ColumnDefault is information_schema.columns.column_default.
	ColumnDefault sql.NullString `db:"column_default"`

	// IdentityIncrement is information_schema.columns.identity_increment.
	IdentityIncrement sql.NullInt32 `db:"identity_increment"`
}
// PostgreIndex is one row of the index query issued by FindIndex; each row
// pairs one index with one of the index's attributes.
type PostgreIndex struct {
	// IndexName is the index name (pg_indexes.indexname).
	IndexName sql.NullString `db:"index_name"`

	// IndexId is pg_index.indexrelid.
	// NOTE(review): indexrelid is an oid (unsigned 32-bit in PostgreSQL);
	// values above math.MaxInt32 may fail to scan into NullInt32 — confirm.
	IndexId sql.NullInt32 `db:"index_id"`

	// IsUnique mirrors pg_index.indisunique.
	IsUnique sql.NullBool `db:"is_unique"`

	// IsPrimary mirrors pg_index.indisprimary.
	IsPrimary sql.NullBool `db:"is_primary"`

	// ColumnName is the index attribute's name (pg_attribute.attname).
	ColumnName sql.NullString `db:"column_name"`

	// IndexSort is the index attribute's number (pg_attribute.attnum).
	IndexSort sql.NullInt32 `db:"index_sort"`
}
- // NewPostgreSqlModel creates an instance and return
- func NewPostgreSqlModel(conn sqlx.SqlConn) *PostgreSqlModel {
- return &PostgreSqlModel{
- conn: conn,
- }
- }
- // GetAllTables selects all tables from TABLE_SCHEMA
- func (m *PostgreSqlModel) GetAllTables(schema string) ([]string, error) {
- query := `select table_name from information_schema.tables where table_schema = $1`
- var tables []string
- err := m.conn.QueryRows(&tables, query, schema)
- if err != nil {
- return nil, err
- }
- return tables, nil
- }
- // FindColumns return columns in specified database and table
- func (m *PostgreSqlModel) FindColumns(schema, table string) (*ColumnData, error) {
- querySql := `select t.num,t.field,t.type,t.not_null,t.comment, c.column_default, identity_increment
- from (
- SELECT a.attnum AS num,
- c.relname,
- a.attname AS field,
- t.typname AS type,
- a.atttypmod AS lengthvar,
- a.attnotnull AS not_null,
- b.description AS comment
- FROM pg_class c,
- pg_attribute a
- LEFT OUTER JOIN pg_description b ON a.attrelid = b.objoid AND a.attnum = b.objsubid,
- pg_type t
- WHERE c.relname = $1
- and a.attnum > 0
- and a.attrelid = c.oid
- and a.atttypid = t.oid
- GROUP BY a.attnum, c.relname, a.attname, t.typname, a.atttypmod, a.attnotnull, b.description
- ORDER BY a.attnum) AS t
- left join information_schema.columns AS c on t.relname = c.table_name
- and t.field = c.column_name and c.table_schema = $2`
- var reply []*PostgreColumn
- err := m.conn.QueryRowsPartial(&reply, querySql, table, schema)
- if err != nil {
- return nil, err
- }
- list, err := m.getColumns(schema, table, reply)
- if err != nil {
- return nil, err
- }
- var columnData ColumnData
- columnData.Db = schema
- columnData.Table = table
- columnData.Columns = list
- return &columnData, nil
- }
- func (m *PostgreSqlModel) getColumns(schema, table string, in []*PostgreColumn) ([]*Column, error) {
- index, err := m.getIndex(schema, table)
- if err != nil {
- return nil, err
- }
- var list []*Column
- for _, e := range in {
- var dft interface{}
- if len(e.ColumnDefault.String) > 0 {
- dft = e.ColumnDefault
- }
- isNullAble := "YES"
- if e.NotNull.Bool {
- isNullAble = "NO"
- }
- var extra string
- // when identity is true, the column is auto increment
- if e.IdentityIncrement.Int32 == 1 {
- extra = "auto_increment"
- }
- // when type is serial, it's auto_increment. and the default value is tablename_columnname_seq
- if strings.Contains(e.ColumnDefault.String, table+"_"+e.Field.String+"_seq") {
- extra = "auto_increment"
- }
- if len(index[e.Field.String]) > 0 {
- for _, i := range index[e.Field.String] {
- list = append(list, &Column{
- DbColumn: &DbColumn{
- Name: e.Field.String,
- DataType: m.convertPostgreSqlTypeIntoMysqlType(e.Type.String),
- Extra: extra,
- Comment: e.Comment.String,
- ColumnDefault: dft,
- IsNullAble: isNullAble,
- OrdinalPosition: int(e.Num.Int32),
- },
- Index: i,
- })
- }
- } else {
- list = append(list, &Column{
- DbColumn: &DbColumn{
- Name: e.Field.String,
- DataType: m.convertPostgreSqlTypeIntoMysqlType(e.Type.String),
- Extra: extra,
- Comment: e.Comment.String,
- ColumnDefault: dft,
- IsNullAble: isNullAble,
- OrdinalPosition: int(e.Num.Int32),
- },
- })
- }
- }
- return list, nil
- }
- func (m *PostgreSqlModel) convertPostgreSqlTypeIntoMysqlType(in string) string {
- r, ok := p2m[strings.ToLower(in)]
- if ok {
- return r
- }
- return in
- }
- func (m *PostgreSqlModel) getIndex(schema, table string) (map[string][]*DbIndex, error) {
- indexes, err := m.FindIndex(schema, table)
- if err != nil {
- return nil, err
- }
- index := make(map[string][]*DbIndex)
- for _, e := range indexes {
- if e.IsPrimary.Bool {
- index[e.ColumnName.String] = append(index[e.ColumnName.String], &DbIndex{
- IndexName: indexPri,
- SeqInIndex: int(e.IndexSort.Int32),
- })
- continue
- }
- nonUnique := 0
- if !e.IsUnique.Bool {
- nonUnique = 1
- }
- index[e.ColumnName.String] = append(index[e.ColumnName.String], &DbIndex{
- IndexName: e.IndexName.String,
- NonUnique: nonUnique,
- SeqInIndex: int(e.IndexSort.Int32),
- })
- }
- return index, nil
- }
- // FindIndex finds index with given schema, table and column.
- func (m *PostgreSqlModel) FindIndex(schema, table string) ([]*PostgreIndex, error) {
- querySql := `select A.INDEXNAME AS index_name,
- C.INDEXRELID AS index_id,
- C.INDISUNIQUE AS is_unique,
- C.INDISPRIMARY AS is_primary,
- G.ATTNAME AS column_name,
- G.attnum AS index_sort
- from PG_AM B
- left join PG_CLASS F on
- B.OID = F.RELAM
- left join PG_STAT_ALL_INDEXES E on
- F.OID = E.INDEXRELID
- left join PG_INDEX C on
- E.INDEXRELID = C.INDEXRELID
- left outer join PG_DESCRIPTION D on
- C.INDEXRELID = D.OBJOID,
- PG_INDEXES A,
- pg_attribute G
- where A.SCHEMANAME = E.SCHEMANAME
- and A.TABLENAME = E.RELNAME
- and A.INDEXNAME = E.INDEXRELNAME
- and F.oid = G.attrelid
- and E.SCHEMANAME = $1
- and E.RELNAME = $2
- order by C.INDEXRELID,G.attnum`
- var reply []*PostgreIndex
- err := m.conn.QueryRowsPartial(&reply, querySql, schema, table)
- if err != nil {
- return nil, err
- }
- return reply, nil
- }
|