Skip to content

Commit 1a512cb

Browse files
committed
Create rnn_ode_optimized.py
1 parent 497ea3e commit 1a512cb

1 file changed

Lines changed: 394 additions & 0 deletions

File tree

Lines changed: 394 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,394 @@
1+
#!/usr/bin/env python3
2+
"""
3+
RNN for Learning ODE Solutions - OPTIMIZED VERSION
4+
Fixes for hanging issues:
5+
1. Smaller dataset (5000 points instead of 20000)
6+
2. num_workers=0 in DataLoader
7+
3. Smaller batch size (32)
8+
4. Progress indicators
9+
5. Early stopping option (not implemented in this script)
10+
"""
11+
12+
import numpy as np
13+
import matplotlib.pyplot as plt
14+
import torch
15+
import torch.nn as nn
16+
import torch.optim as optim
17+
from torch.utils.data import Dataset, DataLoader
18+
import time
19+
from math import ceil, cos
20+
import sys
21+
22+
# Seed both NumPy and PyTorch with the same value so repeated runs match.
_SEED = 42
np.random.seed(_SEED)
torch.manual_seed(_SEED)

# Select the CPU explicitly (the original note cites GPU hanging issues).
# Nothing else in the visible script reads `device`; it is only reported here.
device = torch.device('cpu')
print(f"Using device: {device} (CPU mode to avoid hanging)")
29+
30+
# ============================================================================
31+
# PART I: ODE SOLVER (OPTIMIZED)
32+
# ============================================================================
33+
34+
def SpringForce(v, x, t, gamma=0.2, Omega=0.5, F=1.0):
    """Return the force term of a driven, damped harmonic oscillator.

    Evaluates ``-2*gamma*v - x + F*cos(t*Omega)``.

    Args:
        v: Velocity.
        x: Position.
        t: Time at which the driving term is evaluated.
        gamma: Damping coefficient multiplying ``-2*v``.
        Omega: Angular frequency of the cosine drive.
        F: Amplitude of the cosine drive.

    Returns:
        The sum of the damping, restoring and driving contributions.
    """
    damping = -2*gamma*v
    driving = F*cos(t*Omega)
    # Same left-to-right evaluation order as the single-expression form.
    return damping - x + driving
37+
38+
print("\n" + "="*70)
print("SOLVING ODE (REDUCED SIZE FOR SPEED)")
print("="*70)

# REDUCED parameters to avoid hanging
DeltaT = 0.002  # Larger timestep
tfinal = 10.0   # Shorter simulation
n = ceil(tfinal/DeltaT)

print("\nODE Parameters:")
for _label, _value in (("Time step", DeltaT),
                       ("Final time", tfinal),
                       ("Number of points", n)):
    print(f" {_label}: {_value}")

# Storage for time, position and velocity on the uniform grid.
t, x, v = (np.zeros(n) for _ in range(3))

# Initial conditions and oscillator parameters.
x[0] = 1.0
v[0] = 0.0
gamma = 0.2
Omega = 0.5
F = 1.0

print("\nSolving ODE with RK4...")
for step in range(n - 1):
    if step % 1000 == 0:
        print(f" Progress: {100*step/n:.1f}%", end='\r')

    pos, vel, now = x[step], v[step], t[step]
    t_mid = now + DeltaT*0.5
    t_end = now + DeltaT

    # Stage 1: slopes at the start of the interval.
    k1x = DeltaT * vel
    k1v = DeltaT * SpringForce(vel, pos, now, gamma, Omega, F)

    # Stage 2: slopes at the midpoint, advanced with the stage-1 increments.
    vel_2 = vel + k1v*0.5
    pos_2 = pos + k1x*0.5
    k2x = DeltaT * vel_2
    k2v = DeltaT * SpringForce(vel_2, pos_2, t_mid, gamma, Omega, F)

    # Stage 3: slopes at the midpoint again, advanced with stage-2 increments.
    vel_3 = vel + k2v*0.5
    pos_3 = pos + k2x*0.5
    k3x = DeltaT * vel_3
    k3v = DeltaT * SpringForce(vel_3, pos_3, t_mid, gamma, Omega, F)

    # Stage 4: slopes at the end of the interval, using the full stage-3 step.
    vel_4 = vel + k3v
    pos_4 = pos + k3x
    k4x = DeltaT * vel_4
    k4v = DeltaT * SpringForce(vel_4, pos_4, t_end, gamma, Omega, F)

    # Weighted combination of the four stages (weights 1, 2, 2, 1 over 6).
    x[step + 1] = pos + (k1x + 2*k2x + 2*k3x + k4x)/6.0
    v[step + 1] = vel + (k1v + 2*k2v + 2*k3v + k4v)/6.0
    t[step + 1] = t_end

print(" Progress: 100.0% - Complete!")
print(f"\nODE solved: {len(x)} points")
print(f" Position range: [{x.min():.4f}, {x.max():.4f}]")
96+
97+
# ============================================================================
98+
# PART II: PREPARE DATA
99+
# ============================================================================
100+
101+
print("\n" + "="*70)
print("PREPARING TRAINING DATA")
print("="*70)

seq_length = 50  # Shorter sequences

print(f"\nCreating sequences (length={seq_length})...")
# Window i is x[i : i + seq_length] and its target is x[i + seq_length].
# The last valid start index is len(x) - seq_length - 1, which gives
# len(x) - seq_length windows in total. The earlier bound of
# len(x) - seq_length - 1 silently dropped the final window/target pair.
num_windows = len(x) - seq_length
if num_windows < 1:
    raise ValueError(
        f"Need more than {seq_length} ODE points to build sequences, "
        f"got {len(x)}"
    )

X_list = [x[start:start + seq_length] for start in range(num_windows)]
y_list = [x[start + seq_length] for start in range(num_windows)]

X = np.array(X_list)                 # shape (num_windows, seq_length)
y = np.array(y_list).reshape(-1, 1)  # shape (num_windows, 1)

print(f" Created {len(X)} sequences")

# 75/25 split, taken in order with no shuffling: the test set is the tail
# of the trajectory.
train_size = int(0.75 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

print(f" Train: {len(X_train)} ({100*len(X_train)/len(X):.1f}%)")
print(f" Test: {len(X_test)} ({100*len(X_test)/len(X):.1f}%)")
127+
128+
# ============================================================================
129+
# PART III: PYTORCH DATASET
130+
# ============================================================================
131+
132+
class TimeSeriesDataset(Dataset):
    """Map-style dataset pairing input windows with their targets.

    Both inputs are converted to ``float32`` tensors. ``X`` gains a trailing
    axis of size 1, so a ``(num_samples, seq_length)`` array is stored as
    ``(num_samples, seq_length, 1)``.

    ``torch.as_tensor`` replaces the legacy ``torch.FloatTensor`` constructor,
    which interprets an integer argument as a size and returns uninitialised
    memory. Note that ``as_tensor`` may share memory with an input that is
    already a ``float32`` array or tensor.

    Args:
        X: Array-like of input windows.
        y: Array-like of targets, one entry per window.

    Raises:
        ValueError: If ``X`` and ``y`` hold a different number of samples.
    """

    def __init__(self, X, y):
        self.X = torch.as_tensor(X, dtype=torch.float32).unsqueeze(-1)
        self.y = torch.as_tensor(y, dtype=torch.float32)
        if self.X.shape[0] != self.y.shape[0]:
            raise ValueError(
                f"X and y must have the same number of samples, "
                f"got {self.X.shape[0]} and {self.y.shape[0]}"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(window, target)`` pair stored at ``idx``."""
        return self.X[idx], self.y[idx]
142+
143+
# Wrap the train/test arrays as tensor datasets.
train_dataset = TimeSeriesDataset(X_train, y_train)
test_dataset = TimeSeriesDataset(X_test, y_test)

# CRITICAL: num_workers=0 to avoid multiprocessing hanging
batch_size = 32
_loader_kwargs = dict(batch_size=batch_size, num_workers=0)

# Only the training loader shuffles; the test loader keeps the original order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print("\nDataLoaders ready:")
print(f" Batch size: {batch_size}")
print(f" Train batches: {len(train_loader)}")
156+
157+
# ============================================================================
158+
# PART IV: LSTM MODEL (SINGLE MODEL FOR SPEED)
159+
# ============================================================================
160+
161+
class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear layer that emits one value per sequence.

    Args:
        hidden_size: Width of each LSTM layer's hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, hidden_size=64, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # batch_first=True: inputs are laid out as (batch, seq_len, 1).
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Predict one value for each sequence in the batch.

        Args:
            x: Tensor of shape ``(batch, seq_len, 1)``.

        Returns:
            Tensor of shape ``(batch, 1)`` computed from the last time step's
            top-layer hidden output.
        """
        # Build the zero initial state from `x` so it inherits the input's
        # dtype and device. A bare torch.zeros(...) always used the default
        # dtype/device and broke once the model was cast or moved.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)

        out, _ = self.lstm(x, (h0, c0))
        # Keep only the final time step before the linear read-out.
        return self.fc(out[:, -1, :])
182+
183+
# ============================================================================
184+
# PART V: TRAINING WITH PROGRESS
185+
# ============================================================================
186+
187+
print("\n" + "="*70)
print("TRAINING LSTM MODEL")
print("="*70)

model = LSTMModel(hidden_size=64, num_layers=2)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 50  # Reduced for speed
print(f"\nStarting training ({epochs} epochs)...")
# Report the values actually held by the model rather than repeating literals.
print(f" Hidden size: {model.hidden_size}")
print(f" Num layers: {model.num_layers}")

# Per-epoch mean squared error over every sample of each split.
train_losses = []
test_losses = []

start_time = time.time()

for epoch in range(epochs):
    # Training
    model.train()
    total_train_loss = 0.0

    for X_batch, y_batch in train_loader:
        # Forward
        predictions = model(X_batch)
        loss = criterion(predictions, y_batch)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so a
        # short final batch does not count as much as a full one.
        total_train_loss += loss.item() * X_batch.size(0)

    train_loss = total_train_loss / len(train_loader.dataset)

    # Evaluation
    model.eval()
    total_test_loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            predictions = model(X_batch)
            loss = criterion(predictions, y_batch)
            total_test_loss += loss.item() * X_batch.size(0)

    test_loss = total_test_loss / len(test_loader.dataset)

    train_losses.append(train_loss)
    test_losses.append(test_loss)

    # Print progress on the first epoch and every fifth epoch after that.
    if (epoch + 1) % 5 == 0 or epoch == 0:
        elapsed = time.time() - start_time
        print(f" Epoch {epoch+1:3d}/{epochs}: Train={train_loss:.6f}, Test={test_loss:.6f}, Time={elapsed:.1f}s")

total_time = time.time() - start_time
print(f"\nTraining complete in {total_time:.2f} seconds!")
print(f"Final: Train Loss = {train_losses[-1]:.6f}, Test Loss = {test_losses[-1]:.6f}")
249+
250+
# ============================================================================
251+
# PART VI: PREDICTIONS
252+
# ============================================================================
253+
254+
print("\n" + "="*70)
print("GENERATING PREDICTIONS")
print("="*70)


def _predict_windows(net, windows, chunk_size=1024):
    """Run `net` over every window in batches and return a 1-D float64 array.

    Replaces one forward pass per sample with chunked batched inference.

    Args:
        net: Model mapping a ``(batch, seq_len, 1)`` tensor to ``(batch, 1)``.
        windows: Array of shape ``(num_windows, seq_len)``.
        chunk_size: Maximum number of windows per forward pass.

    Returns:
        NumPy array of shape ``(num_windows,)`` holding the predictions.
    """
    if len(windows) == 0:
        return np.empty(0)
    inputs = torch.as_tensor(windows, dtype=torch.float32).unsqueeze(-1)
    with torch.no_grad():
        outputs = [net(chunk).squeeze(-1) for chunk in inputs.split(chunk_size)]
    return torch.cat(outputs).numpy().astype(np.float64)


model.eval()
train_preds = _predict_windows(model, X_train)
test_preds = _predict_windows(model, X_test)

# Metrics on the held-out split; residuals are computed once and reused.
y_true = y_test.flatten()
residuals = y_true - test_preds
mse = np.mean(residuals**2)
rmse = np.sqrt(mse)
mae = np.mean(np.abs(residuals))
r2 = 1 - (np.sum(residuals**2) /
          np.sum((y_true - np.mean(y_test))**2))

print("\nTest Metrics:")
print(f" MSE = {mse:.6f}")
print(f" RMSE = {rmse:.6f}")
print(f" MAE = {mae:.6f}")
print(f" R² = {r2:.6f}")
288+
289+
# ============================================================================
290+
# PART VII: VISUALIZATION
291+
# ============================================================================
292+
293+
print("\n" + "="*70)
print("CREATING VISUALIZATION")
print("="*70)

fig = plt.figure(figsize=(16, 10))

# Plot 1: ODE solution
ax1 = plt.subplot(2, 3, 1)
ax1.plot(t, x, 'b-', linewidth=1, alpha=0.7)
# Index into x/t of the first test target: window `train_size` predicts
# x[train_size + seq_length].
split_point = train_size + seq_length
if split_point < len(t):
    ax1.axvline(x=t[split_point], color='r', linestyle='--', linewidth=2, label='Train/Test')
    # Only draw a legend when the labelled split marker actually exists.
    ax1.legend()
ax1.set_xlabel('Time [s]')
ax1.set_ylabel('Position x [m]')
ax1.set_title('ODE Solution', fontweight='bold')
ax1.grid(True, alpha=0.3)

# Plot 2: Phase space
ax2 = plt.subplot(2, 3, 2)
ax2.plot(x, v, 'b-', linewidth=0.5, alpha=0.5)
ax2.set_xlabel('Position x')
ax2.set_ylabel('Velocity v')
ax2.set_title('Phase Space', fontweight='bold')
ax2.grid(True, alpha=0.3)

# Plot 3: Training curves
ax3 = plt.subplot(2, 3, 3)
ax3.plot(train_losses, 'b-', linewidth=2, label='Train')
ax3.plot(test_losses, 'r-', linewidth=2, label='Test')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Loss (MSE)')
ax3.set_title('Training Curves', fontweight='bold')
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.set_yscale('log')

# Plot 4: Predictions
ax4 = plt.subplot(2, 3, 4)
# Targets start seq_length steps into the series; test follows train directly.
train_idx = np.arange(seq_length, seq_length + len(train_preds))
test_idx = np.arange(seq_length + len(train_preds),
                     seq_length + len(train_preds) + len(test_preds))
ax4.plot(train_idx, y_train.flatten(), 'b-', linewidth=1, alpha=0.5, label='Train True')
ax4.plot(train_idx, train_preds, 'g-', linewidth=1, label='Train Pred')
ax4.plot(test_idx, y_test.flatten(), 'r-', linewidth=1, alpha=0.5, label='Test True')
ax4.plot(test_idx, test_preds, 'orange', linewidth=1, label='Test Pred')
ax4.set_xlabel('Time Step')
ax4.set_ylabel('Position')
ax4.set_title('Predictions', fontweight='bold')
ax4.legend(fontsize=8)
ax4.grid(True, alpha=0.3)

# Plot 5: Error distribution
ax5 = plt.subplot(2, 3, 5)
errors = test_preds - y_test.flatten()
ax5.hist(errors, bins=30, alpha=0.7, edgecolor='black')
ax5.axvline(x=0, color='r', linestyle='--', linewidth=2)
ax5.set_xlabel('Prediction Error')
ax5.set_ylabel('Frequency')
ax5.set_title(f'Error Distribution (MAE={mae:.4f})', fontweight='bold')
ax5.grid(True, alpha=0.3, axis='y')

# Plot 6: Summary stats
ax6 = plt.subplot(2, 3, 6)
ax6.axis('off')
summary_text = f"""
TRAINING SUMMARY

Dataset:
ODE points: {len(x)}
Sequences: {len(X)}
Train: {len(X_train)} (75%)
Test: {len(X_test)} (25%)

Model: LSTM
Hidden: 64
Layers: 2
Epochs: {epochs}

Results:
MSE: {mse:.6f}
RMSE: {rmse:.6f}
MAE: {mae:.6f}
R²: {r2:.6f}

Time: {total_time:.1f}s
"""
ax6.text(0.1, 0.5, summary_text, fontsize=11, family='monospace',
         verticalalignment='center')

plt.tight_layout()

# Write the figure to disk BEFORE plt.show(): the script previously announced
# a saved plot while the savefig call was commented out, and saving after an
# interactive show() can produce an empty image once the window is closed.
plot_path = 'rnn_ode_optimized.png'
fig.savefig(plot_path, dpi=150)
print(f"\n✓ Plot saved: {plot_path}")
plt.show()

print("\n" + "="*70)
print("COMPLETE!")
print("="*70)
print("\n✓ Successfully trained LSTM on ODE data")
print(f"✓ Test R² score: {r2:.4f}")
print("✓ No hanging issues!")
print("="*70)

0 commit comments

Comments
 (0)