Skip to content
This repository was archived by the owner on Dec 17, 2021. It is now read-only.

Commit 1b8b9ff

Browse files
committed
commands/scp: Add
1 parent 1925f9a commit 1b8b9ff

File tree

2 files changed

+371
-0
lines changed

2 files changed

+371
-0
lines changed

commands/scp.go

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
package commands
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"os"
7+
"os/exec"
8+
"strings"
9+
10+
"github.com/docker/machine/libmachine"
11+
"github.com/docker/machine/libmachine/log"
12+
"github.com/docker/machine/libmachine/persist"
13+
)
14+
15+
var (
16+
errWrongNumberArguments = errors.New("Improper number of arguments")
17+
18+
// TODO: possibly move this to ssh package
19+
baseSSHArgs = []string{
20+
"-o", "StrictHostKeyChecking=no",
21+
"-o", "UserKnownHostsFile=/dev/null",
22+
"-o", "LogLevel=quiet", // suppress "Warning: Permanently added '[localhost]:2022' (ECDSA) to the list of known hosts."
23+
}
24+
)
25+
26+
// HostInfo gives the mandatory information to connect to a host.
27+
type HostInfo interface {
28+
GetMachineName() string
29+
30+
GetSSHHostname() (string, error)
31+
32+
GetSSHPort() (int, error)
33+
34+
GetSSHUsername() string
35+
36+
GetSSHKeyPath() string
37+
}
38+
39+
// HostInfoLoader loads host information.
40+
type HostInfoLoader interface {
41+
load(name string) (HostInfo, error)
42+
}
43+
44+
type storeHostInfoLoader struct {
45+
store persist.Store
46+
}
47+
48+
func (s *storeHostInfoLoader) load(name string) (HostInfo, error) {
49+
host, err := s.store.Load(name)
50+
if err != nil {
51+
return nil, fmt.Errorf("Error loading host: %s", err)
52+
}
53+
54+
return host.Driver, nil
55+
}
56+
57+
func cmdScp(c CommandLine, api libmachine.API) error {
58+
args := c.Args()
59+
if len(args) != 2 {
60+
c.ShowHelp()
61+
return errWrongNumberArguments
62+
}
63+
64+
src := args[0]
65+
dest := args[1]
66+
67+
hostInfoLoader := &storeHostInfoLoader{api}
68+
69+
cmd, err := getScpCmd(src, dest, c.Bool("recursive"), c.Bool("delta"), hostInfoLoader)
70+
if err != nil {
71+
return err
72+
}
73+
74+
return runCmdWithStdIo(*cmd)
75+
}
76+
77+
func getScpCmd(src, dest string, recursive bool, delta bool, hostInfoLoader HostInfoLoader) (*exec.Cmd, error) {
78+
var cmdPath string
79+
var err error
80+
if !delta {
81+
cmdPath, err = exec.LookPath("scp")
82+
if err != nil {
83+
return nil, errors.New("You must have a copy of the scp binary locally to use the scp feature")
84+
}
85+
} else {
86+
cmdPath, err = exec.LookPath("rsync")
87+
if err != nil {
88+
return nil, errors.New("You must have a copy of the rsync binary locally to use the --delta option")
89+
}
90+
}
91+
92+
srcHost, srcPath, srcOpts, err := getInfoForScpArg(src, hostInfoLoader)
93+
if err != nil {
94+
return nil, err
95+
}
96+
97+
destHost, destPath, destOpts, err := getInfoForScpArg(dest, hostInfoLoader)
98+
if err != nil {
99+
return nil, err
100+
}
101+
102+
// TODO: Check that "-3" flag is available in user's version of scp.
103+
// It is on every system I've checked, but the manual mentioned it's "newer"
104+
sshArgs := baseSSHArgs
105+
if !delta {
106+
sshArgs = append(sshArgs, "-3")
107+
if recursive {
108+
sshArgs = append(sshArgs, "-r")
109+
}
110+
}
111+
112+
// Don't use ssh-agent if both hosts have explicit ssh keys
113+
if !missesExplicitSSHKey(srcHost) && !missesExplicitSSHKey(destHost) {
114+
sshArgs = append(sshArgs, "-o", "IdentitiesOnly=yes")
115+
}
116+
117+
// Append needed -i / private key flags to command.
118+
sshArgs = append(sshArgs, srcOpts...)
119+
sshArgs = append(sshArgs, destOpts...)
120+
121+
// Append actual arguments for the scp command (i.e. docker@<ip>:/path)
122+
locationArg, err := generateLocationArg(srcHost, srcPath)
123+
if err != nil {
124+
return nil, err
125+
}
126+
127+
if delta {
128+
sshArgs = append([]string{"-e"}, "ssh "+strings.Join(sshArgs, " "))
129+
if recursive {
130+
sshArgs = append(sshArgs, "-r")
131+
}
132+
}
133+
134+
sshArgs = append(sshArgs, locationArg)
135+
locationArg, err = generateLocationArg(destHost, destPath)
136+
if err != nil {
137+
return nil, err
138+
}
139+
sshArgs = append(sshArgs, locationArg)
140+
141+
cmd := exec.Command(cmdPath, sshArgs...)
142+
log.Debug(*cmd)
143+
return cmd, nil
144+
}
145+
146+
func missesExplicitSSHKey(hostInfo HostInfo) bool {
147+
return hostInfo != nil && hostInfo.GetSSHKeyPath() == ""
148+
}
149+
150+
func getInfoForScpArg(hostAndPath string, hostInfoLoader HostInfoLoader) (HostInfo, string, []string, error) {
151+
// Local path. e.g. "/tmp/foo"
152+
if !strings.Contains(hostAndPath, ":") {
153+
return nil, hostAndPath, nil, nil
154+
}
155+
156+
// Path with hostname. e.g. "hostname:/usr/bin/cmatrix"
157+
parts := strings.SplitN(hostAndPath, ":", 2)
158+
hostName := parts[0]
159+
path := parts[1]
160+
if hostName == "localhost" {
161+
return nil, path, nil, nil
162+
}
163+
164+
// Remote path
165+
hostInfo, err := hostInfoLoader.load(hostName)
166+
if err != nil {
167+
return nil, "", nil, fmt.Errorf("Error loading host: %s", err)
168+
}
169+
170+
args := []string{}
171+
port, err := hostInfo.GetSSHPort()
172+
if err == nil && port > 0 {
173+
args = append(args, "-o", fmt.Sprintf("Port=%v", port))
174+
}
175+
176+
if hostInfo.GetSSHKeyPath() != "" {
177+
args = append(args, "-o", fmt.Sprintf("IdentityFile=%s", hostInfo.GetSSHKeyPath()))
178+
}
179+
180+
return hostInfo, path, args, nil
181+
}
182+
183+
func generateLocationArg(hostInfo HostInfo, path string) (string, error) {
184+
if hostInfo == nil {
185+
return path, nil
186+
}
187+
188+
hostname, err := hostInfo.GetSSHHostname()
189+
if err != nil {
190+
return "", err
191+
}
192+
193+
location := fmt.Sprintf("%s@%s:%s", hostInfo.GetSSHUsername(), hostname, path)
194+
return location, nil
195+
}
196+
197+
func runCmdWithStdIo(cmd exec.Cmd) error {
198+
cmd.Stdin = os.Stdin
199+
cmd.Stdout = os.Stdout
200+
cmd.Stderr = os.Stderr
201+
202+
return cmd.Run()
203+
}

commands/scp_test.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package commands
2+
3+
import (
4+
"os/exec"
5+
"strings"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
type MockHostInfo struct {
12+
name string
13+
ip string
14+
sshPort int
15+
sshUsername string
16+
sshKeyPath string
17+
}
18+
19+
func (h *MockHostInfo) GetMachineName() string {
20+
return h.name
21+
}
22+
23+
func (h *MockHostInfo) GetSSHHostname() (string, error) {
24+
return h.ip, nil
25+
}
26+
27+
func (h *MockHostInfo) GetSSHPort() (int, error) {
28+
return h.sshPort, nil
29+
}
30+
31+
func (h *MockHostInfo) GetSSHUsername() string {
32+
return h.sshUsername
33+
}
34+
35+
func (h *MockHostInfo) GetSSHKeyPath() string {
36+
return h.sshKeyPath
37+
}
38+
39+
type MockHostInfoLoader struct {
40+
hostInfo MockHostInfo
41+
}
42+
43+
func (l *MockHostInfoLoader) load(name string) (HostInfo, error) {
44+
info := l.hostInfo
45+
info.name = name
46+
return &info, nil
47+
}
48+
49+
func TestGetInfoForLocalScpArg(t *testing.T) {
50+
host, path, opts, err := getInfoForScpArg("/tmp/foo", nil)
51+
assert.Nil(t, host)
52+
assert.Equal(t, "/tmp/foo", path)
53+
assert.Nil(t, opts)
54+
assert.NoError(t, err)
55+
56+
host, path, opts, err = getInfoForScpArg("localhost:C:\\path", nil)
57+
assert.Nil(t, host)
58+
assert.Equal(t, "C:\\path", path)
59+
assert.Nil(t, opts)
60+
assert.NoError(t, err)
61+
}
62+
63+
func TestGetInfoForRemoteScpArg(t *testing.T) {
64+
hostInfoLoader := MockHostInfoLoader{MockHostInfo{
65+
sshKeyPath: "/fake/keypath/id_rsa",
66+
}}
67+
68+
host, path, opts, err := getInfoForScpArg("myfunhost:/home/docker/foo", &hostInfoLoader)
69+
assert.Equal(t, "myfunhost", host.GetMachineName())
70+
assert.Equal(t, "/home/docker/foo", path)
71+
assert.Equal(t, []string{"-o", "IdentityFile=/fake/keypath/id_rsa"}, opts)
72+
assert.NoError(t, err)
73+
74+
host, path, opts, err = getInfoForScpArg("myfunhost:C:\\path", &hostInfoLoader)
75+
assert.Equal(t, "myfunhost", host.GetMachineName())
76+
assert.Equal(t, "C:\\path", path)
77+
assert.NoError(t, err)
78+
}
79+
80+
func TestHostLocation(t *testing.T) {
81+
arg, err := generateLocationArg(nil, "/home/docker/foo")
82+
83+
assert.Equal(t, "/home/docker/foo", arg)
84+
assert.NoError(t, err)
85+
}
86+
87+
func TestRemoteLocation(t *testing.T) {
88+
hostInfo := MockHostInfo{
89+
ip: "12.34.56.78",
90+
sshUsername: "root",
91+
}
92+
93+
arg, err := generateLocationArg(&hostInfo, "/home/docker/foo")
94+
95+
assert.Equal(t, "root@12.34.56.78:/home/docker/foo", arg)
96+
assert.NoError(t, err)
97+
}
98+
99+
func TestGetScpCmd(t *testing.T) {
100+
hostInfoLoader := MockHostInfoLoader{MockHostInfo{
101+
ip: "12.34.56.78",
102+
sshPort: 234,
103+
sshUsername: "root",
104+
sshKeyPath: "/fake/keypath/id_rsa",
105+
}}
106+
107+
cmd, err := getScpCmd("/tmp/foo", "myfunhost:/home/docker/foo", true, false, &hostInfoLoader)
108+
109+
expectedArgs := append(
110+
baseSSHArgs,
111+
"-3",
112+
"-r",
113+
"-o",
114+
"IdentitiesOnly=yes",
115+
"-o",
116+
"Port=234",
117+
"-o",
118+
"IdentityFile=/fake/keypath/id_rsa",
119+
"/tmp/foo",
120+
"root@12.34.56.78:/home/docker/foo",
121+
)
122+
expectedCmd := exec.Command("/usr/bin/scp", expectedArgs...)
123+
124+
assert.Equal(t, expectedCmd, cmd)
125+
assert.NoError(t, err)
126+
}
127+
128+
func TestGetScpCmdWithoutSshKey(t *testing.T) {
129+
hostInfoLoader := MockHostInfoLoader{MockHostInfo{
130+
ip: "1.2.3.4",
131+
sshUsername: "user",
132+
}}
133+
134+
cmd, err := getScpCmd("/tmp/foo", "myfunhost:/home/docker/foo", true, false, &hostInfoLoader)
135+
136+
expectedArgs := append(
137+
baseSSHArgs,
138+
"-3",
139+
"-r",
140+
"/tmp/foo",
141+
"user@1.2.3.4:/home/docker/foo",
142+
)
143+
expectedCmd := exec.Command("/usr/bin/scp", expectedArgs...)
144+
145+
assert.Equal(t, expectedCmd, cmd)
146+
assert.NoError(t, err)
147+
}
148+
149+
func TestGetScpCmdWithDelta(t *testing.T) {
150+
hostInfoLoader := MockHostInfoLoader{MockHostInfo{
151+
ip: "1.2.3.4",
152+
sshUsername: "user",
153+
}}
154+
155+
cmd, err := getScpCmd("/tmp/foo", "myfunhost:/home/docker/foo", true, true, &hostInfoLoader)
156+
157+
expectedArgs := append(
158+
[]string{"-e"},
159+
"ssh "+strings.Join(baseSSHArgs, " "),
160+
"-r",
161+
"/tmp/foo",
162+
"user@1.2.3.4:/home/docker/foo",
163+
)
164+
expectedCmd := exec.Command("/usr/bin/rsync", expectedArgs...)
165+
166+
assert.Equal(t, expectedCmd, cmd)
167+
assert.NoError(t, err)
168+
}

0 commit comments

Comments
 (0)