login_sources_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. "gorm.io/gorm"
  10. "gogs.io/gogs/internal/auth"
  11. "gogs.io/gogs/internal/auth/github"
  12. "gogs.io/gogs/internal/auth/pam"
  13. "gogs.io/gogs/internal/errutil"
  14. )
  15. func TestLoginSource_BeforeSave(t *testing.T) {
  16. now := time.Now()
  17. db := &gorm.DB{
  18. Config: &gorm.Config{
  19. SkipDefaultTransaction: true,
  20. NowFunc: func() time.Time {
  21. return now
  22. },
  23. },
  24. }
  25. t.Run("Config has not been set", func(t *testing.T) {
  26. s := &LoginSource{}
  27. err := s.BeforeSave(db)
  28. if err != nil {
  29. t.Fatal(err)
  30. }
  31. assert.Empty(t, s.Config)
  32. })
  33. t.Run("Config has been set", func(t *testing.T) {
  34. s := &LoginSource{
  35. Provider: pam.NewProvider(&pam.Config{
  36. ServiceName: "pam_service",
  37. }),
  38. }
  39. err := s.BeforeSave(db)
  40. if err != nil {
  41. t.Fatal(err)
  42. }
  43. assert.Equal(t, `{"ServiceName":"pam_service"}`, s.Config)
  44. })
  45. }
  46. func TestLoginSource_BeforeCreate(t *testing.T) {
  47. now := time.Now()
  48. db := &gorm.DB{
  49. Config: &gorm.Config{
  50. SkipDefaultTransaction: true,
  51. NowFunc: func() time.Time {
  52. return now
  53. },
  54. },
  55. }
  56. t.Run("CreatedUnix has been set", func(t *testing.T) {
  57. s := &LoginSource{CreatedUnix: 1}
  58. _ = s.BeforeCreate(db)
  59. assert.Equal(t, int64(1), s.CreatedUnix)
  60. assert.Equal(t, int64(0), s.UpdatedUnix)
  61. })
  62. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  63. s := &LoginSource{}
  64. _ = s.BeforeCreate(db)
  65. assert.Equal(t, db.NowFunc().Unix(), s.CreatedUnix)
  66. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  67. })
  68. }
  69. func Test_loginSources(t *testing.T) {
  70. if testing.Short() {
  71. t.Skip()
  72. }
  73. t.Parallel()
  74. tables := []interface{}{new(LoginSource), new(User)}
  75. db := &loginSources{
  76. DB: initTestDB(t, "loginSources", tables...),
  77. }
  78. for _, tc := range []struct {
  79. name string
  80. test func(*testing.T, *loginSources)
  81. }{
  82. {"Create", test_loginSources_Create},
  83. {"Count", test_loginSources_Count},
  84. {"DeleteByID", test_loginSources_DeleteByID},
  85. {"GetByID", test_loginSources_GetByID},
  86. {"List", test_loginSources_List},
  87. {"ResetNonDefault", test_loginSources_ResetNonDefault},
  88. {"Save", test_loginSources_Save},
  89. } {
  90. t.Run(tc.name, func(t *testing.T) {
  91. t.Cleanup(func() {
  92. err := clearTables(t, db.DB, tables...)
  93. if err != nil {
  94. t.Fatal(err)
  95. }
  96. })
  97. tc.test(t, db)
  98. })
  99. if t.Failed() {
  100. break
  101. }
  102. }
  103. }
  104. func test_loginSources_Create(t *testing.T, db *loginSources) {
  105. // Create first login source with name "GitHub"
  106. source, err := db.Create(CreateLoginSourceOpts{
  107. Type: auth.GitHub,
  108. Name: "GitHub",
  109. Activated: true,
  110. Default: false,
  111. Config: &github.Config{
  112. APIEndpoint: "https://api.github.com",
  113. },
  114. })
  115. if err != nil {
  116. t.Fatal(err)
  117. }
  118. // Get it back and check the Created field
  119. source, err = db.GetByID(source.ID)
  120. if err != nil {
  121. t.Fatal(err)
  122. }
  123. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
  124. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
  125. // Try create second login source with same name should fail
  126. _, err = db.Create(CreateLoginSourceOpts{Name: source.Name})
  127. expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
  128. assert.Equal(t, expErr, err)
  129. }
  130. func test_loginSources_Count(t *testing.T, db *loginSources) {
  131. // Create two login sources, one in database and one as source file.
  132. _, err := db.Create(CreateLoginSourceOpts{
  133. Type: auth.GitHub,
  134. Name: "GitHub",
  135. Activated: true,
  136. Default: false,
  137. Config: &github.Config{
  138. APIEndpoint: "https://api.github.com",
  139. },
  140. })
  141. if err != nil {
  142. t.Fatal(err)
  143. }
  144. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  145. MockLen: func() int {
  146. return 2
  147. },
  148. })
  149. assert.Equal(t, int64(3), db.Count())
  150. }
  151. func test_loginSources_DeleteByID(t *testing.T, db *loginSources) {
  152. t.Run("delete but in used", func(t *testing.T) {
  153. source, err := db.Create(CreateLoginSourceOpts{
  154. Type: auth.GitHub,
  155. Name: "GitHub",
  156. Activated: true,
  157. Default: false,
  158. Config: &github.Config{
  159. APIEndpoint: "https://api.github.com",
  160. },
  161. })
  162. if err != nil {
  163. t.Fatal(err)
  164. }
  165. // Create a user that uses this login source
  166. _, err = (&users{DB: db.DB}).Create("alice", "", CreateUserOpts{
  167. LoginSource: source.ID,
  168. })
  169. if err != nil {
  170. t.Fatal(err)
  171. }
  172. // Delete the login source will result in error
  173. err = db.DeleteByID(source.ID)
  174. expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
  175. assert.Equal(t, expErr, err)
  176. })
  177. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  178. MockGetByID: func(id int64) (*LoginSource, error) {
  179. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  180. },
  181. })
  182. // Create a login source with name "GitHub2"
  183. source, err := db.Create(CreateLoginSourceOpts{
  184. Type: auth.GitHub,
  185. Name: "GitHub2",
  186. Activated: true,
  187. Default: false,
  188. Config: &github.Config{
  189. APIEndpoint: "https://api.github.com",
  190. },
  191. })
  192. if err != nil {
  193. t.Fatal(err)
  194. }
  195. // Delete a non-existent ID is noop
  196. err = db.DeleteByID(9999)
  197. if err != nil {
  198. t.Fatal(err)
  199. }
  200. // We should be able to get it back
  201. _, err = db.GetByID(source.ID)
  202. if err != nil {
  203. t.Fatal(err)
  204. }
  205. // Now delete this login source with ID
  206. err = db.DeleteByID(source.ID)
  207. if err != nil {
  208. t.Fatal(err)
  209. }
  210. // We should get token not found error
  211. _, err = db.GetByID(source.ID)
  212. expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
  213. assert.Equal(t, expErr, err)
  214. }
  215. func test_loginSources_GetByID(t *testing.T, db *loginSources) {
  216. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  217. MockGetByID: func(id int64) (*LoginSource, error) {
  218. if id != 101 {
  219. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  220. }
  221. return &LoginSource{ID: id}, nil
  222. },
  223. })
  224. expConfig := &github.Config{
  225. APIEndpoint: "https://api.github.com",
  226. }
  227. // Create a login source with name "GitHub"
  228. source, err := db.Create(CreateLoginSourceOpts{
  229. Type: auth.GitHub,
  230. Name: "GitHub",
  231. Activated: true,
  232. Default: false,
  233. Config: expConfig,
  234. })
  235. if err != nil {
  236. t.Fatal(err)
  237. }
  238. // Get the one in the database and test the read/write hooks
  239. source, err = db.GetByID(source.ID)
  240. if err != nil {
  241. t.Fatal(err)
  242. }
  243. assert.Equal(t, expConfig, source.Provider.Config())
  244. // Get the one in source file store
  245. _, err = db.GetByID(101)
  246. if err != nil {
  247. t.Fatal(err)
  248. }
  249. }
  250. func test_loginSources_List(t *testing.T, db *loginSources) {
  251. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  252. MockList: func(opts ListLoginSourceOpts) []*LoginSource {
  253. if opts.OnlyActivated {
  254. return []*LoginSource{
  255. {ID: 1},
  256. }
  257. }
  258. return []*LoginSource{
  259. {ID: 1},
  260. {ID: 2},
  261. }
  262. },
  263. })
  264. // Create two login sources in database, one activated and the other one not
  265. _, err := db.Create(CreateLoginSourceOpts{
  266. Type: auth.PAM,
  267. Name: "PAM",
  268. Config: &pam.Config{
  269. ServiceName: "PAM",
  270. },
  271. })
  272. if err != nil {
  273. t.Fatal(err)
  274. }
  275. _, err = db.Create(CreateLoginSourceOpts{
  276. Type: auth.GitHub,
  277. Name: "GitHub",
  278. Activated: true,
  279. Config: &github.Config{
  280. APIEndpoint: "https://api.github.com",
  281. },
  282. })
  283. if err != nil {
  284. t.Fatal(err)
  285. }
  286. // List all login sources
  287. sources, err := db.List(ListLoginSourceOpts{})
  288. if err != nil {
  289. t.Fatal(err)
  290. }
  291. assert.Equal(t, 4, len(sources), "number of sources")
  292. // Only list activated login sources
  293. sources, err = db.List(ListLoginSourceOpts{OnlyActivated: true})
  294. if err != nil {
  295. t.Fatal(err)
  296. }
  297. assert.Equal(t, 2, len(sources), "number of sources")
  298. }
  299. func test_loginSources_ResetNonDefault(t *testing.T, db *loginSources) {
  300. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  301. MockList: func(opts ListLoginSourceOpts) []*LoginSource {
  302. return []*LoginSource{
  303. {
  304. File: &mockLoginSourceFileStore{
  305. MockSetGeneral: func(name, value string) {
  306. assert.Equal(t, "is_default", name)
  307. assert.Equal(t, "false", value)
  308. },
  309. MockSave: func() error {
  310. return nil
  311. },
  312. },
  313. },
  314. }
  315. },
  316. MockUpdate: func(source *LoginSource) {},
  317. })
  318. // Create two login sources both have default on
  319. source1, err := db.Create(CreateLoginSourceOpts{
  320. Type: auth.PAM,
  321. Name: "PAM",
  322. Default: true,
  323. Config: &pam.Config{
  324. ServiceName: "PAM",
  325. },
  326. })
  327. if err != nil {
  328. t.Fatal(err)
  329. }
  330. source2, err := db.Create(CreateLoginSourceOpts{
  331. Type: auth.GitHub,
  332. Name: "GitHub",
  333. Activated: true,
  334. Default: true,
  335. Config: &github.Config{
  336. APIEndpoint: "https://api.github.com",
  337. },
  338. })
  339. if err != nil {
  340. t.Fatal(err)
  341. }
  342. // Set source 1 as default
  343. err = db.ResetNonDefault(source1)
  344. if err != nil {
  345. t.Fatal(err)
  346. }
  347. // Verify the default state
  348. source1, err = db.GetByID(source1.ID)
  349. if err != nil {
  350. t.Fatal(err)
  351. }
  352. assert.True(t, source1.IsDefault)
  353. source2, err = db.GetByID(source2.ID)
  354. if err != nil {
  355. t.Fatal(err)
  356. }
  357. assert.False(t, source2.IsDefault)
  358. }
  359. func test_loginSources_Save(t *testing.T, db *loginSources) {
  360. t.Run("save to database", func(t *testing.T) {
  361. // Create a login source with name "GitHub"
  362. source, err := db.Create(CreateLoginSourceOpts{
  363. Type: auth.GitHub,
  364. Name: "GitHub",
  365. Activated: true,
  366. Default: false,
  367. Config: &github.Config{
  368. APIEndpoint: "https://api.github.com",
  369. },
  370. })
  371. if err != nil {
  372. t.Fatal(err)
  373. }
  374. source.IsActived = false
  375. source.Provider = github.NewProvider(&github.Config{
  376. APIEndpoint: "https://api2.github.com",
  377. })
  378. err = db.Save(source)
  379. if err != nil {
  380. t.Fatal(err)
  381. }
  382. source, err = db.GetByID(source.ID)
  383. if err != nil {
  384. t.Fatal(err)
  385. }
  386. assert.False(t, source.IsActived)
  387. assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
  388. })
  389. t.Run("save to file", func(t *testing.T) {
  390. calledSave := false
  391. source := &LoginSource{
  392. Provider: github.NewProvider(&github.Config{
  393. APIEndpoint: "https://api.github.com",
  394. }),
  395. File: &mockLoginSourceFileStore{
  396. MockSetGeneral: func(name, value string) {},
  397. MockSetConfig: func(cfg interface{}) error { return nil },
  398. MockSave: func() error {
  399. calledSave = true
  400. return nil
  401. },
  402. },
  403. }
  404. err := db.Save(source)
  405. if err != nil {
  406. t.Fatal(err)
  407. }
  408. assert.True(t, calledSave)
  409. })
  410. }