login_sources_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "context"
  7. "testing"
  8. "time"
  9. mockrequire "github.com/derision-test/go-mockgen/testutil/require"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. "gorm.io/gorm"
  13. "gogs.io/gogs/internal/auth"
  14. "gogs.io/gogs/internal/auth/github"
  15. "gogs.io/gogs/internal/auth/pam"
  16. "gogs.io/gogs/internal/dbtest"
  17. "gogs.io/gogs/internal/errutil"
  18. )
  19. func TestLoginSource_BeforeSave(t *testing.T) {
  20. now := time.Now()
  21. db := &gorm.DB{
  22. Config: &gorm.Config{
  23. SkipDefaultTransaction: true,
  24. NowFunc: func() time.Time {
  25. return now
  26. },
  27. },
  28. }
  29. t.Run("Config has not been set", func(t *testing.T) {
  30. s := &LoginSource{}
  31. err := s.BeforeSave(db)
  32. require.NoError(t, err)
  33. assert.Empty(t, s.Config)
  34. })
  35. t.Run("Config has been set", func(t *testing.T) {
  36. s := &LoginSource{
  37. Provider: pam.NewProvider(&pam.Config{
  38. ServiceName: "pam_service",
  39. }),
  40. }
  41. err := s.BeforeSave(db)
  42. require.NoError(t, err)
  43. assert.Equal(t, `{"ServiceName":"pam_service"}`, s.Config)
  44. })
  45. }
  46. func TestLoginSource_BeforeCreate(t *testing.T) {
  47. now := time.Now()
  48. db := &gorm.DB{
  49. Config: &gorm.Config{
  50. SkipDefaultTransaction: true,
  51. NowFunc: func() time.Time {
  52. return now
  53. },
  54. },
  55. }
  56. t.Run("CreatedUnix has been set", func(t *testing.T) {
  57. s := &LoginSource{CreatedUnix: 1}
  58. _ = s.BeforeCreate(db)
  59. assert.Equal(t, int64(1), s.CreatedUnix)
  60. assert.Equal(t, int64(0), s.UpdatedUnix)
  61. })
  62. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  63. s := &LoginSource{}
  64. _ = s.BeforeCreate(db)
  65. assert.Equal(t, db.NowFunc().Unix(), s.CreatedUnix)
  66. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  67. })
  68. }
  69. func Test_loginSources(t *testing.T) {
  70. if testing.Short() {
  71. t.Skip()
  72. }
  73. t.Parallel()
  74. tables := []interface{}{new(LoginSource), new(User)}
  75. db := &loginSources{
  76. DB: dbtest.NewDB(t, "loginSources", tables...),
  77. }
  78. for _, tc := range []struct {
  79. name string
  80. test func(*testing.T, *loginSources)
  81. }{
  82. {"Create", loginSourcesCreate},
  83. {"Count", loginSourcesCount},
  84. {"DeleteByID", loginSourcesDeleteByID},
  85. {"GetByID", loginSourcesGetByID},
  86. {"List", loginSourcesList},
  87. {"ResetNonDefault", loginSourcesResetNonDefault},
  88. {"Save", loginSourcesSave},
  89. } {
  90. t.Run(tc.name, func(t *testing.T) {
  91. t.Cleanup(func() {
  92. err := clearTables(t, db.DB, tables...)
  93. require.NoError(t, err)
  94. })
  95. tc.test(t, db)
  96. })
  97. if t.Failed() {
  98. break
  99. }
  100. }
  101. }
  102. func loginSourcesCreate(t *testing.T, db *loginSources) {
  103. ctx := context.Background()
  104. // Create first login source with name "GitHub"
  105. source, err := db.Create(ctx,
  106. CreateLoginSourceOptions{
  107. Type: auth.GitHub,
  108. Name: "GitHub",
  109. Activated: true,
  110. Default: false,
  111. Config: &github.Config{
  112. APIEndpoint: "https://api.github.com",
  113. },
  114. },
  115. )
  116. require.NoError(t, err)
  117. // Get it back and check the Created field
  118. source, err = db.GetByID(ctx, source.ID)
  119. require.NoError(t, err)
  120. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
  121. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
  122. // Try create second login source with same name should fail
  123. _, err = db.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
  124. wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
  125. assert.Equal(t, wantErr, err)
  126. }
  127. func loginSourcesCount(t *testing.T, db *loginSources) {
  128. ctx := context.Background()
  129. // Create two login sources, one in database and one as source file.
  130. _, err := db.Create(ctx,
  131. CreateLoginSourceOptions{
  132. Type: auth.GitHub,
  133. Name: "GitHub",
  134. Activated: true,
  135. Default: false,
  136. Config: &github.Config{
  137. APIEndpoint: "https://api.github.com",
  138. },
  139. },
  140. )
  141. require.NoError(t, err)
  142. mock := NewMockLoginSourceFilesStore()
  143. mock.LenFunc.SetDefaultReturn(2)
  144. setMockLoginSourceFilesStore(t, db, mock)
  145. assert.Equal(t, int64(3), db.Count(ctx))
  146. }
  147. func loginSourcesDeleteByID(t *testing.T, db *loginSources) {
  148. ctx := context.Background()
  149. t.Run("delete but in used", func(t *testing.T) {
  150. source, err := db.Create(ctx,
  151. CreateLoginSourceOptions{
  152. Type: auth.GitHub,
  153. Name: "GitHub",
  154. Activated: true,
  155. Default: false,
  156. Config: &github.Config{
  157. APIEndpoint: "https://api.github.com",
  158. },
  159. },
  160. )
  161. require.NoError(t, err)
  162. // Create a user that uses this login source
  163. _, err = (&users{DB: db.DB}).Create(ctx, "alice", "",
  164. CreateUserOptions{
  165. LoginSource: source.ID,
  166. },
  167. )
  168. require.NoError(t, err)
  169. // Delete the login source will result in error
  170. err = db.DeleteByID(ctx, source.ID)
  171. wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
  172. assert.Equal(t, wantErr, err)
  173. })
  174. mock := NewMockLoginSourceFilesStore()
  175. mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
  176. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  177. })
  178. setMockLoginSourceFilesStore(t, db, mock)
  179. // Create a login source with name "GitHub2"
  180. source, err := db.Create(ctx,
  181. CreateLoginSourceOptions{
  182. Type: auth.GitHub,
  183. Name: "GitHub2",
  184. Activated: true,
  185. Default: false,
  186. Config: &github.Config{
  187. APIEndpoint: "https://api.github.com",
  188. },
  189. },
  190. )
  191. require.NoError(t, err)
  192. // Delete a non-existent ID is noop
  193. err = db.DeleteByID(ctx, 9999)
  194. require.NoError(t, err)
  195. // We should be able to get it back
  196. _, err = db.GetByID(ctx, source.ID)
  197. require.NoError(t, err)
  198. // Now delete this login source with ID
  199. err = db.DeleteByID(ctx, source.ID)
  200. require.NoError(t, err)
  201. // We should get token not found error
  202. _, err = db.GetByID(ctx, source.ID)
  203. wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
  204. assert.Equal(t, wantErr, err)
  205. }
  206. func loginSourcesGetByID(t *testing.T, db *loginSources) {
  207. ctx := context.Background()
  208. mock := NewMockLoginSourceFilesStore()
  209. mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
  210. if id != 101 {
  211. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  212. }
  213. return &LoginSource{ID: id}, nil
  214. })
  215. setMockLoginSourceFilesStore(t, db, mock)
  216. expConfig := &github.Config{
  217. APIEndpoint: "https://api.github.com",
  218. }
  219. // Create a login source with name "GitHub"
  220. source, err := db.Create(ctx,
  221. CreateLoginSourceOptions{
  222. Type: auth.GitHub,
  223. Name: "GitHub",
  224. Activated: true,
  225. Default: false,
  226. Config: expConfig,
  227. },
  228. )
  229. require.NoError(t, err)
  230. // Get the one in the database and test the read/write hooks
  231. source, err = db.GetByID(ctx, source.ID)
  232. require.NoError(t, err)
  233. assert.Equal(t, expConfig, source.Provider.Config())
  234. // Get the one in source file store
  235. _, err = db.GetByID(ctx, 101)
  236. require.NoError(t, err)
  237. }
  238. func loginSourcesList(t *testing.T, db *loginSources) {
  239. ctx := context.Background()
  240. mock := NewMockLoginSourceFilesStore()
  241. mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
  242. if opts.OnlyActivated {
  243. return []*LoginSource{
  244. {ID: 1},
  245. }
  246. }
  247. return []*LoginSource{
  248. {ID: 1},
  249. {ID: 2},
  250. }
  251. })
  252. setMockLoginSourceFilesStore(t, db, mock)
  253. // Create two login sources in database, one activated and the other one not
  254. _, err := db.Create(ctx,
  255. CreateLoginSourceOptions{
  256. Type: auth.PAM,
  257. Name: "PAM",
  258. Config: &pam.Config{
  259. ServiceName: "PAM",
  260. },
  261. },
  262. )
  263. require.NoError(t, err)
  264. _, err = db.Create(ctx,
  265. CreateLoginSourceOptions{
  266. Type: auth.GitHub,
  267. Name: "GitHub",
  268. Activated: true,
  269. Config: &github.Config{
  270. APIEndpoint: "https://api.github.com",
  271. },
  272. },
  273. )
  274. require.NoError(t, err)
  275. // List all login sources
  276. sources, err := db.List(ctx, ListLoginSourceOptions{})
  277. require.NoError(t, err)
  278. assert.Equal(t, 4, len(sources), "number of sources")
  279. // Only list activated login sources
  280. sources, err = db.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
  281. require.NoError(t, err)
  282. assert.Equal(t, 2, len(sources), "number of sources")
  283. }
  284. func loginSourcesResetNonDefault(t *testing.T, db *loginSources) {
  285. ctx := context.Background()
  286. mock := NewMockLoginSourceFilesStore()
  287. mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
  288. mockFile := NewMockLoginSourceFileStore()
  289. mockFile.SetGeneralFunc.SetDefaultHook(func(name, value string) {
  290. assert.Equal(t, "is_default", name)
  291. assert.Equal(t, "false", value)
  292. })
  293. return []*LoginSource{
  294. {
  295. File: mockFile,
  296. },
  297. }
  298. })
  299. setMockLoginSourceFilesStore(t, db, mock)
  300. // Create two login sources both have default on
  301. source1, err := db.Create(ctx,
  302. CreateLoginSourceOptions{
  303. Type: auth.PAM,
  304. Name: "PAM",
  305. Default: true,
  306. Config: &pam.Config{
  307. ServiceName: "PAM",
  308. },
  309. },
  310. )
  311. require.NoError(t, err)
  312. source2, err := db.Create(ctx,
  313. CreateLoginSourceOptions{
  314. Type: auth.GitHub,
  315. Name: "GitHub",
  316. Activated: true,
  317. Default: true,
  318. Config: &github.Config{
  319. APIEndpoint: "https://api.github.com",
  320. },
  321. },
  322. )
  323. require.NoError(t, err)
  324. // Set source 1 as default
  325. err = db.ResetNonDefault(ctx, source1)
  326. require.NoError(t, err)
  327. // Verify the default state
  328. source1, err = db.GetByID(ctx, source1.ID)
  329. require.NoError(t, err)
  330. assert.True(t, source1.IsDefault)
  331. source2, err = db.GetByID(ctx, source2.ID)
  332. require.NoError(t, err)
  333. assert.False(t, source2.IsDefault)
  334. }
  335. func loginSourcesSave(t *testing.T, db *loginSources) {
  336. ctx := context.Background()
  337. t.Run("save to database", func(t *testing.T) {
  338. // Create a login source with name "GitHub"
  339. source, err := db.Create(ctx,
  340. CreateLoginSourceOptions{
  341. Type: auth.GitHub,
  342. Name: "GitHub",
  343. Activated: true,
  344. Default: false,
  345. Config: &github.Config{
  346. APIEndpoint: "https://api.github.com",
  347. },
  348. },
  349. )
  350. require.NoError(t, err)
  351. source.IsActived = false
  352. source.Provider = github.NewProvider(&github.Config{
  353. APIEndpoint: "https://api2.github.com",
  354. })
  355. err = db.Save(ctx, source)
  356. require.NoError(t, err)
  357. source, err = db.GetByID(ctx, source.ID)
  358. require.NoError(t, err)
  359. assert.False(t, source.IsActived)
  360. assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
  361. })
  362. t.Run("save to file", func(t *testing.T) {
  363. mockFile := NewMockLoginSourceFileStore()
  364. source := &LoginSource{
  365. Provider: github.NewProvider(&github.Config{
  366. APIEndpoint: "https://api.github.com",
  367. }),
  368. File: mockFile,
  369. }
  370. err := db.Save(ctx, source)
  371. require.NoError(t, err)
  372. mockrequire.Called(t, mockFile.SaveFunc)
  373. })
  374. }