-
Notifications
You must be signed in to change notification settings - Fork 209
[UX] Extend dstack login with interactive selection of url and default project
#3492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,18 +1,38 @@ | ||
| import argparse | ||
| import queue | ||
| import sys | ||
| import threading | ||
| import urllib.parse | ||
| import webbrowser | ||
| from http.server import BaseHTTPRequestHandler, HTTPServer | ||
| from typing import Optional | ||
| from typing import Any, Optional | ||
|
|
||
| from rich.prompt import Prompt as RichPrompt | ||
| from rich.text import Text | ||
|
|
||
| try: | ||
| import questionary | ||
|
|
||
| is_project_menu_supported = sys.stdin.isatty() | ||
| except (ImportError, NotImplementedError, AttributeError): | ||
| is_project_menu_supported = False | ||
|
|
||
| from dstack._internal.cli.commands import BaseCommand | ||
| from dstack._internal.cli.commands.project import select_default_project | ||
| from dstack._internal.cli.utils.common import console, resolve_url | ||
| from dstack._internal.core.errors import ClientError, CLIError | ||
| from dstack._internal.core.models.users import UserWithCreds | ||
| from dstack._internal.utils.logging import get_logger | ||
| from dstack.api._public.runs import ConfigManager | ||
| from dstack.api.server import APIClient | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
||
| class UrlPrompt(RichPrompt): | ||
| def render_default(self, default: Any) -> Text: | ||
| return Text(f"({default})", style="bold orange1") | ||
|
|
||
|
|
||
| class LoginCommand(BaseCommand): | ||
| NAME = "login" | ||
|
|
@@ -23,7 +43,7 @@ def _register(self): | |
| self._parser.add_argument( | ||
| "--url", | ||
| help="The server URL, e.g. https://sky.dstack.ai", | ||
| required=True, | ||
| required=not is_project_menu_supported, | ||
| ) | ||
| self._parser.add_argument( | ||
| "-p", | ||
|
|
@@ -33,10 +53,25 @@ def _register(self): | |
| " Selected automatically if the server supports only one provider." | ||
| ), | ||
| ) | ||
| self._parser.add_argument( | ||
| "-y", | ||
| "--yes", | ||
| help="Don't ask for confirmation (e.g. set first project as default)", | ||
| action="store_true", | ||
| ) | ||
| self._parser.add_argument( | ||
| "-n", | ||
| "--no", | ||
| help="Don't ask for confirmation (e.g. do not change default project)", | ||
| action="store_true", | ||
| ) | ||
|
|
||
| def _command(self, args: argparse.Namespace): | ||
| super()._command(args) | ||
| base_url = _normalize_url_or_error(args.url) | ||
| url = args.url | ||
| if url is None: | ||
| url = self._prompt_url() | ||
| base_url = _normalize_url_or_error(url) | ||
| api_client = APIClient(base_url=base_url) | ||
| provider = self._select_provider_or_error(api_client=api_client, provider=args.provider) | ||
| server = _LoginServer(api_client=api_client, provider=provider) | ||
|
|
@@ -56,9 +91,9 @@ def _command(self, args: argparse.Namespace): | |
| server.shutdown() | ||
| if user is None: | ||
| raise CLIError("CLI authentication failed") | ||
| console.print(f"Logged in as [code]{user.username}[/].") | ||
| console.print(f"Logged in as [code]{user.username}[/]") | ||
| api_client = APIClient(base_url=base_url, token=user.creds.token) | ||
| self._configure_projects(api_client=api_client, user=user) | ||
| self._configure_projects(api_client=api_client, user=user, args=args) | ||
|
|
||
| def _select_provider_or_error(self, api_client: APIClient, provider: Optional[str]) -> str: | ||
| providers = api_client.auth.list_providers() | ||
|
|
@@ -67,6 +102,8 @@ def _select_provider_or_error(self, api_client: APIClient, provider: Optional[st | |
| raise CLIError("No SSO providers configured on the server.") | ||
| if provider is None: | ||
| if len(available_providers) > 1: | ||
| if is_project_menu_supported: | ||
| return self._prompt_provider(available_providers) | ||
| raise CLIError( | ||
| "Specify -p/--provider to choose SSO provider" | ||
| f" Available providers: {', '.join(available_providers)}" | ||
|
|
@@ -79,7 +116,34 @@ def _select_provider_or_error(self, api_client: APIClient, provider: Optional[st | |
| ) | ||
| return provider | ||
|
|
||
| def _configure_projects(self, api_client: APIClient, user: UserWithCreds): | ||
| def _prompt_url(self) -> str: | ||
| url = UrlPrompt.ask( | ||
| "Enter the server URL", | ||
| default="https://sky.dstack.ai", | ||
| console=console, | ||
| ) | ||
| if url is None: | ||
| raise CLIError("URL is required") | ||
| return url | ||
|
|
||
| def _prompt_provider(self, available_providers: list[str]) -> str: | ||
| choices = [ | ||
| questionary.Choice(title=provider, value=provider) # pyright: ignore[reportPossiblyUnboundVariable] | ||
| for provider in available_providers | ||
| ] | ||
| selected_provider = questionary.select( # pyright: ignore[reportPossiblyUnboundVariable] | ||
| message="Select SSO provider:", | ||
| choices=choices, | ||
| qmark="", | ||
| instruction="(↑↓ Enter)", | ||
| ).ask() | ||
| if selected_provider is None: | ||
| raise CLIError("Provider selection is required") | ||
| return selected_provider | ||
|
|
||
| def _configure_projects( | ||
| self, api_client: APIClient, user: UserWithCreds, args: argparse.Namespace | ||
| ): | ||
| projects = api_client.projects.list(include_not_joined=False) | ||
| if len(projects) == 0: | ||
| console.print( | ||
|
|
@@ -89,30 +153,88 @@ def _configure_projects(self, api_client: APIClient, user: UserWithCreds): | |
| return | ||
| config_manager = ConfigManager() | ||
| default_project = config_manager.get_project_config() | ||
| new_default_project = None | ||
| for i, project in enumerate(projects): | ||
| set_as_default = ( | ||
| default_project is None | ||
| and i == 0 | ||
| or default_project is not None | ||
| and default_project.name == project.project_name | ||
| ) | ||
| if set_as_default: | ||
| new_default_project = project | ||
| for project in projects: | ||
| config_manager.configure_project( | ||
| name=project.project_name, | ||
| url=api_client.base_url, | ||
| token=user.creds.token, | ||
| default=set_as_default, | ||
| default=False, | ||
| ) | ||
| config_manager.save() | ||
| project_names = ", ".join(f"[code]{p.project_name}[/]" for p in projects) | ||
| console.print( | ||
| f"Configured projects: {', '.join(f'[code]{p.project_name}[/]' for p in projects)}." | ||
| f"Added {project_names} project{'' if len(projects) == 1 else 's'} at {config_manager.config_filepath}" | ||
| ) | ||
| if new_default_project: | ||
| console.print( | ||
| f"Set project [code]{new_default_project.project_name}[/] as default project." | ||
| ) | ||
|
|
||
| project_configs = config_manager.list_project_configs() | ||
|
|
||
| if args.no: | ||
| return | ||
|
|
||
| if args.yes: | ||
| if len(projects) > 0: | ||
| first_project_from_server = projects[0] | ||
| first_project_config = next( | ||
| ( | ||
| pc | ||
| for pc in project_configs | ||
| if pc.name == first_project_from_server.project_name | ||
| ), | ||
| None, | ||
| ) | ||
| if first_project_config is not None: | ||
| config_manager.configure_project( | ||
| name=first_project_config.name, | ||
| url=first_project_config.url, | ||
| token=first_project_config.token, | ||
| default=True, | ||
| ) | ||
| config_manager.save() | ||
| console.print( | ||
| f"Set [code]{first_project_config.name}[/] project as default at {config_manager.config_filepath}" | ||
| ) | ||
| return | ||
|
|
||
| if len(project_configs) == 1 or not is_project_menu_supported: | ||
| selected_project = None | ||
| if len(project_configs) == 1: | ||
| selected_project = project_configs[0] | ||
| else: | ||
| for i, project in enumerate(projects): | ||
| set_as_default = ( | ||
| default_project is None | ||
| and i == 0 | ||
| or default_project is not None | ||
| and default_project.name == project.project_name | ||
| ) | ||
| if set_as_default: | ||
| selected_project = next( | ||
| (pc for pc in project_configs if pc.name == project.project_name), | ||
| None, | ||
| ) | ||
| break | ||
| if selected_project is not None: | ||
| config_manager.configure_project( | ||
| name=selected_project.name, | ||
| url=selected_project.url, | ||
| token=selected_project.token, | ||
| default=True, | ||
| ) | ||
| config_manager.save() | ||
| console.print( | ||
| f"Set [code]{selected_project.name}[/] project as default at {config_manager.config_filepath}" | ||
| ) | ||
| else: | ||
| console.print() | ||
| selected_project = select_default_project(project_configs, default_project) | ||
| if selected_project is not None: | ||
| config_manager.configure_project( | ||
| name=selected_project.name, | ||
| url=selected_project.url, | ||
| token=selected_project.token, | ||
| default=True, | ||
| ) | ||
| config_manager.save() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing confirmation message after interactive project selectionLow Severity The interactive project selection path (the |
||
|
|
||
|
|
||
| class _BadRequestError(Exception): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interactive project selection includes projects from all servers
Medium Severity
When
select_default_projectis called, it receivesproject_configsfromconfig_manager.list_project_configs(), which returns ALL configured projects across ALL servers, not just projects from the server being logged into. If the user previously configured projects from a different server, those unrelated projects appear in the default project selection menu. The fallback path (lines 203-214) correctly limits selection to projects from the current server by iterating overprojectsfrom the API.Additional Locations (1)
src/dstack/_internal/cli/commands/login.py#L168-L169