8000 LBFGSFunctionOptimizer by DirkToewe · Pull Request #1385 · tensorflow/tfjs-core · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

LBFGSFunctionOptimizer #1385

Open
wants to merge 14 commits into
base: master 8000
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added reset to L-BFGS if no progress was made.
  • Loading branch information
DirkToewe committed Nov 9, 2018
commit 519df5260c31f5381206ee98ff7012a5c3fa060b
56 changes: 47 additions & 9 deletions src/optimizers/lbfgs_function_optimizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ function dotProd( x: Tensor1D, y: Tensor1D ) {
return z;
}

export class LineSearchError extends Error {
constructor( msg: string ) {
super(msg);
}
}

export class LineSearchNoProgressError extends LineSearchError {
constructor( msg: string ) {
super(msg);
}
}

/** The function type of a linesearch method.
*
* @param fg A function that returns both the value and gradients of the
Expand Down Expand Up @@ -151,7 +163,7 @@ export function strongWolfeLineSearch(
}

if( ! (α < αMax) ) {
throw new Error(
throw new LineSearchError(
'strongWolfeLineSearch(): '
+ 'Strong Wolfe condition not satisfiable in range.'
);
Expand All @@ -162,7 +174,7 @@ export function strongWolfeLineSearch(
}

if( αMin === αMax ) {
throw new Error('strongWolfeLineSearch: bracketing failed.');
throw new LineSearchError('strongWolfeLineSearch: bracketing failed.');
}

// STEP 2: BISECTION PHASE
Expand All @@ -181,7 +193,9 @@ export function strongWolfeLineSearch(

if( f - f0 > c1*α*p0 || f >= fMin ) {
if( αMax === α ) {
throw new Error('strongWolfeLineSearch(): bisection failed.');
throw new LineSearchError(
'strongWolfeLineSearch(): bisection failed.'
);
}
αMax = α;
}
Expand All 8000 @@ -198,14 +212,20 @@ export function strongWolfeLineSearch(
}

if( αMin === α ) {
throw new Error('strongWolfeLineSearch(): bisection failed.');
throw new LineSearchError(
'strongWolfeLineSearch(): bisection failed.'
);
}
αMin = α;
fMin = f;
}

if( αMin === αMax ) {
throw new Error('strongWolfeLineSearch(): bisection failed.');
const msg = 'strongWolfeLineSearch(): bisection failed.';
if( αMin === 0) {
throw new LineSearchNoProgressError(msg);
}
throw new LineSearchError(msg);
}
}
};
Expand Down Expand Up @@ -396,15 +416,33 @@ export class LBFGSFunctionOptimizer {
dGdX = this.dGdX,
dG = this.dG;

const [x,f,g] = ENV.engine.tidy(
() => this.lineSearch(
const [x,f,g] = ( () => {
const stepFunc = () => this.lineSearch(
this.fg,
this.x,
this.f,
this.g,
this.negDir(this.g)
)
);
);
try {
return ENV.engine.tidy(stepFunc);
}
catch( err ) {
if( err instanceof LineSearchNoProgressError ) {
// reset line search
while( dX.length > 0 ) {
dX.pop().dispose();
dGdX.pop().dispose();
dG .pop().dispose();
}
// try one last time
return ENV.engine.tidy(stepFunc);
}
else {
throw err;
}
}
})();

const dXi = sub(x,this.x) as Tensor1D,
dGi = sub(g,this.g) as Tensor1D;
Expand Down
17 changes: 14 additions & 3 deletions src/optimizers/lbfgs_function_optimizer_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import {convertToTensor} from '../tensor_util_env';
import {valueAndGrad} from '../gradients';
import {ENV} from '../environment';
import {randomUniform} from '../ops/array_ops';
import {strongWolfeLineSearch, LBFGSFunctionOptimizer} from './lbfgs_function_optimizer';
import {strongWolfeLineSearch,
LBFGSFunctionOptimizer,
LineSearchNoProgressError} from './lbfgs_function_optimizer';
import {dot} from '../ops/matmul';

function rosenbrock( x: Tensor|TensorLike ): Tensor
Expand Down Expand Up @@ -257,10 +259,19 @@ describeWithFlags('lbfgs', ALL_ENVS, () => {
{ initNegDir: g => g.div(DENOM) }
);

while( ! opt.g.abs().lessEqual( scalar(2**-12) ).all().dataSync()[0] )
const atol = scalar( Math.sqrt(ENV.get('EPSILON')) );
while( ! opt.g.abs().lessEqual(atol).all().dataSync()[0] )
{
++nSteps;
opt.step();
try {
opt.step();
}
catch(err) {
if( err instanceof LineSearchNoProgressError ) {
break;
}
throw err;
}
}

expect(nCalls).toBeLessThan(256);
Expand Down
0