login_sources.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE and LICENSE.gogs file.
  4. package database
  5. import (
  6. "context"
  7. "fmt"
  8. "strconv"
  9. "time"
  10. jsoniter "github.com/json-iterator/go"
  11. "github.com/pkg/errors"
  12. "gorm.io/gorm"
  13. "github.com/SongZihuan/huan-gogs/internal/auth"
  14. "github.com/SongZihuan/huan-gogs/internal/auth/github"
  15. "github.com/SongZihuan/huan-gogs/internal/auth/ldap"
  16. "github.com/SongZihuan/huan-gogs/internal/auth/pam"
  17. "github.com/SongZihuan/huan-gogs/internal/auth/smtp"
  18. "github.com/SongZihuan/huan-gogs/internal/errutil"
  19. )
  20. // LoginSource represents an external way for authorizing users.
  21. type LoginSource struct {
  22. ID int64 `gorm:"primaryKey"`
  23. Type auth.Type
  24. Name string `xorm:"UNIQUE" gorm:"unique"`
  25. IsActived bool `xorm:"NOT NULL DEFAULT false" gorm:"not null"`
  26. IsDefault bool `xorm:"DEFAULT false"`
  27. Provider auth.Provider `xorm:"-" gorm:"-"`
  28. Config string `xorm:"TEXT cfg" gorm:"column:cfg;type:TEXT" json:"RawConfig"`
  29. Created time.Time `xorm:"-" gorm:"-" json:"-"`
  30. CreatedUnix int64
  31. Updated time.Time `xorm:"-" gorm:"-" json:"-"`
  32. UpdatedUnix int64
  33. File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
  34. }
  35. // BeforeSave implements the GORM save hook.
  36. func (s *LoginSource) BeforeSave(_ *gorm.DB) (err error) {
  37. if s.Provider == nil {
  38. return nil
  39. }
  40. s.Config, err = jsoniter.MarshalToString(s.Provider.Config())
  41. return err
  42. }
  43. // BeforeCreate implements the GORM create hook.
  44. func (s *LoginSource) BeforeCreate(tx *gorm.DB) error {
  45. if s.CreatedUnix == 0 {
  46. s.CreatedUnix = tx.NowFunc().Unix()
  47. s.UpdatedUnix = s.CreatedUnix
  48. }
  49. return nil
  50. }
  51. // BeforeUpdate implements the GORM update hook.
  52. func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
  53. s.UpdatedUnix = tx.NowFunc().Unix()
  54. return nil
  55. }
  56. type mockProviderConfig struct {
  57. ExternalAccount *auth.ExternalAccount
  58. }
  59. // AfterFind implements the GORM query hook.
  60. func (s *LoginSource) AfterFind(_ *gorm.DB) error {
  61. s.Created = time.Unix(s.CreatedUnix, 0).Local()
  62. s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
  63. switch s.Type {
  64. case auth.LDAP:
  65. var cfg ldap.Config
  66. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  67. if err != nil {
  68. return err
  69. }
  70. s.Provider = ldap.NewProvider(false, &cfg)
  71. case auth.DLDAP:
  72. var cfg ldap.Config
  73. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  74. if err != nil {
  75. return err
  76. }
  77. s.Provider = ldap.NewProvider(true, &cfg)
  78. case auth.SMTP:
  79. var cfg smtp.Config
  80. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  81. if err != nil {
  82. return err
  83. }
  84. s.Provider = smtp.NewProvider(&cfg)
  85. case auth.PAM:
  86. var cfg pam.Config
  87. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  88. if err != nil {
  89. return err
  90. }
  91. s.Provider = pam.NewProvider(&cfg)
  92. case auth.GitHub:
  93. var cfg github.Config
  94. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  95. if err != nil {
  96. return err
  97. }
  98. s.Provider = github.NewProvider(&cfg)
  99. case auth.Mock:
  100. var cfg mockProviderConfig
  101. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  102. if err != nil {
  103. return err
  104. }
  105. mockProvider := NewMockProvider()
  106. mockProvider.AuthenticateFunc.SetDefaultReturn(cfg.ExternalAccount, nil)
  107. s.Provider = mockProvider
  108. default:
  109. return fmt.Errorf("unrecognized login source type: %v", s.Type)
  110. }
  111. return nil
  112. }
  113. func (s *LoginSource) TypeName() string {
  114. return auth.Name(s.Type)
  115. }
  116. func (s *LoginSource) IsLDAP() bool {
  117. return s.Type == auth.LDAP
  118. }
  119. func (s *LoginSource) IsDLDAP() bool {
  120. return s.Type == auth.DLDAP
  121. }
  122. func (s *LoginSource) IsSMTP() bool {
  123. return s.Type == auth.SMTP
  124. }
  125. func (s *LoginSource) IsPAM() bool {
  126. return s.Type == auth.PAM
  127. }
  128. func (s *LoginSource) IsGitHub() bool {
  129. return s.Type == auth.GitHub
  130. }
  131. func (s *LoginSource) LDAP() *ldap.Config {
  132. return s.Provider.Config().(*ldap.Config)
  133. }
  134. func (s *LoginSource) SMTP() *smtp.Config {
  135. return s.Provider.Config().(*smtp.Config)
  136. }
  137. func (s *LoginSource) PAM() *pam.Config {
  138. return s.Provider.Config().(*pam.Config)
  139. }
  140. func (s *LoginSource) GitHub() *github.Config {
  141. return s.Provider.Config().(*github.Config)
  142. }
  143. // LoginSourcesStore is the storage layer for login sources.
  144. type LoginSourcesStore struct {
  145. db *gorm.DB
  146. files loginSourceFilesStore
  147. }
  148. func newLoginSourcesStore(db *gorm.DB, files loginSourceFilesStore) *LoginSourcesStore {
  149. return &LoginSourcesStore{
  150. db: db,
  151. files: files,
  152. }
  153. }
  154. type CreateLoginSourceOptions struct {
  155. Type auth.Type
  156. Name string
  157. Activated bool
  158. Default bool
  159. Config any
  160. }
  161. type ErrLoginSourceAlreadyExist struct {
  162. args errutil.Args
  163. }
  164. func IsErrLoginSourceAlreadyExist(err error) bool {
  165. return errors.As(err, &ErrLoginSourceAlreadyExist{})
  166. }
  167. func (err ErrLoginSourceAlreadyExist) Error() string {
  168. return fmt.Sprintf("login source already exists: %v", err.args)
  169. }
  170. // Create creates a new login source and persists it to the database. It returns
  171. // ErrLoginSourceAlreadyExist when a login source with same name already exists.
  172. func (s *LoginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
  173. err := s.db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
  174. if err == nil {
  175. return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
  176. } else if !errors.Is(err, gorm.ErrRecordNotFound) {
  177. return nil, err
  178. }
  179. source := &LoginSource{
  180. Type: opts.Type,
  181. Name: opts.Name,
  182. IsActived: opts.Activated,
  183. IsDefault: opts.Default,
  184. }
  185. source.Config, err = jsoniter.MarshalToString(opts.Config)
  186. if err != nil {
  187. return nil, err
  188. }
  189. return source, s.db.WithContext(ctx).Create(source).Error
  190. }
  191. // Count returns the total number of login sources.
  192. func (s *LoginSourcesStore) Count(ctx context.Context) int64 {
  193. var count int64
  194. s.db.WithContext(ctx).Model(new(LoginSource)).Count(&count)
  195. return count + int64(s.files.Len())
  196. }
  197. type ErrLoginSourceInUse struct {
  198. args errutil.Args
  199. }
  200. func IsErrLoginSourceInUse(err error) bool {
  201. return errors.As(err, &ErrLoginSourceInUse{})
  202. }
  203. func (err ErrLoginSourceInUse) Error() string {
  204. return fmt.Sprintf("login source is still used by some users: %v", err.args)
  205. }
  206. // DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
  207. // if at least one user is associated with the login source.
  208. func (s *LoginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
  209. var count int64
  210. err := s.db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
  211. if err != nil {
  212. return err
  213. } else if count > 0 {
  214. return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
  215. }
  216. return s.db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
  217. }
  218. // GetByID returns the login source with given ID. It returns
  219. // ErrLoginSourceNotExist when not found.
  220. func (s *LoginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
  221. source := new(LoginSource)
  222. err := s.db.WithContext(ctx).Where("id = ?", id).First(source).Error
  223. if err != nil {
  224. if errors.Is(err, gorm.ErrRecordNotFound) {
  225. return s.files.GetByID(id)
  226. }
  227. return nil, err
  228. }
  229. return source, nil
  230. }
  231. type ListLoginSourceOptions struct {
  232. // Whether to only include activated login sources.
  233. OnlyActivated bool
  234. }
  235. // List returns a list of login sources filtered by options.
  236. func (s *LoginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
  237. var sources []*LoginSource
  238. query := s.db.WithContext(ctx).Order("id ASC")
  239. if opts.OnlyActivated {
  240. query = query.Where("is_actived = ?", true)
  241. }
  242. err := query.Find(&sources).Error
  243. if err != nil {
  244. return nil, err
  245. }
  246. return append(sources, s.files.List(opts)...), nil
  247. }
  248. // ResetNonDefault clears default flag for all the other login sources.
  249. func (s *LoginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
  250. err := s.db.WithContext(ctx).
  251. Model(new(LoginSource)).
  252. Where("id != ?", dflt.ID).
  253. Updates(map[string]any{"is_default": false}).
  254. Error
  255. if err != nil {
  256. return err
  257. }
  258. for _, source := range s.files.List(ListLoginSourceOptions{}) {
  259. if source.File != nil && source.ID != dflt.ID {
  260. source.File.SetGeneral("is_default", "false")
  261. if err = source.File.Save(); err != nil {
  262. return errors.Wrap(err, "save file")
  263. }
  264. }
  265. }
  266. s.files.Update(dflt)
  267. return nil
  268. }
  269. // Save persists all values of given login source to database or local file. The
  270. // Updated field is set to current time automatically.
  271. func (s *LoginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
  272. if source.File == nil {
  273. return s.db.WithContext(ctx).Save(source).Error
  274. }
  275. source.File.SetGeneral("name", source.Name)
  276. source.File.SetGeneral("is_activated", strconv.FormatBool(source.IsActived))
  277. source.File.SetGeneral("is_default", strconv.FormatBool(source.IsDefault))
  278. if err := source.File.SetConfig(source.Provider.Config()); err != nil {
  279. return errors.Wrap(err, "set config")
  280. } else if err = source.File.Save(); err != nil {
  281. return errors.Wrap(err, "save file")
  282. }
  283. return nil
  284. }