Deep Learning Gymnastics #2: Tensor Indexing

Welcome to the second episode of the Deep Learning Gymnastics series. Hope you’re in good shape. Get warmed up, and let’s start.

Today, we’ll talk about a simple yet powerful aspect of tensor manipulation: tensor indexing.

Batches and embeddings motivating example

At the heart of most modern deep learning models, you’ll deal with batches and embeddings.

Batches? Below is a toy example of what a batch from a training set could look like:

Each number is an index into a vocabulary of size N, whose elements can be any kind of entity: letters or words in a language model, movies in a recommender system, segments on a map in an ETA model, or ads in an ad network.

For this example, let’s assume those are letters (indexed from 0 to 26: the 26 letters plus a special end character), as in Andrej Karpathy’s great “makemore” series.
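In code, the letter-to-index mapping could look like the following minimal sketch. It rests on an assumption: like makemore, it places the special end character (written here as ".") at index 0 and the letters a to z at indices 1 to 26. The names `vocabulary`, `stoi` and `itos` are hypothetical.

```python
import string

# Hypothetical vocabulary: "." as the special end character at index 0,
# followed by the 26 lowercase letters at indices 1..26
vocabulary = ["."] + list(string.ascii_lowercase)
stoi = {ch: i for i, ch in enumerate(vocabulary)}  # character -> index
itos = {i: ch for ch, i in stoi.items()}           # index -> character

print(len(vocabulary))                    # 27
example = [stoi[ch] for ch in "emm"]      # one example of size 3
print(example)                            # [5, 13, 13]
print("".join(itos[i] for i in example))  # emm
```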

Embeddings? For each element of that vocabulary, you’ve learned a representation of its (latent) characteristics as a vector of size k. This vector is often called an embedding. Continuing with our example, let’s consider an embedding of size 4 for each element of the vocabulary (in our case, English letters), i.e. a tensor of shape (27, 4).

Here is the gymnastic exercise: you have a toy batch containing 8 examples of size 3, where each number is taken from a vocabulary of size 27. You also have an embedding matrix of shape (27, 4), where each row is the embedding vector of size 4 of one of the 27 elements of the vocabulary. For each element of the batch, you need to fetch its embedding vector, to end up with a batch that is a tensor of shape (8, 3, 4). This is illustrated below.
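Before reaching for any framework, the exercise can be written as a minimal pure-Python sketch with nested lists: every index in the batch is replaced by the row of the embedding matrix it points to. The function `embed_batch` and the tiny toy sizes (a vocabulary of 3 elements, embeddings of size 2, a batch of 2 examples) are hypothetical and chosen only to keep the output readable.

```python
def embed_batch(batch, embeddings):
    """Replace every index in a 2-D batch by its embedding row.

    batch: list of examples, each a list of integer indices
    embeddings: list of rows, one embedding vector per vocabulary element
    Returns a nested list of shape (len(batch), len(example), embedding_size).
    """
    return [[embeddings[index] for index in example] for example in batch]


# Toy vocabulary of 3 elements, each with an embedding of size 2
toy_embeddings = [[0.1, 0.2], [1.1, 1.2], [2.1, 2.2]]
# Toy batch of 2 examples of size 3
toy_batch = [[0, 2, 1], [2, 2, 0]]

result = embed_batch(toy_batch, toy_embeddings)
print(len(result), len(result[0]), len(result[0][0]))  # 2 3 2
print(result[0][1])  # [2.1, 2.2], the embedding of index 2
```

The (2, 3) batch becomes a (2, 3, 2) nested list, exactly the same shape transformation as (8, 3) becoming (8, 3, 4) in the exercise.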

Tensor indexing, the PyTorch way

Let’s first generate the two input tensors (the same as the two inputs on the left of the picture above):

import torch
torch.manual_seed(18)

# Create a random batch of shape (8,3)
# with indexes between 0 and 26 (high is exclusive, hence high=27)
random_tensor = torch.randint(low=0, high=27, size=(8,3))

# Create a random embedding matrix of shape (27,4):
# one vector for each of the 27 vocabulary elements
embeddings = torch.randn(size=(27, 4))
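One detail worth checking in this setup: the `high` argument of `torch.randint` is exclusive, so sampling indexes between 0 and 26 (all 27 vocabulary elements) requires `high=27`. A quick sanity-check sketch (the sample size of 10,000 is arbitrary):

```python
import torch

# torch.randint samples from [low, high): `high` itself is never drawn
indices = torch.randint(low=0, high=27, size=(10_000,))
# With 10,000 draws, this almost certainly prints 0 and 26
print(indices.min().item(), indices.max().item())

# With high=26, index 26 (the last vocabulary element) can never appear
capped = torch.randint(low=0, high=26, size=(10_000,))
print(capped.max().item() <= 25)  # True
```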

And now, let’s solve the gymnastic exercise. Take a deep breath, prepare the movement, and here you go:

embedded_batch = embeddings[random_tensor]

Yes, that’s right. PyTorch allows you to pass a whole tensor as an index, and it works like magic.
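To see what this single line computes, here is a sketch that rebuilds the same (8,3,4) tensor with explicit Python loops and compares it with `embeddings[random_tensor]`. It recreates the inputs with `high=27` so that all 27 indexes can occur. The loop version only illustrates the semantics; the vectorized indexing is the one to use in practice, as it avoids a Python-level loop over every element.

```python
import torch

torch.manual_seed(18)
random_tensor = torch.randint(low=0, high=27, size=(8, 3))
embeddings = torch.randn(size=(27, 4))

# Vectorized lookup: a tensor used as an index selects rows of `embeddings`
embedded_batch = embeddings[random_tensor]

# Equivalent explicit version: fetch the row for every (example, position) pair
looped = torch.empty(8, 3, 4)
for i in range(random_tensor.shape[0]):
    for j in range(random_tensor.shape[1]):
        looped[i, j] = embeddings[int(random_tensor[i, j])]

print(torch.equal(embedded_batch, looped))  # True
```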

You can check the shape of the result and observe that it is indeed (8, 3, 4), as expected (see the picture above). Indeed, (8, 3) is the shape of the initial batch, and each of its elements is replaced by the corresponding embedding vector of shape (4,).
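The same shape rule holds for an index tensor of any shape: indexing the (27, 4) matrix with an integer tensor of shape S returns a tensor of shape S followed by 4. A small sketch checking this for a few arbitrary index shapes:

```python
import torch

embeddings = torch.randn(size=(27, 4))

# Index tensors of various (arbitrary) shapes, all with values in [0, 26]
for shape in [(8, 3), (5,), (2, 5, 3)]:
    indices = torch.randint(low=0, high=27, size=shape)
    result = embeddings[indices]
    # The output shape is the index shape followed by the embedding size
    assert result.shape == indices.shape + embeddings.shape[1:]
    print(tuple(indices.shape), "->", tuple(result.shape))

# Printed lines:
# (8, 3) -> (8, 3, 4)
# (5,) -> (5, 4)
# (2, 5, 3) -> (2, 5, 3, 4)
```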

Let’s validate that the first element of the result (embedded_batch[0,0]) is the embedding vector of the first index of the batch (random_tensor[0,0]). This corresponds to this part of the picture:

And sure enough, it worked 🎉.
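Since the original output is not reproduced in this text, the check can be sketched as follows: read the vocabulary index stored in `random_tensor[0, 0]`, then compare the corresponding row of `embeddings` with `embedded_batch[0, 0]`. The inputs are recreated with `high=27`; the variable name `first_index` is only for illustration.

```python
import torch

torch.manual_seed(18)
random_tensor = torch.randint(low=0, high=27, size=(8, 3))
embeddings = torch.randn(size=(27, 4))
embedded_batch = embeddings[random_tensor]

# Vocabulary index of the first element of the first example
first_index = random_tensor[0, 0].item()
print(first_index)
print(embeddings[first_index])  # row of the embedding matrix for that index
print(embedded_batch[0, 0])     # first element of the embedded batch

# Both are the same vector of size 4
print(torch.equal(embedded_batch[0, 0], embeddings[first_index]))  # True
```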

What about TensorFlow?

In TensorFlow, it is of course possible to achieve the same result, but this is done a bit differently.

The tf.gather function

Instead of using the batch directly as a (tensor) index into the embedding matrix, in TensorFlow we use a very powerful function: tf.gather.

You can read the details in the documentation, but essentially, the equivalent of the following PyTorch indexing:

embedded_batch = embeddings[random_tensor] 

in TensorFlow would be:

embedded_batch = tf.gather(embeddings, random_tensor)

And that’s all.

Full equivalent TensorFlow code below:

import tensorflow as tf
tf.random.set_seed(18)

# Create a random batch of shape (8,3) with indexes between 0 and 26 (maxval is exclusive, hence maxval=27)
random_tensor = tf.random.uniform(shape=(8,3), minval=0, maxval=27, dtype=tf.int32)

# Create a random embedding matrix of shape (27,4): one vector for each of the 27 vocabulary elements
# (tf.random.normal samples from a standard normal distribution, like torch.randn)
embeddings = tf.random.normal((27,4), dtype=tf.float32)

# Solving the gymnastic exercise: creating an embedded batch with the tf.gather function
embedded_batch = tf.gather(embeddings, random_tensor)

# Validating the results
print(random_tensor)
print(embeddings)
print(embedded_batch.shape) # (8,3,4) which is the expected dimension
print(embedded_batch[0,0])

Hope you enjoyed the gymnastic lesson. Take some rest. Until the next one 🤸.

