Browse Source

Check access before watch repo

Joe Chen 1 year ago
parent
commit
7d4c5b47a2

+ 14 - 5
internal/db/organizations_test.go

@@ -472,11 +472,6 @@ func orgsRemoveMember(t *testing.T, db *organizations) {
 	reposStore := NewRepositoriesStore(db.DB)
 	repo1, err := reposStore.Create(ctx, org1.ID, CreateRepoOptions{Name: "repo1", Private: true})
 	require.NoError(t, err)
-	err = reposStore.Watch(ctx, bob.ID, repo1.ID)
-	require.NoError(t, err)
-	permsStore := NewPermsStore(db.DB)
-	err = permsStore.SetRepoPerms(ctx, repo1.ID, map[int64]AccessMode{bob.ID: AccessModeRead})
-	require.NoError(t, err)
 	// TODO: Use Repositories.AddCollaborator to replace SQL hack when the method is available.
 	err = db.DB.Create(
 		&Collaboration{
@@ -516,6 +511,20 @@ func orgsRemoveMember(t *testing.T, db *organizations) {
 	).Error
 	require.NoError(t, err)
 
+	permsStore := NewPermsStore(db.DB)
+	err = permsStore.SetRepoPerms(ctx, repo1.ID, map[int64]AccessMode{bob.ID: AccessModeRead})
+	require.NoError(t, err)
+	err = reposStore.Watch(
+		ctx,
+		WatchRepositoryOptions{
+			UserID:        bob.ID,
+			RepoID:        repo1.ID,
+			RepoOwnerID:   repo1.OwnerID,
+			RepoIsPrivate: repo1.IsPrivate,
+		},
+	)
+	require.NoError(t, err)
+
 	// Pull the trigger
 	err = db.RemoveMember(ctx, org1.ID, bob.ID)
 	require.NoError(t, err)

+ 37 - 6
internal/db/repositories.go

@@ -50,7 +50,7 @@ type RepositoriesStore interface {
 	// ListWatches returns all watches of the given repository.
 	ListWatches(ctx context.Context, repoID int64) ([]*Watch, error)
 	// Watch marks the user to watch the repository.
-	Watch(ctx context.Context, userID, repoID int64) error
+	Watch(ctx context.Context, opts WatchRepositoryOptions) error
 
 	// HasForkedBy returns true if the given repository has forked by the given user.
 	HasForkedBy(ctx context.Context, repoID, userID int64) bool
@@ -194,7 +194,15 @@ func (db *repositories) Create(ctx context.Context, ownerID int64, opts CreateRe
 			return errors.Wrap(err, "create")
 		}
 
-		err = NewRepositoriesStore(tx).Watch(ctx, ownerID, repo.ID)
+		err = NewRepositoriesStore(tx).Watch(
+			ctx,
+			WatchRepositoryOptions{
+				UserID:        ownerID,
+				RepoID:        repo.ID,
+				RepoOwnerID:   ownerID,
+				RepoIsPrivate: repo.IsPrivate,
+			},
+		)
 		if err != nil {
 			return errors.Wrap(err, "watch")
 		}
@@ -400,11 +408,34 @@ func (db *repositories) recountWatches(tx *gorm.DB, repoID int64) error {
 		Error
 }
 
-func (db *repositories) Watch(ctx context.Context, userID, repoID int64) error {
+type WatchRepositoryOptions struct {
+	UserID        int64
+	RepoID        int64
+	RepoOwnerID   int64
+	RepoIsPrivate bool
+}
+
+func (db *repositories) Watch(ctx context.Context, opts WatchRepositoryOptions) error {
+	// Make sure the user has access to the private repository
+	if opts.RepoIsPrivate &&
+		opts.UserID != opts.RepoOwnerID &&
+		!NewPermsStore(db.DB).Authorize(
+			ctx,
+			opts.UserID,
+			opts.RepoID,
+			AccessModeRead,
+			AccessModeOptions{
+				OwnerID: opts.RepoOwnerID,
+				Private: true,
+			},
+		) {
+		return errors.New("user does not have access to the repository")
+	}
+
 	return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 		w := &Watch{
-			UserID: userID,
-			RepoID: repoID,
+			UserID: opts.UserID,
+			RepoID: opts.RepoID,
 		}
 		result := tx.FirstOrCreate(w, w)
 		if result.Error != nil {
@@ -413,7 +444,7 @@ func (db *repositories) Watch(ctx context.Context, userID, repoID int64) error {
 			return nil // Relation already exists
 		}
 
-		return db.recountWatches(tx, repoID)
+		return db.recountWatches(tx, opts.RepoID)
 	})
 }
 

+ 57 - 11
internal/db/repositories_test.go

@@ -290,11 +290,20 @@ func reposTouch(t *testing.T, ctx context.Context, db *repositories) {
 }
 
 func reposListWatches(t *testing.T, ctx context.Context, db *repositories) {
-	err := db.Watch(ctx, 1, 1)
+	repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
 	require.NoError(t, err)
-	err = db.Watch(ctx, 2, 1)
+	_, err = db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"})
 	require.NoError(t, err)
-	err = db.Watch(ctx, 2, 2)
+
+	err = db.Watch(
+		ctx,
+		WatchRepositoryOptions{
+			UserID:        2,
+			RepoID:        repo1.ID,
+			RepoOwnerID:   repo1.OwnerID,
+			RepoIsPrivate: repo1.IsPrivate,
+		},
+	)
 	require.NoError(t, err)
 
 	got, err := db.ListWatches(ctx, 1)
@@ -314,16 +323,53 @@ func reposWatch(t *testing.T, ctx context.Context, db *repositories) {
 	repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"})
 	require.NoError(t, err)
 
-	err = db.Watch(ctx, 2, repo1.ID)
-	require.NoError(t, err)
+	t.Run("user does not have access to the repository", func(t *testing.T) {
+		repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1", Private: true})
+		require.NoError(t, err)
 
-	// It is OK to watch multiple times and just be noop.
-	err = db.Watch(ctx, 2, repo1.ID)
-	require.NoError(t, err)
+		err = db.Watch(
+			ctx,
+			WatchRepositoryOptions{
+				UserID:        2,
+				RepoID:        repo1.ID,
+				RepoOwnerID:   repo1.OwnerID,
+				RepoIsPrivate: repo1.IsPrivate,
+			},
+		)
+		require.Error(t, err)
+	})
 
-	repo1, err = db.GetByID(ctx, repo1.ID)
-	require.NoError(t, err)
-	assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default.
+	t.Run("user has access to the repository", func(t *testing.T) {
+		repo2, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo2"})
+		require.NoError(t, err)
+
+		err = db.Watch(
+			ctx,
+			WatchRepositoryOptions{
+				UserID:        2,
+				RepoID:        repo2.ID,
+				RepoOwnerID:   repo2.OwnerID,
+				RepoIsPrivate: repo2.IsPrivate,
+			},
+		)
+		require.NoError(t, err)
+
+		// It is OK to watch multiple times and just be noop.
+		err = db.Watch(
+			ctx,
+			WatchRepositoryOptions{
+				UserID:        2,
+				RepoID:        repo2.ID,
+				RepoOwnerID:   repo2.OwnerID,
+				RepoIsPrivate: repo2.IsPrivate,
+			},
+		)
+		require.NoError(t, err)
+
+		repo2, err = db.GetByID(ctx, repo2.ID)
+		require.NoError(t, err)
+		assert.Equal(t, 2, repo2.NumWatches) // The owner is watching the repo by default.
+	})
 }
 
 func reposHasForkedBy(t *testing.T, ctx context.Context, db *repositories) {

+ 9 - 1
internal/db/users_test.go

@@ -513,7 +513,15 @@ func usersDeleteByID(t *testing.T, ctx context.Context, db *users) {
 	require.NoError(t, err)
 
 	// Mock watches, stars and follows
-	err = reposStore.Watch(ctx, testUser.ID, repo2.ID)
+	err = reposStore.Watch(
+		ctx,
+		WatchRepositoryOptions{
+			UserID:        testUser.ID,
+			RepoID:        repo2.ID,
+			RepoOwnerID:   repo2.OwnerID,
+			RepoIsPrivate: repo2.IsPrivate,
+		},
+	)
 	require.NoError(t, err)
 	err = reposStore.Star(ctx, testUser.ID, repo2.ID)
 	require.NoError(t, err)