Skip to content

Commit d31d81a

Browse files
kimoonkim authored and ash211 committed
Fix an HDFS data locality bug in case cluster node names are short host names (apache-spark-on-k8s#291)
* Fix an HDFS data locality bug in case cluster node names are not full host names
* Add a NOTE about InetAddress caching
1 parent 2a2cfb6 commit d31d81a

File tree

2 files changed

+141
-2
lines changed

2 files changed

+141
-2
lines changed

resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesTaskSetManager.scala

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,18 @@
1616
*/
1717
package org.apache.spark.scheduler.cluster.kubernetes
1818

19+
import java.net.InetAddress
20+
1921
import scala.collection.mutable.ArrayBuffer
2022

2123
import org.apache.spark.scheduler.{TaskSchedulerImpl, TaskSet, TaskSetManager}
2224

2325
private[spark] class KubernetesTaskSetManager(
2426
sched: TaskSchedulerImpl,
2527
taskSet: TaskSet,
26-
maxTaskFailures: Int) extends TaskSetManager(sched, taskSet, maxTaskFailures) {
28+
maxTaskFailures: Int,
29+
inetAddressUtil: InetAddressUtil = new InetAddressUtil)
30+
extends TaskSetManager(sched, taskSet, maxTaskFailures) {
2731

2832
/**
2933
* Overrides the lookup to use not only the executor pod IP, but also the cluster node
@@ -52,12 +56,30 @@ private[spark] class KubernetesTaskSetManager(
5256
if (pendingTasksClusterNodeIP.nonEmpty) {
5357
logDebug(s"Got preferred task list $pendingTasksClusterNodeIP for executor host " +
5458
s"$executorIP using cluster node IP $clusterNodeIP")
59+
pendingTasksClusterNodeIP
60+
} else {
61+
val clusterNodeFullName = inetAddressUtil.getFullHostName(clusterNodeIP)
62+
val pendingTasksClusterNodeFullName = super.getPendingTasksForHost(clusterNodeFullName)
63+
if (pendingTasksClusterNodeFullName.nonEmpty) {
64+
logDebug(s"Got preferred task list $pendingTasksClusterNodeFullName " +
65+
s"for executor host $executorIP using cluster node full name $clusterNodeFullName")
66+
}
67+
pendingTasksClusterNodeFullName
5568
}
56-
pendingTasksClusterNodeIP
5769
}
5870
} else {
5971
pendingTasksExecutorIP // Empty
6072
}
6173
}
6274
}
6375
}
76+
77+
/**
 * Small wrapper around `java.net.InetAddress` host name resolution.
 *
 * It exists as its own class so that unit tests can replace it with a mock.
 */
private[kubernetes] class InetAddressUtil {

  /**
   * Returns the canonical host name that `InetAddress` reports for the given address.
   *
   * NOTE: Resolving the name can issue a network call to DNS. The InetAddress class keeps its
   * own internal cache of lookup results, covering both hits and misses.
   *
   * @param ipAddress address string handed to `InetAddress.getByName`.
   * @return the value of `getCanonicalHostName` for the resolved address.
   */
  def getFullHostName(ipAddress: String): String = {
    val resolvedAddress = InetAddress.getByName(ipAddress)
    resolvedAddress.getCanonicalHostName
  }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.scheduler.cluster.kubernetes
18+
19+
import scala.collection.mutable.ArrayBuffer
20+
21+
import io.fabric8.kubernetes.api.model.{Pod, PodSpec, PodStatus}
22+
import org.mockito.Mockito._
23+
24+
import org.apache.spark.{SparkContext, SparkFunSuite}
25+
import org.apache.spark.scheduler.{FakeTask, FakeTaskScheduler, HostTaskLocation, TaskLocation}
26+
27+
/**
 * Tests for `KubernetesTaskSetManager.getPendingTasksForHost`.
 *
 * Each test builds a task set whose preferred locations are expressed in one form (executor pod
 * IP, cluster node name, cluster node IP, or full host name) and checks which pending tasks are
 * returned when the manager is asked about an executor pod IP.
 */
class KubernetesTaskSetManagerSuite extends SparkFunSuite {

  val sc = new SparkContext("local", "test")
  val sched = new FakeTaskScheduler(sc,
    ("execA", "10.0.0.1"), ("execB", "10.0.0.2"), ("execC", "10.0.0.3"))
  // The backend is mocked so that each test decides which pod an executor IP maps to.
  val backend = mock(classOf[KubernetesClusterSchedulerBackend])
  sched.backend = backend

  /**
   * Stops the SparkContext created by this suite so that it does not stay active after the
   * suite has finished.
   */
  protected override def afterAll(): Unit = {
    try {
      sc.stop()
    } finally {
      super.afterAll()
    }
  }

  /**
   * Creates a mocked executor pod and makes the mocked backend return it for `podIP`.
   *
   * @param podIP executor pod IP that the backend maps to the mocked pod.
   * @param nodeName value returned by the pod spec's `getNodeName`.
   * @param hostIP when defined, the value returned by the pod status's `getHostIP`; when empty,
   *               the pod's `getStatus` is left unstubbed.
   */
  private def stubExecutorPod(
      podIP: String,
      nodeName: String,
      hostIP: Option[String] = None): Unit = {
    val spec = mock(classOf[PodSpec])
    when(spec.getNodeName).thenReturn(nodeName)
    val pod = mock(classOf[Pod])
    when(pod.getSpec).thenReturn(spec)
    hostIP.foreach { ip =>
      val status = mock(classOf[PodStatus])
      when(status.getHostIP).thenReturn(ip)
      when(pod.getStatus).thenReturn(status)
    }
    when(backend.getExecutorPodByIP(podIP)).thenReturn(Some(pod))
  }

  test("Find pending tasks for executors using executor pod IP addresses") {
    val taskSet = FakeTask.createTaskSet(3,
      Seq(TaskLocation("10.0.0.1", "execA")), // Task 0 runs on executor pod 10.0.0.1.
      Seq(TaskLocation("10.0.0.1", "execA")), // Task 1 runs on executor pod 10.0.0.1.
      Seq(TaskLocation("10.0.0.2", "execB"))  // Task 2 runs on executor pod 10.0.0.2.
    )

    val manager = new KubernetesTaskSetManager(sched, taskSet, maxTaskFailures = 2)
    assert(manager.getPendingTasksForHost("10.0.0.1") == ArrayBuffer(1, 0))
    assert(manager.getPendingTasksForHost("10.0.0.2") == ArrayBuffer(2))
  }

  test("Find pending tasks for executors using cluster node names that executor pods run on") {
    val taskSet = FakeTask.createTaskSet(2,
      Seq(HostTaskLocation("kube-node1")), // Task 0's partition is on the kube-node1 datanode.
      Seq(HostTaskLocation("kube-node1"))  // Task 1's partition is on the kube-node1 datanode.
    )
    // Only the node name is stubbed; the pod status is not set up for this test.
    stubExecutorPod("10.0.0.1", nodeName = "kube-node1")

    val manager = new KubernetesTaskSetManager(sched, taskSet, maxTaskFailures = 2)
    assert(manager.getPendingTasksForHost("10.0.0.1") == ArrayBuffer(1, 0))
  }

  test("Find pending tasks for executors using cluster node IPs that executor pods run on") {
    val taskSet = FakeTask.createTaskSet(2,
      Seq(HostTaskLocation("196.0.0.5")), // Task 0's partition is on the 196.0.0.5 datanode.
      Seq(HostTaskLocation("196.0.0.5"))  // Task 1's partition is on the 196.0.0.5 datanode.
    )
    stubExecutorPod("10.0.0.1", nodeName = "kube-node1", hostIP = Some("196.0.0.5"))

    val manager = new KubernetesTaskSetManager(sched, taskSet, maxTaskFailures = 2)
    assert(manager.getPendingTasksForHost("10.0.0.1") == ArrayBuffer(1, 0))
  }

  test("Find pending tasks for executors using cluster node FQDNs that executor pods run on") {
    val taskSet = FakeTask.createTaskSet(2,
      Seq(HostTaskLocation("kube-node1.domain1")), // Task 0's partition is on this datanode.
      Seq(HostTaskLocation("kube-node1.domain1"))  // Task 1's partition is on this datanode.
    )
    stubExecutorPod("10.0.0.1", nodeName = "kube-node1", hostIP = Some("196.0.0.5"))
    // The mocked resolver maps the node IP to the full host name used by the task locations.
    val inetAddressUtil = mock(classOf[InetAddressUtil])
    when(inetAddressUtil.getFullHostName("196.0.0.5")).thenReturn("kube-node1.domain1")

    val manager = new KubernetesTaskSetManager(sched, taskSet, maxTaskFailures = 2, inetAddressUtil)
    assert(manager.getPendingTasksForHost("10.0.0.1") == ArrayBuffer(1, 0))
  }

  test("Return empty pending tasks for executors when all look up fail") {
    val taskSet = FakeTask.createTaskSet(1,
      Seq(HostTaskLocation("kube-node1.domain1")) // Task 0's partition is on this datanode.
    )
    // Node name, node IP and resolved full host name all differ from the task's location.
    stubExecutorPod("10.0.0.1", nodeName = "kube-node2", hostIP = Some("196.0.0.6"))
    val inetAddressUtil = mock(classOf[InetAddressUtil])
    when(inetAddressUtil.getFullHostName("196.0.0.6")).thenReturn("kube-node2.domain1")

    val manager = new KubernetesTaskSetManager(sched, taskSet, maxTaskFailures = 2, inetAddressUtil)
    assert(manager.getPendingTasksForHost("10.0.0.1") == ArrayBuffer())
  }
}

0 commit comments

Comments
 (0)