Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Runtime/LLMAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ public virtual async Task<string> Chat(string query, Action<string> callback = n
}

SetCompletionParameters();
result = await llmAgent.ChatAsync(query, addToHistory, wrappedCallback, false, debugPrompt);
result = await llmAgent.ChatAsync(query, addToHistory, wrappedCallback, true, debugPrompt);
result = ParseCompletionResponse(result);
if (this == null) return null;
if (addToHistory && result != null && save != "") _ = SaveHistory();
if (this != null) completionCallback?.Invoke();
Expand Down
43 changes: 42 additions & 1 deletion Runtime/LLMClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
/// @brief File implementing the base LLM client functionality for Unity.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Threading.Tasks;
using UndreamAI.LlamaLib;
Expand Down Expand Up @@ -193,6 +194,12 @@ public string grammar
set => SetGrammar(value);
}

/// <summary>Generated tokens per second from the most recent completed request, or -1 if unavailable</summary>
public float TokensPerSecond { get; protected set; } = -1f;

/// <summary>Prompt tokens per second from the most recent completed request, or -1 if unavailable</summary>
public float PromptTokensPerSecond { get; protected set; } = -1f;

#endregion

#region Private Fields
Expand Down Expand Up @@ -498,6 +505,39 @@ protected virtual void SetCompletionParameters()
}
}

/// <summary>
/// Extracts plain text and timing data from a completion response payload.
/// Falls back to the raw response if the payload is not valid JSON.
/// </summary>
/// <param name="response">Raw response payload returned by LlamaLib</param>
/// <returns>Assistant content as plain text</returns>
/// <summary>
/// Extracts plain text and timing data from a completion response payload.
/// Falls back to the raw response if the payload is not valid JSON.
/// </summary>
/// <param name="response">Raw response payload returned by LlamaLib</param>
/// <returns>Assistant content as plain text</returns>
protected virtual string ParseCompletionResponse(string response)
{
    // Reset the metrics up front so a failed parse never leaves stale values behind.
    TokensPerSecond = -1f;
    PromptTokensPerSecond = -1f;

    if (string.IsNullOrEmpty(response)) return response ?? string.Empty;

    try
    {
        JObject payload = JObject.Parse(response);
        JToken timings = payload["timings"];
        TokensPerSecond = ParsePositiveTimingValue(timings?["predicted_per_second"]);
        PromptTokensPerSecond = ParsePositiveTimingValue(timings?["prompt_per_second"]);
        return payload["content"]?.ToString() ?? string.Empty;
    }
    catch
    {
        // Payload was not JSON (e.g. a plain-text response) - return it untouched.
        return response;
    }
}

/// <summary>
/// Converts a timing token into a strictly positive float,
/// returning -1 for missing, unparsable, or non-positive values.
/// </summary>
/// <param name="token">Timing value token from the response JSON, may be null</param>
/// <returns>The parsed value if it is greater than zero, otherwise -1</returns>
protected float ParsePositiveTimingValue(JToken token)
{
    if (token is null) return -1f;
    // Invariant culture: the payload is machine-generated, never locale-formatted.
    bool parsed = float.TryParse(token.ToString(), NumberStyles.Float, CultureInfo.InvariantCulture, out float value);
    if (!parsed) return -1f;
    return value > 0f ? value : -1f;
}

#endregion

#region Core LLM Operations
Expand Down Expand Up @@ -593,7 +633,8 @@ public virtual async Task<string> Completion(string prompt, Action<string> callb
}

SetCompletionParameters();
string result = await llmClient.CompletionAsync(prompt, wrappedCallback, id_slot);
string result = await llmClient.CompletionAsync(prompt, wrappedCallback, id_slot, true);
result = ParseCompletionResponse(result);
completionCallback?.Invoke();
return result;
}
Expand Down
12 changes: 6 additions & 6 deletions Runtime/LlamaLib/LLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -183,23 +183,23 @@ public void CheckCompletionInternal(string prompt)
CheckLlamaLib();
}

public string CompletionInternal(string prompt, LlamaLib.CharArrayCallback callback, int idSlot)
public string CompletionInternal(string prompt, LlamaLib.CharArrayCallback callback, int idSlot, bool returnResponseJson = false)
{
IntPtr result;
result = llamaLib.LLM_Completion(llm, prompt ?? string.Empty, callback, idSlot);
result = llamaLib.LLM_Completion(llm, prompt ?? string.Empty, callback, idSlot, returnResponseJson);
return Marshal.PtrToStringAnsi(result) ?? string.Empty;
}

public string Completion(string prompt, LlamaLib.CharArrayCallback callback = null, int idSlot = -1)
public string Completion(string prompt, LlamaLib.CharArrayCallback callback = null, int idSlot = -1, bool returnResponseJson = false)
{
CheckCompletionInternal(prompt);
return CompletionInternal(prompt, callback, idSlot);
return CompletionInternal(prompt, callback, idSlot, returnResponseJson);
}

public async Task<string> CompletionAsync(string prompt, LlamaLib.CharArrayCallback callback = null, int idSlot = -1)
public async Task<string> CompletionAsync(string prompt, LlamaLib.CharArrayCallback callback = null, int idSlot = -1, bool returnResponseJson = false)
{
CheckCompletionInternal(prompt);
return await Task.Run(() => CompletionInternal(prompt, callback, idSlot));
return await Task.Run(() => CompletionInternal(prompt, callback, idSlot, returnResponseJson));
}
}

Expand Down
81 changes: 81 additions & 0 deletions Tests/Editor/TestLLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,64 @@ public void TestLoras()
}
}

/// <summary>Unit tests for LLMClient completion-response parsing and timing metrics.</summary>
public class TestLLMClient_ResponseParsing
{
    /// <summary>Exposes the protected ParseCompletionResponse method for testing.</summary>
    private class TestableLLMClient : LLMClient
    {
        public string Parse(string response)
        {
            return ParseCompletionResponse(response);
        }
    }

    GameObject gameObject;
    TestableLLMClient llmClient;

    [SetUp]
    public void SetUp()
    {
        gameObject = new GameObject();
        llmClient = gameObject.AddComponent<TestableLLMClient>();
    }

    // TearDown runs even when an assertion fails, so the temporary GameObject
    // can never leak into subsequent tests (the previous inline cleanup was
    // skipped whenever an Assert threw).
    [TearDown]
    public void TearDown()
    {
        UnityEngine.Object.DestroyImmediate(gameObject);
    }

    [Test]
    public void TestParseCompletionResponseJson()
    {
        string result = llmClient.Parse("{\"content\":\"Hello!\",\"timings\":{\"predicted_per_second\":148.62,\"prompt_per_second\":345.12}}");

        Assert.AreEqual("Hello!", result);
        Assert.AreEqual(148.62f, llmClient.TokensPerSecond);
        Assert.AreEqual(345.12f, llmClient.PromptTokensPerSecond);
    }

    [Test]
    public void TestParseCompletionResponseFallback()
    {
        // Non-JSON payloads must pass through unchanged, with sentinel metrics.
        string result = llmClient.Parse("plain text response");

        Assert.AreEqual("plain text response", result);
        Assert.AreEqual(-1f, llmClient.TokensPerSecond);
        Assert.AreEqual(-1f, llmClient.PromptTokensPerSecond);
    }

    [Test]
    public void TestParseCompletionResponseResetsStaleMetrics()
    {
        // A response without timings must not keep the previous request's metrics.
        string firstResult = llmClient.Parse("{\"content\":\"Hello!\",\"timings\":{\"predicted_per_second\":148.62,\"prompt_per_second\":345.12}}");
        string secondResult = llmClient.Parse("{\"content\":\"No timings this time\"}");

        Assert.AreEqual("Hello!", firstResult);
        Assert.AreEqual("No timings this time", secondResult);
        Assert.AreEqual(-1f, llmClient.TokensPerSecond);
        Assert.AreEqual(-1f, llmClient.PromptTokensPerSecond);
    }
}

public class TestLLM
{
protected string modelNameLLManager;
Expand Down Expand Up @@ -296,21 +354,32 @@ public virtual async Task Tests()
await llmAgent.Tokenize("I", TestTokens);
await llmAgent.Warmup();
TestPostChat(0);
TestWarmupTimings();

string reply = await llmAgent.Chat(query);
TestChat(reply, reply1);
TestInferenceTimings();
TestPostChat(2);

llmAgent.systemPrompt = prompt2;
reply = await llmAgent.Chat(query, TestStreamingChat);
TestChat(reply, reply2);
TestInferenceTimings();
TestPostChat(4);

string completion = await llmAgent.Completion("The cat is away");
Assert.That(!string.IsNullOrWhiteSpace(completion));
TestInferenceTimings();

await llmAgent.ClearHistory();
TestPostChat(0);

await llmAgent.Chat("bye!");
TestInferenceTimings();
TestPostChat(2);

await llmAgent.Warmup();
TestWarmupTimings();
}

public virtual void TestArchitecture()
Expand All @@ -328,6 +397,18 @@ public void TestStreamingChat(string reply)
Assert.That(reply != "");
}

/// <summary>
/// Asserts that a completed inference populated the generation throughput metric.
/// Prompt throughput is allowed to stay at the -1 sentinel (it may be unavailable
/// for a given request) but must be positive when present.
/// </summary>
public void TestInferenceTimings()
{
    Assert.Greater(llmAgent.TokensPerSecond, 0f);
    float promptRate = llmAgent.PromptTokensPerSecond;
    Assert.That(promptRate == -1f || promptRate > 0f);
}

/// <summary>
/// Asserts metric state after a warmup: generation throughput must remain at the
/// -1 sentinel, while prompt throughput is either the sentinel or positive.
/// </summary>
public void TestWarmupTimings()
{
    Assert.AreEqual(-1f, llmAgent.TokensPerSecond);
    float promptRate = llmAgent.PromptTokensPerSecond;
    Assert.That(promptRate == -1f || promptRate > 0f);
}

public void TestChat(string reply, string replyGT)
{
Debug.Log(reply.Trim());
Expand Down