Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 26 additions & 27 deletions code_to_optimize/discrete_riccati.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Utility functions used in CompEcon
"""Utility functions used in CompEcon

Based routines found in the CompEcon toolbox by Miranda and Fackler.

Expand All @@ -9,14 +8,15 @@
and Finance, MIT Press, 2002.

"""

from functools import reduce

import numpy as np
import torch


def ckron(*arrays):
"""
Repeatedly applies the np.kron function to an arbitrary number of
"""Repeatedly applies the np.kron function to an arbitrary number of
input arrays

Parameters
Expand All @@ -43,8 +43,7 @@ def ckron(*arrays):


def gridmake(*arrays):
"""
Expands one or more vectors (or matrices) into a matrix where rows span the
"""Expands one or more vectors (or matrices) into a matrix where rows span the
cartesian product of combinations of the input arrays. Each column of the
input arrays will correspond to one column of the output matrix.

Expand Down Expand Up @@ -79,13 +78,11 @@ def gridmake(*arrays):
out = _gridmake2(out, arr)

return out
else:
raise NotImplementedError("Come back here")
raise NotImplementedError("Come back here")


def _gridmake2(x1, x2):
"""
Expands two vectors (or matrices) into a matrix where rows span the
"""Expands two vectors (or matrices) into a matrix where rows span the
cartesian product of combinations of the input arrays. Each column of the
input arrays will correspond to one column of the output matrix.

Expand Down Expand Up @@ -114,19 +111,17 @@ def _gridmake2(x1, x2):

"""
if x1.ndim == 1 and x2.ndim == 1:
return np.column_stack([np.tile(x1, x2.shape[0]),
np.repeat(x2, x1.shape[0])])
elif x1.ndim > 1 and x2.ndim == 1:
return np.column_stack([np.tile(x1, x2.shape[0]), np.repeat(x2, x1.shape[0])])
if x1.ndim > 1 and x2.ndim == 1:
first = np.tile(x1, (x2.shape[0], 1))
second = np.repeat(x2, x1.shape[0])
return np.column_stack([first, second])
else:
raise NotImplementedError("Come back here")
raise NotImplementedError("Come back here")


@torch.compile
def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
"""
PyTorch version of _gridmake2.
"""PyTorch version of _gridmake2.

Expands two tensors into a matrix where rows span the cartesian product
of combinations of the input tensors. Each column of the input tensors
Expand Down Expand Up @@ -157,14 +152,18 @@ def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:

"""
if x1.dim() == 1 and x2.dim() == 1:
# tile x1 by x2.shape[0] times, repeat_interleave x2 by x1.shape[0]
first = x1.tile(x2.shape[0])
second = x2.repeat_interleave(x1.shape[0])
return torch.column_stack([first, second])
elif x1.dim() > 1 and x2.dim() == 1:
# tile x1 along first dimension
first = x1.tile(x2.shape[0], 1)
second = x2.repeat_interleave(x1.shape[0])
# Avoid unnecessary .tile, which is slow, by repeat_interleave & repeat + reshape
m = x1.shape[0]
n = x2.shape[0]
first = x1.repeat(n)
second = x2.repeat_interleave(m)
return torch.stack((first, second), dim=1)
if x1.dim() > 1 and x2.dim() == 1:
# For 2D or higher dims -- for each row in x1, repeat for each entry in x2
m = x1.shape[0]
n = x2.shape[0]
# This method avoids .tile which makes unnecessary copies
first = x1.repeat(n, 1)
second = x2.repeat_interleave(m)
return torch.column_stack([first, second])
else:
raise NotImplementedError("Come back here")
raise NotImplementedError("Come back here")
Loading