|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import atexit |
15 | 16 | import collections
|
16 | 17 | import io
|
17 | 18 | import math
|
|
20 | 21 | import sys
|
21 | 22 | import inspect
|
22 | 23 | from multiprocess import Pool, RLock
|
| 24 | +import time |
23 | 25 |
|
24 | 26 | import paddle.distributed as dist
|
25 | 27 | from paddle.io import Dataset, IterableDataset
|
26 | 28 | from paddle.dataset.common import md5file
|
27 |
| -from paddle.utils.download import get_path_from_url |
| 29 | +from paddle.utils.download import get_path_from_url, _get_unique_endpoints |
28 | 30 | from paddlenlp.utils.env import DATA_HOME
|
29 | 31 | from typing import Iterable, Iterator, Optional, List, Any, Callable, Union
|
30 | 32 | import importlib
|
@@ -494,19 +496,59 @@ def read_datasets(self, splits=None, data_files=None):
|
494 | 496 | for split, filename in data_files.items()
|
495 | 497 | ]
|
496 | 498 |
|
def remove_if_exit(filepath):
    """Best-effort removal of one path or a list/tuple of paths.

    Registered as an ``atexit`` hook to clean up the ``.done`` lock
    files used for cross-process coordination, so any ``OSError``
    (file never created by this proc, already removed, permission
    denied at interpreter shutdown) is deliberately swallowed.

    Args:
        filepath (str|list|tuple): A single path, or a collection of
            paths, to remove if they exist.
    """
    # Normalize to a list so scalar and collection inputs share one
    # removal loop instead of two duplicated try/except branches.
    paths = filepath if isinstance(filepath, (list, tuple)) else [filepath]
    for path in paths:
        try:
            os.remove(path)
        except OSError:
            # Best-effort cleanup: missing or unremovable lock files
            # are not an error at exit time.
            pass
| 511 | + |
497 | 512 | if splits:
|
498 | 513 | assert isinstance(splits, str) or (
|
499 | 514 | isinstance(splits, list) and isinstance(splits[0], str)
|
500 | 515 | ) or (
|
501 | 516 | isinstance(splits, tuple) and isinstance(splits[0], str)
|
502 | 517 | ), "`splits` should be a string or list of string or a tuple of string."
|
503 | 518 | if isinstance(splits, str):
|
504 |
| - filename = self._get_data(splits) |
505 |
| - datasets.append(self.read(filename=filename, split=splits)) |
506 |
| - else: |
507 |
| - for split in splits: |
508 |
| - filename = self._get_data(split) |
509 |
| - datasets.append(self.read(filename=filename, split=split)) |
| 519 | + splits = [splits] |
| 520 | + parallel_env = dist.ParallelEnv() |
| 521 | + unique_endpoints = _get_unique_endpoints( |
| 522 | + parallel_env.trainer_endpoints[:]) |
| 523 | + # move register hook to first and register together |
| 524 | + lock_files = [] |
| 525 | + for split in splits: |
| 526 | + lock_file = os.path.join(DATA_HOME, self.__class__.__name__) |
| 527 | + if self.name is not None: |
| 528 | + lock_file = lock_file + "." + self.name |
| 529 | + lock_file += "." + split + ".done" + "." + str(os.getppid()) |
| 530 | + lock_files.append(lock_file) |
| 531 | + # Must register in all procs so that the lock file can be removed |
| 532 | + # when any proc breaks. Otherwise, the single registered proc may |
| 533 | + # not receive the proper signal sent by the parent proc to exit. |
| 534 | + atexit.register(lambda: remove_if_exit(lock_files)) |
| 535 | + for split in splits: |
| 536 | + filename = self._get_data(split) |
| 537 | + lock_file = os.path.join(DATA_HOME, self.__class__.__name__) |
| 538 | + if self.name is not None: |
| 539 | + lock_file = lock_file + "." + self.name |
| 540 | + lock_file += "." + split + ".done" + "." + str(os.getppid()) |
| 541 | + # `lock_file` indicates the finished status of `_get_data`. |
| 542 | + # `_get_data` only works in the `unique_endpoints` specified |
| 543 | + # proc since `get_path_from_url` only work for it. The other |
| 544 | + # procs wait `_get_data` to be finished. |
| 545 | + if parallel_env.current_endpoint in unique_endpoints: |
| 546 | + f = open(lock_file, "w") |
| 547 | + f.close() |
| 548 | + else: |
| 549 | + while not os.path.exists(lock_file): |
| 550 | + time.sleep(1) |
| 551 | + datasets.append(self.read(filename=filename, split=split)) |
510 | 552 |
|
511 | 553 | return datasets if len(datasets) > 1 else datasets[0]
|
512 | 554 |
|
|
0 commit comments