
Commit 7ac9971

FrostML and guoshengCS authored
Fix multiprocess download (#383)
* Fix multiprocess download.
* Simplify code.
* Another idea.
* Update.
* Fix multiprocess _get_data using a status file and an atexit hook.
* Register the remove hook of _get_data in all processes to make sure the hook executes.
* Use try/except.
* Move the atexit hook registration earlier.
* Move the register hook to first and register together.
* Fix moving the register hook to first and registering together.
* Move lock files into a list to remove them together.
* Fix OSError breaking the loop.
* Refine nested try-catch.

Co-authored-by: guosheng <[email protected]>
1 parent b7dd5ce commit 7ac9971
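
The pattern the commit message describes, a per-split status file plus an atexit cleanup hook registered in every process, can be sketched in isolation as follows. This is a minimal illustration under assumed names (DATA_DIR, download_once, and fetch_fn are invented for the example and are not part of the PaddleNLP API); the real implementation is the read_datasets change in the diff below.

import atexit
import os
import time

DATA_DIR = "/tmp/data_cache"  # hypothetical shared cache directory


def remove_if_exist(paths):
    # Best-effort cleanup; several processes may race to remove the same file.
    for path in paths:
        try:
            os.remove(path)
        except OSError:
            pass


def download_once(split, is_downloader, fetch_fn):
    """Download `split` in exactly one process; the others wait on a lock file."""
    os.makedirs(DATA_DIR, exist_ok=True)
    lock_file = os.path.join(DATA_DIR, split + ".done." + str(os.getppid()))
    # Register the cleanup hook in every process, not only the downloader, so
    # the lock file is removed even if a non-downloader process is the one
    # that exits abnormally.
    atexit.register(lambda: remove_if_exist([lock_file]))
    if is_downloader:
        filename = fetch_fn(split)      # the actual download happens here
        open(lock_file, "w").close()    # signal completion to the other procs
    else:
        while not os.path.exists(lock_file):
            time.sleep(1)               # poll until the downloader is done
        # Assume the downloader left the split at a conventional path.
        filename = os.path.join(DATA_DIR, split)
    return filename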

File tree

1 file changed: +49, -7 lines

paddlenlp/datasets/dataset.py

Lines changed: 49 additions & 7 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import atexit
 import collections
 import io
 import math
@@ -20,11 +21,12 @@
 import sys
 import inspect
 from multiprocess import Pool, RLock
+import time
 
 import paddle.distributed as dist
 from paddle.io import Dataset, IterableDataset
 from paddle.dataset.common import md5file
-from paddle.utils.download import get_path_from_url
+from paddle.utils.download import get_path_from_url, _get_unique_endpoints
 from paddlenlp.utils.env import DATA_HOME
 from typing import Iterable, Iterator, Optional, List, Any, Callable, Union
 import importlib
@@ -494,19 +496,59 @@ def read_datasets(self, splits=None, data_files=None):
                     for split, filename in data_files.items()
                 ]
 
+        def remove_if_exit(filepath):
+            if isinstance(filepath, (list, tuple)):
+                for file in filepath:
+                    try:
+                        os.remove(file)
+                    except OSError:
+                        pass
+            else:
+                try:
+                    os.remove(filepath)
+                except OSError:
+                    pass
+
         if splits:
             assert isinstance(splits, str) or (
                 isinstance(splits, list) and isinstance(splits[0], str)
             ) or (
                 isinstance(splits, tuple) and isinstance(splits[0], str)
             ), "`splits` should be a string or list of string or a tuple of string."
             if isinstance(splits, str):
-                filename = self._get_data(splits)
-                datasets.append(self.read(filename=filename, split=splits))
-            else:
-                for split in splits:
-                    filename = self._get_data(split)
-                    datasets.append(self.read(filename=filename, split=split))
+                splits = [splits]
+            parallel_env = dist.ParallelEnv()
+            unique_endpoints = _get_unique_endpoints(
+                parallel_env.trainer_endpoints[:])
+            # Move the register hook to first and register together.
+            lock_files = []
+            for split in splits:
+                lock_file = os.path.join(DATA_HOME, self.__class__.__name__)
+                if self.name is not None:
+                    lock_file = lock_file + "." + self.name
+                lock_file += "." + split + ".done" + "." + str(os.getppid())
+                lock_files.append(lock_file)
+            # Must register to all procs so that the lock file can be removed
+            # when any proc breaks. Otherwise, the single registered proc may
+            # not receive the proper signal sent by the parent proc to exit.
+            atexit.register(lambda: remove_if_exit(lock_files))
+            for split in splits:
+                filename = self._get_data(split)
+                lock_file = os.path.join(DATA_HOME, self.__class__.__name__)
+                if self.name is not None:
+                    lock_file = lock_file + "." + self.name
+                lock_file += "." + split + ".done" + "." + str(os.getppid())
+                # `lock_file` indicates the finished status of `_get_data`.
+                # `_get_data` only runs in the proc specified by
+                # `unique_endpoints` since `get_path_from_url` only works for
+                # it. The other procs wait for `_get_data` to finish.
+                if parallel_env.current_endpoint in unique_endpoints:
+                    f = open(lock_file, "w")
+                    f.close()
+                else:
+                    while not os.path.exists(lock_file):
+                        time.sleep(1)
+                datasets.append(self.read(filename=filename, split=split))
 
         return datasets if len(datasets) > 1 else datasets[0]

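For context, the code path above is exercised whenever a dataset is loaded under a multi-process launch (for example python -m paddle.distributed.launch): every trainer process enters read_datasets, but only the unique endpoint downloads, while the others poll the .done lock file. A usage sketch follows, with the dataset name and splits chosen only for illustration:

from paddlenlp.datasets import load_dataset

# Run under `python -m paddle.distributed.launch train.py` so several trainer
# processes call read_datasets concurrently; only one of them downloads.
train_ds, dev_ds = load_dataset("glue", "sst-2", splits=("train", "dev"))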