Skip to content

Commit b4ba17b

Browse files
authored
Add start_trace and stop_trace API in profiler (#8743)
1 parent 71d0ce6 commit b4ba17b

File tree

8 files changed

+202
-0
lines changed

8 files changed

+202
-0
lines changed

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ function run_xla_op_tests1 {
157157
run_test "$CDIR/test_async_closures.py"
158158
run_test "$CDIR/test_hlo_metadata.py"
159159
run_test "$CDIR/test_profiler.py"
160+
run_test "$CDIR/test_profiler_session.py"
160161
run_test "$CDIR/pjrt/test_runtime.py"
161162
run_test "$CDIR/pjrt/test_runtime_single_proc_gpu.py"
162163
run_test "$CDIR/pjrt/test_runtime_multi_gpu.py"

test/test_profiler_session.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import glob
2+
import os
3+
from absl.testing import absltest
4+
5+
import torch
6+
import torch_xla.debug.profiler as xp
7+
8+
9+
def _run_computation():
  """Runs a small two-layer model on the XLA device so a trace has content.

  Each linear layer is wrapped in its own ``xp.Trace`` scope ('fc1', 'fc2').
  The forward pass is executed 20 times and every result is copied back to
  the CPU.
  """

  class M(torch.nn.Module):
    """Two stacked linear layers (10 -> 5 -> 10), each in a trace scope."""

    def __init__(self):
      super().__init__()
      self.fc1 = torch.nn.Linear(10, 5)
      self.fc2 = torch.nn.Linear(5, 10)

    def forward(self, x):
      with xp.Trace('fc1'):
        hidden = self.fc1(x)
      with xp.Trace('fc2'):
        out = self.fc2(hidden)
      return out

  model = M().to('xla')
  inputs = torch.randn(10, 10).to('xla')
  for _ in range(20):
    result = model(inputs)
    # Copy the result back to the host on every iteration.
    result.cpu()
31+
32+
33+
class TestProfilerSession(absltest.TestCase):
  """Tests for the ``xp.start_trace`` / ``xp.stop_trace`` session API."""

  def setUp(self):
    # A profiler server is started on port 8005 before every test.
    self.server = xp.start_server(8005)

  def test_start_and_stop(self):
    """Two back-to-back traces each write exactly one xplane file."""
    first_dir = self.create_tempdir().full_path
    xp.start_trace(first_dir)
    _run_computation()
    xp.stop_trace()

    second_dir = self.create_tempdir().full_path
    xp.start_trace(second_dir)
    _run_computation()
    xp.stop_trace()

    for trace_dir in (first_dir, second_dir):
      pattern = os.path.join(trace_dir, '**', '*.xplane.pb')
      matches = glob.glob(pattern, recursive=True)
      self.assertEqual(len(matches), 1)

  def test_error_double_start(self):
    """Starting a trace while one is already running raises RuntimeError."""
    log_dir = self.create_tempdir().full_path
    xp.start_trace(log_dir)
    try:
      self.assertRaisesRegex(RuntimeError,
                             "Only one profile may be run at a time.",
                             xp.start_trace, log_dir)
    finally:
      # Always stop the first trace so it does not leak into other tests.
      xp.stop_trace()

  def test_error_stop_before_start(self):
    """Stopping without a running trace raises RuntimeError."""
    self.assertRaisesRegex(RuntimeError, "No profile started", xp.stop_trace)
67+
68+
69+
if __name__ == '__main__':
70+
absltest.main()

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
3737
python3 "$TEST_CDIR/test_pallas.py" -v
3838
python3 "$TEST_CDIR/test_pallas_spmd.py"
3939
XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py"
40+
python3 "$TEST_CDIR/test_profiler_session.py"
4041
python3 "$TEST_CDIR/test_multi_queries_paged_attention_kernel.py"
4142
python3 "$TEST_CDIR/test_ragged_paged_attention_kernel.py"
4243
python3 "$TEST_CDIR/test_input_output_aliases.py"

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,24 @@ void BuildProfilerSubmodule(py::module* m) {
978978
[](const std::string& name) -> std::unique_ptr<torch::lazy::ScopePusher> {
979979
return absl::make_unique<torch::lazy::ScopePusher>(name);
980980
});
981+
982+
// Profiler Session Definition.
983+
py::class_<runtime::profiler::TslProfilerSessionWrapper,
984+
std::unique_ptr<runtime::profiler::TslProfilerSessionWrapper>>
985+
profiler_session_class(profiler, "TslProfilerSessionWrapper");
986+
profiler_session_class.def(
987+
py::init(&runtime::profiler::TslProfilerSessionWrapper::Create));
988+
profiler_session_class.def("stop", [](py::object self) -> py::bytes {
989+
std::string xspace_str =
990+
py::cast<runtime::profiler::TslProfilerSessionWrapper*>(self)->Stop();
991+
return py::bytes(xspace_str);
992+
});
993+
profiler_session_class.def("export", [](py::object self, py::bytes xspace,
994+
const std::string& dump_dir) {
995+
const std::string xspace_str = xspace.cast<std::string>();
996+
py::cast<runtime::profiler::TslProfilerSessionWrapper*>(self)->Export(
997+
xspace_str, dump_dir);
998+
});
981999
}
9821000

9831001
class PyLoweringContext {

torch_xla/csrc/runtime/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,10 @@ cc_library(
302302
"@com_google_absl//absl/status",
303303
"@xla//xla/backends/profiler/plugin:profiler_c_api_hdrs",
304304
"@xla//xla/backends/profiler/plugin:plugin_tracer",
305+
"@xla//xla/pjrt:status_casters",
305306
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
306307
"@tsl//tsl/profiler/lib:profiler_factory",
308+
"@tsl//tsl/profiler/lib:profiler_session",
307309
"@xla//xla/tsl/profiler/rpc:profiler_server_impl",
308310
"@xla//xla/tsl/profiler/rpc/client:capture_profile",
309311
"@com_google_absl//absl/container:flat_hash_map",

torch_xla/csrc/runtime/profiler.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "xla/backends/profiler/plugin/plugin_tracer.h"
88
#include "xla/backends/profiler/plugin/profiler_c_api.h"
99
#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h"
10+
#include "xla/pjrt/status_casters.h"
1011
#include "xla/tsl/profiler/rpc/client/capture_profile.h"
1112
#include "xla/tsl/profiler/rpc/profiler_server.h"
1213

@@ -45,6 +46,31 @@ void ProfilerServer::Start(int port) {
4546

4647
ProfilerServer::~ProfilerServer() {}
4748

49+
const std::string TslProfilerSessionWrapper::Stop() const {
50+
tensorflow::profiler::XSpace xspace;
51+
// Disables the ProfilerSession
52+
xla::ThrowIfError(this->session->CollectData(&xspace));
53+
std::string xspace_str = xspace.SerializeAsString();
54+
return xspace_str;
55+
}
56+
57+
void TslProfilerSessionWrapper::Export(
58+
const std::string& xspace_str, const std::string& tensorboard_dir) const {
59+
tensorflow::profiler::XSpace xspace_proto;
60+
xspace_proto.ParseFromString(xspace_str);
61+
xla::ThrowIfError(
62+
tsl::profiler::ExportToTensorBoard(xspace_proto, tensorboard_dir,
63+
/* also_export_trace_json= */ true));
64+
}
65+
66+
std::unique_ptr<TslProfilerSessionWrapper> TslProfilerSessionWrapper::Create() {
67+
tensorflow::ProfileOptions options = tsl::ProfilerSession::DefaultOptions();
68+
options.set_python_tracer_level(1);
69+
options.set_enable_hlo_proto(true);
70+
return absl::make_unique<runtime::profiler::TslProfilerSessionWrapper>(
71+
tsl::ProfilerSession::Create(options));
72+
}
73+
4874
absl::Status Trace(
4975
const char* service_addr, const char* logdir, int duration_ms,
5076
int num_tracing_attempts,

torch_xla/csrc/runtime/profiler.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "absl/container/flat_hash_map.h"
77
#include "absl/status/status.h"
8+
#include "tsl/profiler/lib/profiler_session.h"
89
#include "xla/pjrt/c/pjrt_c_api.h"
910

1011
namespace torch_xla {
@@ -23,6 +24,25 @@ class ProfilerServer {
2324
std::unique_ptr<Impl> impl_;
2425
};
2526

27+
// Profiler session implementation is based on OpenXLA, we cannot reuse
28+
// the Python binding since it's using nanobind and torch_xla is using pybind11.
29+
// https://github.com/openxla/xla/blob/main/xla/python/profiler.cc
30+
class TslProfilerSessionWrapper {
31+
public:
32+
static std::unique_ptr<TslProfilerSessionWrapper> Create();
33+
34+
explicit TslProfilerSessionWrapper(
35+
std::unique_ptr<tsl::ProfilerSession> session)
36+
: session(std::move(session)) {}
37+
38+
void Export(const std::string& xspace_str,
39+
const std::string& tensorboard_dir) const;
40+
const std::string Stop() const;
41+
42+
private:
43+
std::unique_ptr<tsl::ProfilerSession> session;
44+
};
45+
2646
absl::Status Trace(
2747
const char* service_addr, const char* logdir, int duration_ms,
2848
int num_tracing_attempts,

torch_xla/debug/profiler.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import functools
2+
import os
23
import threading
4+
35
import torch_xla
46
import torch_xla.core.xla_model as xm
57

@@ -183,3 +185,65 @@ def wrapper_trace_me(*args, **kwargs):
183185
return wrapper_trace_me
184186

185187
return decorator_trace_me
188+
189+
190+
# The profiler implementation is based on JAX implementation
191+
# https://github.com/jax-ml/jax/blob/main/jax/_src/profiler.py
192+
class _ProfileState:
193+
194+
def __init__(self):
195+
self.profile_session = None
196+
self.log_dir = None
197+
self.create_perfetto_link = False
198+
self.create_perfetto_trace = False
199+
self.lock = threading.Lock()
200+
201+
def reset(self):
202+
_profile_state.profile_session = None
203+
_profile_state.create_perfetto_link = False
204+
_profile_state.create_perfetto_trace = False
205+
_profile_state.log_dir = None
206+
207+
208+
_profile_state = _ProfileState()
209+
210+
211+
def start_trace(log_dir: os.PathLike | str) -> None:
  """Begins collecting a profiler trace.

  The trace records CPU, GPU, and/or TPU activity, including Python functions
  and PyTorch/XLA on-device operations. Call :func:`stop_trace` to finish the
  trace and write the results to ``log_dir``.

  The saved trace can be opened with TensorBoard; TensorBoard does not have to
  be running while the trace is being collected.

  Only a single trace can be active at any time.

  Args:
    log_dir: Directory the trace is written to (usually the TensorBoard log
      directory).

  Raises:
    RuntimeError: If a trace started by an earlier call is still running.
  """
  state = _profile_state
  with state.lock:
    already_running = state.profile_session is not None
    if already_running:
      raise RuntimeError("Profile has already been started. "
                         "Only one profile may be run at a time.")

    session_cls = torch_xla._XLAC.profiler.TslProfilerSessionWrapper
    state.profile_session = session_cls()
    state.log_dir = str(log_dir)
236+
237+
238+
def stop_trace() -> None:
  """Stops the currently-running profiler trace.

  The trace is saved to the ``log_dir`` passed to the corresponding
  :func:`start_trace` call. The profiler state is cleared even when stopping
  or exporting fails, so a new trace can be started afterwards.

  Raises:
    RuntimeError: If no trace has been started.
  """
  with _profile_state.lock:
    if _profile_state.profile_session is None:
      raise RuntimeError("No profile started")
    sess = _profile_state.profile_session
    try:
      sess.export(sess.stop(), str(_profile_state.log_dir))
    finally:
      # Reset unconditionally: if stop() or export() raised and the session
      # were left in place, every later start_trace() would fail with
      # "Only one profile may be run at a time."
      _profile_state.reset()

0 commit comments

Comments
 (0)