Skip to content

Commit e11962d

Browse files
kiukchungfacebook-github-bot
authored andcommitted
(2/n torchx-allocator)(monarch/tools) implement a way to resolve hostname to ipv6 or ipv4 address (#295)
Summary: implement `monarch.tools.network.get_ip_addr(hostname)`, which resolves a hostname to an IP address by: (1) looking up a TCP (`SOCK_STREAM`) compatible IPv6 address; (2) if none is found, falling back to a TCP-compatible IPv4 address; (3) raising an error if neither an IPv6 nor an IPv4 address is found. This is required since hyperactor's `ChannelAddr` for the `tcp` transport (e.g. `tcp!slurm-compute-node-0:26600`) takes a socket address (IP + port), not a hostname, so the caller has to resolve the hostname (typically queried from the scheduler) to an IP. Reviewed By: ahmadsharif1. Differential Revision: D76846286
1 parent 35e9632 commit e11962d

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

python/monarch/tools/network.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
import logging
9+
import socket
10+
from typing import Optional
11+
12+
logger: logging.Logger = logging.getLogger(__name__)


def get_ip_addr(hostname: str) -> str:
    """Resolves and returns the ip address of the given hostname.

    This function will return an ipv6 address if one that can bind
    `SOCK_STREAM` (TCP) socket is found. Otherwise it will fall-back
    to resolving an ipv4 `SOCK_STREAM` address.

    Args:
        hostname: the host name (or ip literal) to resolve.

    Returns:
        The ip address taken from the first entry returned by
        `socket.getaddrinfo` for the first address family that resolves.

    Raises:
        RuntimeError: if neither an ipv6 nor an ipv4 address can be resolved
            from hostname. This includes malformed host names (e.g. a label
            longer than 63 characters) that are rejected before any lookup
            takes place.
    """

    def get_sockaddr(family: socket.AddressFamily) -> Optional[str]:
        # Keep the `try` limited to the lookup itself so that unrelated
        # errors are never mistaken for a resolution failure.
        try:
            addrs = socket.getaddrinfo(
                hostname, port=None, family=family, type=socket.SOCK_STREAM
            )  # tcp
        except (socket.gaierror, UnicodeError) as e:
            # `gaierror`: the resolver found no address for this family.
            # `UnicodeError`: CPython encodes `str` host names with the `idna`
            # codec before resolving, which rejects malformed names (empty or
            # over-long labels). Treat both as "not resolvable" so the caller
            # gets the documented `RuntimeError`.
            logger.info(
                "No %s address that can bind TCP sockets for host: %s. %s",
                family.name,
                hostname,
                e,
            )
            return None

        if not addrs:
            return None

        # socket.getaddrinfo returns a list of 5-tuple addr infos:
        # (family, type, proto, canonname, sockaddr); use the first address
        _, _, _, _, sockaddr = addrs[0]

        # sockaddr is a 2-tuple (ipv4) or a 4-tuple (ipv6) where the first
        # element is the ip addr
        ipaddr = str(sockaddr[0])

        logger.info(
            "Resolved %s address: `%s` for host: `%s`",
            family.name,
            ipaddr,
            hostname,
        )
        return ipaddr

    # Prefer ipv6; `or` short-circuits so ipv4 is only tried when ipv6 fails.
    ipaddr = get_sockaddr(socket.AF_INET6) or get_sockaddr(socket.AF_INET)
    if not ipaddr:
        raise RuntimeError(
            f"Unable to resolve `{hostname}` to ipv6 or ipv4 address that can bind TCP socket."
            " Check the network configuration on the host."
        )
    return ipaddr

python/tests/tools/test_network.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import socket
10+
import unittest
11+
from unittest import mock
12+
13+
from monarch.tools import network
14+
15+
16+
class TestNetwork(unittest.TestCase):
    """Tests for `network.get_ip_addr`."""

    def test_network_ipv4_fallback(self) -> None:
        # The first lookup raises, the second one yields a single ipv4
        # addr info, so the ipv4 address is expected to be returned.
        ipv4_addrinfo = (
            socket.AF_INET,
            socket.SOCK_STREAM,
            socket.IPPROTO_TCP,
            "",
            ("123.45.67.89", 80),
        )
        lookup_results = [socket.gaierror, [ipv4_addrinfo]]

        with mock.patch("socket.getaddrinfo", side_effect=lookup_results):
            resolved = network.get_ip_addr("foo.bar.facebook.com")

        self.assertEqual("123.45.67.89", resolved)

    def test_network_ipv6(self) -> None:
        # Every lookup yields the same single ipv6 addr info, so the
        # ipv6 address is expected to be returned.
        ipv6_addrinfo = (
            socket.AF_INET6,
            socket.SOCK_STREAM,
            socket.IPPROTO_TCP,
            "",
            ("1234:ab00:567c:89d:abcd:0:328:0", 0, 0, 0),
        )

        with mock.patch("socket.getaddrinfo", return_value=[ipv6_addrinfo]):
            resolved = network.get_ip_addr("foo.bar.facebook.com")

        self.assertEqual("1234:ab00:567c:89d:abcd:0:328:0", resolved)

    def test_network(self) -> None:
        # the other tests patch `socket.getaddrinfo`; here we deliberately
        # don't patch and just make sure things don't error out
        fqdn = socket.getfqdn()
        self.assertIsNotNone(network.get_ip_addr(fqdn))

0 commit comments

Comments
 (0)