Skip to content

Commit 5339c9c

Browse files
Fix WSH copy internal (#1906)
I was dumb and used `os.Rename` for copy on the same WSH remote. This change makes it a proper copy, with recursion if needed. --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 71767df commit 5339c9c

3 files changed

Lines changed: 156 additions & 66 deletions

File tree

pkg/remote/connparse/connparse.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,11 @@ func ParseURI(uri string) (*Connection, error) {
128128
}
129129
}
130130

131+
addPrecedingSlash := true
132+
131133
if scheme == "" {
132134
scheme = ConnectionTypeWsh
135+
addPrecedingSlash = false
133136
if len(rest) != len(uri) {
134137
// This accounts for when the uri starts with "//", which would get trimmed in the first split.
135138
parseWshPath()
@@ -152,7 +155,7 @@ func ParseURI(uri string) (*Connection, error) {
152155
}
153156
if strings.HasPrefix(remotePath, "/~") {
154157
remotePath = strings.TrimPrefix(remotePath, "/")
155-
} else if len(remotePath) > 1 && !windowsDriveRegex.MatchString(remotePath) && !strings.HasPrefix(remotePath, "/") && !strings.HasPrefix(remotePath, "~") && !strings.HasPrefix(remotePath, "./") && !strings.HasPrefix(remotePath, "../") && !strings.HasPrefix(remotePath, ".\\") && !strings.HasPrefix(remotePath, "..\\") && remotePath != ".." {
158+
} else if addPrecedingSlash && (len(remotePath) > 1 && !windowsDriveRegex.MatchString(remotePath) && !strings.HasPrefix(remotePath, "/") && !strings.HasPrefix(remotePath, "~") && !strings.HasPrefix(remotePath, "./") && !strings.HasPrefix(remotePath, "../") && !strings.HasPrefix(remotePath, ".\\") && !strings.HasPrefix(remotePath, "..\\") && remotePath != "..") {
156159
remotePath = "/" + remotePath
157160
}
158161
}

pkg/remote/connparse/connparse_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,50 @@ func TestParseURI_WSHCurrentPath(t *testing.T) {
212212
if c.GetFullURI() != expected {
213213
t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI())
214214
}
215+
216+
cstr = "path/to/file"
217+
c, err = connparse.ParseURI(cstr)
218+
if err != nil {
219+
t.Fatalf("failed to parse URI: %v", err)
220+
}
221+
expected = "path/to/file"
222+
if c.Path != expected {
223+
t.Fatalf("expected path to be %q, got %q", expected, c.Path)
224+
}
225+
expected = "current"
226+
if c.Host != expected {
227+
t.Fatalf("expected host to be %q, got %q", expected, c.Host)
228+
}
229+
expected = "wsh"
230+
if c.Scheme != expected {
231+
t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme)
232+
}
233+
expected = "wsh://current/path/to/file"
234+
if c.GetFullURI() != expected {
235+
t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI())
236+
}
237+
238+
cstr = "/etc/path/to/file"
239+
c, err = connparse.ParseURI(cstr)
240+
if err != nil {
241+
t.Fatalf("failed to parse URI: %v", err)
242+
}
243+
expected = "/etc/path/to/file"
244+
if c.Path != expected {
245+
t.Fatalf("expected path to be %q, got %q", expected, c.Path)
246+
}
247+
expected = "current"
248+
if c.Host != expected {
249+
t.Fatalf("expected host to be %q, got %q", expected, c.Host)
250+
}
251+
expected = "wsh"
252+
if c.Scheme != expected {
253+
t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme)
254+
}
255+
expected = "wsh://current/etc/path/to/file"
256+
if c.GetFullURI() != expected {
257+
t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI())
258+
}
215259
}
216260

217261
func TestParseURI_WSHCurrentPathWindows(t *testing.T) {

pkg/wshrpc/wshremote/wshremote.go

Lines changed: 108 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -348,11 +348,113 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C
348348
if err != nil {
349349
return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err)
350350
}
351+
352+
copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) {
353+
destinfo, err = os.Stat(path)
354+
if err != nil && !errors.Is(err, fs.ErrNotExist) {
355+
return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
356+
}
357+
358+
if destinfo != nil {
359+
if destinfo.IsDir() {
360+
if !finfo.IsDir() {
361+
// try to create file in directory
362+
path = filepath.Join(path, filepath.Base(finfo.Name()))
363+
newdestinfo, err := os.Stat(path)
364+
if err != nil && !errors.Is(err, fs.ErrNotExist) {
365+
return 0, fmt.Errorf("cannot stat file %q: %w", path, err)
366+
}
367+
if newdestinfo != nil && !overwrite {
368+
return 0, fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", path)
369+
}
370+
} else if !merge && !overwrite {
371+
return 0, fmt.Errorf("cannot create directory %q, directory exists at path, neither overwrite nor merge specified", path)
372+
} else if overwrite {
373+
err := os.RemoveAll(path)
374+
if err != nil {
375+
return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
376+
}
377+
}
378+
} else {
379+
if finfo.IsDir() {
380+
if !overwrite {
381+
return 0, fmt.Errorf("cannot create file %q, directory exists at path, overwrite not specified", path)
382+
} else {
383+
err := os.RemoveAll(path)
384+
if err != nil {
385+
return 0, fmt.Errorf("cannot remove directory %q: %w", path, err)
386+
}
387+
}
388+
} else if !overwrite {
389+
return 0, fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", path)
390+
}
391+
}
392+
}
393+
394+
if finfo.IsDir() {
395+
err := os.MkdirAll(path, finfo.Mode())
396+
if err != nil {
397+
return 0, fmt.Errorf("cannot create directory %q: %w", path, err)
398+
}
399+
} else {
400+
err := os.MkdirAll(filepath.Dir(path), 0755)
401+
if err != nil {
402+
return 0, fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(path), err)
403+
}
404+
}
405+
406+
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode())
407+
if err != nil {
408+
return 0, fmt.Errorf("cannot create new file %q: %w", path, err)
409+
}
410+
defer file.Close()
411+
_, err = io.Copy(file, srcFile)
412+
if err != nil {
413+
return 0, fmt.Errorf("cannot write file %q: %w", path, err)
414+
}
415+
416+
return finfo.Size(), nil
417+
}
418+
351419
if srcConn.Host == destConn.Host {
352420
srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path))
353-
err := os.Rename(srcPathCleaned, destPathCleaned)
421+
422+
srcFileStat, err := os.Stat(srcPathCleaned)
354423
if err != nil {
355-
return fmt.Errorf("cannot copy file %q to %q: %w", srcPathCleaned, destPathCleaned, err)
424+
return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err)
425+
}
426+
427+
if srcFileStat.IsDir() {
428+
err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error {
429+
if err != nil {
430+
return err
431+
}
432+
srcFilePath := path
433+
destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathCleaned))
434+
var file *os.File
435+
if !info.IsDir() {
436+
file, err = os.Open(srcFilePath)
437+
if err != nil {
438+
return fmt.Errorf("cannot open file %q: %w", srcFilePath, err)
439+
}
440+
defer file.Close()
441+
}
442+
_, err = copyFileFunc(destFilePath, info, file)
443+
return err
444+
})
445+
if err != nil {
446+
return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
447+
}
448+
} else {
449+
file, err := os.Open(srcPathCleaned)
450+
if err != nil {
451+
return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err)
452+
}
453+
defer file.Close()
454+
_, err = copyFileFunc(destPathCleaned, srcFileStat, file)
455+
if err != nil {
456+
return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err)
457+
}
356458
}
357459
} else {
358460
timeout := DefaultTimeout
@@ -376,70 +478,11 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C
376478
}
377479
numFiles++
378480
finfo := next.FileInfo()
379-
nextPath := filepath.Join(destPathCleaned, next.Name)
380-
destinfo, err = os.Stat(nextPath)
381-
if err != nil && !errors.Is(err, fs.ErrNotExist) {
382-
return fmt.Errorf("cannot stat file %q: %w", nextPath, err)
383-
}
384-
if !finfo.IsDir() {
385-
totalBytes += finfo.Size()
386-
}
387-
388-
if destinfo != nil {
389-
if destinfo.IsDir() {
390-
if !finfo.IsDir() {
391-
if !overwrite {
392-
return fmt.Errorf("cannot create directory %q, file exists at path, overwrite not specified", nextPath)
393-
} else {
394-
err := os.Remove(nextPath)
395-
if err != nil {
396-
return fmt.Errorf("cannot remove file %q: %w", nextPath, err)
397-
}
398-
}
399-
} else if !merge && !overwrite {
400-
return fmt.Errorf("cannot create directory %q, directory exists at path, neither overwrite nor merge specified", nextPath)
401-
} else if overwrite {
402-
err := os.RemoveAll(nextPath)
403-
if err != nil {
404-
return fmt.Errorf("cannot remove directory %q: %w", nextPath, err)
405-
}
406-
}
407-
} else {
408-
if finfo.IsDir() {
409-
if !overwrite {
410-
return fmt.Errorf("cannot create file %q, directory exists at path, overwrite not specified", nextPath)
411-
} else {
412-
err := os.RemoveAll(nextPath)
413-
if err != nil {
414-
return fmt.Errorf("cannot remove directory %q: %w", nextPath, err)
415-
}
416-
}
417-
} else if !overwrite {
418-
return fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", nextPath)
419-
}
420-
}
421-
} else {
422-
if finfo.IsDir() {
423-
err := os.MkdirAll(nextPath, finfo.Mode())
424-
if err != nil {
425-
return fmt.Errorf("cannot create directory %q: %w", nextPath, err)
426-
}
427-
} else {
428-
err := os.MkdirAll(filepath.Dir(nextPath), 0755)
429-
if err != nil {
430-
return fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(nextPath), err)
431-
}
432-
file, err := os.OpenFile(nextPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode())
433-
if err != nil {
434-
return fmt.Errorf("cannot create new file %q: %w", nextPath, err)
435-
}
436-
_, err = io.Copy(file, reader)
437-
if err != nil {
438-
return fmt.Errorf("cannot write file %q: %w", nextPath, err)
439-
}
440-
file.Close()
441-
}
481+
n, err := copyFileFunc(filepath.Join(destPathCleaned, next.Name), finfo, reader)
482+
if err != nil {
483+
return fmt.Errorf("cannot copy file %q: %w", next.Name, err)
442484
}
485+
totalBytes += n
443486
return nil
444487
})
445488
if err != nil {

0 commit comments

Comments
 (0)