Skip to content

Commit d7b5f5e

Browse files
Add bundled inputs recipe (pytorch#1524)
* bundled inputs recipe * test * Bundled input tutorials v1 * update format bundled inputs recipe * fix format * clean up bad file * address feedback Co-authored-by: Brian Johnson <[email protected]>
1 parent 2283a3e commit d7b5f5e

File tree

1 file changed

+198
-0
lines changed

1 file changed

+198
-0
lines changed

recipes_source/bundled_inputs.rst

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
(beta) Bundling inputs to PyTorch Models
2+
==================================================================
3+
4+
**Author**: `Jacob Szwejbka <https://github.com/JacobSzwejbka>`_
5+
6+
Introduction
7+
------------
8+
9+
This tutorial introduces the steps to use PyTorch's utility to bundle example or trivial inputs directly into your TorchScript Module.
10+
11+
The interface of the model remains unchanged (other than adding a few methods), so it can still be safely deployed to production. The advantage of this standardized interface is that tools that run models can use it instead of having some sort of external file (or worse, document) that tells you how to run the model properly.
12+
13+
Common case, bundling an input to a model that only uses 'forward' for inference
14+
-------------------
15+
1. **Prepare model**: Convert your model to TorchScript through either tracing or scripting
16+
17+
.. code:: python
18+
19+
import torch
20+
import torch.jit
21+
import torch.utils
22+
import torch.utils.bundled_inputs
23+
24+
class Net(nn.Module):
25+
def __init__(self):
26+
super(Net, self).__init__()
27+
self.lin = nn.Linear(10, 1)
28+
29+
def forward(self, x):
30+
return self.lin(x)
31+
32+
model = Net()
33+
scripted_module = torch.jit.script(model)
34+
35+
2. **Create example input and attach to model**
36+
37+
.. code:: python
38+
39+
# For each method create a list of inputs and each input is a tuple of arguments
40+
sample_input = [(torch.zeros(1,10),)]
41+
42+
# Create model with bundled inputs, if type(input) is list then the input is bundled to 'forward'
43+
bundled_model = bundle_inputs(scripted_module, sample_input)
44+
45+
46+
3. **Run model with input as arguments**
47+
48+
.. code:: python
49+
50+
sample_inputs = bundled_model.get_all_bundled_inputs()
51+
52+
print(bundled_model(*sample_inputs[0]))
53+
54+
55+
Uncommon case, bundling and retrieving inputs for functions beyond 'forward'
56+
-------------------
57+
1. **Prepare model**: Convert your model to TorchScript through either tracing or scripting
58+
59+
.. code:: python
60+
61+
import torch
62+
import torch.jit
63+
import torch.utils
64+
import torch.utils.bundled_inputs
65+
from typing import Dict
66+
67+
class Net(nn.Module):
68+
def __init__(self):
69+
super(Net, self).__init__()
70+
self.lin = nn.Linear(10, 1)
71+
72+
def forward(self, x):
73+
return self.lin(x)
74+
75+
@torch.jit.export
76+
def foo(self, x: Dict[String, int]):
77+
return x['a'] + x['b']
78+
79+
80+
model = Net()
81+
scripted_module = torch.jit.script(model)
82+
83+
2. **Create example input and attach to model**
84+
85+
.. code:: python
86+
87+
# For each method create a list of inputs and each input is a tuple of arguments
88+
example_dict = {'a' : 1, 'b' : 2}
89+
sample_input = {
90+
scripted_module.forward : [(torch.zeros(1,10),)],
91+
scripted_module.foo : [(example_dict,)]
92+
}
93+
94+
# Create model with bundled inputs, if type(sample_input) is Dict then each callable key is mapped to its corresponding bundled input
95+
bundled_model = bundle_inputs(scripted_module, sample_input)
96+
97+
98+
3. **Retrieve inputs and run model on them**
99+
100+
.. code:: python
101+
102+
all_info = bundled_model.get_bundled_inputs_functions_and_info()
103+
104+
# The return type for get_bundled_inputs_functions_and_info is complex, but essentially we are retrieving the name
105+
# of a function we can use to get the bundled input for our models method
106+
for func_name in all_info.keys():
107+
input_func_name = all_info[func_name]['get_inputs_function_name'][0]
108+
func_to_run = getattr(bundled_model, input_func_name)
109+
# retrieve input
110+
sample_input = func_to_run()
111+
model_function = getattr(bundled_model, func_name)
112+
for i in range(len(sample_input)):
113+
print(model_function(*sample_input[i]))
114+
115+
Inflatable args
116+
-------------------
117+
Attaching inputs to models can result in nontrivial size increases. Inflatable args are a way to compress and decompress inputs to minimize this impact.
118+
119+
.. note:: Any automatic compression, or parsing of inflatable args only happens to top level arguments in the input tuple.
120+
121+
- ie if your model takes in a List type of inputs you would need to create an inflatable arg that returned a list not create a list of inflatable args.
122+
123+
1. **Existing Inflatable args**
124+
125+
The following input types are compressed automatically without requiring an explicit inflatable arg:
126+
- Small contiguous tensors are cloned to have small storage.
127+
- Inputs from torch.zeros, torch.ones, or torch.full are moved to their compact representations.
128+
129+
.. code:: python
130+
131+
# bundle_randn will generate a random tensor when the model is asked for bundled inputs
132+
sample_inputs = [(torch.utils.bundled_inputs.bundle_randn((1,10)),)]
133+
bundled_model = bundle_inputs(scripted_module, sample_inputs)
134+
print(bundled_model.get_all_bundled_inputs())
135+
136+
2. **Creating your own**
137+
138+
Inflatable args are composed of 2 parts, the deflated (compressed) argument, and an expression or function definition to inflate them.
139+
140+
.. code:: python
141+
142+
def create_example(*size, dtype=None):
143+
"""Generate a tuple of 2 random tensors both of the specified size"""
144+
145+
deflated_input = (torch.zeros(1, dtype=dtype).expand(*size), torch.zeros(1, dtype=dtype).expand(*size))
146+
147+
# {0} is how you access your deflated value in the inflation expression
148+
return torch.utils.bundled_inputs.InflatableArg(
149+
value=stub,
150+
fmt="(torch.randn_like({0}[0]), torch.randn_like({0}[1]))",
151+
)
152+
153+
3. **Using a function instead**
154+
If you need to create a more complicated input providing a function is an easy alternative
155+
156+
.. code:: python
157+
158+
sample = dict(
159+
a=torch.zeros([10, 20]),
160+
b=torch.zeros([1, 1]),
161+
c=torch.zeros([10, 20]),
162+
)
163+
164+
def condensed(t):
165+
ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape)
166+
assert ret.storage().size() == 1
167+
return ret
168+
169+
# An example of how to create an inflatable arg for a complex model input like Optional[Dict[str, Tensor]]
170+
# here we take in a normal input, deflate it, and define an inflater function that converts the mapped tensors to random values
171+
def bundle_optional_dict_of_randn(template: Optional[Dict[str, Tensor]]):
172+
return torch.utils.bundled_inputs.InflatableArg(
173+
value=(
174+
None
175+
if template is None
176+
else {k: condensed(v) for (k, v) in template.items()}
177+
),
178+
fmt="{}",
179+
fmt_fn="""
180+
def {}(self, value: Optional[Dict[str, Tensor]]):
181+
if value is not None:
182+
output = {{}}
183+
for k, v in value.items():
184+
output[k] = torch.randn_like(v)
185+
return output
186+
else:
187+
return None
188+
""",
189+
)
190+
191+
sample_inputs = (
192+
bundle_optional_dict_of_randn(sample),
193+
)
194+
195+
196+
Learn More
197+
----------
198+
- To learn more about PyTorch Mobile, please refer to `PyTorch Mobile Home Page <https://pytorch.org/mobile/home/>`_

0 commit comments

Comments
 (0)