
Tensorflow.js tf.LayersModel class .fit() Method

Last Updated : 01 Sep, 2021

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in a Node.js environment.

The .fit() method of the tf.LayersModel class trains the model for a fixed number of epochs (iterations over a dataset).

Syntax:

fit(x, y, args?)

Parameters: This method accepts the following parameters.

  • x: A tf.Tensor containing the input data. A model with several inputs takes an array of tensors, or an object mapping input names to tensors.
  • y: A tf.Tensor containing the target (output) data, given in the same form as x for a model with several outputs.
  • args: An optional object whose properties are as follows:
    1. batchSize: The number of samples per gradient update. Defaults to 32.
    2. epochs: The number of times to iterate over the training data arrays. Defaults to 1.
    3. verbose: Controls how progress is reported. 0 prints nothing during the fit() call. 1, the default, prints a progress bar in Node.js (with tfjs-node) and does nothing in the browser. 2 is not implemented yet.
    4. callbacks: A callback, or a list of callbacks, to be called during training. A callback object can define one or more of the hooks onTrainBegin(), onTrainEnd(), onEpochBegin(), onEpochEnd(), onBatchBegin(), onBatchEnd() and onYield().
    5. validationSplit: A number between 0 and 1 giving the fraction of the training data to set aside for validation. The model does not train on this part; it evaluates the loss and any metrics on it at the end of each epoch. The validation samples are taken from the end of x and y, before shuffling. For example, validationSplit: 0.5 uses the last 50% of the data for validation.
    6. validationData: Data on which to evaluate the loss and any metrics at the end of each epoch, given as [xVal, yVal] (or [xVal, yVal, valSampleWeights]). The model is not trained on this data, and validationData overrides validationSplit.
    7. shuffle: Whether to shuffle the training data before each epoch. It has no effect when stepsPerEpoch is not null.
    8. classWeight: An optional object mapping class indices to weights used in the loss function. It can be used to make the model pay more attention to samples from an under-represented class.
    9. sampleWeight: An optional array of weights, one per sample, applied to the model's loss.
    10. initialEpoch: The epoch at which to start training, which is useful for resuming a previous training run. In this case epochs is the index of the final epoch rather than a count: training runs from epoch initialEpoch up to epoch epochs - 1.
    11. stepsPerEpoch: The number of steps (batches of samples) before one epoch is declared finished and the next one starts. If it is not set, it defaults to the number of samples divided by the batch size, or 1 if that cannot be determined.
    12. validationSteps: Only relevant if stepsPerEpoch is specified. The total number of steps (batches of samples) to validate before stopping.
    13. yieldEvery: How often training yields the main thread to other tasks, which keeps a browser page responsive and lets Node.js handle queued events. Possible values: 'auto' (the default) yields at a fixed frame rate; 'batch' yields after every batch; 'epoch' yields after every epoch; a number yields every that many milliseconds; 'never' never yields.

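Returns: A Promise that resolves to a History object. Its history property maps each recorded quantity (such as loss and, when validation is used, val_loss) to an array with one value per epoch.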

Example 1:

Javascript




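// Importing the TensorFlow.js library
// (top-level await below requires running this file as an ES module)
import * as tf from "@tensorflow/tfjs";

// Defining a model with a single dense layer (6 inputs, 2 outputs)
const mymodel = tf.sequential({
  layers: [tf.layers.dense({units: 2, inputShape: [6]})]
});

// Compiling the model
mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

// Calling fit() four times; the weights carry over from one call to the next
for (let i = 0; i < 4; i++) {
  // Each call trains for 4 epochs in batches of 5 samples
  const his = await mymodel.fit(tf.zeros([6, 6]), tf.ones([6, 2]), {
    batchSize: 5,
    epochs: 4
  });

  // Printing the loss recorded at the second epoch (index 1) of this call
  console.log(his.history.loss[1]);
}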


Output:

0.9574100375175476
0.8151942491531372
0.694103479385376
0.5909997820854187

Example 2:

Javascript




// Importing the TensorFlow.js library
// (top-level await below requires running this file as an ES module)
import * as tf from "@tensorflow/tfjs";

// Defining a model with a single sigmoid-activated dense layer
const mymodel = tf.sequential({
  layers: [tf.layers.dense({units: 2, inputShape: [6], activation: "sigmoid"})]
});

// Compiling the model
mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

// Calling fit() on random data with several optional arguments
const his = await mymodel.fit(tf.truncatedNormal([6, 6]), tf.randomNormal([6, 2]), {
  batchSize: 5,
  epochs: 4,
  validationSplit: 0.2,
  shuffle: true,
  initialEpoch: 2,
  stepsPerEpoch: 1,
  validationSteps: 2
});

// Printing the recorded loss and validation loss for each epoch
console.log(JSON.stringify(his.history));


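Output (the input and target data are random, so the exact numbers differ on every run; each array holds two values because, with initialEpoch: 2 and epochs: 4, training runs only epochs 2 and 3):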

{"val_loss":[0.35800713300704956,0.35819053649902344],
"loss":[0.633269190788269,0.632409930229187]}

Reference: https://js.tensorflow.org/api/latest/#tf.LayersModel.fit


