Bayan Bennett

Inference Using Web Workers—Typing Practice w/ Machine Learning

TensorFlow
JavaScript

Articles in this series:

  1. Introduction
  2. Pseudo-English
  3. Keyboard Input
  4. Inference Using Web Workers

The finished project is located here: https://www.bayanbennett.com/projects/rnn-typing-practice

The last post, Keyboard Input, was all about preparing the data that will be used to generate subsequent lines of characters. This article looks at how that input is used and how a line of text is generated using a TensorFlow model. Additionally, the model runs entirely inside a web worker, off the main thread.

To learn more about web workers, I've written a gentle introduction here: Barebones Web Workers.

Setting Up the Worker

The following code all lives in the mlKeyboard.worker.js file. First, import the TensorFlow.js library:

import * as tf from "@tensorflow/tfjs";

Enable production mode, which disables correctness checks in favor of performance:

tf.enableProdMode();

Set the backend to webgl so that we can leverage the GPU. If it isn't available, TensorFlow.js will fall back to the cpu backend.

tf.setBackend("webgl").catch(console.warn);

Fetch the model

This must be loaded from an endpoint on the server. Once it has been loaded, save it to our worker's state object and post a message to the main window indicating that the worker is ready.

tf.loadLayersModel("/model-checkpoint/model.json").then((model) => {
  if (model === null) return;

  // Note: This `setState` function is not from React, but mimics the functionality
  setState({ model });

  postMessage({ type: actionTypes.main.workerReady });
});
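On the receiving side, the worker has to react to messages from the main window. Below is a minimal sketch of that dispatch. The action names (`worker/getNextLine`, `main/nextLine`) are hypothetical placeholders for the project's own `actionTypes`, and `generateLine` and `postMessage` are injected so the handler can be exercised outside a real worker.

```javascript
// Hypothetical action types; the real project defines its own `actionTypes`.
const actionTypes = {
  main: { workerReady: "main/workerReady", nextLine: "main/nextLine" },
  worker: { getNextLine: "worker/getNextLine" },
};

// Builds an `onmessage`-style handler. `generateLine` and `postMessage` are
// injected so the handler can run (and be tested) outside a worker.
const createMessageHandler = ({ generateLine, postMessage }) => async (event) => {
  const { type } = event.data;

  if (type === actionTypes.worker.getNextLine) {
    const { nextLine } = await generateLine();
    // Send the generated character codes back to the main window.
    postMessage({ type: actionTypes.main.nextLine, nextLine });
  }
  // Other message types are ignored in this sketch.
};

// Inside the worker this could be wired up as:
// self.onmessage = createMessageHandler({
//   generateLine,
//   postMessage: (message) => self.postMessage(message),
// });
```

Injecting the dependencies keeps the message routing separate from the model code, which is what makes it easy to test.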

During initialization, and whenever the user finishes typing a line, the main window sends the worker a message of type actionTypes.worker.getNextLine. After some other bookkeeping, this eventually calls the function below, generateLine. It creates a 34-element array and reduces over it, chaining one call to the async addKey function per character so that each key is generated only after the previous one has resolved. A length of 34 was chosen because it fits well on both desktop and mobile screens.

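const generateLine = async () => new Array(34).fill(null).reduce(
  // The accumulator is a Promise, so await the previous state before adding the next key
  async (statePromise) => addKey(await statePromise),
  Promise.resolve({
    ...state, // the worker's current state object
    nextLine: [],
  })
);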
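The reduce works because each accumulator is a Promise: every step awaits the previous state before generating the next key, so the calls run one after another rather than all at once. Here is a self-contained sketch of the pattern, where `stubAddKey` is a hypothetical stand-in for `addKey` that simply appends a character code.

```javascript
// Hypothetical stand-in for addKey: appends one character code to the line.
const stubAddKey = async (state) => ({
  ...state,
  nextLine: [...state.nextLine, state.nextLine.length],
});

// Runs `step` `count` times in sequence, threading state through a Promise chain.
const runSequentially = (count, step, initialState) =>
  new Array(count).fill(null).reduce(
    // The accumulator is a Promise, so await it before taking the next step.
    async (statePromise) => step(await statePromise),
    Promise.resolve(initialState)
  );

runSequentially(34, stubAddKey, { nextLine: [] }).then(({ nextLine }) => {
  console.log(nextLine.length); // 34
});
```

Passing the accumulator straight to `step` without awaiting it would hand `step` a Promise instead of the state object, so destructuring inside `step` would yield `undefined` values.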

Normalizing the Bigram

In the last post, Keyboard Input, a bigram was generated that recorded the average delay between any pair of keystrokes.

Bigram

Both to render the image above and to use the bigram in the model, it first needs to be normalized.

In this case I elected to simply convert each value $x$ into the number of standard deviations $\sigma$ it lies from the mean $\mu$, then use that as the exponent of $e$:

$$y = e^{\frac{x - \mu}{\sigma}}$$

const normalizeBigramTensor = ({ bigram, stdDev, mean }) => {
  // When the bigram is untouched (i.e. all of its values are `null`), the mean is `null`.
  if (mean === null) return tf.ones([vocabSize, vocabSize]);

  return tf.tidy(() => {
    const adjustmentTensor = tf.tensor2d(bigram);
    const meanArray = tf.fill(adjustmentTensor.shape, mean);

    // The bigram is sparse. Rather than assigning values to the empty entries, mask them out
    // (casting to bool makes them false) and substitute the mean, so they normalize to e^0 = 1.
    const maskTensor = adjustmentTensor.asType("bool");

    return adjustmentTensor
      .where(maskTensor, meanArray)
      .sub(mean)
      .divNoNan(stdDev)
      .exp();
  });
};
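To see what those tensor operations compute, here is a plain-JavaScript sketch of the same normalization on a nested array. It assumes unrecorded pairs are stored as `null` and computes the mean and the population standard deviation over the recorded values; the original function receives `mean` and `stdDev` precomputed, so how they are calculated there is an assumption.

```javascript
// Plain-JS sketch of the bigram normalization: y = e^((x - mean) / stdDev).
// Unrecorded pairs (null) become e^0 = 1, i.e. "no adjustment".
const normalizeBigram = (bigram) => {
  const values = bigram.flat().filter((x) => x !== null);
  if (values.length === 0) return bigram.map((row) => row.map(() => 1));

  const mean = values.reduce((sum, x) => sum + x, 0) / values.length;
  // Population standard deviation (an assumption; the post doesn't say which).
  const stdDev = Math.sqrt(
    values.reduce((sum, x) => sum + (x - mean) ** 2, 0) / values.length
  );

  return bigram.map((row) =>
    row.map((x) => {
      if (x === null) return 1;
      // Mirror divNoNan: a zero standard deviation yields an exponent of 0.
      const z = stdDev === 0 ? 0 : (x - mean) / stdDev;
      return Math.exp(z);
    })
  );
};
```

Pairs typed more slowly than average end up above 1, and since the model's prediction is later multiplied by the relevant row, those pairs become more likely to be sampled.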

Generate Each Character

Now here's where things get interesting: running the ML model!

Generate the first character

The model takes the previous character as its input, so what happens when there isn't one? One option is to train the model on words that all start with a space; the first input would then always be a space.

However, I wanted more control over the first character, so I generate it separately by sampling from the bigram row for the space character.

const getFirstLetterTensor = (normalizedBigram) => {
  // Get the bigram row for spaces to see how quickly each character that follows a space is typed.
  // Drop the first two entries, the null character and space, since neither should start a word.
  const spaceRow = normalizedBigram[spaceInt].slice(2);

  return tf.tidy(() => {
    // tf.multinomial expects unnormalized log-probabilities, so take the log of the weights.
    const spaceRowTensor = tf.tensor1d(spaceRow).log();

    // Sample a character. Since the first two positions were sliced off, shift the index back by 2.
    return tf.multinomial(spaceRowTensor, 1).add(tf.scalar(2, "int32"));
  });
};
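Because `tf.multinomial` treats its input as unnormalized log-probabilities, taking the log first means characters are drawn in proportion to the original weights. A dependency-free sketch of the same idea follows; `sampleWeighted` and `sampleFirstLetter` are hypothetical helpers, and `random` is injectable so results can be checked deterministically.

```javascript
// Picks an index with probability proportional to its (non-negative) weight.
// `random` defaults to Math.random but can be injected for testing.
const sampleWeighted = (weights, random = Math.random) => {
  const total = weights.reduce((sum, w) => sum + w, 0);
  let threshold = random() * total;
  for (let i = 0; i < weights.length; i += 1) {
    threshold -= weights[i];
    if (threshold < 0) return i;
  }
  return weights.length - 1; // Guard against floating-point round-off.
};

// Hypothetical first-letter sampler: skip the null (0) and space (1) columns,
// sample from the remaining weights, then shift the index back by 2.
const sampleFirstLetter = (spaceRow, random) =>
  sampleWeighted(spaceRow.slice(2), random) + 2;
```

Zero-weight characters are never chosen (outside the round-off guard), which matches how a weight of 0 becomes a log-probability of negative infinity.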

Get the last character in the line

This one is straightforward: get the last character from the line or, if the line is empty, return a space.

const getLastCharInt = (nextLine) => {
  const { length } = nextLine;
  if (length < 1) return spaceInt;
  return nextLine[length - 1];
};

Get the next character using inference

This is where the magic happens. The model's predict method produces a probability for each character, the bigram and space adjustments are applied to those probabilities, and tf.multinomial samples a character from the result.

const getNextLetterTensor = ({
  lastCharInt,
  normalizedBigram,
  spaceAdjustment,
}) => tf.tidy(() => {
  const lastCharTensor = tf.tensor1d([lastCharInt]);
  const prediction = model.predict(lastCharTensor, { batchSize: 1 });
  // Remove the singleton batch (and time-step) dimensions, leaving one probability per character
  const squeezedPrediction = tf.squeeze(prediction);
  // Weight each candidate by its normalized delay after the last character, then boost the space character
  const rowSpecificAdjustment = tf.tensor1d(normalizedBigram[lastCharInt]);
  const spaceAdjustmentTensor = tf.tensor1d(spaceAdjustment);
  const probabilityDistribution = squeezedPrediction
    .mul(rowSpecificAdjustment)
    .add(spaceAdjustmentTensor);
  // tf.multinomial expects log-probabilities and returns the sampled index
  return tf.multinomial(probabilityDistribution.log(), 1);
});
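Stripped of tensors, the adjustment step is element-wise arithmetic: multiply each predicted probability by the bigram weight for the previous character, then add the space boost. A plain-JavaScript sketch, where `adjustPrediction` is a hypothetical helper and the 4-character vocabulary in the usage note is purely illustrative:

```javascript
// Element-wise version of: prediction * bigramRow + spaceAdjustment.
const adjustPrediction = (prediction, bigramRow, spaceAdjustment) => {
  if (prediction.length !== bigramRow.length || prediction.length !== spaceAdjustment.length) {
    throw new Error("all inputs must have one entry per vocabulary character");
  }
  return prediction.map((p, i) => p * bigramRow[i] + spaceAdjustment[i]);
};
```

For a vocabulary of [null, space, "a", "b"], a prediction of [0, 0.05, 0.5, 0.45], a bigram row of [1, 1, 2, 0.5], and a space adjustment of [0, 0.46, 0, 0] give [0, 0.51, 1, 0.225] (up to floating-point rounding). The result no longer sums to 1, which is fine: since tf.multinomial takes unnormalized log-probabilities, sampling from the log of these values still draws each character in proportion to its weight.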

Calculating the probability of a space

You may have noticed the spaceAdjustment parameter above. The dataset is made up of words of varying lengths, which makes it difficult for the model to predict accurately when a word should end. Knowing the distribution of word lengths, the prediction can be nudged in the right direction. This is the space cumulative distribution function:

$$y = e^{k(x - x_0)}$$

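where $x$ is the current distance from the last space (spaceDistance in the code), $k = 0.64766$, and $x_0 = 6.18632$. These values were found by fitting the curve to the word-length distribution. For example, at $x = 5$ the curve gives $e^{0.64766(5 - 6.18632)} \approx e^{-0.768} \approx 0.46$, so the probability of a space is increased by about 46%.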
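The post doesn't show `calculateSpaceProbability`, which `addKey` calls below. A minimal sketch that simply evaluates the fitted curve, assuming `spaceDistance` is passed in directly as $x$:

```javascript
// Fitted constants from the word-length distribution.
const SPACE_CURVE_K = 0.64766;
const SPACE_CURVE_X0 = 6.18632;

// y = e^(k (x - x0)): grows with the current word length,
// nudging the model toward ending the word.
const calculateSpaceProbability = (spaceDistance) =>
  Math.exp(SPACE_CURVE_K * (spaceDistance - SPACE_CURVE_X0));

console.log(calculateSpaceProbability(5).toFixed(2)); // 0.46
```

At $x = x_0$ the boost reaches exactly 1, and it keeps growing beyond that, so long words become increasingly likely to end.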

Putting it all together

Most of the important logic has been described above; the code below puts it all together.

const addKey = async (state) => {
  const { model, nextLine, spaceDistance, normalizedBigram } = state;

  const lastCharInt = getLastCharInt(nextLine);

  // If the last key is a space, a new word is starting: reset the model's state and generate the first character.
  if (lastCharInt === spaceInt) {
    model.resetStates();
    const firstLetterTensor = getFirstLetterTensor(normalizedBigram);
    const [nextLetterInt] = await firstLetterTensor.data();
    // Tensors returned from tf.tidy are not cleaned up automatically, so dispose of them to avoid leaking memory
    firstLetterTensor.dispose();
    return {
      ...state,
      spaceDistance: 0,
      nextLine: [...nextLine, nextLetterInt],
    };
  }

  const spaceAdjustment = new Array(vocabSize).fill(0);

  // The model assigns spaces a very low probability, so I add my own length-based probability
  spaceAdjustment[spaceInt] = calculateSpaceProbability(spaceDistance);

  const nextLetterTensor = getNextLetterTensor({
    lastCharInt,
    normalizedBigram,
    spaceAdjustment,
  });

  const [nextLetterInt] = await nextLetterTensor.data();
  nextLetterTensor.dispose();

  // If the model outputs a null or space character, append a space and reset spaceDistance
  if (nextLetterInt === nullInt || nextLetterInt === spaceInt) {
    return {
      ...state,
      nextLine: [...nextLine, spaceInt],
      spaceDistance: 0,
    };
  }

  // Append the next letter and increment spaceDistance
  return {
    ...state,
    nextLine: [...nextLine, nextLetterInt],
    spaceDistance: spaceDistance + 1,
  };
};

In summary: get the last character; if it's a space (or the line is empty), reset the model and generate the first character of a new word. Otherwise, run the last character through the model to predict the next one. If the prediction is a space or null character, append a space and reset the space distance to 0; otherwise, append the letter and increment the space distance. generateLine repeats this for each character in the line.
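The bookkeeping in `addKey` can be separated from the model. Below is a plain-JavaScript sketch of just the state transitions, where `pickFirst` and `pickNext` are hypothetical synchronous stand-ins for the two TensorFlow.js samplers, and the character codes assume null = 0 and space = 1 (matching the slicing in `getFirstLetterTensor`).

```javascript
// Assumed character codes: null = 0, space = 1.
const nullInt = 0;
const spaceInt = 1;

// One step of the line builder, without the model.
// `pickFirst()` and `pickNext(lastChar, spaceDistance)` are hypothetical samplers.
const addKeySync = ({ nextLine, spaceDistance }, { pickFirst, pickNext }) => {
  const lastChar = nextLine.length > 0 ? nextLine[nextLine.length - 1] : spaceInt;

  // Start of a word: choose the first letter without consulting the model.
  if (lastChar === spaceInt) {
    return { nextLine: [...nextLine, pickFirst()], spaceDistance: 0 };
  }

  const next = pickNext(lastChar, spaceDistance);

  // A null or space from the model ends the word.
  if (next === nullInt || next === spaceInt) {
    return { nextLine: [...nextLine, spaceInt], spaceDistance: 0 };
  }

  // Otherwise keep extending the current word.
  return { nextLine: [...nextLine, next], spaceDistance: spaceDistance + 1 };
};
```

Keeping these transitions pure makes it easy to check edge cases, such as an empty line or a model that emits a null character, without loading a model.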

Concluding thoughts

This post demonstrates how ML model inference can run in a web worker. I was pleasantly surprised by how little boilerplate was needed: everything here is the TensorFlow.js library plus a few helper functions.

Ultimately, approaches like this make it possible to build web applications with advanced ML capabilities without hurting usability, since the inference runs off the main thread and the UI stays responsive.

© 2022 Bayan Bennett