Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit 1ae201a

Browse files
authored
refactor(embedding): level up embed method to top API add docs (#178)
* refactor(embedding): level up embed method to top API add docs * refactor(embedding): level up embed method to top API add docs
1 parent bf07ab1 commit 1ae201a

10 files changed

Lines changed: 67 additions & 25 deletions

File tree

docs/basics/embed.png

27.6 KB
Loading

docs/basics/fit.md

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,25 @@ Depending on your framework, `display` may require different argument for render
3333

3434
More information can be {ref}`found here<display-method>`.
3535

36+
## Embed documents
37+
38+
You can use the `finetuner.embed()` method to compute the embeddings of a `DocumentArray` or `DocumentArrayMemmap`.
39+
40+
```python
41+
import finetuner
42+
from jina import DocumentArray
43+
44+
docs = DocumentArray(...)
45+
46+
finetuner.embed(docs, model)
47+
48+
print(docs.embeddings)
49+
```
50+
51+
Note that `model` above must be an {term}`Embedding model`.
52+
53+
54+
3655
## Example
3756

3857
```python
@@ -59,9 +78,6 @@ model, summary = finetuner.fit(
5978
)
6079

6180
finetuner.display(model, input_size=(100,), input_dtype='long')
62-
63-
finetuner.save(model, './saved-model')
64-
summary.plot('fit.png')
6581
```
6682

6783
```console
@@ -81,7 +97,32 @@ Green layers can be used as embedding layers, whose name can be used as
8197
layer_name in to_embedding_model(...).
8298
```
8399

100+
```python
101+
finetuner.save(model, './saved-model')
102+
summary.plot('fit.png')
103+
```
104+
84105
```{figure} fit-plot.png
85106
:align: center
86107
:width: 80%
108+
```
109+
110+
```python
111+
from jina import DocumentArray
112+
all_q = DocumentArray(generate_qa_match())
113+
finetuner.embed(all_q, model)
114+
print(all_q.embeddings.shape)
115+
```
116+
117+
```console
118+
(481, 32)
119+
```
120+
121+
```python
122+
all_q.visualize('embed.png', method='tsne')
123+
```
124+
125+
```{figure} embed.png
126+
:align: center
127+
:width: 80%
87128
```

docs/index.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ Perfect! Now `embed_model` and `train_data` are already provided by you, simply
136136
```python
137137
import finetuner
138138

139-
tuned_model, _ = finetuner.fit(
139+
tuned_model, summary = finetuner.fit(
140140
embed_model,
141141
train_data=train_data
142142
)
@@ -159,7 +159,7 @@ emphasize-lines: 6
159159
---
160160
import finetuner
161161
162-
tuned_model, _ = finetuner.fit(
162+
tuned_model, summary = finetuner.fit(
163163
embed_model,
164164
train_data=unlabeled_data,
165165
interactive=True
@@ -183,7 +183,7 @@ emphasize-lines: 6, 7
183183
---
184184
import finetuner
185185
186-
tuned_model, _ = finetuner.fit(
186+
tuned_model, summary = finetuner.fit(
187187
general_model,
188188
train_data=labeled_data,
189189
to_embedding_model=True,
@@ -208,7 +208,7 @@ emphasize-lines: 6, 7
208208
---
209209
import finetuner
210210
211-
tuned_model, _ = finetuner.fit(
211+
tuned_model, summary = finetuner.fit(
212212
general_model,
213213
train_data=labeled_data,
214214
interactive=True,

finetuner/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def fit(
6767
optimizer: str = 'adam',
6868
optimizer_kwargs: Optional[Dict] = None,
6969
device: str = 'cpu',
70-
) -> Tuple['AnyDNN', 'Summary']:
70+
) -> Tuple['AnyDNN', None]:
7171
...
7272

7373

@@ -91,7 +91,7 @@ def fit(
9191
output_dim: Optional[int] = None,
9292
freeze: bool = False,
9393
device: str = 'cpu',
94-
) -> Tuple['AnyDNN', 'Summary']:
94+
) -> Tuple['AnyDNN', None]:
9595
...
9696

9797

@@ -116,3 +116,4 @@ def fit(
116116
# level them up to the top-level
117117
from .tuner import save
118118
from .tailor import display
119+
from .embedding import embed

finetuner/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .helper import AnyDNN, get_framework
66

77

8-
def set_embeddings(
8+
def embed(
99
docs: Union[DocumentArray, DocumentArrayMemmap],
1010
embed_model: AnyDNN,
1111
device: str = 'cpu',

finetuner/labeler/executor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from jina import Executor, DocumentArray, requests, DocumentArrayMemmap
55
from jina.helper import cached_property
66

7-
from ..embedding import set_embeddings
7+
from ..embedding import embed
88
from ..tuner import fit, save
99

1010

@@ -42,8 +42,8 @@ def embed(self, docs: DocumentArray, parameters: Dict, **kwargs):
4242
min(len(self._all_data), int(parameters.get('sample_size', 1000)))
4343
)
4444

45-
set_embeddings(docs, self._embed_model)
46-
set_embeddings(_catalog, self._embed_model)
45+
embed(docs, self._embed_model)
46+
embed(_catalog, self._embed_model)
4747

4848
docs.match(
4949
_catalog,

tests/unit/test_embedding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from jina import DocumentArray, DocumentArrayMemmap
66

7-
from finetuner.embedding import set_embeddings
7+
from finetuner.embedding import embed
88
from finetuner.toydata import generate_fashion_match
99

1010
embed_models = {
@@ -41,11 +41,11 @@ def test_set_embeddings(framework, tmpdir):
4141
# works for DA
4242
embed_model = embed_models[framework]()
4343
docs = DocumentArray(generate_fashion_match(num_total=100))
44-
set_embeddings(docs, embed_model)
44+
embed(docs, embed_model)
4545
assert docs.embeddings.shape == (100, 32)
4646

4747
# works for DAM
4848
dam = DocumentArrayMemmap(tmpdir)
4949
dam.extend(generate_fashion_match(num_total=42))
50-
set_embeddings(dam, embed_model)
50+
embed(dam, embed_model)
5151
assert dam.embeddings.shape == (42, 32)

tests/unit/tuner/keras/test_gpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from jina import DocumentArray, DocumentArrayMemmap
44

55
from finetuner.tuner.keras import KerasTuner
6-
from finetuner.embedding import set_embeddings
6+
from finetuner.embedding import embed
77
from finetuner.toydata import generate_fashion_match
88

99
all_test_losses = [
@@ -47,11 +47,11 @@ def test_set_embeddings_gpu(tmpdir):
4747
]
4848
)
4949
docs = DocumentArray(generate_fashion_match(num_total=100))
50-
set_embeddings(docs, embed_model, 'cuda')
50+
embed(docs, embed_model, 'cuda')
5151
assert docs.embeddings.shape == (100, 32)
5252

5353
# works for DAM
5454
dam = DocumentArrayMemmap(tmpdir)
5555
dam.extend(generate_fashion_match(num_total=42))
56-
set_embeddings(dam, embed_model, 'cuda')
56+
embed(dam, embed_model, 'cuda')
5757
assert dam.embeddings.shape == (42, 32)

tests/unit/tuner/paddle/test_gpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import paddle.nn as nn
33
from jina import DocumentArray, DocumentArrayMemmap
44

5-
from finetuner.embedding import set_embeddings
5+
from finetuner.embedding import embed
66
from finetuner.toydata import generate_fashion_match
77
from finetuner.tuner.paddle import PaddleTuner
88

@@ -45,11 +45,11 @@ def test_set_embeddings_gpu(tmpdir):
4545
nn.Linear(in_features=128, out_features=32),
4646
)
4747
docs = DocumentArray(generate_fashion_match(num_total=100))
48-
set_embeddings(docs, embed_model, 'cuda')
48+
embed(docs, embed_model, 'cuda')
4949
assert docs.embeddings.shape == (100, 32)
5050

5151
# works for DAM
5252
dam = DocumentArrayMemmap(tmpdir)
5353
dam.extend(generate_fashion_match(num_total=42))
54-
set_embeddings(dam, embed_model, 'cuda')
54+
embed(dam, embed_model, 'cuda')
5555
assert dam.embeddings.shape == (42, 32)

tests/unit/tuner/torch/test_gpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn as nn
44
from jina import DocumentArray, DocumentArrayMemmap
55

6-
from finetuner.embedding import set_embeddings
6+
from finetuner.embedding import embed
77
from finetuner.toydata import generate_fashion_match
88
from finetuner.tuner.pytorch import PytorchTuner
99

@@ -49,11 +49,11 @@ def test_set_embeddings_gpu(tmpdir):
4949
nn.Linear(in_features=128, out_features=32),
5050
)
5151
docs = DocumentArray(generate_fashion_match(num_total=100))
52-
set_embeddings(docs, embed_model, 'cuda')
52+
embed(docs, embed_model, 'cuda')
5353
assert docs.embeddings.shape == (100, 32)
5454

5555
# works for DAM
5656
dam = DocumentArrayMemmap(tmpdir)
5757
dam.extend(generate_fashion_match(num_total=42))
58-
set_embeddings(dam, embed_model, 'cuda')
58+
embed(dam, embed_model, 'cuda')
5959
assert dam.embeddings.shape == (42, 32)

0 commit comments

Comments
 (0)