two_factors_test.go 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "context"
  7. "testing"
  8. "time"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "gogs.io/gogs/internal/dbtest"
  12. "gogs.io/gogs/internal/errutil"
  13. )
  14. func TestTwoFactors(t *testing.T) {
  15. if testing.Short() {
  16. t.Skip()
  17. }
  18. t.Parallel()
  19. tables := []interface{}{new(TwoFactor), new(TwoFactorRecoveryCode)}
  20. db := &twoFactors{
  21. DB: dbtest.NewDB(t, "twoFactors", tables...),
  22. }
  23. for _, tc := range []struct {
  24. name string
  25. test func(*testing.T, *twoFactors)
  26. }{
  27. {"Create", twoFactorsCreate},
  28. {"GetByUserID", twoFactorsGetByUserID},
  29. {"IsUserEnabled", twoFactorsIsUserEnabled},
  30. } {
  31. t.Run(tc.name, func(t *testing.T) {
  32. t.Cleanup(func() {
  33. err := clearTables(t, db.DB, tables...)
  34. require.NoError(t, err)
  35. })
  36. tc.test(t, db)
  37. })
  38. if t.Failed() {
  39. break
  40. }
  41. }
  42. }
  43. func twoFactorsCreate(t *testing.T, db *twoFactors) {
  44. ctx := context.Background()
  45. // Create a 2FA token
  46. err := db.Create(ctx, 1, "secure-key", "secure-secret")
  47. require.NoError(t, err)
  48. // Get it back and check the Created field
  49. tf, err := db.GetByUserID(ctx, 1)
  50. require.NoError(t, err)
  51. assert.Equal(t, db.NowFunc().Format(time.RFC3339), tf.Created.UTC().Format(time.RFC3339))
  52. // Verify there are 10 recover codes generated
  53. var count int64
  54. err = db.Model(new(TwoFactorRecoveryCode)).Count(&count).Error
  55. require.NoError(t, err)
  56. assert.Equal(t, int64(10), count)
  57. }
  58. func twoFactorsGetByUserID(t *testing.T, db *twoFactors) {
  59. ctx := context.Background()
  60. // Create a 2FA token for user 1
  61. err := db.Create(ctx, 1, "secure-key", "secure-secret")
  62. require.NoError(t, err)
  63. // We should be able to get it back
  64. _, err = db.GetByUserID(ctx, 1)
  65. require.NoError(t, err)
  66. // Try to get a non-existent 2FA token
  67. _, err = db.GetByUserID(ctx, 2)
  68. wantErr := ErrTwoFactorNotFound{args: errutil.Args{"userID": int64(2)}}
  69. assert.Equal(t, wantErr, err)
  70. }
  71. func twoFactorsIsUserEnabled(t *testing.T, db *twoFactors) {
  72. ctx := context.Background()
  73. // Create a 2FA token for user 1
  74. err := db.Create(ctx, 1, "secure-key", "secure-secret")
  75. require.NoError(t, err)
  76. assert.True(t, db.IsUserEnabled(ctx, 1))
  77. assert.False(t, db.IsUserEnabled(ctx, 2))
  78. }