Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
380 changes: 380 additions & 0 deletions demo-notebooks/guided-demos/6_rayjob_checkpointing_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,380 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ray checkpointing example\n",
"\n",
"This notebook runs a **Ray Train** checkpointing demo on **Red Hat OpenShift AI** using the CodeFlare SDK:\n",
"\n",
"- **Red Hat build of Kueue** must be configured in your cluster (ResourceFlavor → ClusterQueue → **LocalQueue** in your namespace, and `kueue.openshift.io/managed=true` on the namespace). See the OpenShift *AI workloads* documentation for [Red Hat build of Kueue](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue).\n",
"- Submit a **RayJob** with a **managed Ray cluster** (`ManagedClusterConfig`) so KubeRay lifecycles the cluster with the job (`shutdownAfterJobFinishes`). The RayJob is labeled for your **LocalQueue** via `local_queue` (example: `\"default\"`).\n",
"- Configure **AWS credentials** for the S3 bucket used by Ray Train checkpoints.\n",
"- **Monitor** training in the Ray dashboard (**Jobs** tab), then **suspend and resume** the RayJob (`job.stop()` / `job.resubmit()`) to verify training **resumes from S3** after a simulated interruption.\n",
"\n",
"Training script: `train_with_checkpoints.py` in this directory (same source as the CodeFlare SDK guided demo)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import required libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from codeflare_sdk import RayJob, ManagedClusterConfig, set_api_client, get_cluster\n",
"from kube_authkit import AuthConfig, get_k8s_client\n",
"import time"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Authenticate to your OpenShift cluster"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import urllib3\n",
"\n",
"urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)\n",
"\n",
"# Authenticate to your Kubernetes/OpenShift cluster using kube-authkit\n",
"\n",
"# Option 1: Auto-detect credentials (kubeconfig or in-cluster service account)\n",
"# NOTE: In RHOAI Workbenches the workbench service account may not have Ray RBAC\n",
"# permissions. Use Option 2 (token) unless your admin has granted SA permissions\n",
"# (see RHOAIENG-46748). Auto-detect works if you have a local kubeconfig.\n",
"# auth_config = AuthConfig(method=\"auto\")\n",
"\n",
"# Option 2 (Recommended for RHOAI Workbenches): Token-based authentication\n",
"# Get your token with: oc whoami -t (or from the OpenShift console → Copy login command)\n",
"auth_config = AuthConfig(\n",
" method=\"openshift\",\n",
" k8s_api_host=\"https://api.example.com:6443\",\n",
" token=\"sha256~XXXXX\", # oc whoami -t\n",
")\n",
"\n",
"# Option 3: OIDC authentication (for BYOIDC-enabled clusters)\n",
"# auth_config = AuthConfig(\n",
"# method=\"oidc\",\n",
"# k8s_api_host=\"https://api.example.com:6443\",\n",
"# oidc_issuer=\"https://your-oidc-provider.com\",\n",
"# client_id=\"your-client-id\",\n",
"# use_device_flow=True, # Interactive device flow for notebook environments\n",
"# )\n",
"\n",
"api_client = get_k8s_client(config=auth_config)\n",
"# Set to False for self-signed / dev API certificates (optional).\n",
"api_client.configuration.verify_ssl = False\n",
"set_api_client(api_client)\n",
"\n",
"NAMESPACE = \"your-namespace\" # Data Science Project where LocalQueue + RayJob run\n",
"JOB_NAME = \"checkpointing-job\"\n",
"# Must match metadata.name of a LocalQueue in NAMESPACE (create per OpenShift Kueue docs).\n",
"LOCAL_QUEUE = \"default\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Red Hat build of Kueue (required before submit)\n",
"\n",
"Configure **ResourceFlavor** → **ClusterQueue** → **LocalQueue** in your project namespace, and label the namespace so Kueue manages workloads there. Official OpenShift 4.21 *AI workloads* docs:\n",
"\n",
"- [Configuring a resource flavor](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#configuring-resourceflavors_configuring-quotas)\n",
"- [Configuring a cluster queue](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#configuring-clusterqueues_configuring-quotas)\n",
"- [Configuring a local queue](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#configuring-localqueues_configuring-quotas) — the `LocalQueue` `metadata.name` must match **`LOCAL_QUEUE`** above (e.g. `default`).\n",
"- [Labeling namespaces to allow Red Hat build of Kueue to manage jobs](https://docs.redhat.com/en/documentation/openshift_container_platform/4.21/html/ai_workloads/red-hat-build-of-kueue#labeling-namespaces-to-allow-red-hat-build-of-kueue-to-manage-jobs_managing-jobs-and-workloads): `oc label namespace <namespace> kueue.openshift.io/managed=true`\n",
"\n",
"After submit, the RayJob may show **Suspended** until Kueue **admits** a `Workload` — use `oc get workloads.kueue.x-k8s.io -n $NAMESPACE` if needed. That is **not** the same as the manual suspend used later for the checkpoint demo."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: verify Kueue objects exist (cluster admin / user with read access)\n",
"# !oc get resourceflavor.kueue.x-k8s.io\n",
"# !oc get clusterqueue.kueue.x-k8s.io\n",
"# !oc get localqueue.kueue.x-k8s.io -n $NAMESPACE\n",
"\n",
"print(f\"Namespace: {NAMESPACE!r}, RayJob name: {JOB_NAME!r}, LocalQueue: {LOCAL_QUEUE!r}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set your AWS credentials"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set your AWS credentials\n",
"# WARNING: Do not commit credentials to version control. For production,\n",
"# use OpenShift AI Data Connections or OpenShift Secrets instead.\n",
"AWS_CREDENTIALS = {\n",
" \"AWS_ACCESS_KEY_ID\": \"your-access-key\",\n",
" \"AWS_SECRET_ACCESS_KEY\": \"your-secret-key\",\n",
" \"AWS_DEFAULT_REGION\": \"us-east-1\", # e.g. \"us-east-1\"\n",
" \"AWS_S3_BUCKET\": \"your-bucket-name\",\n",
"}\n",
"\n",
"# If using temporary credentials (SSO/federated), add the session token:\n",
"# AWS_CREDENTIALS[\"AWS_SESSION_TOKEN\"] = \"your-session-token\"\n",
"\n",
"print(f\"Using bucket: {AWS_CREDENTIALS['AWS_S3_BUCKET']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit RayJob (managed Ray cluster + Kueue local queue)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"managed = ManagedClusterConfig(\n",
" num_workers=2,\n",
" head_cpu_requests=2,\n",
" head_cpu_limits=4,\n",
" head_memory_requests=4,\n",
" head_memory_limits=8,\n",
" worker_cpu_requests=2,\n",
" worker_cpu_limits=4,\n",
" worker_memory_requests=4,\n",
" worker_memory_limits=8,\n",
")\n",
"\n",
"job = RayJob(\n",
" job_name=JOB_NAME,\n",
" entrypoint=\"python train_with_checkpoints.py\",\n",
" cluster_config=managed,\n",
" namespace=NAMESPACE,\n",
" local_queue=LOCAL_QUEUE,\n",
" runtime_env={\n",
" \"working_dir\": \"./\",\n",
" \"pip\": [\"torch\", \"torchvision\", \"s3fs\", \"pyarrow\"],\n",
" \"env_vars\": {\n",
" **AWS_CREDENTIALS,\n",
" \"RAY_TRAIN_WORKER_GROUP_START_TIMEOUT_S\": \"120\", # Allow time for worker scheduling\n",
" },\n",
" },\n",
")\n",
"\n",
"job.submit()\n",
"print(\n",
" \"RayJob submitted. If status stays Suspended briefly, Kueue may still be admitting the Workload.\"\n",
")\n",
"print(f\"RayCluster name (when assigned): {job.cluster_name}\")\n",
"print(\"Watch logs for: NO CHECKPOINT FOUND - Starting fresh\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Monitor job progress (status + Ray dashboard Jobs)\n",
"\n",
"Poll `job.status()`, then open the **Ray dashboard** URL from the RayCluster created by your RayJob. Use **Jobs** in the dashboard for live driver logs (epochs, checkpoint messages)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(job.status())\n",
"\n",
"# Resolve RayCluster created by the RayJob (retry until it exists after Kueue admission).\n",
"cluster = None\n",
"for _ in range(36):\n",
" try:\n",
" cluster = get_cluster(job.cluster_name, namespace=NAMESPACE, verify_tls=False)\n",
" break\n",
" except Exception:\n",
" time.sleep(5)\n",
"\n",
"if cluster is None:\n",
" raise RuntimeError(\n",
" \"RayCluster not ready — check RayJob / Workload admission and operator logs.\"\n",
" )\n",
"\n",
"print(f\"Ray Dashboard (open in browser): {cluster.cluster_dashboard_uri()}\")\n",
"print(\"In the dashboard, open Jobs and stream logs for the training driver.\")\n",
"print(\n",
" \"Wait for at least one full epoch and a checkpoint to S3 before running the suspend cell.\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Suspend RayJob (checkpoint demo)\n",
"\n",
"After logs show at least **one epoch** and a checkpoint written to **S3**, suspend the RayJob. This is a **manual** suspend for the demo (distinct from Kueue holding the job until admission right after submit).\n",
"\n",
"Use **Pause** in the OpenShift AI UI, or run the next cell (`job.stop()`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 60)\n",
"print(\"SUSPENDING RayJob (checkpoint demo — not deleting the RayJob CR)\")\n",
"print(\"Checkpoints remain in S3.\")\n",
"print(\"=\" * 60)\n",
"\n",
"job.stop()\n",
"print(\"Stop requested; poll job.status() until the RayJob reports suspended / non-running.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Resume RayJob\n",
"\n",
"Use **Resume** in the OpenShift AI UI, or run `job.resubmit()` in the next cell. When the RayCluster is back, confirm in the dashboard **Jobs** view: `RESUMING FROM CHECKPOINT - Starting at epoch N`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"=\" * 60)\n",
"print(\"RESUMING RayJob after suspend\")\n",
"print(\"Watch for RESUMING FROM CHECKPOINT in dashboard Jobs logs\")\n",
"print(\"=\" * 60)\n",
"\n",
"job.resubmit()\n",
"time.sleep(10)\n",
"print(job.status())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify resume from checkpoint\n",
"\n",
"In the Ray dashboard **Jobs** tab, look for:\n",
"\n",
"```\n",
"RESUMING FROM CHECKPOINT - Starting at epoch N\n",
"Previous loss: X.XXXX\n",
"```\n",
"\n",
"That confirms optimizer and progress were restored from S3 across the suspend/resume cycle."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(job.status())\n",
"try:\n",
" cluster = get_cluster(job.cluster_name, namespace=NAMESPACE, verify_tls=False)\n",
" print(f\"Ray Dashboard: {cluster.cluster_dashboard_uri()}\")\n",
"except Exception as e:\n",
" print(f\"Could not resolve cluster yet: {e}\")\n",
"print(\"Check Jobs tab for: RESUMING FROM CHECKPOINT - Starting at epoch N\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cleanup\n",
"\n",
"Delete the RayJob and tear down the RayCluster if it is still present."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Cleaning up...\")\n",
"cluster_name = job.cluster_name\n",
"try:\n",
" job.delete()\n",
"except Exception:\n",
" pass\n",
"\n",
"try:\n",
" c = get_cluster(cluster_name, namespace=NAMESPACE, verify_tls=False)\n",
" c.down()\n",
"except Exception:\n",
" pass\n",
"\n",
"print(\"Cleanup attempted (RayJob delete; cluster.down if RayCluster still exists).\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# No explicit logout needed - authentication is managed automatically by kube-authkit"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading
Loading