diff --git a/worktree.go b/worktree.go index e1a81a39b..c381c70e1 100644 --- a/worktree.go +++ b/worktree.go @@ -138,7 +138,9 @@ func (w *Worktree) PullContext(ctx context.Context, o *PullOptions) error { Mode: MergeReset, Commit: ref.Hash(), }); err != nil { - return err + // revert to the previous HEAD in case of errors + revErr := w.updateHEAD(head.Hash()) + return errors.Join(err, revErr) } if o.RecurseSubmodules != NoRecurseSubmodules { @@ -285,37 +287,72 @@ func (w *Worktree) ResetSparsely(opts *ResetOptions, dirs []string) error { return err } - if opts.Mode == MergeReset { - unstaged, err := w.containsUnstagedChanges() + t, err := w.r.getTreeFromCommitHash(opts.Commit) + if err != nil { + return err + } + + changes, err := w.diffTreeWithStaging(t, true) + if err != nil { + return err + } + var changedFiles []string + for _, ch := range changes { + a, err := ch.Action() if err != nil { return err } - - if unstaged { - return ErrUnstagedChanges + var name string + switch a { + case merkletrie.Modify, merkletrie.Insert: + name = ch.To.String() + case merkletrie.Delete: + name = ch.From.String() } + changedFiles = append(changedFiles, name) } - if err := w.setHEADCommit(opts.Commit); err != nil { - return err - } + if opts.Mode == MergeReset { + ch, err := w.diffStagingWithWorktree(false, true) + if err != nil { + return err + } - if opts.Mode == SoftReset { - return nil - } + for _, c := range ch { + a, err := c.Action() + if err != nil { + return err + } - t, err := w.r.getTreeFromCommitHash(opts.Commit) - if err != nil { - return err + var name string + switch a { + case merkletrie.Modify, merkletrie.Insert: + name = c.To.String() + case merkletrie.Delete: + name = c.From.String() + } + + if inFiles(changedFiles, name) { + return ErrUnstagedChanges + } + } } var removedFiles []string if opts.Mode == MixedReset || opts.Mode == MergeReset || opts.Mode == HardReset { - if removedFiles, err = w.resetIndex(t, dirs, opts.Files); err != nil { + if removedFiles, err = w.resetIndex(t, dirs, opts.Files, changes); err != nil { return err } } + if err := w.setHEADCommit(opts.Commit); err != nil { + return err + } + + if opts.Mode == SoftReset { + return nil + } + if opts.Mode == MergeReset || opts.Mode == HardReset { if err := w.resetWorktree(t, removedFiles); err != nil { return err @@ -370,7 +407,7 @@ func (w *Worktree) Reset(opts *ResetOptions) error { return w.ResetSparsely(opts, nil) } -func (w *Worktree) resetIndex(t *object.Tree, dirs []string, files []string) ([]string, error) { +func (w *Worktree) resetIndex(t *object.Tree, dirs []string, files []string, changes merkletrie.Changes) ([]string, error) { idx, err := w.r.Storer.Index() if err != nil { return nil, err @@ -378,11 +415,6 @@ func (w *Worktree) resetIndex(t *object.Tree, dirs []string, files []string) ([] b := newIndexBuilder(idx) - changes, err := w.diffTreeWithStaging(t, true) - if err != nil { - return nil, err - } - var removedFiles []string for _, ch := range changes { a, err := ch.Action() @@ -457,11 +489,6 @@ func (w *Worktree) resetWorktree(t *object.Tree, files []string) error { } b := newIndexBuilder(idx) - status, err := w.Status() - if err != nil { - return err - } - for _, ch := range changes { if err := w.validChange(ch); err != nil { return err @@ -485,13 +512,8 @@ func (w *Worktree) resetWorktree(t *object.Tree, files []string) error { } } - // only checkout an untracked file if it is in the list of files - // a reset should leave untracked files alone - file := nameFromAction(&ch) - if status.File(file).Worktree != Untracked || inFiles(files, file) { - if err := w.checkoutChange(ch, t, b); err != nil { - return err - } + if err := w.checkoutChange(ch, t, b); err != nil { + return err } } @@ -631,28 +653,6 @@ func (w *Worktree) checkoutChange(ch merkletrie.Change, t *object.Tree, idx *ind return w.checkoutChangeRegularFile(name, a, t, e, idx) } -func (w *Worktree) containsUnstagedChanges() (bool, error) { - ch, err := w.diffStagingWithWorktree(false, true) - if err != nil { - return false, err - } - - for _, c := range ch { - a, err := c.Action() - if err != nil { - return false, err - } - - if a == merkletrie.Insert { - continue - } - - return true, nil - } - - return false, nil -} - func (w *Worktree) setHEADCommit(commit plumbing.Hash) error { head, err := w.r.Reference(plumbing.HEAD, false) if err != nil { diff --git a/worktree_test.go b/worktree_test.go index 7c7a12faa..2c7dc4fbe 100644 --- a/worktree_test.go +++ b/worktree_test.go @@ -1149,7 +1149,7 @@ func (s *WorktreeSuite) TestResetWithUntracked() { continue } if st.Worktree != Unmodified || st.Staging != Unmodified { - s.Fail("file %s not unmodified", file) + s.Failf("file %s not unmodified", file) } } } @@ -1217,14 +1217,14 @@ func (s *WorktreeSuite) TestResetMerge() { err := w.Checkout(&CheckoutOptions{}) s.NoError(err) - err = w.Reset(&ResetOptions{Mode: MergeReset, Commit: commitA}) + err = w.Reset(&ResetOptions{Mode: SoftReset, Commit: commitA}) s.NoError(err) branch, err := w.r.Reference(plumbing.Master, false) s.NoError(err) s.Equal(commitA, branch.Hash()) - f, err := fs.Create(".gitignore") + f, err := fs.Create("vendor/foo.go") s.NoError(err) _, err = f.Write([]byte("foo")) s.NoError(err)