login_sources_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package database
  5. import (
  6. "context"
  7. "testing"
  8. "time"
  9. mockrequire "github.com/derision-test/go-mockgen/testutil/require"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. "gorm.io/gorm"
  13. "gogs.io/gogs/internal/auth"
  14. "gogs.io/gogs/internal/auth/github"
  15. "gogs.io/gogs/internal/auth/ldap"
  16. "gogs.io/gogs/internal/auth/pam"
  17. "gogs.io/gogs/internal/auth/smtp"
  18. "gogs.io/gogs/internal/errutil"
  19. )
  20. func TestLoginSource_BeforeSave(t *testing.T) {
  21. now := time.Now()
  22. db := &gorm.DB{
  23. Config: &gorm.Config{
  24. SkipDefaultTransaction: true,
  25. NowFunc: func() time.Time {
  26. return now
  27. },
  28. },
  29. }
  30. t.Run("Config has not been set", func(t *testing.T) {
  31. s := &LoginSource{}
  32. err := s.BeforeSave(db)
  33. require.NoError(t, err)
  34. assert.Empty(t, s.Config)
  35. })
  36. t.Run("Config has been set", func(t *testing.T) {
  37. s := &LoginSource{
  38. Provider: pam.NewProvider(&pam.Config{
  39. ServiceName: "pam_service",
  40. }),
  41. }
  42. err := s.BeforeSave(db)
  43. require.NoError(t, err)
  44. assert.Equal(t, `{"ServiceName":"pam_service"}`, s.Config)
  45. })
  46. }
  47. func TestLoginSource_BeforeCreate(t *testing.T) {
  48. now := time.Now()
  49. db := &gorm.DB{
  50. Config: &gorm.Config{
  51. SkipDefaultTransaction: true,
  52. NowFunc: func() time.Time {
  53. return now
  54. },
  55. },
  56. }
  57. t.Run("CreatedUnix has been set", func(t *testing.T) {
  58. s := &LoginSource{
  59. CreatedUnix: 1,
  60. }
  61. _ = s.BeforeCreate(db)
  62. assert.Equal(t, int64(1), s.CreatedUnix)
  63. assert.Equal(t, int64(0), s.UpdatedUnix)
  64. })
  65. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  66. s := &LoginSource{}
  67. _ = s.BeforeCreate(db)
  68. assert.Equal(t, db.NowFunc().Unix(), s.CreatedUnix)
  69. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  70. })
  71. }
  72. func TestLoginSource_BeforeUpdate(t *testing.T) {
  73. now := time.Now()
  74. db := &gorm.DB{
  75. Config: &gorm.Config{
  76. SkipDefaultTransaction: true,
  77. NowFunc: func() time.Time {
  78. return now
  79. },
  80. },
  81. }
  82. s := &LoginSource{}
  83. _ = s.BeforeUpdate(db)
  84. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  85. }
  86. func TestLoginSource_AfterFind(t *testing.T) {
  87. now := time.Now()
  88. db := &gorm.DB{
  89. Config: &gorm.Config{
  90. SkipDefaultTransaction: true,
  91. NowFunc: func() time.Time {
  92. return now
  93. },
  94. },
  95. }
  96. tests := []struct {
  97. name string
  98. authType auth.Type
  99. wantType any
  100. }{
  101. {
  102. name: "LDAP",
  103. authType: auth.LDAP,
  104. wantType: &ldap.Provider{},
  105. },
  106. {
  107. name: "DLDAP",
  108. authType: auth.DLDAP,
  109. wantType: &ldap.Provider{},
  110. },
  111. {
  112. name: "SMTP",
  113. authType: auth.SMTP,
  114. wantType: &smtp.Provider{},
  115. },
  116. {
  117. name: "PAM",
  118. authType: auth.PAM,
  119. wantType: &pam.Provider{},
  120. },
  121. {
  122. name: "GitHub",
  123. authType: auth.GitHub,
  124. wantType: &github.Provider{},
  125. },
  126. }
  127. for _, test := range tests {
  128. t.Run(test.name, func(t *testing.T) {
  129. s := LoginSource{
  130. Type: test.authType,
  131. Config: `{}`,
  132. CreatedUnix: now.Unix(),
  133. UpdatedUnix: now.Unix(),
  134. }
  135. err := s.AfterFind(db)
  136. require.NoError(t, err)
  137. assert.Equal(t, s.CreatedUnix, s.Created.Unix())
  138. assert.Equal(t, s.UpdatedUnix, s.Updated.Unix())
  139. assert.IsType(t, test.wantType, s.Provider)
  140. })
  141. }
  142. }
  143. func TestLoginSources(t *testing.T) {
  144. if testing.Short() {
  145. t.Skip()
  146. }
  147. t.Parallel()
  148. ctx := context.Background()
  149. db := &loginSourcesStore{
  150. DB: newTestDB(t, "loginSourcesStore"),
  151. }
  152. for _, tc := range []struct {
  153. name string
  154. test func(t *testing.T, ctx context.Context, db *loginSourcesStore)
  155. }{
  156. {"Create", loginSourcesCreate},
  157. {"Count", loginSourcesCount},
  158. {"DeleteByID", loginSourcesDeleteByID},
  159. {"GetByID", loginSourcesGetByID},
  160. {"List", loginSourcesList},
  161. {"ResetNonDefault", loginSourcesResetNonDefault},
  162. {"Save", loginSourcesSave},
  163. } {
  164. t.Run(tc.name, func(t *testing.T) {
  165. t.Cleanup(func() {
  166. err := clearTables(t, db.DB)
  167. require.NoError(t, err)
  168. })
  169. tc.test(t, ctx, db)
  170. })
  171. if t.Failed() {
  172. break
  173. }
  174. }
  175. }
  176. func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore) {
  177. // Create first login source with name "GitHub"
  178. source, err := db.Create(ctx,
  179. CreateLoginSourceOptions{
  180. Type: auth.GitHub,
  181. Name: "GitHub",
  182. Activated: true,
  183. Default: false,
  184. Config: &github.Config{
  185. APIEndpoint: "https://api.github.com",
  186. },
  187. },
  188. )
  189. require.NoError(t, err)
  190. // Get it back and check the Created field
  191. source, err = db.GetByID(ctx, source.ID)
  192. require.NoError(t, err)
  193. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
  194. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
  195. // Try create second login source with same name should fail
  196. _, err = db.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
  197. wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
  198. assert.Equal(t, wantErr, err)
  199. }
  200. func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore) {
  201. // Create two login sources, one in database and one as source file.
  202. _, err := db.Create(ctx,
  203. CreateLoginSourceOptions{
  204. Type: auth.GitHub,
  205. Name: "GitHub",
  206. Activated: true,
  207. Default: false,
  208. Config: &github.Config{
  209. APIEndpoint: "https://api.github.com",
  210. },
  211. },
  212. )
  213. require.NoError(t, err)
  214. mock := NewMockLoginSourceFilesStore()
  215. mock.LenFunc.SetDefaultReturn(2)
  216. setMockLoginSourceFilesStore(t, db, mock)
  217. assert.Equal(t, int64(3), db.Count(ctx))
  218. }
  219. func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
  220. t.Run("delete but in used", func(t *testing.T) {
  221. source, err := db.Create(ctx,
  222. CreateLoginSourceOptions{
  223. Type: auth.GitHub,
  224. Name: "GitHub",
  225. Activated: true,
  226. Default: false,
  227. Config: &github.Config{
  228. APIEndpoint: "https://api.github.com",
  229. },
  230. },
  231. )
  232. require.NoError(t, err)
  233. // Create a user that uses this login source
  234. _, err = (&usersStore{DB: db.DB}).Create(ctx, "alice", "",
  235. CreateUserOptions{
  236. LoginSource: source.ID,
  237. },
  238. )
  239. require.NoError(t, err)
  240. // Delete the login source will result in error
  241. err = db.DeleteByID(ctx, source.ID)
  242. wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
  243. assert.Equal(t, wantErr, err)
  244. })
  245. mock := NewMockLoginSourceFilesStore()
  246. mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
  247. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  248. })
  249. setMockLoginSourceFilesStore(t, db, mock)
  250. // Create a login source with name "GitHub2"
  251. source, err := db.Create(ctx,
  252. CreateLoginSourceOptions{
  253. Type: auth.GitHub,
  254. Name: "GitHub2",
  255. Activated: true,
  256. Default: false,
  257. Config: &github.Config{
  258. APIEndpoint: "https://api.github.com",
  259. },
  260. },
  261. )
  262. require.NoError(t, err)
  263. // Delete a non-existent ID is noop
  264. err = db.DeleteByID(ctx, 9999)
  265. require.NoError(t, err)
  266. // We should be able to get it back
  267. _, err = db.GetByID(ctx, source.ID)
  268. require.NoError(t, err)
  269. // Now delete this login source with ID
  270. err = db.DeleteByID(ctx, source.ID)
  271. require.NoError(t, err)
  272. // We should get token not found error
  273. _, err = db.GetByID(ctx, source.ID)
  274. wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
  275. assert.Equal(t, wantErr, err)
  276. }
  277. func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
  278. mock := NewMockLoginSourceFilesStore()
  279. mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
  280. if id != 101 {
  281. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  282. }
  283. return &LoginSource{ID: id}, nil
  284. })
  285. setMockLoginSourceFilesStore(t, db, mock)
  286. expConfig := &github.Config{
  287. APIEndpoint: "https://api.github.com",
  288. }
  289. // Create a login source with name "GitHub"
  290. source, err := db.Create(ctx,
  291. CreateLoginSourceOptions{
  292. Type: auth.GitHub,
  293. Name: "GitHub",
  294. Activated: true,
  295. Default: false,
  296. Config: expConfig,
  297. },
  298. )
  299. require.NoError(t, err)
  300. // Get the one in the database and test the read/write hooks
  301. source, err = db.GetByID(ctx, source.ID)
  302. require.NoError(t, err)
  303. assert.Equal(t, expConfig, source.Provider.Config())
  304. // Get the one in source file store
  305. _, err = db.GetByID(ctx, 101)
  306. require.NoError(t, err)
  307. }
  308. func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore) {
  309. mock := NewMockLoginSourceFilesStore()
  310. mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
  311. if opts.OnlyActivated {
  312. return []*LoginSource{
  313. {ID: 1},
  314. }
  315. }
  316. return []*LoginSource{
  317. {ID: 1},
  318. {ID: 2},
  319. }
  320. })
  321. setMockLoginSourceFilesStore(t, db, mock)
  322. // Create two login sources in database, one activated and the other one not
  323. _, err := db.Create(ctx,
  324. CreateLoginSourceOptions{
  325. Type: auth.PAM,
  326. Name: "PAM",
  327. Config: &pam.Config{
  328. ServiceName: "PAM",
  329. },
  330. },
  331. )
  332. require.NoError(t, err)
  333. _, err = db.Create(ctx,
  334. CreateLoginSourceOptions{
  335. Type: auth.GitHub,
  336. Name: "GitHub",
  337. Activated: true,
  338. Config: &github.Config{
  339. APIEndpoint: "https://api.github.com",
  340. },
  341. },
  342. )
  343. require.NoError(t, err)
  344. // List all login sources
  345. sources, err := db.List(ctx, ListLoginSourceOptions{})
  346. require.NoError(t, err)
  347. assert.Equal(t, 4, len(sources), "number of sources")
  348. // Only list activated login sources
  349. sources, err = db.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
  350. require.NoError(t, err)
  351. assert.Equal(t, 2, len(sources), "number of sources")
  352. }
  353. func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSourcesStore) {
  354. mock := NewMockLoginSourceFilesStore()
  355. mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
  356. mockFile := NewMockLoginSourceFileStore()
  357. mockFile.SetGeneralFunc.SetDefaultHook(func(name, value string) {
  358. assert.Equal(t, "is_default", name)
  359. assert.Equal(t, "false", value)
  360. })
  361. return []*LoginSource{
  362. {
  363. File: mockFile,
  364. },
  365. }
  366. })
  367. setMockLoginSourceFilesStore(t, db, mock)
  368. // Create two login sources both have default on
  369. source1, err := db.Create(ctx,
  370. CreateLoginSourceOptions{
  371. Type: auth.PAM,
  372. Name: "PAM",
  373. Default: true,
  374. Config: &pam.Config{
  375. ServiceName: "PAM",
  376. },
  377. },
  378. )
  379. require.NoError(t, err)
  380. source2, err := db.Create(ctx,
  381. CreateLoginSourceOptions{
  382. Type: auth.GitHub,
  383. Name: "GitHub",
  384. Activated: true,
  385. Default: true,
  386. Config: &github.Config{
  387. APIEndpoint: "https://api.github.com",
  388. },
  389. },
  390. )
  391. require.NoError(t, err)
  392. // Set source 1 as default
  393. err = db.ResetNonDefault(ctx, source1)
  394. require.NoError(t, err)
  395. // Verify the default state
  396. source1, err = db.GetByID(ctx, source1.ID)
  397. require.NoError(t, err)
  398. assert.True(t, source1.IsDefault)
  399. source2, err = db.GetByID(ctx, source2.ID)
  400. require.NoError(t, err)
  401. assert.False(t, source2.IsDefault)
  402. }
  403. func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore) {
  404. t.Run("save to database", func(t *testing.T) {
  405. // Create a login source with name "GitHub"
  406. source, err := db.Create(ctx,
  407. CreateLoginSourceOptions{
  408. Type: auth.GitHub,
  409. Name: "GitHub",
  410. Activated: true,
  411. Default: false,
  412. Config: &github.Config{
  413. APIEndpoint: "https://api.github.com",
  414. },
  415. },
  416. )
  417. require.NoError(t, err)
  418. source.IsActived = false
  419. source.Provider = github.NewProvider(&github.Config{
  420. APIEndpoint: "https://api2.github.com",
  421. })
  422. err = db.Save(ctx, source)
  423. require.NoError(t, err)
  424. source, err = db.GetByID(ctx, source.ID)
  425. require.NoError(t, err)
  426. assert.False(t, source.IsActived)
  427. assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
  428. })
  429. t.Run("save to file", func(t *testing.T) {
  430. mockFile := NewMockLoginSourceFileStore()
  431. source := &LoginSource{
  432. Provider: github.NewProvider(&github.Config{
  433. APIEndpoint: "https://api.github.com",
  434. }),
  435. File: mockFile,
  436. }
  437. err := db.Save(ctx, source)
  438. require.NoError(t, err)
  439. mockrequire.Called(t, mockFile.SaveFunc)
  440. })
  441. }