Skip to content
6 changes: 6 additions & 0 deletions cli/azd/pkg/account/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Manager interface {
GetSubscriptions(ctx context.Context) ([]Subscription, error)
GetSubscriptionsWithDefaultSet(ctx context.Context) ([]Subscription, error)
GetLocations(ctx context.Context, subscriptionId string) ([]Location, error)
GetTenantDisplayNames(ctx context.Context) (map[string]string, error)
SetDefaultSubscription(ctx context.Context, subscriptionId string) (*Subscription, error)
SetDefaultLocation(ctx context.Context, subscriptionId string, location string) (*Location, error)
}
Expand Down Expand Up @@ -140,6 +141,11 @@ func (m *manager) GetSubscriptions(ctx context.Context) ([]Subscription, error)
return m.subManager.GetSubscriptions(ctx)
}

// GetTenantDisplayNames returns a map of tenant ID to display name.
func (m *manager) GetTenantDisplayNames(ctx context.Context) (map[string]string, error) {
return m.subManager.GetTenantDisplayNames(ctx)
}

// Gets the available Azure locations for the specified Azure subscription.
func (m *manager) GetLocations(ctx context.Context, subscriptionId string) ([]Location, error) {
locations, err := m.subManager.ListLocations(ctx, subscriptionId)
Expand Down
22 changes: 22 additions & 0 deletions cli/azd/pkg/account/subscriptions_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,28 @@ func (m *SubscriptionsManager) getSubscription(ctx context.Context, subscription
return &sub, nil
}

// GetTenantDisplayNames returns a map of tenant ID to display name for all tenants
// accessible by the current account.
func (m *SubscriptionsManager) GetTenantDisplayNames(ctx context.Context) (map[string]string, error) {
tenants, err := m.service.ListTenants(ctx)
if err != nil {
return nil, fmt.Errorf("listing tenants: %w", err)
}

result := make(map[string]string, len(tenants))
for _, t := range tenants {
if t.TenantID != nil {
name := *t.TenantID
if t.DisplayName != nil && *t.DisplayName != "" {
name = *t.DisplayName
}
result[*t.TenantID] = name
}
}

return result, nil
}

func toSubscriptions(azSubs []*armsubscriptions.Subscription, userAccessTenantId string) []Subscription {
if azSubs == nil {
return nil
Expand Down
79 changes: 59 additions & 20 deletions cli/azd/pkg/prompt/prompt_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ type SelectOptions struct {
HelpMessage string
// LoadingMessage is the loading message to display to the user.
LoadingMessage string
// SkipLoadingSpinner skips the loading spinner in PromptCustomResource.
// Use this when data is pre-loaded and LoadData returns immediately.
SkipLoadingSpinner bool
Comment thread
vhvb1989 marked this conversation as resolved.
// DisplayNumbers specifies whether to display numbers next to the choices.
DisplayNumbers *bool
// DisplayCount is the number of choices to display at a time.
Expand Down Expand Up @@ -157,6 +160,7 @@ type ResourceService interface {
type SubscriptionManager interface {
GetSubscriptions(ctx context.Context) ([]account.Subscription, error)
GetLocations(ctx context.Context, subscriptionId string) ([]account.Location, error)
GetTenantDisplayNames(ctx context.Context) (map[string]string, error)
}

// PromptServiceInterface defines the methods that the PromptService must implement.
Expand Down Expand Up @@ -211,6 +215,8 @@ func NewPromptService(
}

// PromptSubscription prompts the user to select an Azure subscription.
// If the user has access to multiple tenants, a tenant selection prompt is shown first
// to scope down the subscription list.
func (ps *promptService) PromptSubscription(
ctx context.Context,
selectorOptions *SelectOptions,
Expand All @@ -235,6 +241,31 @@ func (ps *promptService) PromptSubscription(
return nil, err
}

// Load subscriptions under a spinner first
var subscriptionList []account.Subscription
loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: mergedOptions.LoadingMessage,
})
Comment thread
vhvb1989 marked this conversation as resolved.

err := loadingSpinner.Run(ctx, func(ctx context.Context) error {
var loadErr error
subscriptionList, loadErr = ps.subscriptionManager.GetSubscriptions(ctx)
return loadErr
})
if err != nil {
return nil, fmt.Errorf("listing subscriptions: %w", err)
}
Comment thread
vhvb1989 marked this conversation as resolved.

// Apply tenant filtering (after spinner is done so the prompt doesn't overlap)
subscriptionList = filterByTenantEnvVar(subscriptionList)
if !ps.console.IsNoPromptMode() {
subscriptionList, err = promptAndFilterByTenant(
ctx, ps.console, subscriptionList, ps.subscriptionManager.GetTenantDisplayNames)
if err != nil {
return nil, err
}
Comment thread
vhvb1989 marked this conversation as resolved.
}

// Get default subscription from user config
var defaultSubscriptionId = ""
userConfig, err := ps.userConfigManager.Load()
Expand All @@ -247,19 +278,19 @@ func (ps *promptService) PromptSubscription(

hideId := isDemoModeEnabled()

return PromptCustomResource(ctx, CustomResourceOptions[account.Subscription]{
SelectorOptions: mergedOptions,
LoadData: func(ctx context.Context) ([]*account.Subscription, error) {
subscriptionList, err := ps.subscriptionManager.GetSubscriptions(ctx)
if err != nil {
return nil, err
}
// Use PromptCustomResource with pre-loaded data
subscriptions := make([]*account.Subscription, len(subscriptionList))
for i := range subscriptionList {
subscriptions[i] = &subscriptionList[i]
}

subscriptions := make([]*account.Subscription, len(subscriptionList))
for i, subscription := range subscriptionList {
subscriptions[i] = &subscription
}
// Create selector options with spinner disabled since data is already loaded
resourceSelectorOptions := *mergedOptions
resourceSelectorOptions.SkipLoadingSpinner = true

return PromptCustomResource(ctx, CustomResourceOptions[account.Subscription]{
SelectorOptions: &resourceSelectorOptions,
LoadData: func(ctx context.Context) ([]*account.Subscription, error) {
return subscriptions, nil
},
Comment thread
vhvb1989 marked this conversation as resolved.
DisplayResource: func(subscription *account.Subscription) (string, error) {
Expand Down Expand Up @@ -768,21 +799,29 @@ func PromptCustomResource[T any](ctx context.Context, options CustomResourceOpti
allowNewResource = true
selectedIndex = new(0)
} else {
loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: options.SelectorOptions.LoadingMessage,
})

err := loadingSpinner.Run(ctx, func(ctx context.Context) error {
loadData := func(ctx context.Context) error {
resourceList, err := options.LoadData(ctx)
if err != nil {
return err
}

resources = resourceList
return nil
})
if err != nil {
return nil, err
}

// Skip the spinner when data is pre-loaded
if mergedSelectorOptions.SkipLoadingSpinner {
if err := loadData(ctx); err != nil {
return nil, err
}
Comment thread
vhvb1989 marked this conversation as resolved.
} else {
loadingSpinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: mergedSelectorOptions.LoadingMessage,
})
if err := loadingSpinner.Run(ctx, func(ctx context.Context) error {
return loadData(ctx)
}); err != nil {
return nil, err
}
}

if !allowNewResource && len(resources) == 0 {
Expand Down
21 changes: 21 additions & 0 deletions cli/azd/pkg/prompt/prompt_service_extra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,27 @@ func TestPromptCustomResource_NilSelectorOptions_UsesDefaultsAndForce(t *testing
require.Equal(t, 42, *result)
}

func TestPromptCustomResource_SkipLoadingSpinner(t *testing.T) {
t.Parallel()

loaded := false
_, err := PromptCustomResource(t.Context(), CustomResourceOptions[string]{
SelectorOptions: &SelectOptions{
SkipLoadingSpinner: true,
AllowNewResource: new(false),
},
LoadData: func(ctx context.Context) ([]*string, error) {
loaded = true
return nil, nil
},
})

// LoadData should have been called directly (without spinner)
require.True(t, loaded)
// No resources and AllowNewResource=false returns the sentinel error
require.ErrorIs(t, err, ErrNoResourcesFound)
}

// helpers

func emptySubs() []account.Subscription { return []account.Subscription{} }
99 changes: 57 additions & 42 deletions cli/azd/pkg/prompt/prompter.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"log"
"os"
"slices"
"strconv"

"github.com/MakeNowJust/heredoc/v2"
"github.com/azure/azure-dev/cli/azd/pkg/account"
Expand Down Expand Up @@ -71,12 +70,12 @@ func NewDefaultPrompter(
}

func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (subscriptionId string, err error) {
subscriptionOptions, subscriptions, defaultSubscription, err := p.getSubscriptionOptions(ctx)
subscriptionInfos, err := p.accountManager.GetSubscriptions(ctx)
if err != nil {
return "", err
return "", fmt.Errorf("listing subscriptions: %w", err)
}

if len(subscriptionOptions) == 0 {
if len(subscriptionInfos) == 0 {
// NOTE: Error text must contain "no subscriptions found" to match the
// pattern in error_suggestions.yaml. Update both if rewording.
return "", errors.New(heredoc.Docf(
Expand All @@ -87,6 +86,32 @@ func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (s
))
}

// Filter by AZURE_TENANT_ID if set (works in both prompt and no-prompt modes)
subscriptionInfos = filterByTenantEnvVar(subscriptionInfos)

// Tenant selection: if multiple tenants, prompt user to pick one
if !p.console.IsNoPromptMode() {
subscriptionInfos, err = promptAndFilterByTenant(
ctx, p.console, subscriptionInfos, p.accountManager.GetTenantDisplayNames)
if err != nil {
return "", err
}
}
Comment thread
vhvb1989 marked this conversation as resolved.

slices.SortFunc(subscriptionInfos, func(a, b account.Subscription) int {
return stringutil.CompareLower(a.Name, b.Name)
})

// The default value is based on AZURE_SUBSCRIPTION_ID, falling back to whatever default subscription in
// set in azd's config.
defaultSubscriptionId := os.Getenv(environment.SubscriptionIdEnvVarName)
if defaultSubscriptionId == "" {
defaultSubscriptionId = p.accountManager.GetDefaultSubscriptionID(ctx)
}
Comment thread
vhvb1989 marked this conversation as resolved.

subscriptionOptions, subscriptions, defaultSubscription :=
formatSubscriptionOptions(subscriptionInfos, defaultSubscriptionId)

for subscriptionId == "" {
subscriptionSelectionIndex, err := p.console.Select(ctx, input.ConsoleOptions{
Message: msg,
Expand All @@ -110,6 +135,34 @@ func (p *DefaultPrompter) PromptSubscription(ctx context.Context, msg string) (s
return subscriptionId, nil
}

// formatSubscriptionOptions formats subscription infos into display options.
func formatSubscriptionOptions(
subscriptionInfos []account.Subscription,
defaultSubscriptionId string,
) (options []string, ids []string, defaultOption any) {
options = make([]string, len(subscriptionInfos))
ids = make([]string, len(subscriptionInfos))

hideId := isDemoModeEnabled()

for index, info := range subscriptionInfos {
if hideId {
options[index] = fmt.Sprintf("%2d. %s", index+1, info.Name)
} else {
options[index] = fmt.Sprintf(
"%2d. %s (%s)", index+1, info.Name, info.Id)
}
Comment thread
vhvb1989 marked this conversation as resolved.

ids[index] = info.Id

if info.Id == defaultSubscriptionId {
defaultOption = options[index]
}
}

return options, ids, defaultOption
}

func (p *DefaultPrompter) PromptLocation(
ctx context.Context,
subId string,
Expand Down Expand Up @@ -246,44 +299,6 @@ func (p *DefaultPrompter) PromptResourceGroupFrom(
return name, nil
}

func (p *DefaultPrompter) getSubscriptionOptions(ctx context.Context) ([]string, []string, any, error) {
subscriptionInfos, err := p.accountManager.GetSubscriptions(ctx)
if err != nil {
return nil, nil, nil, fmt.Errorf("listing accounts: %w", err)
}

slices.SortFunc(subscriptionInfos, func(a, b account.Subscription) int {
return stringutil.CompareLower(a.Name, b.Name)
})

// The default value is based on AZURE_SUBSCRIPTION_ID, falling back to whatever default subscription in
// set in azd's config.
defaultSubscriptionId := os.Getenv(environment.SubscriptionIdEnvVarName)
if defaultSubscriptionId == "" {
defaultSubscriptionId = p.accountManager.GetDefaultSubscriptionID(ctx)
}

var subscriptionOptions = make([]string, len(subscriptionInfos))
var subscriptions = make([]string, len(subscriptionInfos))
var defaultSubscription any

for index, info := range subscriptionInfos {
if v, err := strconv.ParseBool(os.Getenv("AZD_DEMO_MODE")); err == nil && v {
subscriptionOptions[index] = fmt.Sprintf("%2d. %s", index+1, info.Name)
} else {
subscriptionOptions[index] = fmt.Sprintf("%2d. %s (%s)", index+1, info.Name, info.Id)
}

subscriptions[index] = info.Id

if info.Id == defaultSubscriptionId {
defaultSubscription = subscriptionOptions[index]
}
}

return subscriptionOptions, subscriptions, defaultSubscription, nil
}

func (p *DefaultPrompter) IsNoPromptMode() bool {
return p.console.IsNoPromptMode()
}
Loading
Loading