Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 17 additions & 41 deletions IDEAS.mk
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ endif
RUSTFLAGS ?= -Awarnings## Ignore Rust compiler warnings
CARGO_NET_OFFLINE ?= true## Cargo offline mode
CFLAGS ?= -w## Ignore C compiler warnings
export EXTRACT_INFO_CMAKE CFLAGS

GIT = git -C ${TRANSLATION_DIR}

Expand All @@ -34,42 +35,17 @@ endif


# cmake
cmake: build-ninja/build.log

.PRECIOUS: build-ninja/CMakeCache.txt
build-ninja/CMakeCache.txt: test_case/CMakeLists.txt ${EXTRACT_INFO_CMAKE}
@rm -rf build-ninja
ifeq ($(wildcard CMakePresets.json),)
cmake -S test_case -B build-ninja -G Ninja \
-DCMAKE_BUILD_TYPE=Debug \
-DCMAKE_C_FLAGS_DEBUG="-g -O0" \
-DCMAKE_PROJECT_TOP_LEVEL_INCLUDES="${EXTRACT_INFO_CMAKE}" \
-DCMAKE_C_FLAGS="${CFLAGS}" \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
else
cmake -S . --preset test \
-DCMAKE_BUILD_TYPE=Debug \
-DCMAKE_C_FLAGS_DEBUG="-g -O0" \
-DCMAKE_PROJECT_TOP_LEVEL_INCLUDES="${EXTRACT_INFO_CMAKE}" \
-DCMAKE_C_FLAGS="${CFLAGS}" \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
endif

.PRECIOUS: build-ninja/compile_commands.json
build-ninja/compile_commands.json: build-ninja/CMakeCache.txt ;
.PHONY: cmake
cmake: build-ninja/cmake.log

.PRECIOUS: build-ninja/build.log
build-ninja/build.log: build-ninja/CMakeCache.txt
ifeq ($(wildcard CMakePresets.json),)
-cmake --build build-ninja --target all 2> $@
else
-cmake --build build-ninja --target all --preset test 2> $@
endif
@find build-ninja -maxdepth 1 -type f -executable | \
xargs -I{} sh -c "nm --extern-only {} | \
awk '{if (\$$2 == \"T\") print \$$NF}' | \
grep -v ^_ > {}.symbols"
build-ninja/cmake.log: test_case/CMakeLists.txt ${EXTRACT_INFO_CMAKE}
uv run python -m ideas.cmake source_dir=test_case \
build_dir=build-ninja
@touch $@

build-ninja/CMakeCache.txt: build-ninja/cmake.log
build-ninja/compile_commands.json: build-ninja/cmake.log
build-ninja/build.log: build-ninja/cmake.log

# init
.PHONY: init
Expand All @@ -87,7 +63,7 @@ ${TRANSLATION_DIR}/.git/config:
${GIT} commit --quiet --all --message "Initial commit"

.PRECIOUS: ${TRANSLATION_DIR}/Cargo.toml
${TRANSLATION_DIR}/Cargo.toml: ${TRANSLATION_DIR}/.git/config
${TRANSLATION_DIR}/Cargo.toml: | ${TRANSLATION_DIR}/.git/config
echo -n "[workspace]\nresolver = \"3\"" > $@
${GIT} add Cargo.toml
${GIT} commit --quiet --all --message "Created cargo workspace"
Expand Down Expand Up @@ -180,16 +156,16 @@ ${TRANSLATION_DIR}/cargo_test.log: ${TRANSLATION_DIR}/build.log $(patsubst %,${T
.PRECIOUS: ${TRANSLATION_DIR}/%/cargo_test.log
${TRANSLATION_DIR}/%/cargo_test.log: ${TRANSLATION_DIR}/%/build.log ${TRANSLATION_DIR}/%/tests/test_cases.rs
if [ $$(stat -c %s ${TRANSLATION_DIR}/$*/build.log) = 0 ]; then \
cargo test --manifest-path ${TRANSLATION_DIR}/$*/Cargo.toml --test test_cases | tee $@ ; \
else \
find test_vectors -name '*.json' -exec echo "test {} ... FAILED" \; | tee $@ ; \
fi \
cargo test --manifest-path ${TRANSLATION_DIR}/$*/Cargo.toml --test test_cases | tee $@ ; \
else \
find test_vectors -name '*.json' -exec echo "test {} ... FAILED" \; | tee $@ ; \
fi \

.PRECIOUS: ${TRANSLATION_DIR}/%/tests/test_cases.rs
${TRANSLATION_DIR}/%/tests/test_cases.rs: | ${TEST_FILES} ${TRANSLATION_DIR}/%/Cargo.toml build-ninja/%.type
@mkdir -p $(@D)
cargo add --quiet --manifest-path ${TRANSLATION_DIR}/$*/Cargo.toml --dev assert_cmd@2.0.17 ntest@0.9.3 predicates@3.1.3
-uv run python -m ideas.convert_tests ${TEST_FILES} --crate_manifest $(realpath ${TRANSLATION_DIR}/$*/Cargo.toml) | rustfmt > $@
-uv run python -m ideas.convert_tests --crate_manifest $(realpath ${TRANSLATION_DIR}/$*/Cargo.toml) \
${TEST_FILES} | rustfmt > $@
${GIT} add $*/Cargo.toml $*/tests/test_cases.rs
${GIT} commit --quiet --message "Converted \`$*\` test vectors"

Expand Down
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ requires-python = "~=3.14.0"

dependencies = [
"clang==21.1.7",
"dspy==3.1.0",
"dspy==3.1.2",
"hydra-core",
"tree-sitter==0.24.0",
"tree-sitter-c==0.23.4",
"tree-sitter-rust==0.23.2",
"tree-sitter==0.25.2",
"tree-sitter-rust==0.24.0",
]

[dependency-groups]
Expand Down
45 changes: 45 additions & 0 deletions src/ideas/adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Copyright (C) 2026 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
#

from unittest.mock import patch

from pydantic.fields import FieldInfo

import dspy
import dspy.adapters.chat_adapter
from dspy.adapters.chat_adapter import ChatAdapter as _ChatAdapter
from dspy.adapters.utils import translate_field_type as _translate_field_type
from dspy.signatures.utils import get_dspy_field_type


class Code(dspy.Code):
def format(self):
return f"```{self.language.lower()}\n{self.code.rstrip()}\n```"

@classmethod
def short_description(cls):
return f"must be {cls.__name__}"


class ChatAdapter(_ChatAdapter):
def format_field_structure(self, signature: type[dspy.Signature]) -> str:
with patch.object(
dspy.adapters.chat_adapter, "translate_field_type", translate_field_type
):
return super().format_field_structure(signature)


def translate_field_type(field_name: str, field_info: FieldInfo) -> str:
# If a non-input field has a short_description, then use that.
field_type = field_info.annotation
if not field_type:
raise RuntimeError(f"Field '{field_name}' is missing a type annotation")

if hasattr(field_type, "short_description") and get_dspy_field_type(field_info) != "input":
desc = field_type.short_description()
desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
return f"{{{field_name}}}{desc}"
return _translate_field_type(field_name, field_info)
143 changes: 143 additions & 0 deletions src/ideas/ast_rust.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#
# Copyright (C) 2026 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
#

from collections import OrderedDict

from tree_sitter import Language, Parser, Node, Query, QueryCursor
import tree_sitter_rust

# Initialize the Rust language once
RUST_LANGUAGE = Language(tree_sitter_rust.language())
RUST_PARSER = Parser(RUST_LANGUAGE)


class RustFnSignature:
def __init__(self, node: Node):
if not node.type == "function_item":
raise ValueError(
f"Node {node} is not a function_item, so cannot extract a signature!"
)

name = node.child_by_field_name("name")
if not name:
raise ValueError(f"Function name not found in {node}!")

self.name: Node = name
self.params: Node | None = node.child_by_field_name("parameters")
self.return_type: Node | None = node.child_by_field_name("return_type")

def __repr__(self) -> str:
text = ""
if _text := self.name.text:
text += _text.decode()

if self.params and (_text := self.params.text):
text += _text.decode()

if self.return_type and (_text := self.return_type.text):
text += _text.decode()
return text

def __eq__(self, other: object) -> bool:
if not isinstance(other, RustFnSignature):
return NotImplemented

return self.__repr__() == other.__repr__()


def get_root(code: str) -> Node:
tree = RUST_PARSER.parse(code.encode())
return tree.root_node


def get_nodes(node: Node, node_type: str | None = None) -> list[Node]:
nodes = []
for child in node.children:
if not node_type or child.type == node_type:
nodes.append(child)
return nodes


def get_ancestor_nodes(node: Node, node_type: str | None = None) -> list[Node]:
ancestors = []
# Excluding self
current = node.parent
while current:
if not node_type or current.type == node_type:
ancestors.append(current)
current = current.parent

# Remove root node from ancestors
return ancestors[:-1]


def get_macro_nodes(root: Node, placeholder: str) -> list[Node]:
# Query for all nodes containing macro invocation
source = f"""
(macro_invocation
macro: (identifier) @macro_name
(#eq? @macro_name "{placeholder}")) @macro
"""

query = Query(RUST_LANGUAGE, source)
cursor = QueryCursor(query)
captures = cursor.captures(root)

# Collect all unique ancestors by walking up from each macro invocation
ancestors = set()
for macro_node in captures.get("macro", []):
ancestors.update(get_ancestor_nodes(macro_node))

return list(ancestors)


def validate_changes(code: str, template: str) -> OrderedDict[str, str]:
code_root = get_root(code)
template_root = get_root(template)

nodes = get_nodes(code_root)
template_nodes = get_nodes(template_root)
allowed_change_nodes = get_macro_nodes(template_root, "unimplemented")

scope_feedback = OrderedDict()

# Check for top-level changes
if len(nodes) != len(template_nodes):
scope_feedback["top_level_changes"] = (
"The generated code modifies parts outside the function body.\n"
"You must **only** modify the `unimplemented!()` function body and leave everything else **unchanged**!"
)

# Check for allowed changes
for template_node, node in zip(template_nodes, nodes):
if not template_node.text == node.text:
if (
template_node not in allowed_change_nodes
or not template_node.type == "function_item"
):
scope_feedback["top_level_changes"] = (
"The generated code modifies parts outside the function body.\n"
"You must **only** modify the `unimplemented!()` function body and leave everything else **unchanged**!"
)

if not node.type == "function_item" or (node.type != template_node.type):
scope_feedback["signature_changes"] = (
"You must preserve the function signature in the template intact and **not modify it**!"
)
else:
# Compare signatures
template_signature = RustFnSignature(template_node)
try:
signature = RustFnSignature(node)
except ValueError:
signature = None

if template_signature != signature:
scope_feedback["signature_changes"] = (
"You must preserve the function signature in the template intact and **not modify it**!"
)

return scope_feedback
Loading