login_sources_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package database
  5. import (
  6. "context"
  7. "testing"
  8. "time"
  9. mockrequire "github.com/derision-test/go-mockgen/testutil/require"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. "gorm.io/gorm"
  13. "gogs.io/gogs/internal/auth"
  14. "gogs.io/gogs/internal/auth/github"
  15. "gogs.io/gogs/internal/auth/ldap"
  16. "gogs.io/gogs/internal/auth/pam"
  17. "gogs.io/gogs/internal/auth/smtp"
  18. "gogs.io/gogs/internal/dbtest"
  19. "gogs.io/gogs/internal/errutil"
  20. )
  21. func TestLoginSource_BeforeSave(t *testing.T) {
  22. now := time.Now()
  23. db := &gorm.DB{
  24. Config: &gorm.Config{
  25. SkipDefaultTransaction: true,
  26. NowFunc: func() time.Time {
  27. return now
  28. },
  29. },
  30. }
  31. t.Run("Config has not been set", func(t *testing.T) {
  32. s := &LoginSource{}
  33. err := s.BeforeSave(db)
  34. require.NoError(t, err)
  35. assert.Empty(t, s.Config)
  36. })
  37. t.Run("Config has been set", func(t *testing.T) {
  38. s := &LoginSource{
  39. Provider: pam.NewProvider(&pam.Config{
  40. ServiceName: "pam_service",
  41. }),
  42. }
  43. err := s.BeforeSave(db)
  44. require.NoError(t, err)
  45. assert.Equal(t, `{"ServiceName":"pam_service"}`, s.Config)
  46. })
  47. }
  48. func TestLoginSource_BeforeCreate(t *testing.T) {
  49. now := time.Now()
  50. db := &gorm.DB{
  51. Config: &gorm.Config{
  52. SkipDefaultTransaction: true,
  53. NowFunc: func() time.Time {
  54. return now
  55. },
  56. },
  57. }
  58. t.Run("CreatedUnix has been set", func(t *testing.T) {
  59. s := &LoginSource{
  60. CreatedUnix: 1,
  61. }
  62. _ = s.BeforeCreate(db)
  63. assert.Equal(t, int64(1), s.CreatedUnix)
  64. assert.Equal(t, int64(0), s.UpdatedUnix)
  65. })
  66. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  67. s := &LoginSource{}
  68. _ = s.BeforeCreate(db)
  69. assert.Equal(t, db.NowFunc().Unix(), s.CreatedUnix)
  70. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  71. })
  72. }
  73. func TestLoginSource_BeforeUpdate(t *testing.T) {
  74. now := time.Now()
  75. db := &gorm.DB{
  76. Config: &gorm.Config{
  77. SkipDefaultTransaction: true,
  78. NowFunc: func() time.Time {
  79. return now
  80. },
  81. },
  82. }
  83. s := &LoginSource{}
  84. _ = s.BeforeUpdate(db)
  85. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  86. }
  87. func TestLoginSource_AfterFind(t *testing.T) {
  88. now := time.Now()
  89. db := &gorm.DB{
  90. Config: &gorm.Config{
  91. SkipDefaultTransaction: true,
  92. NowFunc: func() time.Time {
  93. return now
  94. },
  95. },
  96. }
  97. tests := []struct {
  98. name string
  99. authType auth.Type
  100. wantType any
  101. }{
  102. {
  103. name: "LDAP",
  104. authType: auth.LDAP,
  105. wantType: &ldap.Provider{},
  106. },
  107. {
  108. name: "DLDAP",
  109. authType: auth.DLDAP,
  110. wantType: &ldap.Provider{},
  111. },
  112. {
  113. name: "SMTP",
  114. authType: auth.SMTP,
  115. wantType: &smtp.Provider{},
  116. },
  117. {
  118. name: "PAM",
  119. authType: auth.PAM,
  120. wantType: &pam.Provider{},
  121. },
  122. {
  123. name: "GitHub",
  124. authType: auth.GitHub,
  125. wantType: &github.Provider{},
  126. },
  127. }
  128. for _, test := range tests {
  129. t.Run(test.name, func(t *testing.T) {
  130. s := LoginSource{
  131. Type: test.authType,
  132. Config: `{}`,
  133. CreatedUnix: now.Unix(),
  134. UpdatedUnix: now.Unix(),
  135. }
  136. err := s.AfterFind(db)
  137. require.NoError(t, err)
  138. assert.Equal(t, s.CreatedUnix, s.Created.Unix())
  139. assert.Equal(t, s.UpdatedUnix, s.Updated.Unix())
  140. assert.IsType(t, test.wantType, s.Provider)
  141. })
  142. }
  143. }
  144. func TestLoginSources(t *testing.T) {
  145. if testing.Short() {
  146. t.Skip()
  147. }
  148. t.Parallel()
  149. ctx := context.Background()
  150. tables := []any{new(LoginSource), new(User)}
  151. db := &loginSources{
  152. DB: dbtest.NewDB(t, "loginSources", tables...),
  153. }
  154. for _, tc := range []struct {
  155. name string
  156. test func(t *testing.T, ctx context.Context, db *loginSources)
  157. }{
  158. {"Create", loginSourcesCreate},
  159. {"Count", loginSourcesCount},
  160. {"DeleteByID", loginSourcesDeleteByID},
  161. {"GetByID", loginSourcesGetByID},
  162. {"List", loginSourcesList},
  163. {"ResetNonDefault", loginSourcesResetNonDefault},
  164. {"Save", loginSourcesSave},
  165. } {
  166. t.Run(tc.name, func(t *testing.T) {
  167. t.Cleanup(func() {
  168. err := clearTables(t, db.DB, tables...)
  169. require.NoError(t, err)
  170. })
  171. tc.test(t, ctx, db)
  172. })
  173. if t.Failed() {
  174. break
  175. }
  176. }
  177. }
  178. func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) {
  179. // Create first login source with name "GitHub"
  180. source, err := db.Create(ctx,
  181. CreateLoginSourceOptions{
  182. Type: auth.GitHub,
  183. Name: "GitHub",
  184. Activated: true,
  185. Default: false,
  186. Config: &github.Config{
  187. APIEndpoint: "https://api.github.com",
  188. },
  189. },
  190. )
  191. require.NoError(t, err)
  192. // Get it back and check the Created field
  193. source, err = db.GetByID(ctx, source.ID)
  194. require.NoError(t, err)
  195. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
  196. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
  197. // Try create second login source with same name should fail
  198. _, err = db.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
  199. wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
  200. assert.Equal(t, wantErr, err)
  201. }
  202. func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) {
  203. // Create two login sources, one in database and one as source file.
  204. _, err := db.Create(ctx,
  205. CreateLoginSourceOptions{
  206. Type: auth.GitHub,
  207. Name: "GitHub",
  208. Activated: true,
  209. Default: false,
  210. Config: &github.Config{
  211. APIEndpoint: "https://api.github.com",
  212. },
  213. },
  214. )
  215. require.NoError(t, err)
  216. mock := NewMockLoginSourceFilesStore()
  217. mock.LenFunc.SetDefaultReturn(2)
  218. setMockLoginSourceFilesStore(t, db, mock)
  219. assert.Equal(t, int64(3), db.Count(ctx))
  220. }
  221. func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources) {
  222. t.Run("delete but in used", func(t *testing.T) {
  223. source, err := db.Create(ctx,
  224. CreateLoginSourceOptions{
  225. Type: auth.GitHub,
  226. Name: "GitHub",
  227. Activated: true,
  228. Default: false,
  229. Config: &github.Config{
  230. APIEndpoint: "https://api.github.com",
  231. },
  232. },
  233. )
  234. require.NoError(t, err)
  235. // Create a user that uses this login source
  236. _, err = (&users{DB: db.DB}).Create(ctx, "alice", "",
  237. CreateUserOptions{
  238. LoginSource: source.ID,
  239. },
  240. )
  241. require.NoError(t, err)
  242. // Delete the login source will result in error
  243. err = db.DeleteByID(ctx, source.ID)
  244. wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
  245. assert.Equal(t, wantErr, err)
  246. })
  247. mock := NewMockLoginSourceFilesStore()
  248. mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
  249. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  250. })
  251. setMockLoginSourceFilesStore(t, db, mock)
  252. // Create a login source with name "GitHub2"
  253. source, err := db.Create(ctx,
  254. CreateLoginSourceOptions{
  255. Type: auth.GitHub,
  256. Name: "GitHub2",
  257. Activated: true,
  258. Default: false,
  259. Config: &github.Config{
  260. APIEndpoint: "https://api.github.com",
  261. },
  262. },
  263. )
  264. require.NoError(t, err)
  265. // Delete a non-existent ID is noop
  266. err = db.DeleteByID(ctx, 9999)
  267. require.NoError(t, err)
  268. // We should be able to get it back
  269. _, err = db.GetByID(ctx, source.ID)
  270. require.NoError(t, err)
  271. // Now delete this login source with ID
  272. err = db.DeleteByID(ctx, source.ID)
  273. require.NoError(t, err)
  274. // We should get token not found error
  275. _, err = db.GetByID(ctx, source.ID)
  276. wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
  277. assert.Equal(t, wantErr, err)
  278. }
  279. func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) {
  280. mock := NewMockLoginSourceFilesStore()
  281. mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
  282. if id != 101 {
  283. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  284. }
  285. return &LoginSource{ID: id}, nil
  286. })
  287. setMockLoginSourceFilesStore(t, db, mock)
  288. expConfig := &github.Config{
  289. APIEndpoint: "https://api.github.com",
  290. }
  291. // Create a login source with name "GitHub"
  292. source, err := db.Create(ctx,
  293. CreateLoginSourceOptions{
  294. Type: auth.GitHub,
  295. Name: "GitHub",
  296. Activated: true,
  297. Default: false,
  298. Config: expConfig,
  299. },
  300. )
  301. require.NoError(t, err)
  302. // Get the one in the database and test the read/write hooks
  303. source, err = db.GetByID(ctx, source.ID)
  304. require.NoError(t, err)
  305. assert.Equal(t, expConfig, source.Provider.Config())
  306. // Get the one in source file store
  307. _, err = db.GetByID(ctx, 101)
  308. require.NoError(t, err)
  309. }
  310. func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) {
  311. mock := NewMockLoginSourceFilesStore()
  312. mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
  313. if opts.OnlyActivated {
  314. return []*LoginSource{
  315. {ID: 1},
  316. }
  317. }
  318. return []*LoginSource{
  319. {ID: 1},
  320. {ID: 2},
  321. }
  322. })
  323. setMockLoginSourceFilesStore(t, db, mock)
  324. // Create two login sources in database, one activated and the other one not
  325. _, err := db.Create(ctx,
  326. CreateLoginSourceOptions{
  327. Type: auth.PAM,
  328. Name: "PAM",
  329. Config: &pam.Config{
  330. ServiceName: "PAM",
  331. },
  332. },
  333. )
  334. require.NoError(t, err)
  335. _, err = db.Create(ctx,
  336. CreateLoginSourceOptions{
  337. Type: auth.GitHub,
  338. Name: "GitHub",
  339. Activated: true,
  340. Config: &github.Config{
  341. APIEndpoint: "https://api.github.com",
  342. },
  343. },
  344. )
  345. require.NoError(t, err)
  346. // List all login sources
  347. sources, err := db.List(ctx, ListLoginSourceOptions{})
  348. require.NoError(t, err)
  349. assert.Equal(t, 4, len(sources), "number of sources")
  350. // Only list activated login sources
  351. sources, err = db.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
  352. require.NoError(t, err)
  353. assert.Equal(t, 2, len(sources), "number of sources")
  354. }
  355. func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSources) {
  356. mock := NewMockLoginSourceFilesStore()
  357. mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
  358. mockFile := NewMockLoginSourceFileStore()
  359. mockFile.SetGeneralFunc.SetDefaultHook(func(name, value string) {
  360. assert.Equal(t, "is_default", name)
  361. assert.Equal(t, "false", value)
  362. })
  363. return []*LoginSource{
  364. {
  365. File: mockFile,
  366. },
  367. }
  368. })
  369. setMockLoginSourceFilesStore(t, db, mock)
  370. // Create two login sources both have default on
  371. source1, err := db.Create(ctx,
  372. CreateLoginSourceOptions{
  373. Type: auth.PAM,
  374. Name: "PAM",
  375. Default: true,
  376. Config: &pam.Config{
  377. ServiceName: "PAM",
  378. },
  379. },
  380. )
  381. require.NoError(t, err)
  382. source2, err := db.Create(ctx,
  383. CreateLoginSourceOptions{
  384. Type: auth.GitHub,
  385. Name: "GitHub",
  386. Activated: true,
  387. Default: true,
  388. Config: &github.Config{
  389. APIEndpoint: "https://api.github.com",
  390. },
  391. },
  392. )
  393. require.NoError(t, err)
  394. // Set source 1 as default
  395. err = db.ResetNonDefault(ctx, source1)
  396. require.NoError(t, err)
  397. // Verify the default state
  398. source1, err = db.GetByID(ctx, source1.ID)
  399. require.NoError(t, err)
  400. assert.True(t, source1.IsDefault)
  401. source2, err = db.GetByID(ctx, source2.ID)
  402. require.NoError(t, err)
  403. assert.False(t, source2.IsDefault)
  404. }
  405. func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSources) {
  406. t.Run("save to database", func(t *testing.T) {
  407. // Create a login source with name "GitHub"
  408. source, err := db.Create(ctx,
  409. CreateLoginSourceOptions{
  410. Type: auth.GitHub,
  411. Name: "GitHub",
  412. Activated: true,
  413. Default: false,
  414. Config: &github.Config{
  415. APIEndpoint: "https://api.github.com",
  416. },
  417. },
  418. )
  419. require.NoError(t, err)
  420. source.IsActived = false
  421. source.Provider = github.NewProvider(&github.Config{
  422. APIEndpoint: "https://api2.github.com",
  423. })
  424. err = db.Save(ctx, source)
  425. require.NoError(t, err)
  426. source, err = db.GetByID(ctx, source.ID)
  427. require.NoError(t, err)
  428. assert.False(t, source.IsActived)
  429. assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
  430. })
  431. t.Run("save to file", func(t *testing.T) {
  432. mockFile := NewMockLoginSourceFileStore()
  433. source := &LoginSource{
  434. Provider: github.NewProvider(&github.Config{
  435. APIEndpoint: "https://api.github.com",
  436. }),
  437. File: mockFile,
  438. }
  439. err := db.Save(ctx, source)
  440. require.NoError(t, err)
  441. mockrequire.Called(t, mockFile.SaveFunc)
  442. })
  443. }