Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions agentic_security/http_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ async def _probe_with_files(self, files):

return response

def validate(
self, prompt: str, encoded_image: str, encoded_audio: str, files: dict | None
) -> None:
def validate(self, prompt: str, encoded_image: str, encoded_audio: str, files: dict | None) -> None:
if self.has_files and not files:
raise ValueError("Files are required for this request.")

Expand Down Expand Up @@ -107,12 +105,17 @@ async def probe(
content = content.replace("<<BASE64_IMAGE>>", encoded_image)
content = content.replace("<<BASE64_AUDIO>>", encoded_audio)

# Remove Content-Length from headers to avoid mismatch when
# placeholder replacement changes body size. httpx will set
# the correct Content-Length based on the actual content.
clean_headers = {k: v for k, v in self.headers.items() if k.lower() != "content-length"}

transport = httpx.AsyncHTTPTransport(retries=settings_var("network.retry", 3))
async with httpx.AsyncClient(transport=transport) as client:
response = await client.request(
method=self.method,
url=self.url,
headers=self.headers,
headers=clean_headers,
content=content,
timeout=self.timeout(),
)
Expand All @@ -127,9 +130,7 @@ async def verify(self) -> httpx.Response:
return await self.probe(
"test",
# TODO: fix url for mp3
encoded_audio=encode_audio_base64_by_url(
"https://www.example.com/audio.mp3"
),
encoded_audio=encode_audio_base64_by_url("https://www.example.com/audio.mp3"),
)
case LLMSpec(has_files=True):
return await self._probe_with_files({})
Expand Down Expand Up @@ -168,18 +169,14 @@ def parse_http_spec(http_spec: str) -> LLMSpec:
# Extract the method and URL from the first line
request_line_parts = lines[0].split()
if len(request_line_parts) < 2:
raise InvalidHTTPSpecError(
"First line of HTTP spec must include the method and URL."
)
raise InvalidHTTPSpecError("First line of HTTP spec must include the method and URL.")
method, url = request_line_parts[0], request_line_parts[1]

# Check url validity
valid_url = urlparse(url)
# if missing the correct formatting ://, urlparse.netloc will be empty
if valid_url.scheme not in ("http", "https") or not valid_url.netloc:
raise InvalidHTTPSpecError(
f"Invalid URL: {url}. Ensure it starts with 'http://' or 'https://'"
)
raise InvalidHTTPSpecError(f"Invalid URL: {url}. Ensure it starts with 'http://' or 'https://'")

# Initialize headers and body
headers = {}
Expand Down
36 changes: 13 additions & 23 deletions agentic_security/probe_actor/fuzzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
INITIAL_OPTIMIZER_POINTS = settings_var("fuzzer.initial_optimizer_points", 25)
MIN_FAILURE_SAMPLES = settings_var("fuzzer.min_failure_samples", 5)
FAILURE_RATE_THRESHOLD = settings_var("fuzzer.failure_rate_threshold", 0.5)
FAILURES_CSV_PATH = settings_var("fuzzer.failures_csv_path", "failures.csv")
FULL_LOG_CSV_PATH = settings_var("fuzzer.full_log_csv_path", "full_scan_log.csv")
MAX_INJECTION_ATTEMPTS = settings_var("fuzzer.max_injection_attempts", 20)


async def generate_prompts(
Expand Down Expand Up @@ -111,9 +114,7 @@ async def process_prompt(

if response.status_code >= 400:
logger.error(f"HTTP {response.status_code} {response.content=}")
fuzzer_state.add_error(
module_name, prompt, response.status_code, response.text
)
fuzzer_state.add_error(module_name, prompt, response.status_code, response.text)
return tokens, True

# Process successful response
Expand All @@ -123,9 +124,7 @@ async def process_prompt(
# Check if the response indicates a refusal
refused = refusal_heuristic(response.json())
if refused:
fuzzer_state.add_refusal(
module_name, prompt, response.status_code, response_text
)
fuzzer_state.add_refusal(module_name, prompt, response.status_code, response_text)

fuzzer_state.add_output(module_name, prompt, response_text, refused)
return tokens, refused
Expand Down Expand Up @@ -169,10 +168,7 @@ async def process_prompt_batch(
- Total number of tokens processed.
- Number of failed prompts.
"""
tasks = [
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
for p in prompts
]
tasks = [process_prompt(request_factory, p, tokens, module_name, fuzzer_state) for p in prompts]
results = await asyncio.gather(*tasks)
total_tokens = sum(r[0] for r in results)
failures = sum(1 for r in results if r[1])
Expand Down Expand Up @@ -216,11 +212,7 @@ async def scan_module(

# Initialize optimizer if optimization is enabled
optimizer = (
Optimizer(
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
)
if optimize
else None
Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS) if optimize else None
)

module_size = 0 if module.lazy else len(module.prompts)
Expand Down Expand Up @@ -422,8 +414,8 @@ async def perform_single_shot_scan(
processed_prompts += module_size

yield ScanResult.status_msg("Scan completed.")
fuzzer_state.export_failures("failures.csv")
fuzzer_state.export_full_log("full_scan_log.csv")
fuzzer_state.export_failures(FAILURES_CSV_PATH)
fuzzer_state.export_full_log(FULL_LOG_CSV_PATH)


async def perform_many_shot_scan(
Expand Down Expand Up @@ -515,7 +507,7 @@ async def perform_many_shot_scan(
tokens += prompt_tokens

injected = False
for _ in range(20):
for _ in range(MAX_INJECTION_ATTEMPTS):
if injected:
break

Expand Down Expand Up @@ -552,14 +544,12 @@ async def perform_many_shot_scan(
).model_dump_json()

if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES:
yield ScanResult.status_msg(
f"High failure rate detected ({failure_rate:.2%}). Stopping this module..."
)
yield ScanResult.status_msg(f"High failure rate detected ({failure_rate:.2%}). Stopping this module...")
break

yield ScanResult.status_msg("Scan completed.")
fuzzer_state.export_failures("failures.csv")
fuzzer_state.export_full_log("full_scan_log.csv")
fuzzer_state.export_failures(FAILURES_CSV_PATH)
fuzzer_state.export_full_log(FULL_LOG_CSV_PATH)


def scan_router(
Expand Down
Loading