Skip to content

Commit c668fda

Browse files
author
neutrino
committed
Fix memory leak in zeroGradients().
1 parent 963332d commit c668fda

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ public void zeroGradients() {
123123
NDManager systemManager = MxNDManager.getSystemManager();
124124
for (NDArray array : systemManager.getManagedArrays()) {
125125
if (array.hasGradient()) {
126-
array.getGradient().subi(array.getGradient());
126+
// To prevent a memory leak, we must close the gradient after use.
127+
try (NDArray gradient = array.getGradient()) {
128+
gradient.subi(gradient);
129+
}
127130
}
128131
}
129132
}

engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ public void zeroGradients() {
7676
NDManager systemManager = PtNDManager.getSystemManager();
7777
for (NDArray array : systemManager.getManagedArrays()) {
7878
if (array.hasGradient()) {
79-
array.getGradient().subi(array.getGradient());
79+
// To prevent a memory leak, we must close the gradient after use.
80+
try (NDArray gradient = array.getGradient()) {
81+
gradient.subi(gradient);
82+
}
8083
}
8184
}
8285
}

0 commit comments

Comments
 (0)