Skip to content

Commit 0bfee30

Browse files
Julie Ganeshan authored and facebook-github-bot committed
Move logic from TorchX CLI -> API, so MVAI can call it (#955)
Summary: Pull Request resolved: #955 MVAI's "light" is synchronous - you can immediately see the logs for jobs you start. Only "fire" is asynchronous. TorchX's API, since it's generic, *always* creates jobs that are asynchronous. Therefore, there isn't a built-in interface for "tailing" the stderr of every started process - just for tailing individual replicas of a given role. The TorchX CLI's `torchx run` command **has** implemented this, but its implementation is coupled with the CLI implementations of `torchx run` and `torchx log`. This diff extracts the useful logic into a helper function of the TorchX API Reviewed By: andywag Differential Revision: D62463211
1 parent ce17fbb commit 0bfee30

File tree

3 files changed

+223
-35
lines changed

3 files changed

+223
-35
lines changed

torchx/cli/cmd_log.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
from torchx.schedulers.api import Stream
2424
from torchx.specs.api import is_started
2525
from torchx.specs.builders import make_app_handle
26+
from torchx.util.log_tee_helpers import (
27+
_find_role_replicas as find_role_replicas,
28+
_prefix_line,
29+
)
2630

2731
from torchx.util.types import none_throws
2832

@@ -39,19 +43,6 @@ def validate(job_identifier: str) -> None:
3943
sys.exit(1)
4044

4145

42-
def _prefix_line(prefix: str, line: str) -> str:
43-
"""
44-
_prefix_line ensure the prefix is still present even when dealing with return characters
45-
"""
46-
if "\r" in line:
47-
line = line.replace("\r", f"\r{prefix}")
48-
if "\n" in line[:-1]:
49-
line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
50-
if not line.startswith("\r"):
51-
line = f"{prefix}{line}"
52-
return line
53-
54-
5546
def print_log_lines(
5647
file: TextIO,
5748
runner: Runner,
@@ -167,17 +158,6 @@ def get_logs(
167158
raise threads_exceptions[0]
168159

169160

170-
def find_role_replicas(
171-
app: specs.AppDef, role_name: Optional[str]
172-
) -> List[Tuple[str, int]]:
173-
role_replicas = []
174-
for role in app.roles:
175-
if role_name is None or role_name == role.name:
176-
for i in range(role.num_replicas):
177-
role_replicas.append((role.name, i))
178-
return role_replicas
179-
180-
181161
class CmdLog(SubCommand):
182162
def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
183163
subparser.add_argument(

torchx/cli/cmd_run.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import torchx.specs as specs
2222
from torchx.cli.argparse_util import ArgOnceAction, torchxconfig_run
2323
from torchx.cli.cmd_base import SubCommand
24-
from torchx.cli.cmd_log import get_logs
2524
from torchx.runner import config, get_runner, Runner
2625
from torchx.runner.config import load_sections
2726
from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
@@ -32,6 +31,7 @@
3231
get_builtin_source,
3332
get_components,
3433
)
34+
from torchx.util.log_tee_helpers import tee_logs
3535
from torchx.util.types import none_throws
3636

3737

@@ -288,16 +288,14 @@ def _wait_and_exit(self, runner: Runner, app_handle: str, log: bool) -> None:
288288
logger.debug(status)
289289

290290
def _start_log_thread(self, runner: Runner, app_handle: str) -> threading.Thread:
291-
thread = threading.Thread(
292-
target=get_logs,
293-
kwargs={
294-
"file": sys.stderr,
295-
"runner": runner,
296-
"identifier": app_handle,
297-
"regex": None,
298-
"should_tail": True,
299-
},
291+
thread = tee_logs(
292+
dst=sys.stderr,
293+
app_handle=app_handle,
294+
regex=None,
295+
runner=runner,
296+
should_tail=True,
297+
streams=None,
298+
colorize=not sys.stderr.closed and sys.stderr.isatty(),
300299
)
301-
thread.daemon = True
302300
thread.start()
303301
return thread

torchx/util/log_tee_helpers.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
"""
10+
If you're wrapping the TorchX API with your own CLI, these functions can
11+
help show the logs of the job within your CLI, just like
12+
`torchx log`
13+
"""
14+
15+
import logging
16+
import threading
17+
from queue import Queue
18+
from typing import List, Optional, TextIO, Tuple, TYPE_CHECKING
19+
20+
from torchx.util.types import none_throws
21+
22+
if TYPE_CHECKING:
23+
from torchx.runner.api import Runner
24+
from torchx.schedulers.api import Stream
25+
from torchx.specs.api import AppDef
26+
27+
logger: logging.Logger = logging.getLogger(__name__)
28+
29+
# A torchX job can have stderr/stdout for many replicas, of many roles
30+
# The scheduler API has functions that allow us to get,
31+
# with unspecified detail, the log lines of a given replica of
32+
# a given role.
33+
#
34+
# So, to neatly tee the results, we:
35+
# 1) Determine every role ID / replica ID pair we want to monitor
36+
# 2) Request the given stderr / stdout / combined streams from them (1 thread each)
37+
# 3) Concatenate each of those streams to a given destination file
38+
39+
40+
def _find_role_replicas(
    app: "AppDef",
    role_name: Optional[str],
) -> List[Tuple[str, int]]:
    """
    Enumerate every (role name, replica id) pair in the given AppDef.

    Replica ids are 0-indexed and run up to (but not including)
    ``num_replicas`` for each role. When ``role_name`` is given, only
    that role's replicas are returned; unknown names yield an empty list.
    """
    # Select the roles we care about, then expand each into its replicas.
    matching_roles = (
        role for role in app.roles if role_name is None or role.name == role_name
    )
    return [
        (role.name, replica_id)
        for role in matching_roles
        for replica_id in range(role.num_replicas)
    ]
56+
57+
58+
def _prefix_line(prefix: str, line: str) -> str:
    """
    Prepend ``prefix`` to ``line`` so it stays visible even when the line
    contains carriage returns or embedded newlines.
    """
    result = line
    # Re-insert the prefix after each carriage return so terminal
    # overwrites (e.g. progress bars) still display it.
    if "\r" in result:
        result = result.replace("\r", "\r" + prefix)
    # Prefix every newline except a trailing one (the next call handles
    # whatever follows the final newline).
    body, tail = result[:-1], result[-1:]
    if "\n" in body:
        result = body.replace("\n", "\n" + prefix) + tail
    # A line that starts with \r already received its prefix above.
    if not result.startswith("\r"):
        result = prefix + result
    return result
69+
70+
71+
def _print_log_lines_for_role_replica(
    dst: TextIO,
    app_handle: str,
    regex: Optional[str],
    runner: "Runner",
    which_role: str,
    which_replica: int,
    exceptions: "Queue[Exception]",
    should_tail: bool,
    streams: Optional["Stream"],
    colorize: bool = False,
) -> None:
    """
    Worker that streams one replica's log lines into ``dst``.

    One of these runs per monitored (role, replica) pair. Any exception is
    recorded in the thread-safe ``exceptions`` queue (for the coordinating
    thread to report) and then re-raised. Based on print_log_lines, but not
    designed for TTY.
    """
    # The prefix never changes for the lifetime of this worker, so build
    # it once. Green ANSI color is applied only when requested.
    if colorize:
        color_begin, color_end = "\033[32m", "\033[0m"
    else:
        color_begin, color_end = "", ""
    prefix = f"{color_begin}{which_role}/{which_replica}{color_end} "

    try:
        log_iter = runner.log_lines(
            app_handle,
            which_role,
            which_replica,
            regex,
            should_tail=should_tail,
            streams=streams,
        )
        for line in log_iter:
            print(_prefix_line(prefix, line), file=dst, end="", flush=True)
    except Exception as e:
        exceptions.put(e)
        raise
109+
110+
111+
def _start_threads_to_monitor_role_replicas(
    dst: TextIO,
    app_handle: str,
    regex: Optional[str],
    runner: "Runner",
    which_role: Optional[str] = None,
    should_tail: bool = False,
    streams: Optional["Stream"] = None,
    colorize: bool = False,
) -> None:
    """
    Fan out one daemon thread per (role, replica) pair, tee each replica's
    logs to ``dst``, and block until every worker finishes.

    Raises ValueError when no replicas match ``which_role``; otherwise
    re-raises the first exception any worker recorded (logging the rest).
    """
    app = none_throws(runner.describe(app_handle))
    replica_ids = _find_role_replicas(app, role_name=which_role)

    if not replica_ids:
        valid_roles = [role.name for role in app.roles]
        raise ValueError(
            f"{which_role} is not a valid role name. Available: {valid_roles}"
        )

    # Thread-safe sink for exceptions raised inside the worker threads.
    exceptions = Queue()

    workers = [
        threading.Thread(
            target=_print_log_lines_for_role_replica,
            kwargs={
                "dst": dst,
                "runner": runner,
                "app_handle": app_handle,
                "which_role": role_name,
                "which_replica": replica_id,
                "regex": regex,
                "should_tail": should_tail,
                "exceptions": exceptions,
                "streams": streams,
                "colorize": colorize,
            },
            daemon=True,
        )
        for role_name, replica_id in replica_ids
    ]

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Drain everything the workers recorded: log all but the first
    # exception, then raise the first one.
    recorded = []
    while not exceptions.empty():
        recorded.append(exceptions.get())

    if recorded:
        for extra in recorded[1:]:
            logger.error(extra)
        raise recorded[0]
172+
173+
174+
def tee_logs(
    dst: TextIO,
    app_handle: str,
    regex: Optional[str],
    runner: "Runner",
    should_tail: bool = False,
    streams: Optional["Stream"] = None,
    colorize: bool = False,
) -> threading.Thread:
    """
    Makes a thread, which in turn will start 1 thread per replica
    per role, that tees that role-replica's logs to the given
    destination buffer.

    You'll need to start and join with this parent thread.

    dst: TextIO to tee the logs into
    app_handle: The return value of runner.run() or runner.schedule()
    regex: Regex to filter the logs that are tee-d
    runner: The Runner you used to schedule the job
    should_tail: If true, continue until we run out of logs. Otherwise, just fetch
        what's available
    streams: Whether to fetch STDERR, STDOUT, or the temporally COMBINED (default) logs
    colorize: If true, color the role/replica prefix with ANSI escapes
    """
    # Bug fix: forward the caller's regex / should_tail / streams instead of
    # hard-coding regex=None, should_tail=True and dropping streams entirely —
    # previously those parameters were accepted but silently ignored.
    thread = threading.Thread(
        target=_start_threads_to_monitor_role_replicas,
        kwargs={
            "dst": dst,
            "runner": runner,
            "app_handle": app_handle,
            "regex": regex,
            "should_tail": should_tail,
            "streams": streams,
            "colorize": colorize,
        },
        daemon=True,
    )
    return thread

0 commit comments

Comments
 (0)