Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions _examples/ssh-proxy-protocol/proxy_protocol.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package main

import (
"fmt"
"io"
"log"

"github.com/gliderlabs/ssh"
)

const (
ADDR = "0.0.0.0:4444"
)

func main() {
ssh.Handle(func(s ssh.Session) {
io.WriteString(s, fmt.Sprintf("Your address is %s\n", s.RemoteAddr()))
})

log.Println("starting ssh server on " + ADDR)
log.Fatal(ssh.ListenAndServe(ADDR, nil, ssh.EnableProxyProtocol()))
}
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module github.com/gliderlabs/ssh

go 1.20
go 1.25

require (
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/pires/go-proxyproto v0.12.0
golang.org/x/crypto v0.31.0
)

Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/pires/go-proxyproto v0.12.0 h1:TTCxD66dU898tahivkqc3hoceZp7P44FnorWyo9d5vM=
github.com/pires/go-proxyproto v0.12.0/go.mod h1:qUvfqUMEoX7T8g0q7TQLDnhMjdTrxnG0hvpMn+7ePNI=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
8 changes: 8 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,11 @@ func WrapConn(fn ConnCallback) Option {
return nil
}
}

// EnableProxyProtocol returns a functional option that sets EnableProxyProtocol on the server
func EnableProxyProtocol() Option {
return func(srv *Server) error {
srv.EnableProxyProtocol = true
return nil
}
}
10 changes: 10 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sync"
"time"

"github.com/pires/go-proxyproto"
gossh "golang.org/x/crypto/ssh"
)

Expand Down Expand Up @@ -56,6 +57,8 @@ type Server struct {
IdleTimeout time.Duration // connection timeout when no activity, none if empty
MaxTimeout time.Duration // absolute connection timeout, none if empty

EnableProxyProtocol bool // Enable support for HA Proxy's and NGinx's PROXY protocol

// ChannelHandlers allow overriding the built-in session handlers or provide
// extensions to the protocol, such as tcpip forwarding. By default only the
// "session" handler is enabled.
Expand Down Expand Up @@ -235,6 +238,13 @@ func (srv *Server) Shutdown(ctx context.Context) error {
//
// Serve always returns a non-nil error.
func (srv *Server) Serve(l net.Listener) error {
if srv.EnableProxyProtocol {
_, ok := l.(*proxyproto.Listener)
if !ok {
l = &proxyproto.Listener{Listener: l}
}
}

srv.ensureHandlers()
defer l.Close()
if err := srv.ensureHostSigner(); err != nil {
Expand Down
107 changes: 107 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@ package ssh
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"testing"
"time"

"github.com/pires/go-proxyproto"
gossh "golang.org/x/crypto/ssh"
)

func TestAddHostKey(t *testing.T) {
Expand Down Expand Up @@ -158,3 +165,103 @@ func TestServerHandshakeTimeout(t *testing.T) {
return
}
}

func TestProxyProtocol(t *testing.T) {
const (
CORRECT_IP = "1.1.1.1"
CORRECT_PORT = 55555
)
handlerDone := make(chan struct{})
var testResult error

handler := func(sess Session) {
defer close(handlerDone)
sourceAddress := sess.RemoteAddr()

index := strings.Index(sourceAddress.String(), ":")
ip := sourceAddress.String()[:index]
portStr := sourceAddress.String()[index+1:]

if ip != CORRECT_IP {
errorMsg := fmt.Sprintf("Expected source address '%s' but got '%s'", CORRECT_IP, ip)
testResult = errors.Join(testResult, fmt.Errorf("%s", errorMsg))
}
port, err := strconv.Atoi(portStr)
if err != nil {
testResult = errors.Join(testResult, fmt.Errorf("%s", err))
} else if port != CORRECT_PORT {
errorMsg := fmt.Sprintf("Expected source port '%d' but got '%d'", CORRECT_PORT, port)
testResult = errors.Join(testResult, fmt.Errorf("%s", errorMsg))
}
}

// Bind the port before starting the goroutine so net.Dial never races
// with the server not yet listening.
l := newLocalListener()
srv := &Server{Handler: handler}
srv.SetOption(EnableProxyProtocol())

serverDone := make(chan error, 1)

go func() {
serverDone <- srv.Serve(l)
}()

defer func() {
srv.Close()
if err := <-serverDone; err != nil && err != ErrServerClosed {
t.Error(err)
}
}()

serverIP, serverPortStr, _ := net.SplitHostPort(l.Addr().String())
serverPort, _ := strconv.Atoi(serverPortStr)
conn, err := net.Dial("tcp", l.Addr().String())

if err != nil {
t.Fatal(err)
}

header := &proxyproto.Header{
Version: 1,
Command: proxyproto.PROXY,
TransportProtocol: proxyproto.TCPv4,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP(CORRECT_IP),
Port: CORRECT_PORT,
},
DestinationAddr: &net.TCPAddr{
IP: net.ParseIP(serverIP),
Port: serverPort,
},
}

// Writes the PROXY header to the TCP stream before SSH begins
_, err = header.WriteTo(conn)
if err != nil {
t.Fatal(err)
}

// Hand the same conn to the SSH stack — handshake starts from here.
clientConn, chans, reqs, err := gossh.NewClientConn(conn, l.Addr().String(), &gossh.ClientConfig{
User: "testuser",
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
})
if err != nil {
t.Fatal(err)
}
client := gossh.NewClient(clientConn, chans, reqs)
defer client.Close()

session, err := client.NewSession()
if err != nil {
t.Fatal(err)
}
session.Run("") // triggers the handler; ignore exec error

<-handlerDone

if testResult != nil {
t.Fatal(testResult)
}
}