4 changes: 2 additions & 2 deletions README.md
@@ -66,7 +66,7 @@ rullm --system "You are a helpful assistant." "Summarize this text"
# List available models (shows only chat models, with your aliases)
rullm models list

# Update model list for all providers with API keys
# Update model list
rullm models update

# Manage aliases
@@ -129,4 +129,4 @@ source <(COMPLETE=bash ./target/debug/rullm)

# zsh
source <(COMPLETE=zsh ./target/debug/rullm)
```
```
12 changes: 0 additions & 12 deletions crates/rullm-cli/src/cli_client.rs
@@ -486,18 +486,6 @@ impl CliClient {
}
}

/// Get available models for the provider
pub async fn available_models(&self) -> Result<Vec<String>, LlmError> {
match self {
Self::OpenAI { client, .. } => client.list_models().await,
Self::Anthropic { client, .. } => client.list_models().await,
Self::Google { client, .. } => client.list_models().await,
Self::Groq { client, .. } | Self::OpenRouter { client, .. } => {
client.available_models().await
}
}
}

/// Get provider name
pub fn provider_name(&self) -> &'static str {
match self {
2 changes: 1 addition & 1 deletion crates/rullm-cli/src/commands/mod.rs
@@ -37,7 +37,7 @@ const CHAT_EXAMPLES: &str = r#"EXAMPLES:

const MODELS_EXAMPLES: &str = r#"EXAMPLES:
rullm models list # List cached models
rullm models update -m openai/gpt-4 # Fetch OpenAI models
rullm models update # Fetch latest models
rullm models default openai/gpt-4o # Set default model
rullm models clear # Clear model cache"#;

172 changes: 62 additions & 110 deletions crates/rullm-cli/src/commands/models.rs
@@ -1,14 +1,14 @@
use crate::cli_client::CliClient;
use anyhow::Result;
use chrono::Utc;
use clap::{Args, Subcommand};
use rullm_core::LlmError;
use serde::Deserialize;
use std::collections::HashMap;
use std::time::Duration;
use strum::IntoEnumIterator;

use crate::{
aliases::UserAliasConfig,
args::{Cli, CliConfig},
client,
commands::{ModelsCache, format_duration},
constants::{ALIASES_CONFIG_FILE, MODEL_FILE_NAME},
output::OutputLevel,
@@ -23,14 +23,14 @@ pub struct ModelsArgs {

#[derive(Subcommand)]
pub enum ModelsAction {
/// List available models for the current provider (default)
/// List cached models
List,
/// Set a default model that will be used when --model is not supplied
Default {
/// Model identifier in the form provider:model-name (e.g. openai:gpt-4o)
model: Option<String>,
},
/// Fetch fresh models from all providers with available API keys and update local cache
/// Fetch fresh models from models.dev and update local cache
Update,
/// Clear the local models cache
Clear,
@@ -41,7 +41,7 @@ impl ModelsArgs {
&self,
output_level: OutputLevel,
cli_config: &mut CliConfig,
cli: &Cli,
_cli: &Cli,
) -> Result<()> {
match &self.action {
ModelsAction::List => {
@@ -67,34 +67,28 @@ impl ModelsArgs {
}
}
ModelsAction::Update => {
// List of supported providers
let providers = Provider::iter();
let mut updated = vec![];
let mut skipped = vec![];

for provider in providers {
let provider = format!("{provider}");
// Try to create a client for this provider
let model_hint = format!("{provider}:dummy"); // dummy model name, just to get the client
let client = match client::from_model(&model_hint, cli, cli_config).await {
Ok(c) => c,
Err(_) => {
skipped.push(provider);
continue;
}
};
match update_models(cli_config, &client, output_level).await {
Ok(_) => updated.push(provider),
Err(_) => skipped.push(provider),
}
let supported: Vec<String> = Provider::iter().map(|p| p.to_string()).collect();

crate::output::progress("Fetching models from models.dev...", output_level);

let models = fetch_models_from_models_dev(&supported).await?;
if models.is_empty() {
anyhow::bail!("No models returned by models.dev");
}

if !skipped.is_empty() {
crate::output::note(
&format!("Skipped (no API key or error): {}", skipped.join(", ")),
output_level,
);
let cache = ModelsCache::new(models);
let path = cli_config.data_base_path.join(MODEL_FILE_NAME);

if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}

std::fs::write(&path, serde_json::to_string_pretty(&cache)?)?;

crate::output::success(
&format!("Updated {} models", cache.models.len()),
output_level,
);
}
ModelsAction::Clear => {
clear_models_cache(cli_config, output_level)?;
@@ -216,85 +210,6 @@ pub fn clear_models_cache(cli_config: &CliConfig, output_level: OutputLevel) ->
Ok(())
}

pub async fn update_models(
cli_config: &mut CliConfig,
client: &CliClient,
output_level: OutputLevel,
) -> Result<(), LlmError> {
crate::output::progress(
&format!(
"Fetching models from {}...",
crate::output::format_provider(client.provider_name())
),
output_level,
);

let mut models = client.available_models().await.map_err(|e| {
crate::output::error(&format!("Failed to fetch models: {e}"), output_level);
e
})?;

if models.is_empty() {
crate::output::error("No models returned by provider", output_level);
return Err(LlmError::model(
"No models returned by provider".to_string(),
));
}

models.sort();
models.dedup();

_cache_models(cli_config, client.provider_name(), &models).map_err(|e| {
crate::output::error(&format!("Failed to update models cache: {e}"), output_level);
LlmError::unknown(e.to_string())
})?;

crate::output::success(
&format!(
"Updated {} models for {}",
models.len(),
client.provider_name()
),
output_level,
);

Ok(())
}

fn _cache_models(cli_config: &CliConfig, provider_name: &str, models: &[String]) -> Result<()> {
use std::fs;

let path = cli_config.data_base_path.join(MODEL_FILE_NAME);
// TODO: we shouldn't need to do this here, this should be done while cli_config is created
// TODO: Remove if we already do this.
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}

// Load existing cache if present
let mut entries = if let Ok(Some(cache)) = load_models_cache(cli_config) {
cache.models
} else {
Vec::new()
};

// Remove all entries for this provider
let prefix = format!("{}:", provider_name.to_lowercase());
entries.retain(|m| !m.starts_with(&prefix));

// Add new models for this provider
let new_entries: Vec<String> = models
.iter()
.map(|m| format!("{}:{}", provider_name.to_lowercase(), m))
.collect();
entries.extend(new_entries);

let cache = ModelsCache::new(entries);
let json = serde_json::to_string_pretty(&cache)?;
fs::write(path, json)?;
Ok(())
}

pub(crate) fn load_models_cache(cli_config: &CliConfig) -> Result<Option<ModelsCache>> {
use std::fs;

@@ -314,3 +314,40 @@ pub(crate) fn load_models_cache(cli_config: &CliConfig) -> Result<Option<ModelsC
// Old format doesn't have timestamp info
Ok(None)
}

#[derive(Deserialize)]
struct ModelsDevProvider {
models: HashMap<String, ModelsDevModel>,
}

#[derive(Deserialize)]
struct ModelsDevModel {
id: Option<String>,
}

async fn fetch_models_from_models_dev(supported_providers: &[String]) -> Result<Vec<String>> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()?;

let response = client
.get("https://models.dev/api.json")
.send()
.await?
.error_for_status()?;
let providers: HashMap<String, ModelsDevProvider> = response.json().await?;

let mut all_models = Vec::new();
for provider_id in supported_providers {
if let Some(provider) = providers.get(provider_id) {
for (model_id, model) in &provider.models {
let id = model.id.as_deref().unwrap_or(model_id);
all_models.push(format!("{provider_id}:{id}"));
}
}
}

all_models.sort();
all_models.dedup();
Ok(all_models)
}
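
A note on the new flow: `models update` above now populates the cache from a single models.dev request instead of querying each provider. The sketch below is a minimal, standalone approximation of that fetch-and-flatten step, assuming `reqwest` (with the `json` feature), `serde`, and `tokio`; the struct shapes mirror `ModelsDevProvider`/`ModelsDevModel` from this diff, and any payload field beyond `models` and `id` is deliberately omitted.

```rust
// Minimal sketch of the models.dev fetch-and-flatten flow behind `models update`.
// Assumes reqwest (json feature), serde, and tokio; not the crate's actual wiring.
use serde::Deserialize;
use std::collections::HashMap;

#[derive(Deserialize)]
struct ProviderEntry {
    // Each provider entry carries a map of model key -> model record.
    models: HashMap<String, ModelEntry>,
}

#[derive(Deserialize)]
struct ModelEntry {
    // Some records expose an explicit id; otherwise the map key is used.
    id: Option<String>,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Top-level payload: provider id -> provider entry.
    let providers: HashMap<String, ProviderEntry> =
        reqwest::get("https://models.dev/api.json")
            .await?
            .error_for_status()?
            .json()
            .await?;

    // Flatten into the "provider:model" identifiers the cache stores.
    let mut ids: Vec<String> = providers
        .iter()
        .flat_map(|(provider_id, p)| {
            p.models.iter().map(move |(key, m)| {
                format!("{provider_id}:{}", m.id.as_deref().unwrap_or(key))
            })
        })
        .collect();
    ids.sort();
    ids.dedup();
    println!("fetched {} model ids", ids.len());
    Ok(())
}
```

Sorting and deduplicating before writing matches what the new `Update` arm does prior to serializing the `ModelsCache`.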
28 changes: 11 additions & 17 deletions crates/rullm-core/examples/README.md
@@ -348,7 +348,6 @@ match provider.chat_completion(request).await {

- **`provider.chat_completion(request)`** - Send chat completion
- **`provider.health_check()`** - Test API connectivity
- **`provider.available_models()`** - Get supported models
- **`config.validate()`** - Validate configuration

### Supported Models
@@ -391,7 +390,7 @@ match provider.chat_completion(request).await {

## Test All Providers (`test_all_providers.rs`)

Comprehensive test that validates all LLM providers and their `available_models` functionality:
Comprehensive test that validates all LLM providers with health checks:

```bash
# Set up your API keys
@@ -405,31 +404,26 @@ cargo run --example test_all_providers

**Features:**
- Tests OpenAI, Anthropic, and Google providers
- Calls `available_models()` for each provider
- Validates expected model patterns
- Performs health checks
- Provides detailed success/failure reporting
- Gracefully handles missing API keys

**Sample Output:**
```
🚀 Testing All LLM Providers and Their Available Models
🚀 Testing All LLM Providers

🔍 Testing OpenAI Provider...
Provider name: openai
Health check: ✅ Passed
Expected model 'gpt-4': ✅ Found
Expected model 'gpt-3.5-turbo': ✅ Found
✅ OpenAI: Found 5 models
✅ OpenAI: Health check passed

📊 SUMMARY:
┌─────────────┬────────┬─────────────
│ Provider │ Status │ Models │
├─────────────┼────────┼─────────────
│ OpenAI │ ✅ Pass │ 5 models │
│ Anthropic │ ✅ Pass │ 5 models │
│ Google │ ✅ Pass │ 5 models │
└─────────────┴────────┴─────────────
┌─────────────┬────────┐
│ Provider │ Status │
├─────────────┼────────┤
│ OpenAI │ ✅ Pass │
│ Anthropic │ ✅ Pass │
│ Google │ ✅ Pass │
└─────────────┴────────┘

🎉 All providers are working correctly!
```
Expand All @@ -438,4 +432,4 @@ Use this example for:
- Verifying your API keys work
- Testing network connectivity
- Validating provider implementations
- CI/CD pipeline health checks
- CI/CD pipeline health checks
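
Since the example now exercises only health checks, here is a minimal sketch of that pattern, assuming the `OpenAIClient::from_env` constructor and `health_check()` method that appear in the example diffs below; the `rullm_core::OpenAIClient` import path is an assumption, not confirmed by this diff.

```rust
// Hedged sketch: provider health check only, mirroring the trimmed examples below.
// `rullm_core::OpenAIClient` is an assumed import path; adjust to the crate's exports.
use rullm_core::OpenAIClient;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = OpenAIClient::from_env()?; // reads the API key from the environment
    match client.health_check().await {
        Ok(_) => println!("✅ OpenAI: Health check passed"),
        Err(e) => println!("❌ OpenAI: Health check failed: {e}"),
    }
    Ok(())
}
```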
12 changes: 1 addition & 11 deletions crates/rullm-core/examples/google_simple.rs
@@ -125,17 +125,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
}

// 7. List models
println!("\n📋 Available models:");
let models = client.list_models().await?;
for (i, model) in models.iter().take(5).enumerate() {
println!(" {}. {}", i + 1, model);
}
if models.len() > 5 {
println!(" ... and {} more", models.len() - 5);
}

// 8. Health check
// 7. Health check
match client.health_check().await {
Ok(_) => println!("\n✅ Google AI is healthy"),
Err(e) => println!("\n❌ Health check failed: {e}"),
14 changes: 0 additions & 14 deletions crates/rullm-core/examples/openai_config.rs
@@ -72,20 +72,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Err(e) => println!(" ❌ Health check failed: {e}"),
}

// Get available models
match client.list_models().await {
Ok(models) => {
println!(" Available models (first 5):");
for (i, model) in models.iter().take(5).enumerate() {
println!(" {}. {}", i + 1, model);
}
if models.len() > 5 {
println!(" ... and {} more", models.len() - 5);
}
}
Err(e) => println!(" ❌ Error getting models: {e}"),
}

// Make a simple request
println!("\n Testing chat completion...");
let mut test_request = ChatCompletionRequest::new(
14 changes: 0 additions & 14 deletions crates/rullm-core/examples/openai_conversation.rs
@@ -23,20 +23,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Configure OpenAI client from environment
let client = OpenAIClient::from_env()?;

// Check available models
println!("Available models:");
match client.list_models().await {
Ok(models) => {
for (i, model) in models.iter().take(10).enumerate() {
println!(" {}. {}", i + 1, model);
}
if models.len() > 10 {
println!(" ... and {} more", models.len() - 10);
}
}
Err(e) => println!("Error getting models: {e}"),
}

// Health check
match client.health_check().await {
Ok(_) => println!("✅ Client is healthy\n"),