Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 100 additions & 13 deletions bittensor_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4516,6 +4516,11 @@ def stake_add(
amount: float = typer.Option(
0.0, "--amount", help="The amount of TAO to stake"
),
amounts: str = typer.Option(
"",
"--amounts",
help="Comma-separated amounts of TAO to stake for each netuid. Must be used with --netuids and the number of amounts must match the number of netuids. Example: --netuids 1,2,3 --amounts 0.1,0.2,0.3",
),
include_hotkeys: str = typer.Option(
"",
"--include-hotkeys",
Expand Down Expand Up @@ -4586,7 +4591,10 @@ def stake_add(
7. Stake the same amount to multiple subnets:
[green]$[/green] btcli stake add --amount 100 --netuids 4,5,6

8. Stake without MEV protection:
8. Stake different amounts to multiple subnets:
[green]$[/green] btcli stake add --netuids 1,2,3 --amounts 0.1,0.2,0.3

9. Stake without MEV protection:
[green]$[/green] btcli stake add --amount 100 --netuid 1 --no-mev-protection

[bold]Safe Staking Parameters:[/bold]
Expand Down Expand Up @@ -4616,12 +4624,57 @@ def stake_add(
# ensure no negative netuids make it into our list
validate_netuid(netuid_)

# Validate mutually exclusive options
if amount and amounts:
print_error(
"Cannot specify both --amount and --amounts. Use --amount for single amount or --amounts for per-netuid amounts."
)
return

if stake_all and amount:
print_error(
"Cannot specify an amount and 'stake-all'. Choose one or the other."
)
return

if stake_all and amounts:
print_error(
"Cannot specify --amounts and 'stake-all'. Choose one or the other."
)
return

# Parse and validate --amounts if provided
amounts_list = None
if amounts:
if not netuids or len(netuids) == 0:
print_error(
"--amounts can only be used with --netuids. Please specify netuids."
)
return
try:
amounts_list = parse_to_list(
amounts,
float,
"Amounts must be numbers separated by commas, e.g., `--amounts 0.1,0.2,0.3`.",
False,
)
if len(amounts_list) != len(netuids):
print_error(
f"Number of amounts ({len(amounts_list)}) must match number of netuids ({len(netuids)}). "
f"Netuids: {netuids}, Amounts: {amounts_list}"
)
return
# Validate all amounts are positive
for amt in amounts_list:
if amt <= 0:
print_error(
f"All amounts must be positive. Invalid amount: {amt}"
)
return
except Exception as e:
print_error(f"Failed to parse amounts: {e}")
return

if stake_all and not amount:
if not confirm_action(
"Stake all the available TAO tokens?",
Expand Down Expand Up @@ -4747,8 +4800,10 @@ def stake_add(
else:
exclude_hotkeys = []

# TODO: Ask amount for each subnet explicitly if more than one
if not stake_all and not amount:
# Use amounts_list if provided via --amounts flag
if amounts_list:
amount = amounts_list
elif not stake_all and not amount:
free_balance = self._run_command(
wallets.wallet_balance(
wallet, self.initialize_chain(network), False, None
Expand All @@ -4759,23 +4814,55 @@ def stake_add(
if free_balance == Balance.from_tao(0):
print_error("You dont have any balance to stake.")
return
if netuids:

# If netuids is provided and has multiple subnets, ask for amount per netuid
if netuids and len(netuids) > 1:
amounts_prompted = []
remaining_balance = free_balance
for netuid in netuids:
netuid_amount = FloatPrompt.ask(
f"Amount to [{COLORS.G.SUBHEAD_MAIN}]stake to netuid {netuid} (TAO τ)[/] "
f"[dim](remaining balance: {remaining_balance})[/dim]"
)
if netuid_amount <= 0:
print_error(
f"You entered an incorrect stake amount: {netuid_amount}"
)
raise typer.Exit()
if Balance.from_tao(netuid_amount) > remaining_balance:
print_error(
f"You dont have enough balance to stake. Remaining balance: {remaining_balance}."
)
raise typer.Exit()
amounts_prompted.append(netuid_amount)
remaining_balance -= Balance.from_tao(netuid_amount)
amount = amounts_prompted
elif netuids:
# Single netuid
amount = FloatPrompt.ask(
f"Amount to [{COLORS.G.SUBHEAD_MAIN}]stake (TAO τ)"
)
if amount <= 0:
print_error(f"You entered an incorrect stake amount: {amount}")
raise typer.Exit()
if Balance.from_tao(amount) > free_balance:
print_error(
f"You dont have enough balance to stake. Current free Balance: {free_balance}."
)
raise typer.Exit()
else:
# netuids is empty list or None (all subnets) - ask for amount per netuid
amount = FloatPrompt.ask(
f"Amount to [{COLORS.G.SUBHEAD_MAIN}]stake to each netuid (TAO τ)"
)

if amount <= 0:
print_error(f"You entered an incorrect stake amount: {amount}")
raise typer.Exit()
if Balance.from_tao(amount) > free_balance:
print_error(
f"You dont have enough balance to stake. Current free Balance: {free_balance}."
)
raise typer.Exit()
if amount <= 0:
print_error(f"You entered an incorrect stake amount: {amount}")
raise typer.Exit()
if Balance.from_tao(amount) > free_balance:
print_error(
f"You dont have enough balance to stake. Current free Balance: {free_balance}."
)
raise typer.Exit()
logger.debug(
"args:\n"
f"network: {network}\n"
Expand Down
41 changes: 27 additions & 14 deletions bittensor_cli/src/commands/stake/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from functools import partial

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from async_substrate_interface import AsyncExtrinsicReceipt
from rich.table import Table
Expand Down Expand Up @@ -38,7 +38,7 @@ async def stake_add(
subtensor: "SubtensorInterface",
netuids: Optional[list[int]],
stake_all: bool,
amount: float,
amount: Union[float, list[float]],
prompt: bool,
decline: bool,
quiet: bool,
Expand All @@ -59,7 +59,7 @@ async def stake_add(
subtensor: SubtensorInterface object
netuids: the netuids to stake to (None indicates all subnets)
stake_all: whether to stake all available balance
amount: specified amount of balance to stake
amount: specified amount of balance to stake (float for single amount, list[float] for per-netuid amounts)
prompt: whether to prompt the user
all_hotkeys: whether to stake all hotkeys
include_hotkeys: list of hotkeys to include in staking process (if not specifying `--all`)
Expand Down Expand Up @@ -350,8 +350,12 @@ async def stake_extrinsic(
remaining_wallet_balance = current_wallet_balance
max_slippage = 0.0

amount_list = []
if isinstance(amount, list):
amount_list = amount

for hotkey in hotkeys_to_stake_to:
for netuid in netuids:
for netuid_idx, netuid in enumerate(netuids):
# Check that the subnet exists.
subnet_info = all_subnets.get(netuid)
if not subnet_info:
Expand All @@ -361,7 +365,11 @@ async def stake_extrinsic(

# Get the amount.
amount_to_stake = Balance(0)
if amount:
if amount_list:
# Use the amount from the list for this specific netuid
amount_to_stake = Balance.from_tao(amount_list[netuid_idx])
elif amount:
# Single amount for all netuids
amount_to_stake = Balance.from_tao(amount)
elif stake_all:
amount_to_stake = current_wallet_balance / len(netuids)
Expand All @@ -373,15 +381,6 @@ async def stake_extrinsic(
)
amounts_to_stake.append(amount_to_stake)

# Check enough to stake.
if amount_to_stake > remaining_wallet_balance:
print_error(
f"Not enough stake:[bold white]\n wallet balance:{remaining_wallet_balance} < "
f"staking amount: {amount_to_stake}[/bold white]"
)
return
remaining_wallet_balance -= amount_to_stake

# Calculate slippage
# TODO: Update for V3, slippage calculation is significantly different in v3
# try:
Expand Down Expand Up @@ -433,6 +432,20 @@ async def stake_extrinsic(
safe_staking_=safe_staking,
)
row_extension = []

# Check enough balance to cover stake amount and extrinsic fee
total_cost = (
amount_to_stake + extrinsic_fee if not proxy else amount_to_stake
)
if total_cost > remaining_wallet_balance:
print_error(
f"[red]Not enough stake[/red]:[bold white]\n wallet balance: {remaining_wallet_balance} < "
f"staking amount: {amount_to_stake}[/bold white]"
)
return

# Deduct stake amount and extrinsic fee from remaining balance
remaining_wallet_balance -= total_cost
# TODO this should be asyncio gathered before the for loop
amount_minus_fee = (
(amount_to_stake - extrinsic_fee) if not proxy else amount_to_stake
Expand Down
101 changes: 101 additions & 0 deletions tests/e2e_tests/test_staking_sudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,107 @@ def line(key: str) -> Union[str, bool]:
assert line("error_messages") == ""
assert isinstance(line("extrinsic_ids"), str)

# Test staking with prompted amounts for each netuid
add_stake_prompted = exec_command_alice(
command="stake",
sub_command="add",
extra_args=[
"--netuids",
",".join(str(x) for x in multiple_netuids),
"--wallet-path",
wallet_path_alice,
"--wallet-name",
wallet_alice.name,
"--hotkey",
wallet_alice.hotkey_str,
"--chain",
"ws://127.0.0.1:9945",
"--tolerance",
"0.1",
"--partial",
"--era",
"32",
"--json-output",
"--no-prompt",
# Note: No --amount or --amounts flag, will trigger prompts
],
inputs=["50", "30"], # 50 TAO for netuid 2, 30 TAO for netuid 3
)

# Verify prompts appeared in output
assert "stake to netuid 2" in add_stake_prompted.stdout
assert "stake to netuid 3" in add_stake_prompted.stdout
assert "remaining balance" in add_stake_prompted.stdout

# Extract JSON from stdout (prompts are mixed with JSON output)
json_match = re.search(r"\{.*\}", add_stake_prompted.stdout, re.DOTALL)
if json_match:
json_str = json_match.group(0)
add_stake_prompted_output = json.loads(json_str)

for netuid_ in multiple_netuids:

def line_prompted(key: str) -> Union[str, bool]:
return add_stake_prompted_output[key][str(netuid_)][
wallet_alice.hotkey.ss58_address
]

assert line_prompted("staking_success") is True, (
f"Staking to netuid {netuid_} should succeed"
)
assert line_prompted("error_messages") == "", (
f"No error messages expected for netuid {netuid_}"
)
assert isinstance(line_prompted("extrinsic_ids"), str), (
f"Extrinsic ID should be a string for netuid {netuid_}"
)

# Test staking with --amounts option for different amounts per netuid
add_stake_amounts = exec_command_alice(
command="stake",
sub_command="add",
extra_args=[
"--netuids",
",".join(str(x) for x in multiple_netuids),
"--amounts",
"25,15", # 25 TAO for netuid 2, 15 TAO for netuid 3
"--wallet-path",
wallet_path_alice,
"--wallet-name",
wallet_alice.name,
"--hotkey",
wallet_alice.hotkey_str,
"--chain",
"ws://127.0.0.1:9945",
"--tolerance",
"0.1",
"--partial",
"--era",
"32",
"--json-output",
"--no-prompt",
],
)

# Parse and verify the staking results for --amounts
add_stake_amounts_output = json.loads(add_stake_amounts.stdout)
for netuid_ in multiple_netuids:

def line_amounts(key: str) -> Union[str, bool]:
return add_stake_amounts_output[key][str(netuid_)][
wallet_alice.hotkey.ss58_address
]

assert line_amounts("staking_success") is True, (
f"Staking with --amounts to netuid {netuid_} should succeed"
)
assert line_amounts("error_messages") == "", (
f"No error messages expected for netuid {netuid_} with --amounts"
)
assert isinstance(line_amounts("extrinsic_ids"), str), (
f"Extrinsic ID should be a string for netuid {netuid_} with --amounts"
)

# Fetch the hyperparameters of the subnet
hyperparams = exec_command_alice(
command="sudo",
Expand Down