Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion certc/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ func NewRoot(sk ed25519.PrivateKey) (*Cert, error) {
sk = priv
}

pk := sk.Public().(ed25519.PublicKey)

template := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().UnixMicro()),

Expand All @@ -44,14 +46,17 @@ func NewRoot(sk ed25519.PrivateKey) (*Cert, error) {

Subject: SharedSubject,

SubjectKeyId: pk,
AuthorityKeyId: pk,

BasicConstraintsValid: true,
IsCA: true,

KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
ExtKeyUsage: []x509.ExtKeyUsage{},
}

der, err := x509.CreateCertificate(rand.Reader, template, template, sk.Public(), sk)
der, err := x509.CreateCertificate(rand.Reader, template, template, pk, sk)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -98,6 +103,9 @@ func (c *Cert) NewServer(opts CertOpts) (*Cert, error) {
Issuer: parent.Subject,
Subject: subject,

AuthorityKeyId: parent.SubjectKeyId,
SubjectKeyId: pk,

DNSNames: opts.Domains,
IPAddresses: opts.IPs,

Expand Down Expand Up @@ -136,6 +144,9 @@ func (c *Cert) NewClient() (*Cert, error) {
Issuer: parent.Subject,
Subject: SharedSubject,

AuthorityKeyId: parent.SubjectKeyId,
SubjectKeyId: pk,

BasicConstraintsValid: false,
IsCA: false,

Expand Down
10 changes: 9 additions & 1 deletion certc/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"fmt"
"io"
"net"
Expand All @@ -21,8 +22,15 @@ func TestChain(t *testing.T) {
seed := make([]byte, ed25519.SeedSize)
_, err := io.ReadFull(rand.Reader, seed)
require.NoError(t, err)
fmt.Println("seed", hex.EncodeToString(seed))
priv := ed25519.NewKeyFromSeed(seed)
require.NoError(t, err)
fmt.Println("priv", hex.EncodeToString(priv))

pub := priv.Public().(ed25519.PublicKey)
fmt.Println("pub", hex.EncodeToString(pub))

root, err := NewRoot(ed25519.NewKeyFromSeed(seed))
root, err := NewRoot(priv)
require.NoError(t, err)

caPool, err := root.CertPool()
Expand Down
15 changes: 4 additions & 11 deletions client/destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import (
"context"
"crypto/ed25519"
"crypto/tls"
"crypto/x509"
"fmt"
"log/slog"
"net"
"time"

"github.com/connet-dev/connet/cryptoc"
"github.com/connet-dev/connet/model"
"github.com/connet-dev/connet/netc"
"github.com/connet-dev/connet/notify"
"github.com/connet-dev/connet/proto"
"github.com/connet-dev/connet/proto/pbconnect"
Expand Down Expand Up @@ -278,7 +278,7 @@ func (d *destinationConn) runConnect(ctx context.Context, stream *quic.Stream, r

connect.DestinationEncryption = pbconnect.RelayEncryptionScheme_TLS
connect.DestinationTls = &pbconnect.TLSConfiguration{
ClientName: d.dst.peer.serverCert.Leaf.DNSNames[0],
ClientName: netc.GenServerNameTLS(d.dst.peer.rootCert),
}
case encryption == model.DHXCPEncryption:
// get check peer public key
Expand Down Expand Up @@ -355,21 +355,14 @@ func (d *Destination) getSourceTLS(name string) (*tls.Config, error) {
}

for _, remote := range remotes {
switch cfg, err := newServerTLSConfig(remote.Peer.ServerCertificate); {
switch cfg, err := newServerTLSConfigInternal(remote.Peer.Certificate); {
case err != nil:
return nil, fmt.Errorf("source peer server cert: %w", err)
case cfg.name == name:
clientCert, err := x509.ParseCertificate(remote.Peer.ClientCertificate)
if err != nil {
return nil, fmt.Errorf("source peer client cert: %w", err)
}

clientCAs := x509.NewCertPool()
clientCAs.AddCert(clientCert)
return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{d.peer.serverCert},
ClientCAs: clientCAs,
ClientCAs: cfg.cas,
}, nil
}
}
Expand Down
23 changes: 12 additions & 11 deletions client/direct.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -50,7 +51,7 @@ func (s *vServer) dequeue(key model.Key, cert *x509.Certificate) *vClient {
s.mu.Lock()
defer s.mu.Unlock()

if exp, ok := s.clients[key]; ok && exp.cert.Equal(cert) {
if exp, ok := s.clients[key]; ok && bytes.Equal(exp.cert.SubjectKeyId, cert.AuthorityKeyId) {
delete(s.clients, key)
return exp
}
Expand All @@ -69,16 +70,16 @@ func (s *vServer) updateClientCA() {
s.clientCA.Store(clientCA)
}

func (s *DirectServer) addServerCert(cert tls.Certificate) {
serverName := cert.Leaf.DNSNames[0]
func (s *DirectServer) addServerCert(localServerCert tls.Certificate) {
localServerName := localServerCert.Leaf.DNSNames[0]

s.serversMu.Lock()
defer s.serversMu.Unlock()

s.logger.Debug("add server cert", "server", serverName, "cert", model.NewKey(cert.Leaf))
s.servers[serverName] = &vServer{
serverName: serverName,
serverCert: cert,
s.logger.Debug("add server cert", "server", localServerName, "cert", model.NewKey(localServerCert.Leaf))
s.servers[localServerName] = &vServer{
serverName: localServerName,
serverCert: localServerCert,
clients: map[model.Key]*vClient{},
}
}
Expand All @@ -90,9 +91,9 @@ func (s *DirectServer) getServer(serverName string) *vServer {
return s.servers[serverName]
}

func (s *DirectServer) expect(serverCert tls.Certificate, cert *x509.Certificate) (chan *quic.Conn, func()) {
key := model.NewKey(cert)
srv := s.getServer(serverCert.Leaf.DNSNames[0])
func (s *DirectServer) expect(localServerCert tls.Certificate, cert *x509.Certificate) (chan *quic.Conn, func()) {
key := model.NewKeyRaw(cert.SubjectKeyId)
srv := s.getServer(localServerCert.Leaf.DNSNames[0])

defer srv.updateClientCA()

Expand Down Expand Up @@ -161,7 +162,7 @@ func (s *DirectServer) runConn(conn *quic.Conn) {
}

cert := conn.ConnectionState().TLS.PeerCertificates[0]
key := model.NewKey(cert)
key := model.NewKeyRaw(cert.AuthorityKeyId)
s.logger.Debug("accepted conn", "server", srv.serverName, "cert", key, "remote", conn.RemoteAddr())

exp := srv.dequeue(key, cert)
Expand Down
57 changes: 37 additions & 20 deletions client/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type peer struct {
peerConns *notify.V[map[peerConnKey]*quic.Conn]

direct *DirectServer
rootCert tls.Certificate
serverCert tls.Certificate
clientCert tls.Certificate
logger *slog.Logger
Expand Down Expand Up @@ -67,13 +68,17 @@ func (s peerStyle) String() string {
}

func newPeer(direct *DirectServer, logger *slog.Logger, privateKey ed25519.PrivateKey) (*peer, error) {
root, err := certc.NewRoot(privateKey)
peerCert, err := certc.NewRoot(privateKey)
if err != nil {
return nil, err
}
peerTLSCert, err := peerCert.TLSCert()
if err != nil {
return nil, err
}

serverCert, err := root.NewServer(certc.CertOpts{
Domains: []string{netc.GenServerName("connet-direct")},
serverCert, err := peerCert.NewServer(certc.CertOpts{
Domains: []string{netc.GenServerNameTLS(peerTLSCert)},
})
if err != nil {
return nil, err
Expand All @@ -82,7 +87,7 @@ func newPeer(direct *DirectServer, logger *slog.Logger, privateKey ed25519.Priva
if err != nil {
return nil, err
}
clientCert, err := root.NewClient()
clientCert, err := peerCert.NewClient()
if err != nil {
return nil, err
}
Expand All @@ -93,15 +98,15 @@ func newPeer(direct *DirectServer, logger *slog.Logger, privateKey ed25519.Priva

return &peer{
self: notify.New(&pbclient.Peer{
ServerCertificate: serverTLSCert.Leaf.Raw,
ClientCertificate: clientTLSCert.Leaf.Raw,
Certificate: peerCert.Raw(),
}),
relays: notify.NewEmpty[[]*pbclient.Relay](),
relayConns: notify.New(map[relayID]*quic.Conn{}),
peers: notify.NewEmpty[[]*pbclient.RemotePeer](),
peerConns: notify.New(map[peerConnKey]*quic.Conn{}),

direct: direct,
rootCert: peerTLSCert,
serverCert: serverTLSCert,
clientCert: clientTLSCert,
logger: logger,
Expand All @@ -119,10 +124,9 @@ func (p *peer) isDirect() bool {
func (p *peer) setDirectAddrs(addrs []netip.AddrPort) {
p.self.Update(func(cp *pbclient.Peer) *pbclient.Peer {
return &pbclient.Peer{
Directs: pbmodel.AsAddrPorts(addrs),
RelayIds: cp.RelayIds,
ServerCertificate: cp.ServerCertificate,
ClientCertificate: cp.ClientCertificate,
Directs: pbmodel.AsAddrPorts(addrs),
RelayIds: cp.RelayIds,
Certificate: cp.Certificate,
}
})
}
Expand Down Expand Up @@ -159,7 +163,7 @@ func (p *peer) runRelays(ctx context.Context) error {

activeRelays[id] = struct{}{}

cfg, err := newServerTLSConfig(relay.ServerCertificate)
cfg, err := newServerTLSConfigPublic(relay.ServerCertificate)
if err != nil {
return err
}
Expand Down Expand Up @@ -192,10 +196,9 @@ func (p *peer) runShareRelays(ctx context.Context) error {
}
p.self.Update(func(cp *pbclient.Peer) *pbclient.Peer {
return &pbclient.Peer{
Directs: cp.Directs,
RelayIds: ids,
ServerCertificate: cp.ServerCertificate,
ClientCertificate: cp.ClientCertificate,
Directs: cp.Directs,
RelayIds: ids,
Certificate: cp.Certificate,
}
})
return nil
Expand Down Expand Up @@ -272,7 +275,21 @@ type serverTLSConfig struct {
cas *x509.CertPool
}

func newServerTLSConfig(serverCert []byte) (*serverTLSConfig, error) {
func newServerTLSConfigInternal(serverCert []byte) (*serverTLSConfig, error) {
cert, err := x509.ParseCertificate(serverCert)
if err != nil {
return nil, err
}
cas := x509.NewCertPool()
cas.AddCert(cert)
return &serverTLSConfig{
key: model.NewKey(cert),
name: netc.GenServerNameX509(cert),
cas: cas,
}, nil
}

func newServerTLSConfigPublic(serverCert []byte) (*serverTLSConfig, error) {
cert, err := x509.ParseCertificate(serverCert)
if err != nil {
return nil, err
Expand All @@ -296,14 +313,14 @@ func (p *peer) newECDHConfig() (*ecdh.PrivateKey, *pbconnect.ECDHConfiguration,
keyTime = append(keyTime, sk.PublicKey().Bytes()...)
keyTime = binary.BigEndian.AppendUint64(keyTime, uint64(time.Now().Nanosecond()))

certSK := p.serverCert.PrivateKey.(ed25519.PrivateKey)
certSK := p.rootCert.PrivateKey.(ed25519.PrivateKey)
signature, err := certSK.Sign(rand.Reader, keyTime, &ed25519.Options{})
if err != nil {
return nil, nil, fmt.Errorf("peer sign: %w", err)
}

return sk, &pbconnect.ECDHConfiguration{
ClientName: p.serverCert.Leaf.DNSNames[0],
ClientName: netc.GenServerNameTLS(p.rootCert),
KeyTime: keyTime,
Signature: signature,
}, nil
Expand All @@ -316,11 +333,11 @@ func (p *peer) getECDHPublicKey(cfg *pbconnect.ECDHConfiguration) (*ecdh.PublicK
}
var candidates []*x509.Certificate
for _, remote := range remotes {
cert, err := x509.ParseCertificate(remote.Peer.ServerCertificate)
cert, err := x509.ParseCertificate(remote.Peer.Certificate)
if err != nil {
return nil, err
}
if cert.DNSNames[0] == cfg.ClientName {
if netc.GenServerNameX509(cert) == cfg.ClientName {
candidates = append(candidates, cert)
}
}
Expand Down
4 changes: 2 additions & 2 deletions client/peer_direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ func (p *directPeer) runRemote(ctx context.Context) error {
return p.remote.Listen(ctx, func(remote *pbclient.RemotePeer) error {
if p.local.isDirect() && len(remote.Peer.Directs) > 0 {
if p.incoming == nil {
remoteClientCert, err := x509.ParseCertificate(remote.Peer.ClientCertificate)
remoteClientCert, err := x509.ParseCertificate(remote.Peer.Certificate)
if err != nil {
return fmt.Errorf("parse client certificate: %w", err)
}
p.incoming = newDirectPeerIncoming(ctx, p, remoteClientCert)
}

if p.outgoing == nil {
remoteServerConf, err := newServerTLSConfig(remote.Peer.ServerCertificate)
remoteServerConf, err := newServerTLSConfigInternal(remote.Peer.Certificate)
if err != nil {
return fmt.Errorf("parse server certificate: %w", err)
}
Expand Down
5 changes: 3 additions & 2 deletions client/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/connet-dev/connet/cryptoc"
"github.com/connet-dev/connet/model"
"github.com/connet-dev/connet/netc"
"github.com/connet-dev/connet/notify"
"github.com/connet-dev/connet/proto"
"github.com/connet-dev/connet/proto/pbconnect"
Expand Down Expand Up @@ -408,7 +409,7 @@ func (s *Source) dialStream(ctx context.Context, dest sourceConn, stream *quic.S

if slices.Contains(s.cfg.RelayEncryptions, model.TLSEncryption) {
connect.SourceTls = &pbconnect.TLSConfiguration{
ClientName: s.peer.serverCert.Leaf.DNSNames[0],
ClientName: netc.GenServerNameTLS(s.peer.rootCert),
}
}

Expand Down Expand Up @@ -487,7 +488,7 @@ func (s *Source) getDestinationTLS(name string) (*tls.Config, error) {
}

for _, remote := range remotes {
switch cfg, err := newServerTLSConfig(remote.Peer.ServerCertificate); {
switch cfg, err := newServerTLSConfigInternal(remote.Peer.Certificate); {
case err != nil:
return nil, fmt.Errorf("destination peer server cert: %w", err)
case cfg.name == name:
Expand Down
5 changes: 1 addition & 4 deletions control/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,12 +580,9 @@ func (s *clientStream) runErr(ctx context.Context) error {
}

func validatePeerCert(endpoint model.Endpoint, peer *pbclient.Peer) *pberror.Error {
if _, err := x509.ParseCertificate(peer.ClientCertificate); err != nil {
if _, err := x509.ParseCertificate(peer.Certificate); err != nil {
return pberror.NewError(pberror.Code_AnnounceInvalidClientCertificate, "'%s' client cert is invalid", endpoint)
}
if _, err := x509.ParseCertificate(peer.ServerCertificate); err != nil {
return pberror.NewError(pberror.Code_AnnounceInvalidServerCertificate, "'%s' server cert is invalid", endpoint)
}
return nil
}

Expand Down
1 change: 1 addition & 0 deletions examples/client-source.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ direct-addr = ":19193"
[client.sources.sws]
relay-encryptions = ["tls"]
url = "tcp://:9999"
private-key = "nccarlqoint2fcpcsah54ids7v3r5ol63jbe4i4v5d0dt5blf5a0"
Loading