Skip to content

Commit 5ee1df8

Browse files
Add par_map methods to collections
* `par_map` facilitates running an async function across all items in the collection
* Resolves #213
1 parent 65cd3a9 commit 5ee1df8

File tree

8 files changed

+163
-4
lines changed

8 files changed

+163
-4
lines changed

expression/collections/array.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
from __future__ import annotations
1111

1212
import array
13+
import asyncio
1314
import builtins
1415
import functools
15-
from collections.abc import Callable, Iterable, Iterator, MutableSequence
16+
from collections.abc import Awaitable, Callable, Iterable, Iterator, MutableSequence
1617
from enum import Enum
1718
from typing import Any, TypeVar, cast
1819

@@ -185,9 +186,36 @@ def __init__(
185186
self.typecode = typecode
186187

187188
def map(self, mapping: Callable[[_TSource], _TResult]) -> TypedArray[_TResult]:
189+
"""Map array.
190+
191+
Builds a new array whose elements are the results of applying
192+
the given function to each of the elements of the array.
193+
194+
Args:
195+
mapping: A function to transform items from the input array.
196+
197+
Returns:
198+
The result sequence.
199+
"""
188200
result = builtins.map(mapping, self.value)
189201
return TypedArray(result)
190202

203+
async def par_map(self, mapping: Callable[[_TSource], Awaitable[_TResult]]) -> TypedArray[_TResult]:
204+
"""Map array asynchronously.
205+
206+
Builds a new array whose elements are the results of applying
207+
the given asynchronous function to each of the elements of the
208+
array.
209+
210+
Args:
211+
mapping: A function to transform items from the input array.
212+
213+
Returns:
214+
The result sequence.
215+
"""
216+
result = await asyncio.gather(*(mapping(item) for item in self))
217+
return TypedArray(result)
218+
191219
def choose(self, chooser: Callable[[_TSource], Option[_TResult]]) -> TypedArray[_TResult]:
192220
"""Choose items from the list.
193221

expression/collections/block.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020

2121
from __future__ import annotations
2222

23+
import asyncio
2324
import builtins
2425
import functools
2526
import itertools
26-
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
27+
from collections.abc import Awaitable, Callable, Collection, Iterable, Iterator, Sequence
2728
from typing import TYPE_CHECKING, Any, Literal, TypeVar, get_args, overload
2829

2930
from typing_extensions import TypeVarTuple, Unpack
@@ -239,6 +240,23 @@ def map(self, mapping: Callable[[_TSource], _TResult]) -> Block[_TResult]:
239240
"""
240241
return Block((*builtins.map(mapping, self),))
241242

243+
async def par_map(self, mapping: Callable[[_TSource], Awaitable[_TResult]]) -> Block[_TResult]:
244+
"""Map list asynchronously.
245+
246+
Builds a new collection whose elements are the results of
247+
applying the given asynchronous function to each of the
248+
elements of the collection.
249+
250+
Args:
251+
mapping: The function to transform elements from the input
252+
list.
253+
254+
Returns:
255+
The list of transformed elements.
256+
"""
257+
result = await asyncio.gather(*(mapping(item) for item in self))
258+
return Block(result)
259+
242260
def starmap(self: Block[tuple[Unpack[_P]]], mapping: Callable[[Unpack[_P]], _TResult]) -> Block[_TResult]:
243261
"""Starmap source sequence.
244262

expression/collections/map.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
# - https://github.com/fsharp/fsharp/blob/master/src/fsharp/FSharp.Core/map.fs
1717
from __future__ import annotations
1818

19-
from collections.abc import Callable, ItemsView, Iterable, Iterator, Mapping
19+
import asyncio
20+
from collections.abc import Awaitable, Callable, ItemsView, Iterable, Iterator, Mapping
2021
from typing import Any, TypeVar, cast
2122

2223
from expression.core import Option, PipeMixin, SupportsLessThan, curry_flip, pipe
@@ -114,6 +115,25 @@ def map(self, mapping: Callable[[_Key, _Value], _Result]) -> Map[_Key, _Result]:
114115
"""
115116
return Map(maptree.map(mapping, self._tree))
116117

118+
async def par_map(self, mapping: Callable[[_Key, _Value], Awaitable[_Result]]) -> Map[_Key, _Result]:
119+
"""Map the mapping asynchronously.
120+
121+
Builds a new collection whose elements are the results of
122+
applying the given asynchronous function to each of the elements
123+
of the collection. The key passed to the function indicates the
124+
key of element being transformed.
125+
126+
Args:
127+
mapping: The function to transform the key/value pairs.
128+
129+
Returns:
130+
The resulting map of keys and transformed values.
131+
"""
132+
keys_and_values = self.to_seq()
133+
result = await asyncio.gather(*(mapping(key, value) for key, value in keys_and_values))
134+
keys = [key for key, _ in keys_and_values]
135+
return Map.of_seq(zip(keys, result))
136+
117137
def partition(self, predicate: Callable[[_Key, _Value], bool]) -> tuple[Map[_Key, _Value], Map[_Key, _Value]]:
118138
r1, r2 = maptree.partition(predicate, self._tree)
119139
return Map(r1), Map(r2)

expression/collections/seq.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424

2525
from __future__ import annotations
2626

27+
import asyncio
2728
import builtins
2829
import functools
2930
import itertools
30-
from collections.abc import Callable, Iterable, Iterator
31+
from collections.abc import Awaitable, Callable, Iterable, Iterator
3132
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
3233

3334
from expression.core import (
@@ -175,6 +176,22 @@ def map(self, mapper: Callable[[_TSource], _TResult]) -> Seq[_TResult]:
175176
"""
176177
return Seq(pipe(self, map(mapper)))
177178

179+
async def par_map(self, mapper: Callable[[_TSource], Awaitable[_TResult]]) -> Seq[_TResult]:
180+
"""Map sequence asynchronously.
181+
182+
Builds a new collection whose elements are the results of
183+
applying the given asynchronous function to each of the elements
184+
of the collection.
185+
186+
Args:
187+
mapper: A function to transform items from the input sequence.
188+
189+
Returns:
190+
The result sequence.
191+
"""
192+
result = await asyncio.gather(*(mapper(item) for item in self))
193+
return Seq(result)
194+
178195
@overload
179196
def starmap(self: Seq[tuple[_T1, _T2]], mapping: Callable[[_T1, _T2], _TResult]) -> Seq[_TResult]: ...
180197

tests/test_array.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import functools
23
from collections.abc import Callable
34
from typing import Any
@@ -428,3 +429,20 @@ def test_array_monad_law_associativity_iterable(xs: list[int]):
428429

429430
m = array.of_seq(xs)
430431
assert m.collect(f).collect(g) == m.collect(lambda x: f(x).collect(g))
432+
433+
@pytest.mark.asyncio
434+
async def test_par_map():
435+
async def async_fn(i: int):
436+
await asyncio.sleep(0.1)
437+
return i * 2
438+
439+
xs = TypedArray(range(1, 10))
440+
441+
start_time = asyncio.get_event_loop().time()
442+
ys = await xs.par_map(async_fn)
443+
end_time = asyncio.get_event_loop().time()
444+
445+
assert ys == TypedArray(i * 2 for i in range(1, 10))
446+
447+
time_taken = end_time - start_time
448+
assert time_taken < 0.2, "par_map took too long"

tests/test_block.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import functools
23
from builtins import list as list
34
from collections.abc import Callable
@@ -7,6 +8,7 @@
78
from hypothesis import strategies as st
89
from pydantic import BaseModel, Field, GetCoreSchemaHandler
910
from pydantic_core import CoreSchema, core_schema
11+
import pytest
1012

1113
from expression import Nothing, Option, Some, pipe
1214
from expression.collections import Block, block
@@ -458,3 +460,20 @@ def test_serialize_block_works():
458460
assert model_.annotated_type_empty == block.empty
459461
assert model_.custom_type == Block(["a", "b", "c"])
460462
assert model_.custom_type_empty == block.empty
463+
464+
@pytest.mark.asyncio
465+
async def test_par_map():
466+
async def async_fn(i: int):
467+
await asyncio.sleep(0.1)
468+
return i * 2
469+
470+
xs = Block(range(1, 10))
471+
472+
start_time = asyncio.get_event_loop().time()
473+
ys = await xs.par_map(async_fn)
474+
end_time = asyncio.get_event_loop().time()
475+
476+
assert ys == Block(i * 2 for i in range(1, 10))
477+
478+
time_taken = end_time - start_time
479+
assert time_taken < 0.2, "par_map took too long"

tests/test_map.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import asyncio
12
from collections.abc import Callable, ItemsView, Iterable
23

4+
import pytest
35
from hypothesis import given # type: ignore
46
from hypothesis import strategies as st
57

@@ -150,3 +152,21 @@ def test_expression_issue_105():
150152
m = m.add("1", 1).add("2", 2).add("3", 3).add("4", 4)
151153
m = m.change("2", lambda x: x)
152154
m = m.change("3", lambda x: x)
155+
156+
157+
@pytest.mark.asyncio
158+
async def test_par_map():
159+
async def async_fn(key: str, value: int) -> int:
160+
await asyncio.sleep(0.1)
161+
return int(key) * value
162+
163+
xs = Map.of_seq((str(i), i) for i in range(1, 10))
164+
165+
start_time = asyncio.get_event_loop().time()
166+
ys = await xs.par_map(async_fn)
167+
end_time = asyncio.get_event_loop().time()
168+
169+
assert ys == Map.of_seq((str(i), i * i) for i in range(1, 10))
170+
171+
time_taken = end_time - start_time
172+
assert time_taken < 0.2, "par_map took too long"

tests/test_seq.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import functools
23
from collections.abc import Callable, Iterable
34
from itertools import accumulate
@@ -382,3 +383,21 @@ def test_seq_monad_law_associativity_empty(value: int):
382383
# Empty list
383384
m = empty
384385
assert list(m.collect(f).collect(g)) == list(m.collect(lambda x: f(x).collect(g)))
386+
387+
388+
@pytest.mark.asyncio
389+
async def test_par_map():
390+
async def async_fn(i: int):
391+
await asyncio.sleep(0.1)
392+
return i * 2
393+
394+
xs = seq.of_iterable(range(1, 10))
395+
396+
start_time = asyncio.get_event_loop().time()
397+
ys = await xs.par_map(async_fn)
398+
end_time = asyncio.get_event_loop().time()
399+
400+
assert list(ys) == [i * 2 for i in range(1, 10)]
401+
402+
time_taken = end_time - start_time
403+
assert time_taken < 0.2, "par_map took too long"

0 commit comments

Comments
 (0)