
Tensorflow.js tf.train.Optimizer Class

Last Updated : 18 Aug, 2021

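TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in a Node.js environment. tf.train.Optimizer is the base class for all TensorFlow.js optimizers, and it extends the Serializable class. It is not instantiated directly; concrete optimizers are created through factory functions such as tf.train.sgd(), tf.train.adagrad(), tf.train.adam() and tf.train.rmsprop().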

The tf.train.Optimizer class provides three built-in methods, which are illustrated below.

The .minimize() method executes the given function f() and minimizes its scalar output: it computes the gradients of that output with respect to the list of trainable variables given by varList and then updates those variables. If no list is provided, it computes gradients with respect to all trainable variables.

 

Example 1:

Javascript




// Importing tensorflow
import * as tf from "@tensorflow/tfjs"

const xs = tf.tensor1d([0, 1, 2]);
const ys = tf.tensor1d([1.3, 2.5, 3.7]);

const x = tf.scalar(Math.random()).variable();
const y = tf.scalar(Math.random()).variable();

// Define f(input) = input + y. The variable x is not part of f,
// so only y is trained and x keeps its initial value.
const f = input => input.add(y);
const loss = (pred, label) =>
    pred.sub(label).square().mean();

const learningRate = 0.05;

// Create adagrad optimizer
const optimizer = tf.train.adagrad(learningRate);

// Train the model.
for (let i = 0; i < 5; i++) {
    optimizer.minimize(() => loss(f(xs), ys));
}

// Make predictions.
console.log(`x: ${x.dataSync()}, y: ${y.dataSync()}`);
const preds = f(xs).dataSync();
preds.forEach((pred, i) => {
    console.log(`x: ${i}, pred: ${pred}`);
});


Output (values vary between runs because x and y are initialized with Math.random()):

x: 0.9395854473114014, y: 1.0498266220092773
x: 0, pred: 1.0498266220092773
x: 1, pred: 2.0498266220092773
x: 2, pred: 3.0498266220092773
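
To make the update loop behind .minimize() more concrete, the following dependency-free sketch repeats the same idea in plain JavaScript for the model f(input) = input + y. It is an illustration only, not TensorFlow.js code: the helper names lossAndGrad and minimizeStep are hypothetical, the gradient is written out by hand, the start value is fixed instead of random, and plain gradient descent stands in for Adagrad to keep the arithmetic short.

```javascript
// Data from Example 1.
const xs = [0, 1, 2];
const ys = [1.3, 2.5, 3.7];

// Hypothetical helper: mean squared error of f(input) = input + y,
// together with its derivative with respect to y.
function lossAndGrad(y) {
  let loss = 0;
  let grad = 0;
  for (let i = 0; i < xs.length; i++) {
    const diff = xs[i] + y - ys[i];
    loss += diff * diff;
    grad += 2 * diff;
  }
  return { loss: loss / xs.length, grad: grad / xs.length };
}

// Hypothetical helper: one plain gradient-descent step on y.
function minimizeStep(y, learningRate) {
  const { grad } = lossAndGrad(y);
  return y - learningRate * grad;
}

let y = 0.5; // fixed start value (assumption) instead of Math.random()
for (let i = 0; i < 200; i++) {
  y = minimizeStep(y, 0.05);
}
console.log(`y: ${y}, loss: ${lossAndGrad(y).loss}`);
```

With this data, y settles near 1.5, which is the mean of ys[i] - xs[i].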

Example 2: The .computeGradients() method executes f() and computes the gradient of its scalar output with respect to the list of trainable variables given by varList. If no list is provided, it defaults to all trainable variables. The method returns the value of f() together with the gradients; it does not update any variable.

Javascript




// Importing tensorflow
import * as tf from "@tensorflow/tfjs"

const xs = tf.tensor1d([3, 4, 5]);
const ys = tf.tensor1d([3.5, 4.7, 5.3]);

const x = tf.scalar(Math.random()).variable();
const y = tf.scalar(Math.random()).variable();

// Define f(input) = input^2 - y. The variable x is not part of f.
const f = input => input.square().sub(y);
const loss = (pred, label) =>
    pred.sub(label).square().mean();

const learningRate = 0.05;

// Create adam optimizer
const optimizer = tf.train.adam(learningRate);

// Compute the loss and its gradients. computeGradients() does not
// update the variables, so x and y keep their initial values.
for (let i = 0; i < 6; i++) {
    const {value, grads} =
        optimizer.computeGradients(() => loss(f(xs), ys));
    // Release the returned tensors once they are no longer needed.
    tf.dispose(value);
    tf.dispose(grads);
}

// Make predictions.
console.log(`x: ${x.dataSync()}, y: ${y.dataSync()}`);
const preds = f(xs).dataSync();
preds.forEach((pred, i) => {
    console.log(`x: ${i}, pred: ${pred}`);
});


Output (values vary between runs because x and y are initialized with Math.random()):

x: 0.38272422552108765, y: 0.7651948928833008
x: 0, pred: 8.2348051071167
x: 1, pred: 15.2348051071167
x: 2, pred: 24.234806060791016
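
The pair of a loss value and per-variable gradients that .computeGradients() returns can be pictured with a small dependency-free sketch. The code below is plain JavaScript, not TensorFlow.js: numericGradients is a hypothetical helper that estimates each gradient by central differences (TensorFlow.js uses automatic differentiation instead), and the variable values are fixed rather than random so the run is repeatable.

```javascript
// Data and model from Example 2: f(input) = input^2 - y.
const xs = [3, 4, 5];
const ys = [3.5, 4.7, 5.3];

// Mean squared error as a function of a name-to-value map of variables.
function loss(vars) {
  let sum = 0;
  for (let i = 0; i < xs.length; i++) {
    const diff = xs[i] * xs[i] - vars.y - ys[i];
    sum += diff * diff;
  }
  return sum / xs.length;
}

// Hypothetical helper: returns the loss value and one gradient per
// variable name, estimated with central differences of step h.
function numericGradients(f, vars, h = 1e-4) {
  const grads = {};
  for (const name of Object.keys(vars)) {
    const plus = { ...vars, [name]: vars[name] + h };
    const minus = { ...vars, [name]: vars[name] - h };
    grads[name] = (f(plus) - f(minus)) / (2 * h);
  }
  return { value: f(vars), grads };
}

const vars = { x: 0.4, y: 0.8 }; // fixed values (assumption)
const result = numericGradients(loss, vars);
console.log(result.value, result.grads);

// Only gradients were computed; the variables are untouched.
console.log(vars);
```

In this sketch the gradient for x is 0 because x does not appear in the loss, and the gradient for y is negative, meaning the loss would shrink if y were increased.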

Example 3: The .applyGradients() method updates variables using previously computed gradients. It accepts either a map from variable name to gradient tensor or an array of {name, tensor} objects.

Javascript




// Importing tensorflow
import * as tf from "@tensorflow/tfjs"

const xs = tf.tensor1d([0, 1, 2]);
const ys = tf.tensor1d([1.58, 2.24, 3.41]);

const x = tf.scalar(Math.random()).variable();
const y = tf.scalar(Math.random()).variable();

// Define f(input) = input^2 + y. The variable x is not part of f,
// so it receives no gradient and is left unchanged.
const f = input => input.square().add(y);
const loss = (pred, label) =>
    pred.sub(label).square().mean();

const learningRate = 0.05;

// Create rmsprop optimizer
const optimizer = tf.train.rmsprop(learningRate);

// Updating variables: compute the gradients of the loss, then pass
// the resulting name-to-tensor map to applyGradients().
for (let i = 0; i < 5; i++) {
    const {value, grads} =
        optimizer.computeGradients(() => loss(f(xs), ys));
    optimizer.applyGradients(grads);
    tf.dispose(value);
    tf.dispose(grads);
}

// Make predictions.
console.log(`x: ${x.dataSync()}, y: ${y.dataSync()}`);
const preds = f(xs).dataSync();
preds.forEach((pred, i) => {
    console.log(`x: ${i}, pred: ${pred}`);
});


Output (sample; exact values depend on the random initial values of x and y):

x: -0.526353657245636, y: 0.15607579052448273
x: 0, pred: 0.15607579052448273
x: 1, pred: 1.1560758352279663
x: 2, pred: 4.156075954437256
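
As a rough picture of what an RMSProp-style .applyGradients() does with a map of gradients, here is a dependency-free sketch in plain JavaScript. It follows the textbook RMSProp rule without momentum or centering. The factory name createRmsprop, the decay of 0.9 and the epsilon of 1e-7 are assumptions chosen for illustration, not values read from the TensorFlow.js source.

```javascript
// Hypothetical RMSProp sketch: keeps a running mean of squared
// gradients for every variable name it has seen.
function createRmsprop(learningRate, decay = 0.9, epsilon = 1e-7) {
  const meanSquares = {};
  return {
    // Updates `vars` in place using `grads`; both are keyed by name.
    applyGradients(vars, grads) {
      for (const name of Object.keys(grads)) {
        const g = grads[name];
        const prev = meanSquares[name] || 0;
        meanSquares[name] = decay * prev + (1 - decay) * g * g;
        vars[name] -=
          (learningRate * g) / Math.sqrt(meanSquares[name] + epsilon);
      }
    },
  };
}

const vars = { x: 0.3, y: 0.2 }; // fixed values (assumption)
const optimizer = createRmsprop(0.05);

// Only y has a gradient here, so x is left unchanged.
optimizer.applyGradients(vars, { y: -1.0 });
console.log(vars);
```

Under these assumptions the first step has a size of roughly learningRate / sqrt(1 - decay), about 0.158, almost independent of the gradient's magnitude; only the sign of the gradient decides the direction.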

Reference: https://js.tensorflow.org/api/latest/#class:train.Optimizer


