Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ fn main() {
.build()
.unwrap();

let mut client = client::Client::new(config);
let client = client::Client::new(config);
let request = chat::CreateChatCompletionRequestArgs::default()
.messages([
chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),
Expand Down
2 changes: 1 addition & 1 deletion examples/wrr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fn main() {
.build()
.unwrap();

let mut client = client::Client::new(config);
let client = client::Client::new(config);
let request = chat::CreateChatCompletionRequestArgs::default()
.messages([
chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),
Expand Down
4 changes: 2 additions & 2 deletions src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl Client {
}

pub async fn create_response(
&mut self,
&self,
request: responses::CreateResponse,
) -> Result<responses::Response, OpenAIError> {
let candidate = self.router.sample();
Expand All @@ -39,7 +39,7 @@ impl Client {

// This is chat completion endpoint.
pub async fn create_completion(
&mut self,
&self,
request: chat::CreateChatCompletionRequest,
) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
let candidate = self.router.sample();
Expand Down
2 changes: 1 addition & 1 deletion src/router/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl Router for RandomRouter {
"RandomRouter"
}

fn sample(&mut self) -> ModelName {
fn sample(&self) -> ModelName {
let mut rng = rand::rng();
let idx = rng.random_range(0..self.model_infos.len());
self.model_infos[idx].name.clone()
Expand Down
2 changes: 1 addition & 1 deletion src/router/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub fn construct_router(mode: RouterMode, models: Vec<ModelConfig>) -> Box<dyn R

pub trait Router {
fn name(&self) -> &'static str;
fn sample(&mut self) -> ModelName;
fn sample(&self) -> ModelName;
}

#[cfg(test)]
Expand Down
25 changes: 14 additions & 11 deletions src/router/wrr.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::sync::atomic::AtomicI32;

use crate::client::config::ModelName;
use crate::router::router::{ModelInfo, Router};

pub struct WeightedRoundRobinRouter {
total_weight: i32,
model_infos: Vec<ModelInfo>,
// current_weight is ordered by model_infos index.
current_weights: Vec<i32>,
current_weights: Vec<AtomicI32>,
}

impl WeightedRoundRobinRouter {
Expand All @@ -16,7 +18,7 @@ impl WeightedRoundRobinRouter {
Self {
model_infos: model_infos,
total_weight: total_weight,
current_weights: vec![0; length],
current_weights: (0..length).map(|_| AtomicI32::new(0)).collect(),
}
}
}
Expand All @@ -27,27 +29,28 @@ impl Router for WeightedRoundRobinRouter {
}

// Use Smooth Weighted Round Robin Algorithm.
fn sample(&mut self) -> ModelName {
fn sample(&self) -> ModelName {
// return early if only one model.
if self.model_infos.len() == 1 {
return self.model_infos[0].name.clone();
}

self.current_weights
.iter_mut()
.enumerate()
.for_each(|(i, weight)| {
*weight += self.model_infos[i].weight;
});
// 1. add weight to current weight.
self.model_infos.iter().enumerate().for_each(|(i, weight)| {
self.current_weights[i].fetch_add(weight.weight, std::sync::atomic::Ordering::Relaxed);
});

let mut max_index = 0;
for i in 1..self.current_weights.len() {
if self.current_weights[i] > self.current_weights[max_index] {
if self.current_weights[i].load(std::sync::atomic::Ordering::Relaxed)
> self.current_weights[max_index].load(std::sync::atomic::Ordering::Relaxed)
{
max_index = i;
}
}

self.current_weights[max_index] -= self.total_weight;
self.current_weights[max_index]
.fetch_sub(self.total_weight, std::sync::atomic::Ordering::Relaxed);
self.model_infos[max_index].name.clone()
}
}
Expand Down
56 changes: 28 additions & 28 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,31 @@ use arms::types::responses;
mod tests {
use super::*;

#[tokio::test]
async fn test_completion() {
from_filename(".env.integration-test").ok();

let config = client::Config::builder()
.provider("faker")
.model(
client::ModelConfig::builder()
.name("fake-completion-model")
.build()
.unwrap(),
)
.build()
.unwrap();

let client = client::Client::new(config);
let request = chat::CreateChatCompletionRequestArgs::default()
.build()
.unwrap();

let response = client.create_completion(request).await.unwrap();
assert!(response.id.starts_with("fake-completion-id"));
assert!(response.model == "fake-completion-model");
}

#[tokio::test]
async fn test_response() {
from_filename(".env.integration-test").ok();
Expand All @@ -24,7 +49,7 @@ mod tests {
.build()
.unwrap();

let mut client = client::Client::new(config);
let client = client::Client::new(config);
let request = responses::CreateResponseArgs::default()
.input("tell me the weather today")
.build()
Expand All @@ -45,7 +70,7 @@ mod tests {
)
.build()
.unwrap();
let mut client = client::Client::new(config);
let client = client::Client::new(config);
let request = responses::CreateResponseArgs::default()
.model("gpt-3.5-turbo")
.input("tell me a joke")
Expand Down Expand Up @@ -74,36 +99,11 @@ mod tests {
)
.build()
.unwrap();
let mut client = client::Client::new(config);
let client = client::Client::new(config);
let request = responses::CreateResponseArgs::default()
.input("give me a poem about nature")
.build()
.unwrap();
let _ = client.create_response(request).await.unwrap();
}

#[tokio::test]
async fn test_completion() {
from_filename(".env.integration-test").ok();

let config = client::Config::builder()
.provider("faker")
.model(
client::ModelConfig::builder()
.name("fake-completion-model")
.build()
.unwrap(),
)
.build()
.unwrap();

let mut client = client::Client::new(config);
let request = chat::CreateChatCompletionRequestArgs::default()
.build()
.unwrap();

let response = client.create_completion(request).await.unwrap();
assert!(response.id.starts_with("fake-completion-id"));
assert!(response.model == "fake-completion-model");
}
}
Loading