@@ -16,6 +16,7 @@ import (
1616 "github.com/MakeNowJust/heredoc"
1717 "github.com/briandowns/spinner"
1818 "github.com/github/gh-models/internal/azuremodels"
19+ "github.com/github/gh-models/internal/modelkey"
1920 "github.com/github/gh-models/internal/sse"
2021 "github.com/github/gh-models/pkg/command"
2122 "github.com/github/gh-models/pkg/prompt"
@@ -513,9 +514,21 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
513514 return "" , errors .New (noMatchErrorMessage )
514515 }
515516
517+ parsedModel , err := modelkey .ParseModelKey (modelName )
518+ if err != nil {
519+ return "" , fmt .Errorf ("invalid model format: %w" , err )
520+ }
521+
522+ if parsedModel .Provider == "custom" {
523+ // Skip validation for custom provider
524+ return parsedModel .String (), nil
525+ }
526+
527+ // For non-custom providers, validate the model exists
528+ expectedModelID := azuremodels .FormatIdentifier (parsedModel .Publisher , parsedModel .ModelName )
516529 foundMatch := false
517530 for _ , model := range models {
518- if model .HasName (modelName ) {
531+ if model .HasName (expectedModelID ) {
519532 foundMatch = true
520533 break
521534 }
@@ -525,7 +538,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
525538 return "" , errors .New (noMatchErrorMessage )
526539 }
527540
528- return modelName , nil
541+ return expectedModelID , nil
529542}
530543
531544func (h * runCommandHandler ) getChatCompletionStreamReader (req azuremodels.ChatCompletionOptions , org string ) (sse.Reader [azuremodels.ChatCompletion ], error ) {
0 commit comments