Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ cc_binary(
srcs = ["main.cc"],
visibility = ["//visibility:public"],
deps = [
":chat",
":fetch",
":llms",
":tui",
Expand All @@ -21,6 +22,17 @@ cc_binary(
],
)

cc_library(
name = "chat",
srcs = ["chat.cc"],
hdrs = ["chat.h"],
deps = [
":event_loop",
"@abseil-cpp//absl/functional:any_invocable",
"@abseil-cpp//absl/synchronization",
],
)

cc_library(
name = "fetch",
srcs = ["fetch.cc"],
Expand All @@ -36,6 +48,17 @@ cc_library(
],
)

cc_library(name="event_loop",
srcs = ["event_loop.cc"],
hdrs = ["event_loop.h"],
visibility = ["//visibility:public"],
deps = [
"@abseil-cpp//absl/functional:any_invocable",
"@abseil-cpp//absl/synchronization",
"@abseil-cpp//absl/log",
],
)

cc_library(
name = "json_decode",
srcs = ["json_decode.cc"],
Expand All @@ -52,6 +75,7 @@ cc_library(
name = "llms",
srcs = [
"anthropic.cc",
"model.cc",
"openai.cc",
],
hdrs = [
Expand All @@ -61,9 +85,11 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
":chat",
":fetch",
":json_decode",
"@abseil-cpp//absl/flags:flag",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
"@nlohmann_json//:json",
Expand Down
66 changes: 30 additions & 36 deletions src/anthropic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

#include "nlohmann/json.hpp"
#include "src/chat.h"
#include "src/fetch.h"
#include "src/json_decode.h"
#include "src/model.h"
Expand All @@ -27,47 +30,39 @@ namespace {

class AnthropicModel : public Model {
public:
AnthropicModel(std::string_view model,
std::string_view api_key, int max_tokens)
: model_(model),
AnthropicModel(std::string model, std::string_view api_key, int max_tokens,
std::shared_ptr<Fetch> fetch)
: Model(std::move(model)),
api_key_(api_key),
max_tokens_(max_tokens) {}
max_tokens_(max_tokens),
fetch_(std::move(fetch)) {}
~AnthropicModel() override = default;

std::string_view name() const override { return model_; }

absl::StatusOr<std::string> Prompt(
const Fetch& fetch, std::string_view prompt,
absl::Span<const std::string_view> input_contents) override;

private:
std::string model_;
absl::StatusOr<std::string> Send(const Message& message) override;

std::string api_key_;
int max_tokens_;
std::shared_ptr<Fetch> fetch_;
std::unordered_map<Chat*, std::unique_ptr<Chat::Unsubscribe>> subscriptions_;
};

absl::StatusOr<std::string> AnthropicModel::Prompt(
const Fetch& fetch, std::string_view prompt,
absl::Span<const std::string_view> input_contents) {
std::string combined_input = absl::StrJoin(input_contents, "\n\n");

absl::StatusOr<std::string> AnthropicModel::Send(const Message& message) {
nlohmann::json request = {
{"model", model_},
{"model", name()},
{"max_tokens", max_tokens_},
{"messages",
nlohmann::json::array(
{{{"role", "user"},
{"content", absl::StrCat(prompt, "\n\n", combined_input)}}})},
{"messages", nlohmann::json::array(
{{{"role", "user"}, {"content", message.content()}}})},
};

auto response =
fetch.Post("https://api.anthropic.com/v1/messages",
{
{.key = "Content-Type", .value = "application/json"},
{.key = "x-api-key", .value = api_key_},
{.key = "anthropic-version", .value = "2023-06-01"},
},
request);
fetch_->Post("https://api.anthropic.com/v1/messages",
{
{.key = "Content-Type", .value = "application/json"},
{.key = "x-api-key", .value = api_key_},
{.key = "anthropic-version", .value = "2023-06-01"},
},
request);

if (!response.ok()) {
return std::move(response).status();
Expand All @@ -83,13 +78,12 @@ absl::StatusOr<std::string> AnthropicModel::Prompt(
(*json_response)["error"].dump(2)));
}

auto message =
json::JsonDecode(*json_response)["content"][0]["text"].String();
if (!message.ok()) {
auto res = json::JsonDecode(*json_response)["content"][0]["text"].String();
if (!res.ok()) {
return absl::InternalError(
absl::StrCat("Anthropic API error: ", message.error()));
absl::StrCat("Anthropic API error: ", res.error()));
}
return message.value();
return res.value();
}

class AnthropicModelProvider : public ModelProvider {
Expand All @@ -106,8 +100,8 @@ class AnthropicModelProvider : public ModelProvider {
if (!api_key.has_value()) {
return absl::InvalidArgumentError("Anthropic API key is required");
}
auto client = std::make_unique<AnthropicModel>(model, *api_key,
parameters_.max_tokens());
auto client = std::make_unique<AnthropicModel>(
std::string(model), *api_key, parameters_.max_tokens(), fetch_);
return ModelHandle(std::move(client));
}

Expand Down
42 changes: 42 additions & 0 deletions src/chat.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "src/chat.h"

#include <utility>

namespace uchen::chat {

std::optional<Message> Chat::FindMessage(int id) const {
absl::MutexLock lock(&message_mutex_);
auto it = messages_.find(id);
if (it != messages_.end()) {
return it->second;
}
return std::nullopt;
}

Message Chat::SendMessage(Message::Origin origin, std::string content,
std::optional<int> parent_id, void* provider) {
absl::MutexLock lock(&message_mutex_);
int id = next_id_++;
auto result = messages_.emplace(
id, Message(id, origin, std::move(content), parent_id, provider));
event_loop_->Run(
[chat = shared_from_this(), message = result.first->second]() {
absl::MutexLock lock(&chat->callback_mutex_);
for (const auto& [_, callback] : chat->callbacks_) {
callback(message);
}
});
return result.first->second;
}

std::unique_ptr<Chat::Unsubscribe> Chat::Subscribe(Callback callback) {
size_t key = next_id_++;
event_loop_->Run([chat = shared_from_this(), key,
callback = std::move(callback)]() mutable {
absl::MutexLock lock(&chat->callback_mutex_);
chat->callbacks_.emplace(key, std::move(callback));
});
return std::make_unique<Unsubscribe>(this, key);
}

} // namespace uchen::chat
90 changes: 90 additions & 0 deletions src/chat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#ifndef SRC_CHAT_H_
#define SRC_CHAT_H_

#include <atomic>
#include <cstddef>
#include <memory>
#include <unordered_map>

#include "absl/functional/any_invocable.h"
#include "absl/synchronization/mutex.h"

#include "src/event_loop.h"

namespace uchen::chat {

class Message {
public:
enum class Origin { kAssistant, kSystem, kUser };

Message() = default;
Message(int id, Origin origin, std::string content,
std::optional<int> parent_id, void* provider)
: id_(id),
origin_(origin),
content_(std::move(content)),
parent_id_(parent_id),
provider_(provider) {}

Message(const Message&) = default;
Message& operator=(const Message&) = default;
Message(Message&&) = default;
Message& operator=(Message&&) = default;

bool operator==(const Message& other) const = default;

int id() const { return id_; }
Origin origin() const { return origin_; }
const std::string& content() const { return content_; }
std::optional<int> parent_id() const { return parent_id_; }
void* provider() const { return provider_; }

private:
int id_;
Origin origin_;
std::string content_;
std::optional<int> parent_id_;
void* provider_;
};

class Chat : public std::enable_shared_from_this<Chat> {
public:
using Callback = absl::AnyInvocable<void(const Message&) const>;

class Unsubscribe {
public:
Unsubscribe(Chat* chat, size_t id) : chat_(chat), id_(id) {}

~Unsubscribe() { chat_->callbacks_.erase(id_); }

private:
Chat* chat_;
size_t id_;
};

static std::shared_ptr<Chat> Create(std::shared_ptr<EventLoop> event_loop) {
// Can't use std::make_shared because ctor is private.
return std::shared_ptr<Chat>(new Chat(std::move(event_loop)));
}

std::optional<Message> FindMessage(int id) const;
Message SendMessage(Message::Origin origin, std::string content,
std::optional<int> parent_id, void* provider);
std::unique_ptr<Unsubscribe> Subscribe(Callback callback);

private:
explicit Chat(std::shared_ptr<EventLoop> event_loop)
: event_loop_(std::move(event_loop)) {}
mutable absl::Mutex message_mutex_;
mutable absl::Mutex callback_mutex_;
std::shared_ptr<EventLoop> event_loop_;
std::atomic_int next_id_{1};
std::unordered_map<size_t, Callback> callbacks_
ABSL_GUARDED_BY(&callback_mutex_);
std::unordered_map<size_t, Message> messages_
ABSL_GUARDED_BY(&message_mutex_);
};

} // namespace uchen::chat

#endif // SRC_CHAT_H_
41 changes: 41 additions & 0 deletions src/event_loop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include "src/event_loop.h"

#include <utility>
#include <vector>

namespace uchen::chat {

void EventLoop::Run(absl::AnyInvocable<void() > task) {
absl::MutexLock lock(&mutex_);
tasks_.push_back(std::move(task));
}

void EventLoop::Loop(EventLoop* event_loop) {
while (true) {
auto done_tasks = event_loop->GetTasks();
if (std::holds_alternative<bool>(done_tasks)) {
break;
}
for (auto& task : std::get<TasksList>(done_tasks)) {
task();
}
}
}

std::variant<bool, std::vector<absl::AnyInvocable<void()>>>
EventLoop::GetTasks() {
absl::MutexLock lock(&mutex_);
absl::Condition condition(
+[](EventLoop* event_loop)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(event_loop->mutex_) {
return !event_loop->tasks_.empty() || event_loop->stop_;
},
this);
mutex_.Await(condition);
if (stop_) {
return true;
}
return std::move(tasks_);
}

} // namespace uchen::chat
Loading