Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,51 @@ class Dummy:
assert "runtime_env" not in captured


class TestTimeitMark:
"""Tests for the ``timeit.mark_start`` / ``mark_end`` non-context-manager API."""

def setup_method(self):
from torchrl._utils import timeit

timeit._REG.clear()
timeit._MARKS.clear()

def test_mark_start_end_records_into_reg(self):
from torchrl._utils import timeit

timeit.mark_start("alpha")
timeit.mark_end("alpha")
assert "alpha" in timeit._REG
avg, total, count = timeit._REG["alpha"]
assert count == 1
assert total >= 0.0
assert avg == total

def test_mark_end_pops_outstanding_mark(self):
from torchrl._utils import timeit

timeit.mark_start("beta")
assert "beta" in timeit._MARKS
timeit.mark_end("beta")
assert "beta" not in timeit._MARKS

def test_mark_env_is_alias_of_mark_end(self):
from torchrl._utils import timeit

assert timeit.mark_env is not timeit.mark_end # bound classmethod identity
timeit.mark_start("gamma")
timeit.mark_env("gamma")
assert "gamma" in timeit._REG
assert timeit._REG["gamma"][2] == 1

def test_context_manager_still_records(self):
from torchrl._utils import timeit

with timeit("delta"):
pass
assert "delta" in timeit._REG


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
29 changes: 24 additions & 5 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ class timeit:
"""

_REG = {}
_MARKS = {}

def __init__(self, name):
self.name = name
Expand Down Expand Up @@ -295,16 +296,34 @@ def elapsed(self) -> float:
"""
return time.time() - self.t0

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
t = self.elapsed()
val = self._REG.setdefault(self.name, [0.0, 0.0, 0])
@classmethod
def _record(cls, name: str, elapsed: float) -> None:
val = cls._REG.setdefault(name, [0.0, 0.0, 0])

count = val[2]
N = count + 1
val[0] = val[0] * (count / N) + t / N
val[1] += t
val[0] = val[0] * (count / N) + elapsed / N
val[1] += elapsed
val[2] = N

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self._record(self.name, self.elapsed())

@classmethod
def mark_start(cls, name: str) -> None:
"""Mark the start of a named timed event."""
cls._MARKS[name] = time.time()

@classmethod
def mark_end(cls, name: str) -> None:
"""Mark the end of a named timed event and record its elapsed time."""
cls._record(name, time.time() - cls._MARKS.pop(name))

@classmethod
def mark_env(cls, name: str) -> None:
"""Alias for :meth:`mark_end`."""
cls.mark_end(name)
Comment on lines +322 to +325
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this


@staticmethod
def print(prefix: str | None = None) -> str: # noqa: T202
"""Prints the state of the timer.
Expand Down
Loading