Fix #4382: Fix thread leak in WSTP by replacing LinkedTransferQueue with SynchronousQueue and ConcurrentHashMap #4388

Draft
wants to merge 6 commits into
base: series/3.6.x
@@ -39,7 +39,7 @@ import scala.concurrent.duration.{Duration, FiniteDuration}

import java.time.Instant
import java.time.temporal.ChronoField
- import java.util.concurrent.{LinkedTransferQueue, ThreadLocalRandom}
+ import java.util.concurrent.{ConcurrentHashMap, SynchronousQueue, ThreadLocalRandom}
import java.util.concurrent.atomic.{
AtomicBoolean,
AtomicInteger,
@@ -131,8 +131,11 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
*/
private[this] val state: AtomicInteger = new AtomicInteger(threadCount << UnparkShift)

- private[unsafe] val cachedThreads: LinkedTransferQueue[WorkerThread[P]] =
- new LinkedTransferQueue
+ private[unsafe] val transferStateQueue: SynchronousQueue[WorkerThread.TransferState] =
+ new SynchronousQueue[WorkerThread.TransferState](false)
+
+ private[unsafe] val blockerThreads: ConcurrentHashMap[WorkerThread[P], java.lang.Boolean] =
+ new ConcurrentHashMap()

/**
* The shutdown latch of the work stealing thread pool.
@@ -749,11 +752,9 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
system.close()
}

- var t: WorkerThread[P] = null
- while ({
- t = cachedThreads.poll()
- t ne null
- }) {
+ val it = blockerThreads.keySet().iterator()
+ while (it.hasNext()) {
+ val t = it.next()
Comment on lines -752 to +757
Member

If I remember correctly, I think one of the goals here is to avoid any allocations, in case the runtime was shutting down in a fatal condition (e.g. out-of-memory). Unfortunately, creating the iterator is an allocation. But, I don't know how to iterate the elements of a ConcurrentHashMap without an iterator 🤔

Author

After searching, I was also unable to find an allocation-free way to do this. It seems we might need to accept this small allocation as a trade-off for now. I will keep looking and am open to suggestions.

Member

I'd have to retrace the code, but I think this is an area of secondary concern in terms of allocations. The critical path is ensuring that exceptions can propagate out to the IO#unsafe calling point without allocation. So long as that is achieved, everything else is gravy. Logically, I don't think WSTP shutdown matters as much since, in any fatal error case, the process is torpedoed anyway and about to die.

t.interrupt()
// don't bother joining, cached threads are not doing anything interesting
}
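
For anyone following the allocation discussion in the thread above, here is a minimal standalone sketch of the shape of the new shutdown loop. It is not the pool's code: `InterruptAllSketch`, `registry`, `spawnSleeper`, and `interruptAll` are hypothetical names, and plain `Thread`s stand in for `WorkerThread[P]`. The only allocation in the loop itself is the key-set iterator, which is the cost being weighed in the comments.

```scala
import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}

object InterruptAllSketch {
  // Hypothetical stand-in for the pool's `blockerThreads` map.
  val registry: ConcurrentHashMap[Thread, java.lang.Boolean] =
    new ConcurrentHashMap()

  // Interrupts every registered thread and returns how many were visited.
  def interruptAll(): Int = {
    var visited = 0
    // The iterator is weakly consistent, so threads may remove themselves
    // from the map while this loop is running.
    val it = registry.keySet().iterator()
    while (it.hasNext()) {
      it.next().interrupt()
      visited += 1
    }
    visited
  }

  // Starts a thread that registers itself, sleeps until interrupted,
  // and then removes itself from the map (as the cached threads do in the diff).
  def spawnSleeper(started: CountDownLatch): Thread = {
    val t = new Thread(new Runnable {
      def run(): Unit = {
        started.countDown()
        try Thread.sleep(60000L)
        catch { case _: InterruptedException => () }
        finally { registry.remove(Thread.currentThread()); () }
      }
    })
    t.setDaemon(true)
    registry.put(t, java.lang.Boolean.TRUE)
    t.start()
    t
  }
}
```

As in the diff, the loop does not join the interrupted threads; each thread cleans up its own map entry on the way out.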
41 changes: 26 additions & 15 deletions core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
@@ -26,7 +26,7 @@ import scala.concurrent.{BlockContext, CanAwait}
import scala.concurrent.duration.{Duration, FiniteDuration}

import java.lang.Long.MIN_VALUE
- import java.util.concurrent.{ArrayBlockingQueue, ThreadLocalRandom}
+ import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.atomic.AtomicBoolean

import WorkerThread.{Metrics, TransferState}
@@ -110,7 +110,6 @@ private[effect] final class WorkerThread[P <: AnyRef](
*/
private[this] var _active: Runnable = _

- private val stateTransfer: ArrayBlockingQueue[TransferState] = new ArrayBlockingQueue(1)
private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration

private[effect] var currentIOFiber: IOFiber[?] = _
@@ -732,20 +731,27 @@ private[effect] final class WorkerThread[P <: AnyRef](
// by another thread in the future.
val len = runtimeBlockingExpiration.length
val unit = runtimeBlockingExpiration.unit
- if (pool.cachedThreads.tryTransfer(this, len, unit)) {
- // Someone accepted the transfer of this thread and will transfer the state soon.
- val newState = stateTransfer.take()
+
+ // Try to poll for a new state from the transfer queue
+ val newState = pool.transferStateQueue.poll(len, unit)
+
+ if (newState ne null) {
+ // Got a state to take over
init(newState)
+
} else {
- // The timeout elapsed and no one woke up this thread. It's time to exit.
+ // No state to take over after timeout, exit
pool.blockedWorkerThreadCounter.decrementAndGet()
+ // Remove from blocker threads map if present
+ pool.blockerThreads.remove(this)
return
}
} catch {
case _: InterruptedException =>
// This thread was interrupted while cached. This should only happen
// during the shutdown of the pool. Nothing else to be done, just
// exit.
+ pool.blockerThreads.remove(this)
return
}
}
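
The change in this hunk relies on the hand-off semantics of `SynchronousQueue`: the queue has no capacity, so a non-blocking `offer` succeeds only if a consumer is already waiting in `poll`, and a timed `poll` returns `null` when nobody offers in time. A minimal sketch of that behavior follows, assuming a hypothetical `State` class in place of `WorkerThread.TransferState`; `HandOffSketch` and its methods are illustrative names, not part of this PR.

```scala
import java.util.concurrent.{SynchronousQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

object HandOffSketch {
  // Hypothetical stand-in for WorkerThread.TransferState.
  final class State(val index: Int)

  // Non-fair queue, matching `new SynchronousQueue(false)` in the diff.
  val queue: SynchronousQueue[State] = new SynchronousQueue[State](false)

  // No consumer is waiting, so the non-blocking offer fails immediately.
  def offerWithoutConsumer(): Boolean = queue.offer(new State(0))

  // No producer shows up, so the timed poll returns null after the timeout.
  def pollWithoutProducer(): State = queue.poll(50L, TimeUnit.MILLISECONDS)

  // A consumer waits in a timed poll (like a cached thread); the producer
  // retries the non-blocking offer until the consumer is ready to take it.
  // Returns the index the consumer received, or -1 if nothing was handed off.
  def handOff(index: Int): Int = {
    val received = new AtomicReference[State]()
    val consumer = new Thread(new Runnable {
      def run(): Unit = received.set(queue.poll(5L, TimeUnit.SECONDS))
    })
    consumer.start()

    val state = new State(index)
    val deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5L)
    while (!queue.offer(state) && System.nanoTime() < deadline) {
      Thread.sleep(1L)
    }

    consumer.join()
    val s = received.get()
    if (s eq null) -1 else s.index
  }
}
```

The retry loop exists only to make the sketch deterministic. In the diff, a failed `offer` instead falls through to the `else` branch that spawns a new `WorkerThread`.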
@@ -928,15 +934,18 @@ private[effect] final class WorkerThread[P <: AnyRef](
// Set the name of this thread to a blocker prefixed name.
setName(s"$prefix-$nameIndex")

- val cached = pool.cachedThreads.poll()
- if (cached ne null) {
- // There is a cached worker thread that can be reused.
- val idx = index
- pool.replaceWorker(idx, cached)
- // Transfer the data structures to the cached thread and wake it up.
- transferState.index = idx
- transferState.tick = tick + 1
- val _ = cached.stateTransfer.offer(transferState)
+ val idx = index
+
+ // Prepare the transfer state
+ transferState.index = idx
+ transferState.tick = tick + 1
+
+ // Register this thread in the blockerThreads map
+ val _ = pool.blockerThreads.put(this, java.lang.Boolean.TRUE)
+
+ if (pool.transferStateQueue.offer(transferState)) {
+ // If successful, a waiting thread will pick it up
+
} else {
// Spawn a new `WorkerThread`, a literal clone of this one. It is safe to
// transfer ownership of the local queue and the parked signal to the new
@@ -1002,6 +1011,8 @@ private[effect] final class WorkerThread[P <: AnyRef](
setName(s"$prefix-${_index}")

blocking = false
+
+ pool.replaceWorker(newIdx, this)
}

/**