postgresqlmodel.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. package model
  2. import (
  3. "database/sql"
  4. "strings"
  5. "github.com/zeromicro/go-zero/core/stores/sqlx"
  6. )
  7. var p2m = map[string]string{
  8. "int8": "bigint",
  9. "numeric": "bigint",
  10. "float8": "double",
  11. "float4": "float",
  12. "int2": "smallint",
  13. "int4": "integer",
  14. "timestamptz": "timestamp",
  15. }
  16. // PostgreSqlModel gets table information from information_schema、pg_catalog
  17. type PostgreSqlModel struct {
  18. conn sqlx.SqlConn
  19. }
  20. // PostgreColumn describes a column in table
  21. type PostgreColumn struct {
  22. Num sql.NullInt32 `db:"num"`
  23. Field sql.NullString `db:"field"`
  24. Type sql.NullString `db:"type"`
  25. NotNull sql.NullBool `db:"not_null"`
  26. Comment sql.NullString `db:"comment"`
  27. ColumnDefault sql.NullString `db:"column_default"`
  28. IdentityIncrement sql.NullInt32 `db:"identity_increment"`
  29. }
  30. // PostgreIndex describes an index for a column
  31. type PostgreIndex struct {
  32. IndexName sql.NullString `db:"index_name"`
  33. IndexId sql.NullInt32 `db:"index_id"`
  34. IsUnique sql.NullBool `db:"is_unique"`
  35. IsPrimary sql.NullBool `db:"is_primary"`
  36. ColumnName sql.NullString `db:"column_name"`
  37. IndexSort sql.NullInt32 `db:"index_sort"`
  38. }
  39. // NewPostgreSqlModel creates an instance and return
  40. func NewPostgreSqlModel(conn sqlx.SqlConn) *PostgreSqlModel {
  41. return &PostgreSqlModel{
  42. conn: conn,
  43. }
  44. }
  45. // GetAllTables selects all tables from TABLE_SCHEMA
  46. func (m *PostgreSqlModel) GetAllTables(schema string) ([]string, error) {
  47. query := `select table_name from information_schema.tables where table_schema = $1`
  48. var tables []string
  49. err := m.conn.QueryRows(&tables, query, schema)
  50. if err != nil {
  51. return nil, err
  52. }
  53. return tables, nil
  54. }
  55. // FindColumns return columns in specified database and table
  56. func (m *PostgreSqlModel) FindColumns(schema, table string) (*ColumnData, error) {
  57. querySql := `select t.num,t.field,t.type,t.not_null,t.comment, c.column_default, identity_increment
  58. from (
  59. SELECT a.attnum AS num,
  60. c.relname,
  61. a.attname AS field,
  62. t.typname AS type,
  63. a.atttypmod AS lengthvar,
  64. a.attnotnull AS not_null,
  65. b.description AS comment
  66. FROM pg_class c,
  67. pg_attribute a
  68. LEFT OUTER JOIN pg_description b ON a.attrelid = b.objoid AND a.attnum = b.objsubid,
  69. pg_type t
  70. WHERE c.relname = $1
  71. and a.attnum > 0
  72. and a.attrelid = c.oid
  73. and a.atttypid = t.oid
  74. GROUP BY a.attnum, c.relname, a.attname, t.typname, a.atttypmod, a.attnotnull, b.description
  75. ORDER BY a.attnum) AS t
  76. left join information_schema.columns AS c on t.relname = c.table_name
  77. and t.field = c.column_name and c.table_schema = $2`
  78. var reply []*PostgreColumn
  79. err := m.conn.QueryRowsPartial(&reply, querySql, table, schema)
  80. if err != nil {
  81. return nil, err
  82. }
  83. list, err := m.getColumns(schema, table, reply)
  84. if err != nil {
  85. return nil, err
  86. }
  87. var columnData ColumnData
  88. columnData.Db = schema
  89. columnData.Table = table
  90. columnData.Columns = list
  91. return &columnData, nil
  92. }
  93. func (m *PostgreSqlModel) getColumns(schema, table string, in []*PostgreColumn) ([]*Column, error) {
  94. index, err := m.getIndex(schema, table)
  95. if err != nil {
  96. return nil, err
  97. }
  98. var list []*Column
  99. for _, e := range in {
  100. var dft interface{}
  101. if len(e.ColumnDefault.String) > 0 {
  102. dft = e.ColumnDefault
  103. }
  104. isNullAble := "YES"
  105. if e.NotNull.Bool {
  106. isNullAble = "NO"
  107. }
  108. var extra string
  109. // when identity is true, the column is auto increment
  110. if e.IdentityIncrement.Int32 == 1 {
  111. extra = "auto_increment"
  112. }
  113. // when type is serial, it's auto_increment. and the default value is tablename_columnname_seq
  114. if strings.Contains(e.ColumnDefault.String, table+"_"+e.Field.String+"_seq") {
  115. extra = "auto_increment"
  116. }
  117. if len(index[e.Field.String]) > 0 {
  118. for _, i := range index[e.Field.String] {
  119. list = append(list, &Column{
  120. DbColumn: &DbColumn{
  121. Name: e.Field.String,
  122. DataType: m.convertPostgreSqlTypeIntoMysqlType(e.Type.String),
  123. Extra: extra,
  124. Comment: e.Comment.String,
  125. ColumnDefault: dft,
  126. IsNullAble: isNullAble,
  127. OrdinalPosition: int(e.Num.Int32),
  128. },
  129. Index: i,
  130. })
  131. }
  132. } else {
  133. list = append(list, &Column{
  134. DbColumn: &DbColumn{
  135. Name: e.Field.String,
  136. DataType: m.convertPostgreSqlTypeIntoMysqlType(e.Type.String),
  137. Extra: extra,
  138. Comment: e.Comment.String,
  139. ColumnDefault: dft,
  140. IsNullAble: isNullAble,
  141. OrdinalPosition: int(e.Num.Int32),
  142. },
  143. })
  144. }
  145. }
  146. return list, nil
  147. }
  148. func (m *PostgreSqlModel) convertPostgreSqlTypeIntoMysqlType(in string) string {
  149. r, ok := p2m[strings.ToLower(in)]
  150. if ok {
  151. return r
  152. }
  153. return in
  154. }
  155. func (m *PostgreSqlModel) getIndex(schema, table string) (map[string][]*DbIndex, error) {
  156. indexes, err := m.FindIndex(schema, table)
  157. if err != nil {
  158. return nil, err
  159. }
  160. index := make(map[string][]*DbIndex)
  161. for _, e := range indexes {
  162. if e.IsPrimary.Bool {
  163. index[e.ColumnName.String] = append(index[e.ColumnName.String], &DbIndex{
  164. IndexName: indexPri,
  165. SeqInIndex: int(e.IndexSort.Int32),
  166. })
  167. continue
  168. }
  169. nonUnique := 0
  170. if !e.IsUnique.Bool {
  171. nonUnique = 1
  172. }
  173. index[e.ColumnName.String] = append(index[e.ColumnName.String], &DbIndex{
  174. IndexName: e.IndexName.String,
  175. NonUnique: nonUnique,
  176. SeqInIndex: int(e.IndexSort.Int32),
  177. })
  178. }
  179. return index, nil
  180. }
  181. // FindIndex finds index with given schema, table and column.
  182. func (m *PostgreSqlModel) FindIndex(schema, table string) ([]*PostgreIndex, error) {
  183. querySql := `select A.INDEXNAME AS index_name,
  184. C.INDEXRELID AS index_id,
  185. C.INDISUNIQUE AS is_unique,
  186. C.INDISPRIMARY AS is_primary,
  187. G.ATTNAME AS column_name,
  188. G.attnum AS index_sort
  189. from PG_AM B
  190. left join PG_CLASS F on
  191. B.OID = F.RELAM
  192. left join PG_STAT_ALL_INDEXES E on
  193. F.OID = E.INDEXRELID
  194. left join PG_INDEX C on
  195. E.INDEXRELID = C.INDEXRELID
  196. left outer join PG_DESCRIPTION D on
  197. C.INDEXRELID = D.OBJOID,
  198. PG_INDEXES A,
  199. pg_attribute G
  200. where A.SCHEMANAME = E.SCHEMANAME
  201. and A.TABLENAME = E.RELNAME
  202. and A.INDEXNAME = E.INDEXRELNAME
  203. and F.oid = G.attrelid
  204. and E.SCHEMANAME = $1
  205. and E.RELNAME = $2
  206. order by C.INDEXRELID,G.attnum`
  207. var reply []*PostgreIndex
  208. err := m.conn.QueryRowsPartial(&reply, querySql, schema, table)
  209. if err != nil {
  210. return nil, err
  211. }
  212. return reply, nil
  213. }