Skip to content
Merged
4 changes: 2 additions & 2 deletions .github/workflows/ci-rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
diff: true
diff-branch: main
diff-storage: _xml_coverage_reports
uncovered-statements-increase-failure: true
new-uncovered-statements-failure: true
uncovered-statements-increase-failure: true # DO NOT CHANGE THIS, ADD TESTS
new-uncovered-statements-failure: true # DO NOT CHANGE THIS, ADD TESTS
coverage-rate-reduction-failure: true
togglable-report: true
107 changes: 8 additions & 99 deletions rsworkspace/crates/acp-nats/src/client/fs_read_text_file.rs
Original file line number Diff line number Diff line change
@@ -1,72 +1,14 @@
use crate::client::rpc_reply;
use crate::jsonrpc::extract_request_id;
use crate::nats::{FlushClient, PublishClient, headers_with_trace_context};
use crate::nats::{FlushClient, PublishClient};
use agent_client_protocol::{
Client, Error, ErrorCode, ReadTextFileRequest, ReadTextFileResponse, Request, RequestId,
Response,
Client, ErrorCode, ReadTextFileRequest, ReadTextFileResponse, Request, Response,
};
use bytes::Bytes;
use serde::de::Error as SerdeDeError;
use tracing::{instrument, warn};
use trogon_std::JsonSerialize;

const CONTENT_TYPE_JSON: &str = "application/json";
const CONTENT_TYPE_PLAIN: &str = "text/plain";

fn error_response_fallback_bytes<S: JsonSerialize>(serializer: &S) -> (Bytes, &'static str) {
match serializer.to_vec(&Response::<()>::Error {
id: RequestId::Null,
error: Error::new(-32603, "Internal error"),
}) {
Ok(v) => (Bytes::from(v), CONTENT_TYPE_JSON),
Err(e) => {
warn!(
error = %e,
"Fallback JSON serialization failed, response may not be valid JSON-RPC"
);
(Bytes::from("Internal error"), CONTENT_TYPE_PLAIN)
}
}
}

async fn publish_reply<N: PublishClient + FlushClient>(
nats: &N,
reply_to: &str,
bytes: Bytes,
content_type: &str,
context: &str,
) {
let mut headers = headers_with_trace_context();
headers.insert("Content-Type", content_type);
if let Err(e) = nats
.publish_with_headers(reply_to.to_string(), headers, bytes)
.await
{
warn!(error = %e, "Failed to publish {}", context);
}
if let Err(e) = nats.flush().await {
warn!(error = %e, "Failed to flush {}", context);
}
}

fn error_response_bytes<S: JsonSerialize>(
serializer: &S,
request_id: RequestId,
code: ErrorCode,
message: &str,
) -> (Bytes, &'static str) {
let response = Response::<()>::Error {
id: request_id,
error: Error::new(i32::from(code), message),
};
match serializer.to_vec(&response) {
Ok(v) => (Bytes::from(v), CONTENT_TYPE_JSON),
Err(e) => {
warn!(error = %e, "JSON serialization failed, using fallback error");
error_response_fallback_bytes(serializer)
}
}
}

#[derive(Debug)]
pub enum FsReadTextFileError {
InvalidRequest(serde_json::Error),
Expand Down Expand Up @@ -134,17 +76,17 @@ pub async fn handle<N: PublishClient + FlushClient, C: Client, S: JsonSerialize>
id: request_id.clone(),
result: response,
})
.map(|v| (Bytes::from(v), CONTENT_TYPE_JSON))
.map(|v| (Bytes::from(v), rpc_reply::CONTENT_TYPE_JSON))
.unwrap_or_else(|e| {
warn!(error = %e, "JSON serialization of response failed, sending error reply");
error_response_bytes(
rpc_reply::error_response_bytes(
serializer,
request_id,
ErrorCode::InternalError,
&format!("Failed to serialize response: {}", e),
)
});
publish_reply(
rpc_reply::publish_reply(
nats,
reply_to,
response_bytes,
Expand All @@ -161,8 +103,8 @@ pub async fn handle<N: PublishClient + FlushClient, C: Client, S: JsonSerialize>
"Failed to handle fs_read_text_file"
);
let (bytes, content_type) =
error_response_bytes(serializer, request_id, code, &message);
publish_reply(
rpc_reply::error_response_bytes(serializer, request_id, code, &message);
rpc_reply::publish_reply(
nats,
reply_to,
bytes,
Expand Down Expand Up @@ -634,21 +576,6 @@ mod tests {
assert!(fs_err.source().is_some());
}

#[test]
fn error_response_bytes_first_fallback_uses_null_id() {
let mock = FailNextSerialize::new(1);
let (bytes, content_type) = error_response_bytes(
&mock,
RequestId::Number(42),
ErrorCode::InvalidParams,
"test message",
);
assert_eq!(content_type, "application/json");
let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
assert_eq!(parsed["id"], serde_json::Value::Null);
assert_eq!(parsed["error"]["code"], -32603);
}

#[tokio::test]
async fn mock_client_request_permission_returns_err() {
let client = MockClient::new("x");
Expand All @@ -674,22 +601,4 @@ mod tests {
let result = client.request_permission(req).await;
assert!(result.is_err());
}

#[test]
fn error_response_bytes_last_resort_returns_plain_text() {
let mock = FailNextSerialize::new(2);
let (bytes, content_type) =
error_response_bytes(&mock, RequestId::Number(1), ErrorCode::InternalError, "msg");
assert_eq!(content_type, "text/plain");
assert_eq!(bytes.as_ref(), b"Internal error");
}

#[test]
fn error_response_fallback_bytes_std_serializer_returns_json() {
let (bytes, content_type) = error_response_fallback_bytes(&StdJsonSerialize);
assert_eq!(content_type, "application/json");
let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
assert_eq!(parsed["id"], serde_json::Value::Null);
assert_eq!(parsed["error"]["code"], -32603);
}
}
Loading