diff --git a/Cargo.toml b/Cargo.toml index c38421f..c9dfe97 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,9 +21,9 @@ name = "_lib" crate-type = ["cdylib"] [dependencies] -# argmin_core = { path = "../argmin-core" } -# argmin = { path = "../argmin", features = ["ctrlc", "ndarrayl"]} -argmin = {version = "0.2.6", features = ["ctrlc", "ndarrayl"]} +#argmin = { path = "../argmin", features = ["ctrlc", "ndarrayl"]} +#argmin = {version = "0.2.6", features = ["ctrlc", "ndarrayl"]} +argmin = {git = "https://github.com/argmin-rs/argmin" } pyo3 = { version = "0.8.5", features = ["extension-module"] } serde = "1.0" numpy = "0.7.0" diff --git a/src/lib.rs b/src/lib.rs index 15b0f89..1ce747d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,18 +96,13 @@ impl PyLBFGS { } } -#[pyfunction] -/// blah -fn lbfgs(m: usize) -> Py { +#[pyfunction(m = "10", tol_grad = "1e-5")] +/// lbfgs(m=10, tol_grad=1e-5) +fn lbfgs(m: usize, tol_grad: f64) -> Py { let gil_guard = Python::acquire_gil(); let py = gil_guard.python(); - Py::new( - py, - PyLBFGS { - solver: LBFGS::new(MoreThuenteLineSearch::new(), m), - }, - ) - .unwrap() + let solver = LBFGS::new(MoreThuenteLineSearch::new(), m).with_tol_grad(tol_grad); + Py::new(py, PyLBFGS { solver: solver }).unwrap() } #[pyfunction]