13 changes: 8 additions & 5 deletions beginner_source/examples_autograd/polynomial_autograd.py
@@ -3,7 +3,10 @@
 -------------------------------

 A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`
-to :math:`\pi` by minimizing squared Euclidean distance.
+to :math:`\pi` by minimizing squared Euclidean distance. Feel free to try other
+functions, such as :math:`y=\exp(x)` (which converges faster), and to experiment
+with the learning rate.
+

 This implementation computes the forward pass using operations on PyTorch
 Tensors, and uses PyTorch autograd to compute gradients.
@@ -27,8 +30,8 @@
 # Create Tensors to hold input and outputs.
 # By default, requires_grad=False, which indicates that we do not need to
 # compute gradients with respect to these Tensors during the backward pass.
-x = torch.linspace(-1, 1, 2000, dtype=dtype)
-y = torch.exp(x) # A Taylor expansion would be 1 + x + (1/2) x**2 + (1/3!) x**3 + ...
+x = torch.linspace(-math.pi, math.pi, 2000, dtype=dtype)
+y = torch.sin(x) # Note that a Taylor expansion would be y = 0 + x + 0 + (-1/3!) x^3 + ...

 # Create random Tensors for weights. For a third order polynomial, we need
 # 4 weights: y = a + b x + c x^2 + d x^3
@@ -40,7 +43,7 @@
 d = torch.randn((), dtype=dtype, requires_grad=True)

 initial_loss = 1.
-learning_rate = 1e-5
+learning_rate = 1e-6
 for t in range(5000):
     # Forward pass: compute predicted y using operations on Tensors.
     y_pred = a + b * x + c * x ** 2 + d * x ** 3
@@ -50,7 +53,7 @@
     # loss.item() gets the scalar value held in the loss.
     loss = (y_pred - y).pow(2).sum()

-    # Calculare initial loss, so we can report loss relative to it
+    # Calculate initial loss, so we can report loss relative to it
     if t==0:
         initial_loss=loss.item()

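For reference, here is a minimal runnable sketch of the script as it should read after this diff: sin target, learning rate 1e-6, and loss reported relative to the initial loss. Only the lines shown in the hunks above are confirmed by the diff; the rest of the loop body and the weight updates are assumed to match the stock tutorial file, and the print interval is illustrative.

import math
import torch

dtype = torch.float

# Input and target: 2000 points of y = sin(x) on [-pi, pi], per the diff above.
x = torch.linspace(-math.pi, math.pi, 2000, dtype=dtype)
y = torch.sin(x)

# Random weights for the cubic y = a + b x + c x^2 + d x^3; autograd tracks them.
a = torch.randn((), dtype=dtype, requires_grad=True)
b = torch.randn((), dtype=dtype, requires_grad=True)
c = torch.randn((), dtype=dtype, requires_grad=True)
d = torch.randn((), dtype=dtype, requires_grad=True)

initial_loss = 1.
learning_rate = 1e-6
for t in range(5000):
    # Forward pass and squared-Euclidean loss.
    y_pred = a + b * x + c * x ** 2 + d * x ** 3
    loss = (y_pred - y).pow(2).sum()

    # Calculate initial loss, so we can report loss relative to it.
    if t == 0:
        initial_loss = loss.item()
    if t % 500 == 0:  # illustrative print interval, not from the diff
        print(t, loss.item() / initial_loss)

    # Backward pass: autograd fills a.grad, b.grad, c.grad, d.grad.
    loss.backward()

    # Gradient descent step; disable autograd tracking while mutating weights.
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad
        a.grad = None; b.grad = None; c.grad = None; d.grad = None

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

The smaller step size presumably keeps plain gradient descent stable here, since the loss sums squared errors over all 2000 points and its gradients are correspondingly large. The comment in the added line abbreviates the Taylor series :math:`\sin(x) = x - \frac{x^3}{3!} + \frac{x^5}{5!} - \cdots`, whose zero terms correspond to the cubic's a and c coefficients.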