115 changes: 115 additions & 0 deletions MaxKernel/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# tpu_kernel_gen


## Installation

### As a Package (Recommended)

Install this package locally for use anywhere:

```bash
# From the project directory
pip install -e .
```

This allows you to import and use the modules from anywhere:

```python
from tpu_kernel_gen.kernel_parser import parse_kernels
from tpu_kernel_gen.embed import generate_embeddings
from tpu_kernel_gen.kernel_retrieval import search_similar_kernels
```

### Dependencies Only

Alternatively, install just the required dependencies:

```bash
pip install -r requirements.txt
```

## Usage

### As Python Package

After installing the package, you can use it programmatically:

```python
import tpu_kernel_gen.kernel_parser as parser
import tpu_kernel_gen.embed as embedder
import tpu_kernel_gen.kernel_retrieval as retriever

# Parse kernels
kernels = parser.parse_kernels("/path/to/source")

# Generate embeddings
embedder.generate_embeddings("kernels.csv")

# Search for similar kernels
results = retriever.search_similar_kernels("matrix multiplication", k=10)
```

### Command Line Usage

The command-line scripts below first populate the kernel database, then query it.

## How to populate the kernel DB

### Step 1: Parse kernels from source code

Use the kernel parser to extract Pallas kernels from Python source files:

```bash
python kernel_parser.py /path/to/source/directory --output kernels.csv
```

This will:
- Recursively scan Python files for JAX Pallas kernels
- Extract kernel definitions and call sites
- Save results to `kernels.csv`
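
As a rough illustration of what the parser looks for, a Pallas kernel can be detected by scanning a module's AST for functions that reference the `pl` (Pallas) namespace. This is a simplified sketch, not the actual heuristics of `kernel_parser.py`:

```python
import ast


def find_pallas_kernels(source: str) -> list[str]:
  """Return names of functions that reference the `pl.*` (Pallas) namespace."""
  names = []
  for node in ast.walk(ast.parse(source)):
    if isinstance(node, ast.FunctionDef):
      # A function counts as a hit if it contains any `pl.<attr>` access
      uses_pl = any(
        isinstance(sub, ast.Attribute)
        and isinstance(sub.value, ast.Name)
        and sub.value.id == "pl"
        for sub in ast.walk(node)
      )
      if uses_pl:
        names.append(node.name)
  return names


example = """
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
  return pl.pallas_call(add_kernel, out_shape=x)(x, y)
"""
```

Here only `add` is flagged, because it is the function that actually calls into `pl.pallas_call`; the real parser also records call sites and kernel bodies.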

### Step 2: Generate embeddings

Add code embeddings to the kernel data using UniXcoder:

```bash
python embed.py kernels.csv --code_column code
```

This will:
- Load the UniXcoder model for code embeddings
- Process each kernel's code to generate vector embeddings
- Add embedding columns to the CSV file in-place
- Create a backup of the original file
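
The in-place update and backup can be pictured roughly as follows; `embed_fn` stands in for the UniXcoder model, and the `emb_i` column names are an illustrative assumption, not necessarily the columns `embed.py` writes:

```python
import shutil

import pandas as pd


def add_embedding_columns(csv_path: str, embed_fn, code_column: str = "code") -> None:
  """Back up the CSV, embed each row's code, and append emb_i columns in place."""
  shutil.copy(csv_path, csv_path + ".bak")  # keep a backup of the original file
  df = pd.read_csv(csv_path)
  # embed_fn maps a code string to a fixed-length vector of floats
  vectors = [embed_fn(code) for code in df[code_column]]
  for i in range(len(vectors[0])):
    df[f"emb_{i}"] = [v[i] for v in vectors]
  df.to_csv(csv_path, index=False)
```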

### Step 3: Upload to BigQuery

Upload the enriched kernel data to BigQuery:

```bash
python bq_upload.py --csv-file kernels.csv --table-name your_dataset.kernels --project-id your-project-id
```

This will:
- Upload the CSV data to the specified BigQuery table
- Auto-generate incremental UUIDs for new entries
- Apply the proper schema for the kernel database
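
The retrieval output described below suggests a table along these lines; the column names here are illustrative assumptions, not the script's actual schema, expressed in BigQuery's JSON schema format:

```json
[
  {"name": "id", "type": "INTEGER", "mode": "REQUIRED"},
  {"name": "operation", "type": "STRING", "mode": "NULLABLE"},
  {"name": "framework", "type": "STRING", "mode": "NULLABLE"},
  {"name": "hardware", "type": "STRING", "mode": "NULLABLE"},
  {"name": "file_path", "type": "STRING", "mode": "NULLABLE"},
  {"name": "code", "type": "STRING", "mode": "NULLABLE"},
  {"name": "embedding", "type": "FLOAT64", "mode": "REPEATED"}
]
```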


## How to retrieve from the kernel DB


Use the kernel retrieval tool to search for similar kernels in the BigQuery vector database:

```bash
python kernel_retrieval.py --project-id your-project-id --dataset-name your_dataset --table-name kernels --query "matrix multiplication kernel" --k 10
```

This will:
- Connect to your BigQuery vector store using UniXcoder embeddings
- Search for kernels similar to your query using cosine similarity
- Return the top k most similar results with metadata and similarity scores
- Display operation names, frameworks, hardware targets, and file locations

Optional flags:
- `--verbose`: Enable detailed output during the search process
- `--k`: Number of similar kernels to retrieve (default: 5)

The results will show ranked kernels with their similarity scores, operation metadata, and source file information to help you find relevant kernel implementations.
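
Conceptually, the ranking works like this local NumPy sketch; the actual search runs server-side in BigQuery's vector store:

```python
import numpy as np


def top_k_cosine(query_vec, kernel_vecs, k=5):
  """Rank stored kernel embeddings by cosine similarity to a query embedding.

  Returns (row_index, similarity) pairs, best match first.
  """
  q = np.asarray(query_vec, dtype=np.float64)
  m = np.asarray(kernel_vecs, dtype=np.float64)
  # Cosine similarity: dot product normalised by both vector lengths
  sims = (m @ q) / (np.linalg.norm(m, axis=1) * np.linalg.norm(q))
  order = np.argsort(-sims)[:k]
  return [(int(i), float(sims[i])) for i in order]
```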
10 changes: 10 additions & 0 deletions MaxKernel/format.sh
@@ -0,0 +1,10 @@
#!/bin/bash
set -e

echo "Applying formatting to MaxKernel..."
ruff format .

echo "Applying lint fixes to MaxKernel..."
ruff check --fix .

echo "Formatting and fixes complete!"
10 changes: 10 additions & 0 deletions MaxKernel/lint.sh
@@ -0,0 +1,10 @@
#!/bin/bash
set -e

echo "Running ruff check on MaxKernel..."
ruff check .

echo "Running ruff format check on MaxKernel..."
ruff format --check .

echo "Linting passed!"
28 changes: 28 additions & 0 deletions MaxKernel/pyproject.toml
@@ -0,0 +1,28 @@
[tool.ruff]
line-length = 120
target-version = "py310"
indent-width = 2

[tool.ruff.lint]
select = ["E", "F", "I", "N", "W"]
ignore = [
"E501", # Line too long
"N", # Naming conventions
"W", # Warnings
"E741", # Ambiguous variable name
"F821", # Undefined name (often in specialized kernels)
"F823", # Local variable referenced before assignment
"E722", # Bare except
"F841", # Local variable assigned but never used
"E731", # Lambda assignment
"E402", # Module level import not at top of file
]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"] # Unused imports are common in __init__.py for re-exporting

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
9 changes: 9 additions & 0 deletions MaxKernel/requirements.txt
@@ -0,0 +1,9 @@
torch
google-cloud-bigquery
langchain-google-community
tensorflow
transformers
jax[tpu]
matplotlib
pandas
protobuf
14 changes: 14 additions & 0 deletions MaxKernel/setup.py
@@ -0,0 +1,14 @@
from setuptools import find_packages, setup

setup(
  name="tpu-kernel-gen",
  version="0.1.0",
  description="TPU Kernel Generation Package",
  author="Your Name",
  author_email="your.email@example.com",
  packages=find_packages(),
  python_requires=">=3.7",
  install_requires=[
    # Mirrors requirements.txt
    "torch",
    "google-cloud-bigquery",
    "langchain-google-community",
    "tensorflow",
    "transformers",
    "jax[tpu]",
    "matplotlib",
    "pandas",
    "protobuf",
  ],
)
37 changes: 37 additions & 0 deletions MaxKernel/tests/client.py
@@ -0,0 +1,37 @@
import numpy as np
import requests

# IMPORTANT: Replace with your server's actual internal IP address
SERVER_IP = "10.130.0.155"
SERVER_URL = f"http://{SERVER_IP}:5000/receive_data"


def send_data_to_server():
  """Creates a NumPy array and sends it to the server."""

  # 1. Create some data to send
  my_array = np.arange(12, dtype=np.float32).reshape(3, 4)
  print("📦 Data to be sent from client:")
  print(my_array)

  # 2. Prepare the data for transfer
  # We convert the NumPy array to a standard Python list to serialize it into JSON.
  payload = {"array_data": my_array.tolist()}

  # 3. Send the HTTP POST request
  try:
    print(f"\n🚀 Sending data to {SERVER_URL}...")
    response = requests.post(SERVER_URL, json=payload, timeout=10)

    # Check if the request was successful
    response.raise_for_status()

    print("✅ Server responded successfully!")
    print(f"Response JSON: {response.json()}")

  except requests.exceptions.RequestException as e:
    print(f"\n❌ Failed to connect to server: {e}")


if __name__ == "__main__":
  send_data_to_server()
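
For testing the client locally, a minimal stand-in for the `/receive_data` endpoint can be sketched with only the standard library and NumPy. The real server is not part of this diff, so the handler and its response format below are assumptions for illustration:

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

import numpy as np


class ReceiveDataHandler(BaseHTTPRequestHandler):
  """Accepts the client's JSON payload and rebuilds the NumPy array."""

  def do_POST(self):
    if self.path != "/receive_data":
      self.send_error(404)
      return
    length = int(self.headers.get("Content-Length", 0))
    payload = json.loads(self.rfile.read(length))
    # Rebuild the array from its nested-list JSON representation
    array = np.array(payload["array_data"], dtype=np.float32)
    body = json.dumps(
      {"status": "ok", "shape": list(array.shape), "sum": float(array.sum())}
    ).encode()
    self.send_response(200)
    self.send_header("Content-Type", "application/json")
    self.send_header("Content-Length", str(len(body)))
    self.end_headers()
    self.wfile.write(body)

  def log_message(self, fmt, *args):
    pass  # silence per-request logging


# To serve for real, bind port 5000 on all interfaces:
#   HTTPServer(("0.0.0.0", 5000), ReceiveDataHandler).serve_forever()
```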