8000 Remove duplicate call to objective function in strong wolfe line sear… · pytorch/pytorch@f43165a · GitHub
[go: up one dir, main page]

Skip to content

Commit f43165a

Browse files
svenslaggarepytorchmergebot
authored andcommitted
Remove duplicate call to objective function in strong wolfe line search in L-BFGS optimizer. (#72773)
Summary: With this change, the optimizer is almost twice as fast as before. As the result of the first call is never used, it looks like a copy paste error and therefore can be removed. In addition, this duplicate call is not present in the Python implementation. Pull Request resolved: #72773 Reviewed By: samdow Differential Revision: D34214312 Pulled By: albanD fbshipit-source-id: 4f4de08633c7236f3ccce8a2a74e56500003281b (cherry picked from commit 4a63f81)
1 parent 80f2346 commit f43165a

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

torch/csrc/api/src/optim/lbfgs.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,6 @@ std::tuple<double, Tensor, double, int64_t> _strong_wolfe(const Function& obj_fu
232232
auto d_norm = val(d.abs().max());
233233
g = g.clone(at::MemoryFormat::Contiguous);
234234
// evaluate objective and gradient using initial step
235-
auto obj_func_res = obj_func(x, t, d);
236235
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
237236
double f_new;
238237
Tensor g_new;
@@ -285,7 +284,6 @@ std::tuple<double, Tensor, double, int64_t> _strong_wolfe(const Function& obj_fu
285284
f_prev = f_new;
286285
g_prev = g_new.clone(at::MemoryFormat::Contiguous);
287286
gtd_prev = gtd_new;
288-
obj_func_res = obj_func(x, t, d);
289287
std::tie(f_new, g_new) = obj_func(x, t, d);
290288
ls_func_evals += 1;
291289
gtd_new = g_new.dot(d);
@@ -335,9 +333,7 @@ std::tuple<double, Tensor, double, int64_t> _strong_wolfe(const Function& obj_fu
335333
}
336334

337335
// Evaluate new point
338-
obj_func_res = obj_func(x, t, d);
339-
f_new = std::get<0>(obj_func_res);
340-
g_new = std::get<1>(obj_func_res);
336+
std::tie(f_new, g_new) = obj_func(x, t, d);
341337
ls_func_evals += 1;
342338
gtd_new = g_new.dot(d);
343339
ls_iter += 1;

0 commit comments

Comments
 (0)
0