Skip to content

Commit 17a76a6

Browse files
authored
Ensure duplicate arguments are only checked within their respective argument groups
Differential Revision: D57459718 Pull Request resolved: #911
1 parent ac3cc78 commit 17a76a6

File tree

2 files changed

+45
-10
lines changed

2 files changed

+45
-10
lines changed

torchx/cli/cmd_run.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import threading
1414
from collections import Counter
1515
from dataclasses import asdict
16+
from itertools import groupby
1617
from pathlib import Path
1718
from pprint import pformat
1819
from typing import Dict, List, Optional, Tuple
@@ -85,16 +86,19 @@ def _parse_component_name_and_args(
8586
component = args[0]
8687
component_args = args[1:]
8788

88-
# Error if there are repeated command line arguments
89-
all_options = [
90-
x
91-
for x in component_args
92-
if x.startswith("-") and x.strip() != "-" and x.strip() != "--"
93-
]
94-
arg_count = Counter(all_options)
95-
duplicates = [arg for arg, count in arg_count.items() if count > 1]
96-
if len(duplicates) > 0:
97-
subparser.error(f"Repeated Command Line Arguments: {duplicates}")
89+
# Error if there are repeated command line arguments each group of arguments,
90+
# where the groups are separated by "--"
91+
arg_groups = [list(g) for _, g in groupby(component_args, key=lambda x: x == "--")]
92+
for arg_group in arg_groups:
93+
all_options = [
94+
x
95+
for x in arg_group
96+
if x.startswith("-") and x.strip() != "-" and x.strip() != "--"
97+
]
98+
arg_count = Counter(all_options)
99+
duplicates = [arg for arg, count in arg_count.items() if count > 1]
100+
if len(duplicates) > 0:
101+
subparser.error(f"Repeated Command Line Arguments: {duplicates}")
98102

99103
if not component:
100104
subparser.error(MISSING_COMPONENT_ERROR_MSG)

torchx/cli/test/cmd_run_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,32 @@ def test_parse_component_name_and_args_no_default(self) -> None:
251251
),
252252
)
253253

254+
self.assertEqual(
255+
(
256+
"fb.python.binary",
257+
[
258+
"--img",
259+
"lex_ig_o3_package",
260+
"-m",
261+
"dper_lib.instagram.pyper_v2.teams.stories.train",
262+
"--",
263+
"-m",
264+
],
265+
),
266+
_parse_component_name_and_args(
267+
[
268+
"fb.python.binary",
269+
"--img",
270+
"lex_ig_o3_package",
271+
"-m",
272+
"dper_lib.instagram.pyper_v2.teams.stories.train",
273+
"--",
274+
"-m",
275+
],
276+
sp,
277+
),
278+
)
279+
254280
with self.assertRaises(SystemExit):
255281
_parse_component_name_and_args(["--"], sp)
256282

@@ -271,6 +297,11 @@ def test_parse_component_name_and_args_no_default(self) -> None:
271297
["--msg ", "hello", "--msg ", "repeate"], sp
272298
)
273299

300+
with self.assertRaises(SystemExit):
301+
_parse_component_name_and_args(
302+
["--m", "hello", "--", "--msg", "msg", "--msg", "repeate"], sp
303+
)
304+
274305
def test_parse_component_name_and_args_with_default(self) -> None:
275306
sp = argparse.ArgumentParser(prog="test")
276307
dirs = [str(self.tmpdir)]

0 commit comments

Comments
 (0)