
Commit 4e6976d

SGD and TrainingModule in Python
Differential Revision: D63650449
Pull Request resolved: #5847
Parent: 3b25b05

12 files changed, 463 insertions(+), 4 deletions(-)


examples/llm_pte_finetuning/runner.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def main() -> None:
         # for us to update with the gradients in-place.
         # See https://github.com/pytorch/executorch/blob/main/extension/pybindings/pybindings.cpp#L736
         # for more info.
-        out = et_mod.forward((tokens, labels), clone_outputs=False)  # pyre-ignore
+        out = et_mod.forward((tokens, labels), clone_outputs=False)
 
         loss = out[0]
         losses.append(loss.item())
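
For context, a minimal sketch of the loop this hunk sits in. The module loading and the batch iterable here are illustrative assumptions, not part of the diff:

from executorch.extension.pybindings.portable_lib import _load_for_executorch

et_mod = _load_for_executorch("/path/to/model.pte")  # hypothetical .pte path

losses = []
for tokens, labels in eval_batches:  # hypothetical iterable of (tokens, labels) tensors
    # clone_outputs=False asks the runtime not to copy its output tensors, so
    # the returned values alias the module's memory and the weights can be
    # updated with the gradients in-place (see the linked pybindings.cpp note).
    out = et_mod.forward((tokens, labels), clone_outputs=False)
    loss = out[0]
    losses.append(loss.item())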

extension/pybindings/pybindings.pyi

Lines changed: 12 additions & 3 deletions
@@ -33,11 +33,20 @@ class ExecuTorchModule:
     """
 
     # pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
-    def __call__(self, inputs: Any) -> List[Any]: ...
+    def __call__(self, inputs: Any, clone_outputs: bool = True) -> List[Any]: ...
     # pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
-    def run_method(self, method_name: str, inputs: Sequence[Any]) -> List[Any]: ...
+    def run_method(
+        self,
+        method_name: str,
+        inputs: Sequence[Any],  # pyre-ignore[2]: "Any" in parameter type annotations.
+        clone_outputs: bool = True,
+    ) -> List[Any]: ...
     # pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
-    def forward(self, inputs: Sequence[Any]) -> List[Any]: ...
+    def forward(
+        self,
+        inputs: Sequence[Any],  # pyre-ignore[2]: "Any" in parameter type annotations.
+        clone_outputs: bool = True,
+    ) -> List[Any]: ...
     # pyre-ignore[3]: "Any" in return type annotations.
     def plan_execute(self) -> List[Any]: ...
     # Bundled program methods.
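
All three execution entry points now accept the same flag. A short sketch of the two modes, reusing the names from the runner diff above and assuming the default reproduces the previous copying behavior:

# Default (clone_outputs=True): outputs are copies, valid independently of the module.
outs = et_mod.run_method("forward", (tokens, labels))

# Opt out of cloning: outputs alias the runtime's buffers, which is what the
# fine-tuning runner relies on for in-place weight updates.
outs_aliased = et_mod.forward((tokens, labels), clone_outputs=False)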

extension/training/TARGETS

Lines changed: 20 additions & 0 deletions (new file)

# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

python_library(
    name = "lib",
    srcs = [
        "__init__.py",
    ],
    deps = [
        "//executorch/extension/training/pybindings:_training_lib",
        "//executorch/extension/training/pybindings:_training_module",
    ],
)

extension/training/__init__.py

Lines changed: 22 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from executorch.extension.training.pybindings._training_lib import get_sgd_optimizer

from executorch.extension.training.pybindings._training_module import (
    _load_for_executorch_for_training,
    _load_for_executorch_for_training_from_buffer,
    TrainingModule,
)

__all__ = [
    "get_sgd_optimizer",
    "TrainingModule",
    "_load_for_executorch_for_training_from_buffer",
    "_load_for_executorch_for_training",
]
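
Together with the TARGETS file above, this gives downstream code a single import surface. The names below come directly from __all__:

from executorch.extension.training import (
    TrainingModule,
    _load_for_executorch_for_training,
    _load_for_executorch_for_training_from_buffer,
    get_sgd_optimizer,
)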

extension/training/pybindings/TARGETS

Lines changed: 40 additions & 0 deletions (new file)

# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

runtime.cxx_python_extension(
    name = "_training_lib",
    srcs = [
        "_training_lib.cpp",
    ],
    base_module = "executorch.extension.training.pybindings",
    types = ["_training_lib.pyi"],
    visibility = ["//executorch/extension/training/..."],
    deps = [
        "//executorch/extension/aten_util:aten_bridge",
        "//executorch/extension/training/optimizer:sgd",
    ],
    external_deps = [
        "pybind11",
        "libtorch_python",
    ],
)

runtime.python_library(
    name = "_training_module",
    srcs = [
        "_training_module.py",
    ],
    base_module = "executorch.extension.training.pybindings",
    visibility = ["//executorch/extension/training/..."],
    deps = [
        "//caffe2:torch",
        "//executorch/extension/pybindings:portable_lib",
    ],
)
extension/training/pybindings/_training_lib.cpp

Lines changed: 151 additions & 0 deletions (new file)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>

#include <ATen/Tensor.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <torch/csrc/utils/pybind.h>
#include "executorch/extension/tensor/tensor.h"
#include "executorch/extension/training/optimizer/sgd.h"
#ifndef USE_ATEN_LIB
#include <executorch/extension/aten_util/aten_bridge.h>
#endif

namespace py = pybind11;

namespace executorch {
namespace extension {
namespace training {

namespace {

struct PySGD final {
  explicit PySGD(
      const py::dict& named_params,
      double lr,
      double momentum,
      double dampening,
      double weight_decay,
      bool nesterov)
      : sgd_(nullptr),
        fqns_()
#ifndef USE_ATEN_LIB
        ,
        params_()
#endif
  {
    std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;
    auto py_named_params =
        py::cast<std::unordered_map<std::string, at::Tensor>>(named_params);
    const auto params_size = py::len(named_params);
    fqns_ = std::vector<std::string>();
    fqns_.reserve(params_size);

    for (auto pair : py_named_params) {
      fqns_.push_back(pair.first);
      exec_aten::string_view v{fqns_.back().c_str(), pair.first.size()};
#ifndef USE_ATEN_LIB
      // convert at::Tensor to torch::executor::Tensor
      params_.emplace_back(alias_tensor_ptr_to_attensor(pair.second));
      cpp_inputs.insert({v, *params_.back()});
#else
      cpp_inputs.insert({v, pair.second});
#endif
    }
    sgd_ = std::make_unique<optimizer::SGD>(
        cpp_inputs,
        extension::training::optimizer::SGDOptions(
            lr, momentum, dampening, weight_decay, nesterov));
  }

  // Not needed for now, so just delete.
  PySGD(const PySGD&) = delete;
  PySGD& operator=(const PySGD&) = delete;
  PySGD(PySGD&&) = delete;
  PySGD& operator=(PySGD&&) = delete;

  void step(const py::dict& py_dict) {
    auto py_named_gradients =
        py::cast<std::unordered_map<std::string, at::Tensor>>(py_dict);
    std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;

    std::vector<std::string> fqn;
#ifndef USE_ATEN_LIB
    std::vector<TensorPtr> et_tensors;
#endif

    // Convert python objects into cpp.
    for (const auto& pair : py_named_gradients) {
      fqn.push_back(pair.first);
      auto at_tensor = pair.second;
      // alias_etensor_to_attensor will assert on this later, so to better
      // propagate up to python we check early and throw an exception.
      if (!at_tensor.is_contiguous()) {
        auto error_msg = "Gradient is not contiguous.";
        throw std::runtime_error(error_msg);
      }
#ifndef USE_ATEN_LIB
      // convert at::Tensor to torch::executor::Tensor
      auto temp = alias_tensor_ptr_to_attensor(at_tensor);
      et_tensors.push_back(temp);
      cpp_inputs.insert({pair.first.c_str(), *et_tensors.back()});
#else
      cpp_inputs.insert({pair.first.c_str(), at_tensor});
#endif
    }

    auto err = sgd_->step(cpp_inputs);
    if (err != runtime::Error::Ok) {
      throw std::runtime_error("SGD step failed");
    }
  }

 private:
  // TODO(jakeszwe): Write an optimizer interface and use it here instead of SGD
  // specifically.
  std::unique_ptr<optimizer::SGD> sgd_ = nullptr;
  std::vector<std::string> fqns_;

#ifndef USE_ATEN_LIB // Portable mode
  std::vector<TensorPtr> params_;
#endif
};

static std::unique_ptr<PySGD> get_sgd_optimizer(
    const py::dict& named_params,
    double lr,
    double momentum = 0,
    double dampening = 0,
    double weight_decay = 0,
    bool nesterov = false) {
  return std::make_unique<PySGD>(
      named_params, lr, momentum, dampening, weight_decay, nesterov);
}

} // namespace

PYBIND11_MODULE(_training_lib, m) {
  m.def(
      "get_sgd_optimizer",
      &get_sgd_optimizer,
      py::arg("named_params"),
      py::arg("lr") = 0.1,
      py::arg("momentum") = 0.0,
      py::arg("dampening") = 0.0,
      py::arg("weight_decay") = 0.0,
      py::arg("nesterov") = false);
  py::class_<PySGD>(m, "ExecuTorchSGD").def("step", &PySGD::step);
}

} // namespace training
} // namespace extension
} // namespace executorch
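
Note the eager contiguity check in step(): it exists so the failure surfaces as a Python exception rather than a later C++ assert in alias_etensor_to_attensor. A hedged sketch of working around it; the parameter and gradient dicts are illustrative stand-ins:

import torch
from executorch.extension.training import get_sgd_optimizer

params = {"linear.weight": torch.randn(4, 4)}  # hypothetical named parameters
opt = get_sgd_optimizer(params, lr=0.1)

grads = {"linear.weight": torch.randn(4, 4).t()}  # a transposed view is not contiguous
# step() raises "Gradient is not contiguous." for views like the one above,
# so materialize each gradient as a contiguous tensor first.
opt.step({name: g.contiguous() for name, g in grads.items()})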
extension/training/pybindings/_training_lib.pyi

Lines changed: 45 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple

from executorch.exir._warnings import experimental
from torch import Tensor

@experimental("This API is experimental and subject to change without notice.")
class ExecuTorchSGD:
    """SGD Optimizer.

    .. warning::

        This API is experimental and subject to change without notice.
    """

    def step(self, named_gradients: Dict[str, Tensor]) -> None:
        """Take a step in the direction of the gradients."""
        ...

@experimental("This API is experimental and subject to change without notice.")
def get_sgd_optimizer(
    named_parameters: Dict[str, Tensor],
    lr: float,
    momentum: float = 0,
    dampening: float = 0,
    weight_decay: float = 0,
    nesterov: bool = False,
) -> ExecuTorchSGD:
    """Creates an SGD optimizer that operates on the passed-in named_parameters
    according to the specified hyperparameters.

    .. warning::

        This API is experimental and subject to change without notice.
    """
    ...
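
A minimal end-to-end sketch built only on this stub, assuming (as the tensor aliasing in _training_lib.cpp suggests) that step() updates the parameter tensors in place:

import torch
from executorch.extension.training import get_sgd_optimizer

w = torch.zeros(3)
opt = get_sgd_optimizer({"w": w}, lr=0.5)

grad = torch.ones(3)  # stand-in for the gradient of some loss w.r.t. w
opt.step({"w": grad})
print(w)  # expected: tensor([-0.5000, -0.5000, -0.5000]), i.e. w - lr * grad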
