2 changes: 1 addition & 1 deletion .env.integration-test
@@ -1,3 +1,3 @@
 AMRS_API_KEY=your_amrs_api_key_here
 OPENAI_API_KEY=your_openai_api_key_here
-FAKE_API_KEY=your_fake_api_key_here
+FAKER_API_KEY=your_faker_api_key_here
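A quick way to verify the rename end to end is to load this file the way the client tests later in this diff do, with `dotenvy::from_filename` (a minimal sketch; the path is assumed to be relative to the crate root):

```rust
// Sketch: load .env.integration-test and confirm the renamed key is
// visible, mirroring the dotenvy::from_filename call in the client
// tests further down this diff.
use dotenvy::from_filename;

fn main() {
    from_filename(".env.integration-test").expect("env file not found");
    let key = std::env::var("FAKER_API_KEY").expect("FAKER_API_KEY not set");
    assert_eq!(key, "your_faker_api_key_here");
}
```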
59 changes: 59 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -12,4 +12,4 @@ lazy_static = "1.5.0"
 rand = "0.9.2"
 reqwest = "0.12.26"
 serde = "1.0.228"
-tokio = "1.48.0"
+tokio = { version = "1.48.0", features = ["full"] }
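The move to `features = ["full"]` is what lets the examples drive the async client from `#[tokio::main]`: tokio's default feature set includes neither the attribute macro nor a runtime. A minimal compile check (the sleep only exercises the runtime):

```rust
// #[tokio::main] needs tokio's "macros" and "rt-multi-thread" features;
// both come with "full", which the Cargo.toml change above enables.
#[tokio::main]
async fn main() {
    // "time" is also part of "full"; a trivial await proves the runtime runs.
    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
}
```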
21 changes: 11 additions & 10 deletions README.md
@@ -11,7 +11,7 @@ Thanks to [async-openai](https://github.com/64bit/async-openai), AMRS builds on
 - Flexible routing strategies, including:
   - **Random**: Randomly selects a model from the available models.
   - **WRR**: Weighted Round Robin selects models based on predefined weights.
-  - **UCB**: Upper Confidence Bound based model selection (coming soon).
+  - **UCB1**: Upper Confidence Bound based model selection (coming soon).
   - **Adaptive**: Dynamically selects models based on performance metrics (coming soon).

 - Broad provider support:
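For readers skimming the strategy list above: here is a self-contained sketch of smooth weighted round robin, one common way weights like these are honored; the types and sequence are illustrative, not this crate's internals. (UCB1, once it lands, would instead score each model by its mean reward plus an exploration bonus.)

```rust
// Smooth weighted round robin: on each pick, every model's running score
// grows by its weight; the highest score wins and then pays back the
// total weight. Weights 2:1 yield the repeating sequence A, B, A.
struct Weighted {
    name: &'static str,
    weight: i64,
    current: i64,
}

fn sample_wrr(models: &mut [Weighted]) -> &str {
    let total: i64 = models.iter().map(|m| m.weight).sum();
    for m in models.iter_mut() {
        m.current += m.weight;
    }
    let best = models
        .iter_mut()
        .max_by_key(|m| m.current)
        .expect("at least one model");
    best.current -= total;
    best.name
}

fn main() {
    let mut models = [
        Weighted { name: "gpt-3.5-turbo", weight: 2, current: 0 },
        Weighted { name: "gpt-4", weight: 1, current: 0 },
    ];
    for _ in 0..6 {
        print!("{} ", sample_wrr(&mut models));
    }
    println!(); // gpt-3.5-turbo appears twice as often as gpt-4
}
```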
@@ -27,30 +27,31 @@ Here's a simple example with the Weighted Round Robin (WRR) routing mode:
 // Before running the code, make sure to set your OpenAI API key in the environment variable:
 // export OPENAI_API_KEY="your_openai_api_key"

-use arms::{Client, Config, ModelConfig, CreateResponseArgs, RoutingMode};
+use arms::client;
+use arms::types::responses;

-let config = Config::builder()
+let config = client::Config::builder()
     .provider("openai")
-    .routing_mode(RoutingMode::WRR)
+    .routing_mode(client::RoutingMode::WRR)
     .model(
-        ModelConfig::builder()
-            .id("gpt-3.5-turbo")
+        client::ModelConfig::builder()
+            .name("gpt-3.5-turbo")
             .weight(2)
             .build()
             .unwrap(),
     )
     .model(
-        ModelConfig::builder()
-            .id("gpt-4")
+        client::ModelConfig::builder()
+            .name("gpt-4")
             .weight(1)
             .build()
             .unwrap(),
     )
     .build()
     .unwrap();

-let mut client = Client::new(config);
-let request = CreateResponseArgs::default()
+let mut client = client::Client::new(config);
+let request = responses::CreateResponseArgs::default()
     .input("give me a poem about nature")
     .build()
     .unwrap();
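Pieced together from the added lines, the post-change README example reads as one snippet. The async `main` wrapper and the final `create_response` call are assumptions here, since the diff is cut off before the request is actually sent:

```rust
use arms::client;
use arms::types::responses;

#[tokio::main]
async fn main() {
    let config = client::Config::builder()
        .provider("openai")
        .routing_mode(client::RoutingMode::WRR)
        .model(
            client::ModelConfig::builder()
                .name("gpt-3.5-turbo")
                .weight(2)
                .build()
                .unwrap(),
        )
        .model(
            client::ModelConfig::builder()
                .name("gpt-4")
                .weight(1)
                .build()
                .unwrap(),
        )
        .build()
        .unwrap();

    let mut client = client::Client::new(config);
    let request = responses::CreateResponseArgs::default()
        .input("give me a poem about nature")
        .build()
        .unwrap();

    // Assumed continuation; the diff ends before this call.
    let _response = client.create_response(request).await.unwrap();
}
```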
4 changes: 2 additions & 2 deletions bindings/python/amrs/config.py
@@ -47,10 +47,10 @@ class BasicModelConfig(BaseModel):
     )


-type ModelID = str
+type ModelName = str

 class ModelConfig(BasicModelConfig):
-    id: ModelID = Field(
+    id: ModelName = Field(
         description="ID of the model to be used."
     )
     weight: Optional[int] = Field(
6 changes: 3 additions & 3 deletions bindings/python/amrs/router/random.py
@@ -1,11 +1,11 @@
 import random

-from amrs.config import ModelID
+from amrs.config import ModelName
 from amrs.router.router import Router

 class RandomRouter(Router):
-    def __init__(self, model_list: list[ModelID]):
+    def __init__(self, model_list: list[ModelName]):
         super().__init__(model_list)

-    def sample(self, _: str) -> ModelID:
+    def sample(self, _: str) -> ModelName:
         return random.choice(self._model_list)
4 changes: 2 additions & 2 deletions bindings/python/amrs/router/router.py
@@ -7,11 +7,11 @@ class ModelInfo:
     average_latency: float = 0.0

 class Router(abc.ABC):
-    def __init__(self, model_list: list[config.ModelID]):
+    def __init__(self, model_list: list[config.ModelName]):
         self._model_list = model_list

     @abc.abstractmethod
-    def sample(self, content: str) -> config.ModelID:
+    def sample(self, content: str) -> config.ModelName:
         pass

 def new_router(model_cfgs: list[config.ModelConfig], mode: config.RoutingMode) -> Router:
32 changes: 17 additions & 15 deletions src/client/client.rs
@@ -1,11 +1,13 @@
 use std::collections::HashMap;

-use crate::config::{Config, ModelId};
+use crate::client::config::{Config, ModelName};
 use crate::provider::provider;
 use crate::router::router;
+use crate::types::error::OpenAIError;
+use crate::types::responses::{CreateResponse, Response};

 pub struct Client {
-    providers: HashMap<ModelId, Box<dyn provider::Provider>>,
+    providers: HashMap<ModelName, Box<dyn provider::Provider>>,
     router: Box<dyn router::Router>,
 }
@@ -17,7 +19,7 @@ impl Client {
         let providers = cfg
             .models
             .iter()
-            .map(|m| (m.id.clone(), provider::construct_provider(m.clone())))
+            .map(|m| (m.name.clone(), provider::construct_provider(m.clone())))
             .collect();

         Self {
@@ -28,18 +30,18 @@

     pub async fn create_response(
         &mut self,
-        request: provider::CreateResponseReq,
-    ) -> Result<provider::CreateResponseRes, provider::APIError> {
-        let model_id = self.router.sample(&request);
-        let provider = self.providers.get(&model_id).unwrap();
+        request: CreateResponse,
+    ) -> Result<Response, OpenAIError> {
+        let candidate = self.router.sample(&request);
+        let provider = self.providers.get(&candidate).unwrap();
         provider.create_response(request).await
     }
 }

 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::config::{Config, ModelConfig, RoutingMode};
+    use crate::client::config::{Config, ModelConfig, RoutingMode};
     use dotenvy::from_filename;

     #[test]
@@ -58,7 +60,7 @@ mod tests {
             config: Config::builder()
                 .models(vec![
                     ModelConfig::builder()
-                        .id("model_c".to_string())
+                        .name("model_c".to_string())
                         .build()
                         .unwrap(),
                 ])
@@ -71,15 +73,15 @@
             config: Config::builder()
                 .routing_mode(RoutingMode::WRR)
                 .models(vec![
-                    crate::config::ModelConfig::builder()
-                        .id("model_a".to_string())
+                    crate::client::config::ModelConfig::builder()
+                        .name("model_a".to_string())
                         .provider(Some("openai".to_string()))
                         .base_url(Some("https://api.openai.com/v1".to_string()))
                         .weight(1)
                         .build()
                         .unwrap(),
-                    crate::config::ModelConfig::builder()
-                        .id("model_b".to_string())
+                    crate::client::config::ModelConfig::builder()
+                        .name("model_b".to_string())
                         .provider(Some("openai".to_string()))
                         .base_url(Some("https://api.openai.com/v1".to_string()))
                         .weight(3)
@@ -95,13 +97,13 @@
             config: Config::builder()
                 .models(vec![
                     ModelConfig::builder()
-                        .id("model_a".to_string())
+                        .name("model_a".to_string())
                         .provider(Some("openai".to_string()))
                         .base_url(Some("https://api.openai.com/v1".to_string()))
                         .build()
                         .unwrap(),
                     ModelConfig::builder()
-                        .id("model_b".to_string())
+                        .name("model_b".to_string())
                         .provider(Some("openai".to_string()))
                         .base_url(Some("https://api.openai.com/v1".to_string()))
                         .build()
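One design note on the reworked `create_response`: the provider map is keyed by whatever name the router returns, and the lookup still calls `unwrap()`, so a pick with no registered provider panics. Below is a self-contained sketch of the same dispatch shape with the missing-entry case surfaced as an error instead; every type in it is a stand-in, not the crate's own:

```rust
// Dispatch pattern from Client::create_response above: the router picks
// a model name, the name keys a provider map, and a missing entry
// becomes an error rather than a panic.
use std::collections::HashMap;

type ModelName = String;

trait Provider {
    fn call(&self, input: &str) -> String;
}

struct Echo;
impl Provider for Echo {
    fn call(&self, input: &str) -> String {
        format!("echo: {input}")
    }
}

fn dispatch(
    providers: &HashMap<ModelName, Box<dyn Provider>>,
    candidate: &str,
    input: &str,
) -> Result<String, String> {
    providers
        .get(candidate)
        .map(|p| p.call(input))
        .ok_or_else(|| format!("router returned unknown model: {candidate}"))
}

fn main() {
    let mut providers: HashMap<ModelName, Box<dyn Provider>> = HashMap::new();
    providers.insert("model_a".into(), Box::new(Echo));
    assert_eq!(dispatch(&providers, "model_a", "hi").unwrap(), "echo: hi");
    assert!(dispatch(&providers, "model_b", "hi").is_err());
}
```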