
WIP: TypeScript Refactor #3

Open · wants to merge 2 commits into main
Conversation

lukemovement

Hey, I was just taking a punt at rewriting this in TypeScript to make it easier to use, but I ran into a few issues. So far I have gotten through most of the src directory, but I can't find usage examples to match up a few of the type hints. If you are alright with giving me some pointers on it, I can finish off the build process so it can be used with the examples.

The places where I am struggling to match up the types are:
src/utils.mts:3
src/utils.mts:26
src/model.mts:771
src/model.mts:778

@lukemovement lukemovement changed the title TypeScript Refactor WIP: TypeScript Refactor Jul 27, 2023
@zemlyansky
Owner

Hi @lukemovement, great idea! I use type hints in Python a lot; they really help with large codebases. Unfortunately, I haven't tried TypeScript yet, so it will take some time for me to grasp the main concepts.

Perhaps someone else can help with this refactor. I've just posted on ShowHN mentioning your PR:
https://news.ycombinator.com/item?id=36906098

@lukemovement
Author

TypeScript is pretty similar to Java with regard to its type-hinting implementation. I'm slowly looking through the samples and have found a few bits that I need. I'm actually unsure whether or not the bits I was having issues with are needed. They weren't used within demo.js, and I was able to successfully return a prediction without using them. They are used within test.js at the root of the repository; if you are familiar with that code, could you shed some light on what that file does, please?

@robertleeplummerjr

This has me very excited.

@zemlyansky
Owner

Hi @lukemovement! I've been playing with TypeScript this weekend and have tried to implement more models in TensorFlow.js, such as ViT and CLIP. They share common layers with GPT, so I decided to merge them into a single repo: https://github.com/zemlyansky/modelzoo with a corresponding NPM package: https://www.npmjs.com/package/modelzoo. I started it from scratch, piecing some components together. Unfortunately, it's missing the original gpt-tfjs as well as your commits with the TypeScript refactor. I'm not sure if using a monorepo was a good long-term idea, but it seems to make iteration faster. What do you think?

@lukemovement
Author

lukemovement commented Oct 23, 2023

> Hi @lukemovement! I've been playing with TypeScript this weekend and have tried to implement more models in TensorFlow.js, such as ViT and CLIP. They share common layers with GPT, so I decided to merge them into a single repo: https://github.com/zemlyansky/modelzoo with a corresponding NPM package: https://www.npmjs.com/package/modelzoo. I started it from scratch, piecing some components together. Unfortunately, it's missing the original gpt-tfjs as well as your commits with the TypeScript refactor. I'm not sure if using a monorepo was a good long-term idea, but it seems to make iteration faster. What do you think?

The monorepo choice comes down to preference. It can make merge requests more difficult to deal with on larger code bases. On a side note, are you aware you are reinventing the wheel with a lot of your code? Take a look at @tensorflow/tfjs-layers, as they have Keras ported, as in the Python version of TF. Edit: Maybe not, sorry, I'm struggling to wrap my head around it a bit. Code bases tend to be more fragmented in JS/TS projects, e.g. a class per file.

I'm following the T5 architecture at the moment. Give this a try; it's much more effective at sharing data between embeddings, as well as being less dependent on memory. I put it into an RNN, return the last thought, add it to the embeddings, then normalize. A cross between T5 and BERT, I guess, or maybe closer to PaLM; I'm not sure.

import * as tf from "@tensorflow/tfjs-node";
import { Attention } from "./multi-head-self-attention.mjs";

class _AttentionRNNCell extends tf.layers.RNNCell {
  units: number;
  attentionLayer: tf.layers.Layer;
  normalizeLayer: tf.layers.Layer;
  keyLayer: tf.layers.Layer;
  valueLayer: tf.layers.Layer;
  numHeads: number;
  stateSize: number;
  axis: number;

  constructor({
    units,
    keyLayer,
    valueLayer,
    numHeads,
    trainable = true,
  }: {
    units: number;
    keyLayer: tf.layers.Layer;
    valueLayer: tf.layers.Layer;
    numHeads: number;
    trainable?: boolean;
  }) {
    super({ trainable });
    this.units = units;
    this.axis = 1;
    this.attentionLayer = Attention({
      name: `attention`,
      axis: this.axis + 1,
    });
    this.keyLayer = keyLayer;
    this.valueLayer = valueLayer;
    this.numHeads = numHeads;
    this.stateSize = units;
    this.normalizeLayer = tf.layers.layerNormalization({});
    this.trainable = trainable;
  }

  build(inputShape: tf.Shape | tf.Shape[]): void {
    // A nested array means we received a list of shapes; recurse on the first.
    if ("number" !== typeof inputShape[0] && null !== inputShape[0]) {
      return this.build(inputShape[0] as tf.Shape);
    }

    const dims = inputShape[inputShape.length - 1] as number;

    try {
      this.keyLayer.build([dims]);
      this.valueLayer.build([dims]);
      this.normalizeLayer.build([1, this.units]);
    } catch (error) {
      throw new Error(`Error building RNNCell: ${(error as any).message}`);
    }

    this.trainableWeights = [
      ...this.keyLayer.trainableWeights,
      ...this.valueLayer.trainableWeights,
      ...this.normalizeLayer.trainableWeights,
    ];

    this.nonTrainableWeights = [
      ...this.keyLayer.nonTrainableWeights,
      ...this.valueLayer.nonTrainableWeights,
      ...this.normalizeLayer.nonTrainableWeights,
    ];

    this.built = true;
  }

  computeHeadedShape(inputShape: number[]): number[] {
    const shape = [...inputShape] as number[];
    shape.splice(this.axis, 0, this.numHeads);
    shape[shape.length - 1] = this.units / this.numHeads;

    return shape;
  }

  /**
   * Perform the forward pass of the RNN cell.
   *
   * @param inputs - An array of input tensors.
   * @returns An array containing the output and recurrentKernel.
   */
  call(inputs: tf.Tensor<tf.Rank>[]): [tf.Tensor<tf.Rank>, tf.Tensor<tf.Rank>] {
    const input = inputs[0];
    const hPrev = inputs[1];

    // Compute the headed shape before disposing of the input tensor.
    const shape = this.computeHeadedShape(input.shape);

    const key = this.keyLayer.apply(input) as tf.Tensor<tf.Rank>;
    const value = this.valueLayer.apply(input) as tf.Tensor<tf.Rank>;
    input.dispose();

    const attention = this.attentionLayer.apply([
      hPrev.reshape(shape),
      key.reshape(shape),
      value.reshape(shape),
    ]) as tf.Tensor;
    key.dispose();
    value.dispose();

    const shapedAttention = attention.reshape(hPrev.shape);
    attention.dispose();

    const mul = tf.mul(hPrev, shapedAttention);

    const output = this.normalizeLayer.apply(mul) as tf.Tensor;
    mul.dispose();

    return [output, shapedAttention];
  }

  static get className() {
    return "AttentionRnnCell";
  }
}

tf.serialization.registerClass(_AttentionRNNCell);

export const AttentionRnnCell = (config: {
  units: number;
  keyLayer: tf.layers.Layer;
  valueLayer: tf.layers.Layer;
  numHeads: number;
  trainable?: boolean;
}) => new _AttentionRNNCell(config);
// multi-head-self-attention.mts
import * as tf from "@tensorflow/tfjs-node";

class _Attention extends tf.layers.Layer {
  axis: number;

  constructor({ name, axis }: { name?: string; axis: number }) {
    // Forward the name so the factory's `name` config is not silently dropped.
    super({ name });
    this.axis = axis;
  }

  computeOutputShape(inputShape: tf.Shape[]): tf.Shape | tf.Shape[] {
    return inputShape[0];
  }

  call(inputs: tf.Tensor[]) {
    const [query, key, value] = inputs as never as [
      tf.Tensor,
      tf.Tensor,
      tf.Tensor,
    ];

    const depth = query.shape[this.axis] as number;
    const logits = tf.matMul(query, key, false, true);
    query.dispose();
    key.dispose();

    const attention = tf.matMul(
      tf.softmax(logits.div(tf.scalar(Math.sqrt(depth)))),
      value,
    );
    logits.dispose();
    value.dispose();

    return attention;
  }

  static get className() {
    return "Attention";
  }
}

tf.serialization.registerClass(_Attention);

export const Attention = (config: { name: string; axis: number }) =>
  new _Attention(config);

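As a side note (not part of the PR), the scaled dot-product attention that the custom Attention layer above computes, softmax(Q · Kᵀ / √d) · V, can be sketched on plain number arrays with no tfjs dependency. The helper names below are illustrative; this is only handy for sanity-checking the layer's output on small inputs.

```typescript
// Illustrative, dependency-free sketch of scaled dot-product
// attention: softmax(Q · Kᵀ / sqrt(d)) · V. Mirrors the math inside
// the custom Attention layer, on plain number arrays.

type Matrix = number[][];

// Multiply a by b (or by bᵀ when transposeB is true).
function matmul(a: Matrix, b: Matrix, transposeB = false): Matrix {
  // Rows to dot against: b's rows if transposeB, else b's columns.
  const bRows = transposeB ? b : b[0].map((_, j) => b.map((row) => row[j]));
  return a.map((row) =>
    bRows.map((col) => row.reduce((sum, v, k) => sum + v * col[k], 0)),
  );
}

// Numerically stable row-wise softmax.
function softmaxRows(m: Matrix): Matrix {
  return m.map((row) => {
    const max = Math.max(...row);
    const exps = row.map((v) => Math.exp(v - max));
    const sum = exps.reduce((acc, e) => acc + e, 0);
    return exps.map((e) => e / sum);
  });
}

// attention(Q, K, V) = softmax(Q · Kᵀ / sqrt(depth)) · V
function attention(query: Matrix, key: Matrix, value: Matrix): Matrix {
  const depth = query[0].length;
  const logits = matmul(query, key, true).map((row) =>
    row.map((v) => v / Math.sqrt(depth)),
  );
  return matmul(softmaxRows(logits), value);
}

// With identical key rows the weights are uniform, so the output is
// the mean of the value rows:
// attention([[1, 1]], [[1, 0], [1, 0]], [[1, 2], [3, 4]]) → [[2, 3]]
```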

@zemlyansky
Owner

zemlyansky commented Oct 23, 2023

@lukemovement I am personally also in favor of the "do one thing" approach, but for now, it's more like a TypeScript learning exercise. If it goes further, it would still be possible to split the repo back into dozens of npm modules, collecting them into one through imports.

On layers: TensorFlow.js is quite tricky, as it requires tf.Layers in functional models. So you can't just mix regular ops (e.g., tf.add) when defining a model. On the other hand, you also can't just build and call other layers inside custom layers, as there's some issue with how weights are treated (that's why there's a custom dense op inside the Attention layer of gpt-tfjs). So it made sense to implement some missing pieces like the Slice layer. There's also a chance that I don't know what I'm doing, and there's an easier way to mix custom ops with tf.Layers, or some of them are already implemented :) T5 is very interesting indeed!

@zemlyansky
Owner

zemlyansky commented Oct 23, 2023

@lukemovement In your example, you pass layers as parameters to the custom layer? That's interesting! Did that work through gradient calculation and training? I mean this.keyLayer and the others.

@lukemovement
Author

> @lukemovement In your example, you pass layers as parameters to the custom layer? That's interesting! Did that work through gradient calculation and training? I mean this.keyLayer and the others.

The layers passed through the constructor of the RNN cell are stock dense layers. The weights then have to be registered on the layer:

    this.trainableWeights = [
      ...this.keyLayer.trainableWeights,
      ...this.valueLayer.trainableWeights,
      ...this.normalizeLayer.trainableWeights,
    ];

    this.nonTrainableWeights = [
      ...this.keyLayer.nonTrainableWeights,
      ...this.valueLayer.nonTrainableWeights,
      ...this.normalizeLayer.nonTrainableWeights,
    ];

For the imports, you can use import { dense } from "@tensorflow/tfjs-layers";, then treat the import as a function call.
For any missing layer, like tf.transpose for example, I would recommend using the log layer as a base to create the functionality you need.

@peacefulotter
Contributor

Any news on this? Would be helpful to work with TS out-of-the-box!

@zemlyansky
Owner

zemlyansky commented Dec 11, 2023

@peacefulotter, hi! Have you checked https://github.com/zemlyansky/modelzoo? GPT, ViT, and CLIP are in TypeScript there and pass tests. Documentation is still missing, but the GPT interface should be the same as the gpt-tfjs one (upd.: that model is mostly @lukemovement's TS fork with some small changes)

4 participants