Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 102 additions & 31 deletions staticaddr/withdraw/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,20 +238,11 @@ func (m *Manager) recoverWithdrawals(ctx context.Context) error {
return err
}

// Group the deposits by their finalized withdrawal transaction.
depositsByWithdrawalTx := make(map[chainhash.Hash][]*deposit.Deposit)
hash2tx := make(map[chainhash.Hash]*wire.MsgTx)
for _, d := range withdrawingDeposits {
withdrawalTx := d.FinalizedWithdrawalTx
if withdrawalTx == nil {
continue
}
txid := withdrawalTx.TxHash()
hash2tx[txid] = withdrawalTx

depositsByWithdrawalTx[txid] = append(
depositsByWithdrawalTx[txid], d,
)
depositsByWithdrawalTx, hash2tx, err := m.groupWithdrawingDepositsByTx(
ctx, withdrawingDeposits,
)
if err != nil {
return err
}

// Publishing a transaction can take a while in neutrino mode, so
Expand Down Expand Up @@ -303,6 +294,98 @@ func (m *Manager) recoverWithdrawals(ctx context.Context) error {
return nil
}

// groupWithdrawingDepositsByTx clusters withdrawing deposits by their finalized
// withdrawal transaction hash.
func (m *Manager) groupWithdrawingDepositsByTx(ctx context.Context,
withdrawingDeposits []*deposit.Deposit) (
map[chainhash.Hash][]*deposit.Deposit, map[chainhash.Hash]*wire.MsgTx,
error) {

depositsByWithdrawalTx := make(map[chainhash.Hash][]*deposit.Deposit)
hash2tx := make(map[chainhash.Hash]*wire.MsgTx)

// Build an index of all known finalized withdrawal transactions.
for _, d := range withdrawingDeposits {
if d.FinalizedWithdrawalTx == nil {
continue
}

txid := d.FinalizedWithdrawalTx.TxHash()
hash2tx[txid] = d.FinalizedWithdrawalTx
}

// If exactly one tx hash is present, we can recover missing tx pointers
// from that single cluster.
var fallbackTx *wire.MsgTx
if len(hash2tx) == 1 {
for _, tx := range hash2tx {
fallbackTx = tx
}
}

for _, d := range withdrawingDeposits {
withdrawalTx := d.FinalizedWithdrawalTx
if withdrawalTx == nil {
if fallbackTx == nil {
log.Warnf("Skipping withdrawing deposit %v "+
"during recovery: missing finalized "+
"withdrawal tx", d.OutPoint)

continue
}

// Persist the recovered tx pointer so future restarts
// don't depend on in-memory fallback recovery.
d.Lock()
d.FinalizedWithdrawalTx = fallbackTx
d.Unlock()

err := m.cfg.DepositManager.UpdateDeposit(ctx, d)
if err != nil {
return nil, nil, fmt.Errorf("unable to "+
"persist recovered finalized "+
"withdrawal tx for deposit %v: %w",
d.OutPoint, err)
}

log.Warnf("Recovered missing finalized withdrawal tx "+
"for deposit %v", d.OutPoint)

withdrawalTx = fallbackTx
}

txid := withdrawalTx.TxHash()
hash2tx[txid] = withdrawalTx
depositsByWithdrawalTx[txid] = append(
depositsByWithdrawalTx[txid], d,
)
}
Comment on lines +304 to +362
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation iterates over withdrawingDeposits twice. This can be made more efficient and arguably clearer by partitioning the deposits into those with and without a finalized transaction in a single pass. This avoids redundant work and simplifies the logic.

depositsByWithdrawalTx := make(map[chainhash.Hash][]*deposit.Deposit)
	hash2tx := make(map[chainhash.Hash]*wire.MsgTx)
	var missingTxDeposits []*deposit.Deposit

	// Partition deposits into those with and without a finalized tx, and
	// group the ones that have a tx.
	for _, d := range withdrawingDeposits {
		if d.FinalizedWithdrawalTx == nil {
			missingTxDeposits = append(missingTxDeposits, d)
			continue
		}

		txid := d.FinalizedWithdrawalTx.TxHash()
		hash2tx[txid] = d.FinalizedWithdrawalTx
		depositsByWithdrawalTx[txid] = append(
			depositsByWithdrawalTx[txid], d,
		)
	}

	// If there are deposits with missing tx pointers and there is exactly
	// one cluster of deposits with a known tx, we can recover.
	if len(missingTxDeposits) > 0 && len(hash2tx) == 1 {
		var fallbackTx *wire.MsgTx
		var fallbackTxID chainhash.Hash
		for txid, tx := range hash2tx {
			fallbackTx = tx
			fallbackTxID = txid
			break
		}

		for _, d := range missingTxDeposits {
			// Persist the recovered tx pointer so future restarts
			// don't depend on in-memory fallback recovery.
			d.Lock()
			d.FinalizedWithdrawalTx = fallbackTx
			d.Unlock()

			err := m.cfg.DepositManager.UpdateDeposit(ctx, d)
			if err != nil {
				return nil, nil, fmt.Errorf("unable to "+
					"persist recovered finalized "+
					"withdrawal tx for deposit %v: %w",
					d.OutPoint, err)
			}

			log.Warnf("Recovered missing finalized withdrawal tx "+
				"for deposit %v", d.OutPoint)

			// Add the recovered deposit to the cluster.
			depositsByWithdrawalTx[fallbackTxID] = append(
				depositsByWithdrawalTx[fallbackTxID], d,
			)
		}
	} else if len(missingTxDeposits) > 0 {
		// Log deposits that could not be recovered.
		for _, d := range missingTxDeposits {
			log.Warnf("Skipping withdrawing deposit %v "+
				"during recovery: missing finalized "+
				"withdrawal tx", d.OutPoint)
		}
	}


return depositsByWithdrawalTx, hash2tx, nil
}

// persistFinalizedWithdrawalTx updates the selected deposits with the finalized
// withdrawal tx and persists the change before state transitions.
func (m *Manager) persistFinalizedWithdrawalTx(ctx context.Context,
deposits []*deposit.Deposit, finalizedTx *wire.MsgTx) error {

for _, d := range deposits {
d.Lock()
d.FinalizedWithdrawalTx = finalizedTx
d.Unlock()
}

for _, d := range deposits {
err := m.cfg.DepositManager.UpdateDeposit(ctx, d)
if err != nil {
return fmt.Errorf("failed to update deposit %v: %w",
d.OutPoint, err)
}
}
Comment on lines +372 to +384
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function persistFinalizedWithdrawalTx currently iterates through the deposits slice twice: once to update the FinalizedWithdrawalTx in memory, and a second time to persist each deposit to the database. These two loops can be combined into a single loop for better efficiency and readability. In each iteration, you can update the in-memory object and then immediately call UpdateDeposit. This also improves consistency in case of a failure during persistence.

	for _, d := range deposits {
		d.Lock()
		d.FinalizedWithdrawalTx = finalizedTx
		d.Unlock()

		err := m.cfg.DepositManager.UpdateDeposit(ctx, d)
		if err != nil {
			return fmt.Errorf("failed to update deposit %v: %w",
				d.OutPoint, err)
		}
	}


return nil
}

// WithdrawDeposits starts a deposits withdrawal flow. If the amount is set to 0
// the full amount of the selected deposits will be withdrawn.
func (m *Manager) WithdrawDeposits(ctx context.Context,
Expand Down Expand Up @@ -478,14 +561,11 @@ func (m *Manager) WithdrawDeposits(ctx context.Context,
m.mu.Unlock()
}

// Attach the finalized withdrawal tx to the deposits. After a client
// restart we can use this address as an indicator to republish the
// withdrawal tx and continue the withdrawal.
// Deposits with the same withdrawal tx are part of the same withdrawal.
for _, d := range deposits {
d.Lock()
d.FinalizedWithdrawalTx = finalizedTx
d.Unlock()
// Persist the finalized withdrawal tx before state transitions so that
// a restart can recover the full withdrawal cluster.
err = m.persistFinalizedWithdrawalTx(ctx, deposits, finalizedTx)
if err != nil {
return "", "", err
}

// Add the new withdrawal tx to the finalized withdrawals to republish
Expand All @@ -504,15 +584,6 @@ func (m *Manager) WithdrawDeposits(ctx context.Context,
err)
}

// Update the deposits in the database.
for _, d := range deposits {
err = m.cfg.DepositManager.UpdateDeposit(ctx, d)
if err != nil {
return "", "", fmt.Errorf("failed to update "+
"deposit %w", err)
}
}

return finalizedTx.TxID(), withdrawalAddress.String(), nil
}

Expand Down
195 changes: 195 additions & 0 deletions staticaddr/withdraw/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/loop/fsm"
"github.com/lightninglabs/loop/staticaddr/address"
"github.com/lightninglabs/loop/staticaddr/deposit"
"github.com/lightninglabs/loop/staticaddr/script"
"github.com/lightninglabs/loop/swapserverrpc"
"github.com/lightninglabs/loop/test"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/funding"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnrpc"
Expand All @@ -20,6 +24,10 @@ import (
"github.com/stretchr/testify/require"
)

func init() {
UseLogger(build.NewSubLogger("WDRW", nil))
}

// TestNewManagerHeightValidation ensures the constructor rejects zero heights.
func TestNewManagerHeightValidation(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -606,3 +614,190 @@ func TestCalculateWithdrawalTxValues(t *testing.T) {
})
}
}

// recoveryDepositManager is a test stub that tracks recovery interactions for
// deposits in the WITHDRAWING state.
type recoveryDepositManager struct {
withdrawingDeposits []*deposit.Deposit
transitioned [][]wire.OutPoint
updated []wire.OutPoint
}

// GetActiveDepositsInState returns the preset withdrawing deposits for the
// recovery test.
func (m *recoveryDepositManager) GetActiveDepositsInState(
_ fsm.StateType) (
[]*deposit.Deposit, error) {

return m.withdrawingDeposits, nil
}

// AllOutpointsActiveDeposits reports no active deposit set lookup in this
// test stub.
func (m *recoveryDepositManager) AllOutpointsActiveDeposits(
_ []wire.OutPoint, _ fsm.StateType) ([]*deposit.Deposit, bool) {

return nil, false
}

// TransitionDeposits records the outpoints transitioned by recovery.
func (m *recoveryDepositManager) TransitionDeposits(_ context.Context,
deposits []*deposit.Deposit, _ fsm.EventType, _ fsm.StateType) error {

outpoints := make([]wire.OutPoint, len(deposits))
for i, d := range deposits {
outpoints[i] = d.OutPoint
}

m.transitioned = append(m.transitioned, outpoints)

return nil
}

// UpdateDeposit records which deposits were updated during recovery.
func (m *recoveryDepositManager) UpdateDeposit(_ context.Context,
d *deposit.Deposit) error {

m.updated = append(m.updated, d.OutPoint)

return nil
}

// recoveryAddressManager is a test stub that serves static address parameters
// needed by withdrawal recovery.
type recoveryAddressManager struct {
params *address.Parameters
}

// GetStaticAddressParameters returns the preset static address parameters for
// the recovery test.
func (m *recoveryAddressManager) GetStaticAddressParameters(
_ context.Context) (*address.Parameters, error) {

return m.params, nil
}

// GetStaticAddress returns no static address in this test stub.
func (m *recoveryAddressManager) GetStaticAddress(
_ context.Context) (*script.StaticAddress, error) {

return nil, nil
}

// TestRecoverWithdrawalsIncludesMissingFinalizedTxDeposits verifies regression
// coverage for restart recovery where some deposits are in WITHDRAWING but
// missing FinalizedWithdrawalTx pointers.
//
// Without the fix this test still builds, but fails at runtime because the
// legacy recovery code silently skips those deposits and only reinstates the
// subset with non-nil FinalizedWithdrawalTx.
func TestRecoverWithdrawalsIncludesMissingFinalizedTxDeposits(t *testing.T) {
t.Parallel()

tx := wire.NewMsgTx(2)
tx.AddTxIn(&wire.TxIn{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{9},
Index: 0,
},
})
tx.AddTxOut(&wire.TxOut{
Value: 1000,
PkScript: []byte{txscript.OP_1},
})

known1 := &deposit.Deposit{
OutPoint: wire.OutPoint{
Hash: chainhash.Hash{1},
Index: 0,
},
ConfirmationHeight: 100,
FinalizedWithdrawalTx: tx,
}
known2 := &deposit.Deposit{
OutPoint: wire.OutPoint{
Hash: chainhash.Hash{2},
Index: 0,
},
ConfirmationHeight: 100,
FinalizedWithdrawalTx: tx,
}
missing1 := &deposit.Deposit{
OutPoint: wire.OutPoint{
Hash: chainhash.Hash{3},
Index: 0,
},
ConfirmationHeight: 100,
}
missing2 := &deposit.Deposit{
OutPoint: wire.OutPoint{
Hash: chainhash.Hash{4},
Index: 0,
},
ConfirmationHeight: 100,
}

depositMgr := &recoveryDepositManager{
withdrawingDeposits: []*deposit.Deposit{
known1, known2, missing1, missing2,
},
}
addrMgr := &recoveryAddressManager{
params: &address.Parameters{
PkScript: []byte{txscript.OP_1},
},
}

lnd := test.NewMockLnd()
go func() {
<-lnd.TxPublishChannel
}()
go func() {
<-lnd.RegisterSpendChannel
}()

mgr, err := NewManager(&ManagerConfig{
DepositManager: depositMgr,
WalletKit: lnd.WalletKit,
ChainNotifier: lnd.ChainNotifier,
AddressManager: addrMgr,
}, 101)
require.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

err = mgr.recoverWithdrawals(ctx)
require.NoError(t, err)

// Assert we re-instated one withdrawal cluster containing all four
// deposits. The old buggy behavior re-instated only the two deposits
// that already had finalized tx pointers.
require.Len(t, depositMgr.transitioned, 1)
require.Len(t, depositMgr.transitioned[0], 4)

transitioned := make(map[wire.OutPoint]struct{})
for _, op := range depositMgr.transitioned[0] {
transitioned[op] = struct{}{}
}
_, ok := transitioned[missing1.OutPoint]
require.True(t, ok)
_, ok = transitioned[missing2.OutPoint]
require.True(t, ok)

// Missing pointers should be recovered and persisted.
updated := make(map[wire.OutPoint]struct{})
for _, op := range depositMgr.updated {
updated[op] = struct{}{}
}
_, ok = updated[missing1.OutPoint]
require.True(t, ok)
_, ok = updated[missing2.OutPoint]
require.True(t, ok)
require.NotNil(t, missing1.FinalizedWithdrawalTx)
require.NotNil(t, missing2.FinalizedWithdrawalTx)

// Shut down notifier goroutines started by recovery.
cancel()
lnd.WaitForFinished()
}
Loading