Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 78 additions & 5 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,75 @@ def toy_computation() -> None:
x.block_until_ready()


def _is_default_profile_options(
profiler_options: jax.profiler.ProfileOptions,
) -> bool:
if jax.version.__version_info__ < (0, 9, 2):
return True

default_options = jax.profiler.ProfileOptions()
return (
profiler_options.host_tracer_level == default_options.host_tracer_level
and profiler_options.python_tracer_level
== default_options.python_tracer_level
and profiler_options.duration_ms == default_options.duration_ms
and not getattr(profiler_options, "advanced_configuration", None)
)


def _create_profile_request(
log_dir: os.PathLike[str] | str,
profiler_options: jax.profiler.ProfileOptions | None = None,
) -> Mapping[str, Any]:
"""Creates a profile request mapping from the given options."""
profile_request = {}
profile_request["traceLocation"] = str(log_dir)
profile_request: dict[str, Any] = {
"traceLocation": str(log_dir),
}

if profiler_options is None or _is_default_profile_options(profiler_options):
return profile_request

advanced_config = None
if getattr(profiler_options, "advanced_configuration", None):
advanced_config = {}
for k, v in getattr(profiler_options, "advanced_configuration").items():
# Convert python dict to tensorflow.ProfileOptions.AdvancedConfigValue
# json-compatible dict
if isinstance(v, bool):
advanced_config[k] = {"boolValue": v}
elif isinstance(v, int):
advanced_config[k] = {"intValue": v}
elif isinstance(v, str):
advanced_config[k] = {"stringValue": v}
else:
raise ValueError(
f"Unsupported advanced configuration value type: {type(v)}. "
"Supported types are bool, int, and str."
)

xprof_options: dict[str, Any] = {
"traceDirectory": str(log_dir),
}

if profiler_options.host_tracer_level != 2:
xprof_options["hostTraceLevel"] = profiler_options.host_tracer_level

pw_trace_opts: dict[str, Any] = {}
if profiler_options.python_tracer_level:
pw_trace_opts["enablePythonTracer"] = bool(
profiler_options.python_tracer_level
)

if advanced_config:
pw_trace_opts["advancedConfiguration"] = advanced_config

if pw_trace_opts:
xprof_options["pwTraceOptions"] = pw_trace_opts

profile_request["xprofTraceOptions"] = xprof_options

if profiler_options.duration_ms > 0:
profile_request["maxDurationSecs"] = profiler_options.duration_ms / 1000.0

return profile_request

Expand Down Expand Up @@ -104,7 +167,7 @@ def start_trace(
*,
create_perfetto_link: bool = False,
create_perfetto_trace: bool = False,
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
profiler_options: jax.profiler.ProfileOptions | None = None,
) -> None:
"""Starts a profiler trace.

Expand Down Expand Up @@ -133,7 +196,6 @@ def start_trace(
This feature is experimental for Pathways on Cloud and may not be fully
supported.
profiler_options: Profiler options to configure the profiler for collection.
Options are not currently supported and ignored.
"""
if not str(log_dir).startswith("gs://"):
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
Expand All @@ -144,7 +206,18 @@ def start_trace(
"features for Pathways on Cloud and may not be fully supported."
)

_start_pathways_trace_from_profile_request(_create_profile_request(log_dir))
if jax.version.__version_info__ < (0, 9, 2) and profiler_options is not None:
_logger.warning(
"ProfileOptions are not supported until JAX 0.9.2 and will be omitted. "
"Some options can be specified via command line flags."
)
profiler_options = None

profile_request = _create_profile_request(log_dir, profiler_options)

_logger.debug("Profile request: %s", profile_request)

_start_pathways_trace_from_profile_request(profile_request)

_original_start_trace(
log_dir=log_dir,
Expand Down
65 changes: 59 additions & 6 deletions pathwaysutils/test/profiling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import json
import logging
import unittest
from unittest import mock

from absl.testing import absltest
Expand Down Expand Up @@ -225,9 +226,11 @@ def test_start_trace_success(self):

self.mock_toy_computation.assert_called_once()
self.mock_plugin_executable_cls.assert_called_once_with(
json.dumps(
{"profileRequest": {"traceLocation": "gs://test_bucket/test_dir"}}
)
json.dumps({
"profileRequest": {
"traceLocation": "gs://test_bucket/test_dir",
}
})
)
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
self.mock_original_start_trace.assert_called_once_with(
Expand Down Expand Up @@ -391,10 +394,60 @@ def test_monkey_patched_stop_server(self):

mocks["stop_server"].assert_called_once()

def test_create_profile_request_no_options(self):
  """A request built without profiler options only holds the location."""
  log_dir = "gs://bucket/dir"
  request = profiling._create_profile_request(log_dir)
  self.assertEqual(request, {"traceLocation": log_dir})
@parameterized.parameters(None, jax.profiler.ProfileOptions())
def test_create_profile_request_default_options(self, profiler_options):
  """Missing or default options produce a location-only request."""
  expected_request = {"traceLocation": "gs://bucket/dir"}
  request = profiling._create_profile_request(
      "gs://bucket/dir", profiler_options=profiler_options
  )
  self.assertEqual(request, expected_request)

@unittest.skipIf(
    jax.version.__version_info__ < (0, 9, 2),
    "ProfileOptions requires JAX 0.9.2 or newer",
)
def test_create_profile_request_with_options(self):
  """Non-default options are translated into the xprof request fields."""
  log_dir = "gs://bucket/dir"

  options = jax.profiler.ProfileOptions()
  options.host_tracer_level = 2
  options.python_tracer_level = 1
  options.duration_ms = 2000
  options.start_timestamp_ns = 123456789
  options.advanced_configuration = {
      "tpu_num_chips_to_profile_per_task": 3,
      "tpu_num_sparse_core_tiles_to_trace": 5,
      "tpu_trace_mode": "TRACE_COMPUTE",
  }

  # Each advanced value is expected to be wrapped in a typed single-key dict.
  expected_advanced_config = {
      "tpu_num_chips_to_profile_per_task": {"intValue": 3},
      "tpu_num_sparse_core_tiles_to_trace": {"intValue": 5},
      "tpu_trace_mode": {"stringValue": "TRACE_COMPUTE"},
  }
  expected_request = {
      "traceLocation": log_dir,
      "maxDurationSecs": 2.0,
      "xprofTraceOptions": {
          "traceDirectory": log_dir,
          "pwTraceOptions": {
              "enablePythonTracer": True,
              "advancedConfiguration": expected_advanced_config,
          },
      },
  }

  request = profiling._create_profile_request(
      log_dir, profiler_options=options
  )
  self.assertEqual(request, expected_request)

@unittest.skipIf(
jax.version.__version_info__ < (0, 9, 2),
"ProfileOptions requires JAX 0.9.2 or newer",
)
@parameterized.parameters(
({"traceLocation": "gs://test_bucket/test_dir"},),
({
Expand Down
Loading