Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 113 additions & 64 deletions pkg/commands/git_commands/working_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package git_commands

import (
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
Expand Down Expand Up @@ -116,64 +115,124 @@ func (self *WorkingTreeCommands) BeforeAndAfterFileForRename(file *models.File)
return beforeFile, afterFile, nil
}

// DiscardAllFileChanges directly
func (self *WorkingTreeCommands) DiscardAllFileChanges(file *models.File) error {
if file.IsRename() {
beforeFile, afterFile, err := self.BeforeAndAfterFileForRename(file)
if err != nil {
return err
// DiscardAllFilesChanges discards changes for multiple files in batch
func (self *WorkingTreeCommands) DiscardAllFilesChanges(files []*models.File) error {
// Group files by their discard strategy
var (
aaStatusFiles []*models.File
duStatusFiles []*models.File
filesToReset []*models.File
addedFilesToRemove []*models.File
filesToCheckout []*models.File
)

// Helper function to categorize a file into the appropriate group
categorizeFile := func(file *models.File) {
if file.ShortStatus == "AA" {
aaStatusFiles = append(aaStatusFiles, file)
return
}

if err := self.DiscardAllFileChanges(beforeFile); err != nil {
return err
if file.ShortStatus == "DU" {
duStatusFiles = append(duStatusFiles, file)
return
}

if err := self.DiscardAllFileChanges(afterFile); err != nil {
return err
// Track which files need to be reset first
needsReset := file.HasStagedChanges || file.HasMergeConflicts
if needsReset {
filesToReset = append(filesToReset, file)
}

if file.ShortStatus == "DD" || file.ShortStatus == "AU" {
} else if file.Added {
addedFilesToRemove = append(addedFilesToRemove, file)
} else {
filesToCheckout = append(filesToCheckout, file)
}
}

for _, file := range files {
if file.IsRename() {
// Get the before and after files for the rename and add them to the appropriate groups
beforeFile, afterFile, err := self.BeforeAndAfterFileForRename(file)
if err != nil {
return err
}
categorizeFile(beforeFile)
categorizeFile(afterFile)
continue
}

return nil
categorizeFile(file)
}

if file.ShortStatus == "AA" {
// Batch reset files that need resetting
if len(filesToReset) > 0 {
paths := make([]string, len(filesToReset))
for i, file := range filesToReset {
paths[i] = file.Path
}
if err := self.cmd.New(
NewGitCmd("checkout").Arg("--ours", "--", file.Path).ToArgv(),
NewGitCmd("reset").Arg("--").Arg(paths...).ToArgv(),
).Run(); err != nil {
return err
}
}

// Batch remove DU status files
if len(duStatusFiles) > 0 {
paths := make([]string, len(duStatusFiles))
for i, file := range duStatusFiles {
paths[i] = file.Path
}
if err := self.cmd.New(
NewGitCmd("add").Arg("--", file.Path).ToArgv(),
NewGitCmd("rm").Arg("--").Arg(paths...).ToArgv(),
).Run(); err != nil {
return err
}
return nil
}

if file.ShortStatus == "DU" {
return self.cmd.New(
NewGitCmd("rm").Arg("--", file.Path).ToArgv(),
).Run()
}

// if the file isn't tracked, we assume you want to delete it
if file.HasStagedChanges || file.HasMergeConflicts {
// Batch checkout --ours for AA status files
if len(aaStatusFiles) > 0 {
paths := make([]string, len(aaStatusFiles))
for i, file := range aaStatusFiles {
paths[i] = file.Path
}
if err := self.cmd.New(
NewGitCmd("checkout").Arg("--ours", "--").Arg(paths...).ToArgv(),
).Run(); err != nil {
return err
}
// Stage them after checkout
if err := self.cmd.New(
NewGitCmd("reset").Arg("--", file.Path).ToArgv(),
NewGitCmd("add").Arg("--").Arg(paths...).ToArgv(),
).Run(); err != nil {
return err
}
}

if file.ShortStatus == "DD" || file.ShortStatus == "AU" {
return nil
// Remove added files from filesystem
for _, file := range addedFilesToRemove {
if err := self.os.RemoveFile(file.Path); err != nil {
return err
}
}

if file.Added {
return self.os.RemoveFile(file.Path)
// Batch checkout other files
if len(filesToCheckout) > 0 {
paths := make([]string, len(filesToCheckout))
for i, file := range filesToCheckout {
paths[i] = file.Path
}
if err := self.cmd.New(
NewGitCmd("checkout").Arg("--").Arg(paths...).ToArgv(),
).Run(); err != nil {
return err
}
}

return self.DiscardUnstagedFileChanges(file)
return nil
}

type IFileNode interface {
Expand All @@ -184,55 +243,45 @@ type IFileNode interface {
GetFile() *models.File
}

func (self *WorkingTreeCommands) DiscardAllDirChanges(node IFileNode) error {
// this could be more efficient but we would need to handle all the edge cases
return node.ForEachFile(self.DiscardAllFileChanges)
}

func (self *WorkingTreeCommands) DiscardUnstagedDirChanges(node IFileNode) error {
file := node.GetFile()
if file == nil {
if err := self.RemoveUntrackedDirFiles(node); err != nil {
return err
}
// DiscardUnstagedFilesChanges discards unstaged changes for multiple files in batch
func (self *WorkingTreeCommands) DiscardUnstagedFilesChanges(files []*models.File) error {
var (
addedFilesToRemove []*models.File
trackedFilesToCheckout []*models.File
)

cmdArgs := NewGitCmd("checkout").Arg("--", node.GetPath()).ToArgv()
if err := self.cmd.New(cmdArgs).Run(); err != nil {
return err
}
} else {
for _, file := range files {
// Only remove files that are added but not staged
if file.Added && !file.HasStagedChanges {
return self.os.RemoveFile(file.Path)
addedFilesToRemove = append(addedFilesToRemove, file)
} else {
// Checkout tracked files to discard unstaged changes
trackedFilesToCheckout = append(trackedFilesToCheckout, file)
}
}

if err := self.DiscardUnstagedFileChanges(file); err != nil {
// Remove added files from filesystem
for _, file := range addedFilesToRemove {
if err := self.os.RemoveFile(file.Path); err != nil {
return err
}
}

return nil
}

func (self *WorkingTreeCommands) RemoveUntrackedDirFiles(node IFileNode) error {
untrackedFilePaths := node.GetFilePathsMatching(
func(file *models.File) bool { return !file.GetIsTracked() },
)

for _, path := range untrackedFilePaths {
err := os.Remove(path)
if err != nil {
// Batch checkout tracked files
if len(trackedFilesToCheckout) > 0 {
paths := make([]string, len(trackedFilesToCheckout))
for i, file := range trackedFilesToCheckout {
paths[i] = file.Path
}
cmdArgs := NewGitCmd("checkout").Arg("--").Arg(paths...).ToArgv()
if err := self.cmd.New(cmdArgs).Run(); err != nil {
return err
}
}

return nil
}

func (self *WorkingTreeCommands) DiscardUnstagedFileChanges(file *models.File) error {
cmdArgs := NewGitCmd("checkout").Arg("--", file.Path).ToArgv()
return self.cmd.New(cmdArgs).Run()
}

// Escapes special characters in a filename for gitignore and exclude files
func escapeFilename(filename string) string {
re := regexp.MustCompile(`^[!#]|[\[\]*]`)
Expand Down
6 changes: 3 additions & 3 deletions pkg/commands/git_commands/working_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestWorkingTreeUnstageFile(t *testing.T) {
// these tests don't cover everything, in part because we already have an integration
// test which does cover everything. I don't want to unnecessarily assert on the 'how'
// when the 'what' is what matters
func TestWorkingTreeDiscardAllFileChanges(t *testing.T) {
func TestWorkingTreeDiscardAllFilesChanges(t *testing.T) {
type scenario struct {
testName string
file *models.File
Expand Down Expand Up @@ -190,7 +190,7 @@ func TestWorkingTreeDiscardAllFileChanges(t *testing.T) {
for _, s := range scenarios {
t.Run(s.testName, func(t *testing.T) {
instance := buildWorkingTreeCommands(commonDeps{runner: s.runner, removeFile: s.removeFile})
err := instance.DiscardAllFileChanges(s.file)
err := instance.DiscardAllFilesChanges([]*models.File{s.file})

if s.expectedError == "" {
assert.Nil(t, err)
Expand Down Expand Up @@ -476,7 +476,7 @@ func TestWorkingTreeDiscardUnstagedFileChanges(t *testing.T) {
for _, s := range scenarios {
t.Run(s.testName, func(t *testing.T) {
instance := buildWorkingTreeCommands(commonDeps{runner: s.runner})
s.test(instance.DiscardUnstagedFileChanges(s.file))
s.test(instance.DiscardUnstagedFilesChanges([]*models.File{s.file}))
s.runner.CheckForMissingCalls()
})
}
Expand Down
24 changes: 22 additions & 2 deletions pkg/gui/controllers/files_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -1382,12 +1382,22 @@ func (self *FilesController) remove(selectedNodes []*filetree.FileNode) error {
defer self.context().CancelRangeSelect()
}

// Collect all files from the selected nodes
var files []*models.File
for _, node := range selectedNodes {
if err := self.c.Git().WorkingTree.DiscardAllDirChanges(node); err != nil {
if err := node.ForEachFile(func(file *models.File) error {
files = append(files, file)
return nil
}); err != nil {
return err
}
}

// TODO: Send all nodes to delete untracked directories
if err := self.c.Git().WorkingTree.DiscardAllFilesChanges(files); err != nil {
return err
}

self.c.Refresh(types.RefreshOptions{Mode: types.ASYNC, Scope: []types.RefreshableView{types.FILES, types.WORKTREES}})
return nil
},
Expand All @@ -1409,12 +1419,22 @@ func (self *FilesController) remove(selectedNodes []*filetree.FileNode) error {
defer self.context().CancelRangeSelect()
}

// Collect all files from the selected nodes
var files []*models.File
for _, node := range selectedNodes {
if err := self.c.Git().WorkingTree.DiscardUnstagedDirChanges(node); err != nil {
if err := node.ForEachFile(func(file *models.File) error {
files = append(files, file)
return nil
}); err != nil {
return err
}
}

// TODO: Send all nodes to delete untracked directories
if err := self.c.Git().WorkingTree.DiscardUnstagedFilesChanges(files); err != nil {
return err
}

self.c.Refresh(types.RefreshOptions{Mode: types.ASYNC, Scope: []types.RefreshableView{types.FILES, types.WORKTREES}})
return nil
},
Expand Down
Loading