Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,25 @@ func EncryptAndSave(auth *CachedAuth, username string, secretKey *[32]byte) erro
return saveAuths(auths)
}

func authenticate(c *protonmail.Client, cachedAuth *CachedAuth, username string) (openpgp.EntityList, error) {
func authenticate(c *protonmail.Client, cachedAuth *CachedAuth, username string) (openpgp.EntityList, uint64, error) {
auth, err := c.AuthRefresh(&cachedAuth.Auth)
if apiErr, ok := err.(*protonmail.APIError); ok && apiErr.Code == 10013 {
// Invalid refresh token, re-authenticate
authInfo, err := c.AuthInfo(username)
if err != nil {
return nil, fmt.Errorf("cannot re-authenticate: failed to get auth info: %v", err)
return nil, 0, fmt.Errorf("cannot re-authenticate: failed to get auth info: %v", err)
}

auth, err = c.Auth(username, cachedAuth.LoginPassword, authInfo)
if err != nil {
return nil, fmt.Errorf("cannot re-authenticate: %v", err)
return nil, 0, fmt.Errorf("cannot re-authenticate: %v", err)
}

if auth.TwoFactor.Enabled != 0 {
return nil, fmt.Errorf("cannot re-authenticate: two factor authentication enabled, please login again manually")
return nil, 0, fmt.Errorf("cannot re-authenticate: two factor authentication enabled, please login again manually")
}
} else if err != nil {
return nil, err
return nil, 0, err
}
cachedAuth.Auth = *auth

Expand Down Expand Up @@ -171,6 +171,7 @@ type session struct {
hashedSecretKey []byte
c *protonmail.Client
privateKeys openpgp.EntityList
primaryKeyID uint64
}

var ErrUnauthorized = errors.New("Invalid username or password")
Expand All @@ -180,73 +181,74 @@ type Manager struct {
sessions map[string]*session
}

func (m *Manager) Auth(username, password string) (*protonmail.Client, openpgp.EntityList, error) {
func (m *Manager) Auth(username, password string) (*protonmail.Client, openpgp.EntityList, uint64, error) {
var secretKey [32]byte
passwordBytes, err := base64.StdEncoding.DecodeString(password)
if err != nil || len(passwordBytes) != len(secretKey) {
return nil, nil, ErrUnauthorized
return nil, nil, 0, ErrUnauthorized
}
copy(secretKey[:], passwordBytes)

s, ok := m.sessions[username]
if ok {
err := bcrypt.CompareHashAndPassword(s.hashedSecretKey, secretKey[:])
if err != nil {
return nil, nil, ErrUnauthorized
return nil, nil, 0, ErrUnauthorized
}
} else {
auths, err := readCachedAuths()
if err != nil && !os.IsNotExist(err) {
return nil, nil, err
return nil, nil, 0, err
}

encrypted, ok := auths[username]
if !ok {
return nil, nil, ErrUnauthorized
return nil, nil, 0, ErrUnauthorized
}

decrypted, err := decrypt(encrypted, &secretKey)
if err != nil {
return nil, nil, ErrUnauthorized
return nil, nil, 0, ErrUnauthorized
}

var cachedAuth CachedAuth
if err := json.Unmarshal(decrypted, &cachedAuth); err != nil {
return nil, nil, err
return nil, nil, 0, err
}

c := m.newClient()
c.ReAuth = func() error {
if _, err := authenticate(c, &cachedAuth, username); err != nil {
if _, _, err := authenticate(c, &cachedAuth, username); err != nil {
return err
}
return EncryptAndSave(&cachedAuth, username, &secretKey)
}

// authenticate updates cachedAuth with the new refresh token
privateKeys, err := authenticate(c, &cachedAuth, username)
privateKeys, primaryKeyID, err := authenticate(c, &cachedAuth, username)
if err != nil {
return nil, nil, err
return nil, nil, 0, err
}

if err := EncryptAndSave(&cachedAuth, username, &secretKey); err != nil {
return nil, nil, err
return nil, nil, 0, err
}

hashed, err := bcrypt.GenerateFromPassword(secretKey[:], bcrypt.DefaultCost)
if err != nil {
return nil, nil, err
return nil, nil, 0, err
}

s = &session{
c: c,
privateKeys: privateKeys,
hashedSecretKey: hashed,
primaryKeyID: primaryKeyID,
}
m.sessions[username] = s
}

return s.c, s.privateKeys, nil
return s.c, s.privateKeys, s.primaryKeyID, nil
}

func NewManager(newClient func() *protonmail.Client) *Manager {
Expand Down
47 changes: 36 additions & 11 deletions carddav/carddav.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"path"
"strconv"
Expand Down Expand Up @@ -141,11 +142,12 @@ func (b *backend) toAddressObject(contact *protonmail.Contact, req *carddav.Addr
}

type backend struct {
c *protonmail.Client
cache map[string]*protonmail.Contact
locker sync.Mutex
total int
privateKeys openpgp.EntityList
c *protonmail.Client
cache map[string]*protonmail.Contact
locker sync.Mutex
total int
privateKeys openpgp.EntityList
mainAccountKey *openpgp.Entity
}

func (b *backend) CurrentUserPrincipal(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -310,7 +312,7 @@ func (b *backend) PutAddressObject(ctx context.Context, path string, card vcard.
return nil, err
}

contactImport, err := formatCard(card, b.privateKeys[0])
contactImport, err := formatCard(card, b.mainAccountKey)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -397,16 +399,39 @@ func (b *backend) receiveEvents(events <-chan *protonmail.Event) {
}
}

func NewHandler(c *protonmail.Client, privateKeys openpgp.EntityList, events <-chan *protonmail.Event) http.Handler {
func NewHandler(c *protonmail.Client, privateKeys openpgp.EntityList, primaryKeyID uint64, events <-chan *protonmail.Event) http.Handler {
if len(privateKeys) == 0 {
panic("hydroxide/carddav: no private key available")
}

// Find the primary account key by matching primaryKeyID.
// Proton's web client uses the user key (not address key) to
// encrypt and sign contacts. On modern accounts these differ,
// so we must use the correct key to avoid "decryption failed"
// errors in the Proton web UI.
var mainAccountKey *openpgp.Entity
if primaryKeyID != 0 {
for _, entity := range privateKeys {
if entity.PrimaryKey != nil && entity.PrimaryKey.KeyId == primaryKeyID {
mainAccountKey = entity
break
}
}
}
// Fallback to first key if primary not found
if mainAccountKey == nil {
mainAccountKey = privateKeys[0]
if primaryKeyID != 0 {
log.Printf("warning: primary key ID %x not found in key ring, falling back to first key", primaryKeyID)
}
}

b := &backend{
c: c,
cache: make(map[string]*protonmail.Contact),
total: -1,
privateKeys: privateKeys,
c: c,
cache: make(map[string]*protonmail.Contact),
total: -1,
privateKeys: privateKeys,
mainAccountKey: mainAccountKey,
}

if events != nil {
Expand Down
14 changes: 7 additions & 7 deletions cmd/hydroxide/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func listenAndServeCardDAV(addr string, authManager *auth.Manager, eventsManager
return
}

c, privateKeys, err := authManager.Auth(username, password)
c, privateKeys, primaryKeyID, err := authManager.Auth(username, password)
if err != nil {
if err == auth.ErrUnauthorized {
resp.WriteHeader(http.StatusUnauthorized)
Expand All @@ -145,7 +145,7 @@ func listenAndServeCardDAV(addr string, authManager *auth.Manager, eventsManager
if !ok {
ch := make(chan *protonmail.Event)
eventsManager.Register(c, username, ch, nil)
h = carddav.NewHandler(c, privateKeys, ch)
h = carddav.NewHandler(c, privateKeys, primaryKeyID, ch)

handlers[username] = h
}
Expand Down Expand Up @@ -337,7 +337,7 @@ func main() {
log.Fatal(err)
}

_, err = c.Unlock(a, keySalts, mailboxPassword)
_, _, err = c.Unlock(a, keySalts, mailboxPassword)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -384,7 +384,7 @@ func main() {
log.Fatal(err)
}

_, privateKeys, err := auth.NewManager(newClient).Auth(username, bridgePassword)
_, privateKeys, _, err := auth.NewManager(newClient).Auth(username, bridgePassword)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -425,7 +425,7 @@ func main() {
log.Fatal(err)
}

c, _, err := auth.NewManager(newClient).Auth(username, bridgePassword)
c, _, _, err := auth.NewManager(newClient).Auth(username, bridgePassword)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -467,7 +467,7 @@ func main() {
log.Fatal(err)
}

c, privateKeys, err := auth.NewManager(newClient).Auth(username, bridgePassword)
c, privateKeys, _, err := auth.NewManager(newClient).Auth(username, bridgePassword)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -544,7 +544,7 @@ func main() {
log.Fatal(err)
}

c, privateKeys, err := auth.NewManager(newClient).Auth(username, bridgePassword)
c, privateKeys, _, err := auth.NewManager(newClient).Auth(username, bridgePassword)
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion imap/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type backend struct {
}

func (be *backend) Login(info *imap.ConnInfo, username, password string) (imapbackend.User, error) {
c, privateKeys, err := be.sessions.Auth(username, password)
c, privateKeys, _, err := be.sessions.Auth(username, password)
if err != nil {
return nil, err
}
Expand Down
33 changes: 21 additions & 12 deletions protonmail/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,9 @@ func unlockPrivateKey(key *PrivateKey, userKeyRing openpgp.EntityList, keySalt [
return entity, nil
}

func unlockKeyRing(keys []*PrivateKey, userKeyRing openpgp.EntityList, keySalts map[string][]byte, passphraseBytes []byte) (openpgp.EntityList, error) {
func unlockKeyRing(keys []*PrivateKey, userKeyRing openpgp.EntityList, keySalts map[string][]byte, passphraseBytes []byte) (openpgp.EntityList, uint64, error) {
var keyRing openpgp.EntityList
var primaryKeyID uint64
for _, key := range keys {
if key.Active != 1 {
continue
Expand All @@ -319,37 +320,43 @@ func unlockKeyRing(keys []*PrivateKey, userKeyRing openpgp.EntityList, keySalts
continue
}

if key.Primary == 1 && entity.PrimaryKey != nil {
primaryKeyID = entity.PrimaryKey.KeyId
}

keyRing = append(keyRing, entity)
}

if len(keyRing) == 0 {
return nil, fmt.Errorf("failed to unlock any key")
return nil, 0, fmt.Errorf("failed to unlock any key")
}
return keyRing, nil
return keyRing, primaryKeyID, nil
}

func (c *Client) Unlock(auth *Auth, keySalts map[string][]byte, passphrase string) (openpgp.EntityList, error) {
func (c *Client) Unlock(auth *Auth, keySalts map[string][]byte, passphrase string) (openpgp.EntityList, uint64, error) {
c.uid = auth.UID
c.accessToken = auth.AccessToken

u, err := c.GetCurrentUser()
if err != nil {
return nil, err
return nil, 0, err
}

userKeyRing, err := unlockKeyRing(u.Keys, nil, keySalts, []byte(passphrase))
userKeyRing, userPrimaryKeyID, err := unlockKeyRing(u.Keys, nil, keySalts, []byte(passphrase))
if err != nil {
return nil, err
return nil, 0, err
}

addrs, err := c.ListAddresses()
if err != nil {
return nil, err
return nil, 0, err
}

var keyRing openpgp.EntityList
// Start with user keys (needed for contact encryption/signing)
// then append address keys (needed for email)
keyRing := append(openpgp.EntityList{}, userKeyRing...)
for _, addr := range addrs {
addrKeyRing, err := unlockKeyRing(addr.Keys, userKeyRing, keySalts, []byte(passphrase))
addrKeyRing, _, err := unlockKeyRing(addr.Keys, userKeyRing, keySalts, []byte(passphrase))
if err != nil {
log.Printf("warning: failed to unlock address <%v>: %v", addr.Email, err)
continue
Expand All @@ -359,12 +366,14 @@ func (c *Client) Unlock(auth *Auth, keySalts map[string][]byte, passphrase strin
}

if len(keyRing) == 0 {
return nil, fmt.Errorf("failed to unlock any key")
return nil, 0, fmt.Errorf("failed to unlock any key")
}

c.keyRing = keyRing

return keyRing, nil
// Return the primary USER key ID — this is the key Proton uses
// for contact encryption/signing (not address keys)
return keyRing, userPrimaryKeyID, nil
}

func (c *Client) Logout() error {
Expand Down
2 changes: 1 addition & 1 deletion smtp/smtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func (s *session) AuthMechanisms() []string {
}

func (s *session) authPlain(username, password string) error {
c, privateKeys, err := s.be.sessions.Auth(username, password)
c, privateKeys, _, err := s.be.sessions.Auth(username, password)
if err != nil {
return err
}
Expand Down