Mayalen Etcheverry | March, 2024

Reproduce in Notebook Github Repository

This blogpost provides a tutorial on how to combine:

- The 🚀 **Transformer model** 🚀, and more specifically Andrej Karpathy's nanoGPT architecture, which is the decoder-only version of the Transformer model from the original "Attention is all you Need" paper (and also the architecture of OpenAI's GPT-2/3 models)
- A 😺 **drawing application** 😺 that I've dug up from a very nice paper by David Ha and Douglas Eck called sketchRNN, where they train (recurrent) neural networks to generate human-like doodles as a *sequence of strokes*, in a manner similar to how children learn to depict objects (and even abstract concepts) with only a few pen strokes

You can run the code directly in your browser using Google Colab, or you can download it as a notebook and run it on your local machine. Note that it requires training the transformer model in an auto-regressive fashion on a database of sketches, which can take more or less time depending on your hardware setup. I personally did it on Colab with a V100 GPU, which takes approximately 20 minutes with the proposed hyperparameters.
If you do not have a GPU, I recommend playing with a smaller model, for instance by decreasing `embd`, `num_heads`, and `n_layers`.

**Disclaimer:** This tutorial is just me playing with transformers around a fun project that I can combine with our CNC-drawing machine ✏
Nothing groundbreaking here as both Transformers and SketchRNN are from 2017, old times in the fast-paced world of machine learning!
In fact I'm quite late to the party, as a 2020 paper by Ribeiro et al. called "SketchFormer" already did something very similar, although with a slightly different architecture/pipeline: they used an encoder-decoder architecture and discretized tokens (or continuous but deterministic tokens), whereas I use a decoder-only architecture with a mixture density network (MDN) output layer.
More below 👇

Before diving into the code, let's look at what we will need/implement for the dataset, neural network model, and training loss.

The code uses the Quick, Draw! Dataset as training data.
The model is currently trained on the `cat` class of this dataset, but other sets of classes can easily be tested by changing the `data_classes` hyperparameter in the notebook.

For each class, the dataset contains 70K sketches for training, 2.5K for validation and 2.5K for testing. A sketch is represented as a sequence of pen stroke actions where each action $a$ is a vector of 3 elements $a=(\Delta x, \Delta y, p)$. The $(\Delta x, \Delta y)$ values are continuous and represent the offset of the pen from the previous position to the current one; in the notebook they are normalized to have a standard deviation of 1. The $p$ value is discrete: $0$ for drawing, $1$ for lifting the pen and $2$ for end of sketch, indicating that no subsequent points will be rendered.
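To make this format concrete, here is a minimal sketch that turns a sequence of $(\Delta x, \Delta y, p)$ actions into polylines of absolute pen positions. The helper name `strokes_to_lines` and the toy sequence are made up for illustration; the convention that the first offset is a pen-up move mirrors the default `lift_pen=1` of the notebook's `create_path` helper.

```python
def strokes_to_lines(actions):
    """Group (dx, dy, p) actions into polylines of absolute (x, y) points.

    p = 0: keep drawing, p = 1: lift the pen after this point,
    p = 2: end of sketch (nothing after this action is rendered).
    """
    x, y = 0.0, 0.0
    lines = []
    pen_up = True  # assumption: the first offset only positions the pen
    for dx, dy, p in actions:
        x, y = x + dx, y + dy
        if pen_up:
            lines.append([(x, y)])    # start a new polyline at this point
        else:
            lines[-1].append((x, y))  # extend the current polyline
        if p == 2:
            break
        pen_up = (p == 1)
    return lines


# Toy sequence (illustrative values only): two polylines, then end of sketch
toy = [(1, 0, 0), (0, 1, 0), (-1, 0, 1), (3, 0, 0), (0, 1, 2), (5, 5, 0)]
lines = strokes_to_lines(toy)
print(lines)
```

The last action `(5, 5, 0)` is ignored because it comes after the end-of-sketch action.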

The model architecture is summarized on the right image and below.

As you can see, the backbone is very similar to the decoder architecture of the famous Transformer model from the "Attention is all you Need" paper.

There are however some minor differences within the decoder blocks which follow the nanoGPT tutorial:

- Layer normalization comes *before* the multi-head attention (MHA) and feedforward (FF) layers
- As there is no encoder, cross-attention layers are removed

👉 For more details on the self-attention mechanism and for an in-depth understanding of this architecture be sure to check the nanoGPT tutorial which I highly recommend to anyone interested!

The **input** and **output** layers are however quite different from those of the original Transformer, as we're dealing with sequences of *continuous* strokes and *discrete* pen actions, and not simply strings.

Let's start with the input layer. The input data, which is a partial sketch, is a sequence of stroke-3 tuples $(dx, dy, p)$ where $(dx, dy) \in \mathbb{R}^2$ and $p \in \lbrace 0,1,2 \rbrace$.

- The pen action $p$ is discrete and can take only 3 values, so the `Pen Embedding` layer is a simple table of 3 d-dimensional embeddings, which we implement with the nn.Embedding layer.
- The stroke action $(dx, dy)$, however, takes continuous values. Therefore, for the `Stroke Embedding` layer, we use a linear projection mapping the 2D strokes into a d-dimensional space.
- The positional embedding is, as is traditional, an nn.Embedding table of size `block_size` (max context length) which encodes the position of a stroke in the sequence.

The output layer is divided into two heads:

- The `MDN Head`, which stands for *Mixture Density Network* and which we detail below.
- The `Pen Head`, which is simply a linear layer that outputs logits of size 3. The output logits are used as parameters of a *Categorical Distribution* from which we sample the pen action at each time step, akin to what is done for string sequence generation.
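As a dependency-free illustration of the pen head's sampling step, the sketch below turns 3 logits into probabilities with a softmax and draws pen actions from the resulting categorical distribution. The logit values and the helper names are made up; the notebook itself relies on `torch.distributions.Categorical` for this.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_pen_action(logits, rng):
    """Sample a pen action in {0, 1, 2} from the categorical distribution."""
    probs = softmax(logits)
    return rng.choices([0, 1, 2], weights=probs, k=1)[0]


rng = random.Random(0)
logits = [2.0, 0.5, -1.0]  # made-up pen head outputs for one time step
probs = softmax(logits)
samples = [sample_pen_action(logits, rng) for _ in range(1000)]
counts = [samples.count(k) for k in range(3)]
print(probs, counts)
```

With these logits most samples are $p=0$ (keep drawing), while the end-of-sketch action $p=2$ is rare.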

Mixture Density Networks, originally proposed by Christopher Bishop in 1994, use the output of a neural network as *parameters of a probability distribution* instead of direct output values.

In our case, as shown above, the network outputs are parameters $\lbrace \pi_j, \mu_j, \Sigma_j \rbrace_{j=1..M}$ of a Multivariate Gaussian Mixture Model (GMM) with $M=20$ components that we use to sample stroke actions from:

$$\begin{align*} a &\sim f(\theta) \newline f(\theta) &= \sum^{M}_{j=1} \pi_j \: \mathcal{N} (\mu_j, \Sigma_j) \end{align*}$$

The intuition here is that if we directly trained our network to predict strokes with a vanilla regression task, e.g. with an MSE loss, it would only learn to recover the average output stroke for each input sequence. That average is often wrong: when drawing a cat face for instance, after drawing the ears some people move on to the eyes whereas others draw the face contour, and learning the average of these stroke options results in something pretty bad looking.
By modelling the output values with a GMM (or another distribution) instead, the network can learn to generate *multiple* output values for a given input, and hence is more likely to recover the true data distribution.
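A toy 1-D illustration of this intuition, with made-up numbers: suppose that after the same partial sketch, half of the drawers move the pen by $-1$ and the other half by $+1$. The prediction that minimizes the mean squared error is the mean, $0$, a move that nobody makes, whereas a two-component mixture puts its probability mass on both options.

```python
import math


def normal_pdf(x, mu, sigma):
    """Density of a 1-D Gaussian."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def mixture_pdf(x, weights, mus, sigmas):
    """Density of a 1-D Gaussian mixture."""
    return sum(w * normal_pdf(x, m, s) for w, m, s in zip(weights, mus, sigmas))


# Made-up targets: two equally likely continuations of the same input
targets = [-1.0] * 50 + [1.0] * 50

# The constant prediction minimizing the mean squared error is the sample mean
mse_prediction = sum(targets) / len(targets)

# A two-component mixture centered on the two options (hand-picked parameters)
weights, mus, sigmas = [0.5, 0.5], [-1.0, 1.0], [0.1, 0.1]
density_at_mean = mixture_pdf(mse_prediction, weights, mus, sigmas)
density_at_mode = mixture_pdf(1.0, weights, mus, sigmas)
print(mse_prediction, density_at_mean, density_at_mode)
```

The mixture assigns a vanishingly small density to the averaged move and a high density to each actual option, so sampling from it yields one of the two plausible continuations.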

If you think about it, Transformers already do a similar thing with discrete sequences: they probabilistically generate an output token given an input token sequence
by modelling the outputs with a categorical distribution. MDN is a way to extend this idea to *continuous* tokens.

👉 David Ha has made a complete tutorial on MDN that I highly recommend to intuitively understand why this can be useful for many modern ML tasks.
You can also have a look at this Google Colab where I extend David Ha's tutorial to play with an MDN with *full* covariance matrix on a task with *2-dimensional* output, and with a new implementation that uses the `MultivariateNormal` and `OneHotCategorical` torch distributions to implement the GMM, as well as `torch.logsumexp` to compute the loss with better numerical stability, akin to what is done in this repo.

Our loss is the same as the *reconstruction loss* $L_R$ of the Sketch-RNN paper, which basically maximizes the log-likelihood of the training data under the predicted probability distributions.
More precisely, $L_R$ is the sum of the negative log-likelihoods of 1) the predicted GMM distributions in explaining the $\lbrace (\Delta x_i, \Delta y_i) \rbrace$ stroke actions ($L_s$) and 2) the predicted categorical distributions in explaining the $\lbrace p_i \rbrace$ pen actions ($L_p$):

$$\begin{align*} L_R &= L_s + L_p \newline L_s &= - \frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \left( \sum_{j=1}^{M} \pi_{i,j} \: \mathcal{N} (\Delta x_i, \Delta y_i | \mu_{i,j}, \Sigma_{i,j}) \right) \newline L_p &= - \frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \log q_{i}[p_i] \end{align*}$$

where $(\pi_{i,j}, \mu_{i,j}, \Sigma_{i,j})_{j=1..M}$ are the outputs of the `MDN head` for the i-th entry, $q_i$ are the outputs of the `Pen head`, $N_s$ is the number of strokes in the sketch and $N_{max}$ is the maximum sequence length (`block_size` in the notebook).
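As a rough numerical check of these terms for a single time step, the sketch below evaluates the stroke negative log-likelihood of one offset under a two-component mixture with the log-sum-exp trick, plus the pen term. It assumes diagonal covariances to stay dependency-free (the notebook uses full covariances through `MultivariateNormal`), and all numbers and helper names are made up.

```python
import math


def log_normal_diag(x, mu, sigma):
    """Log-density of a 2-D Gaussian with diagonal covariance (assumption)."""
    return sum(
        -0.5 * ((xi - mi) / si) ** 2 - math.log(si) - 0.5 * math.log(2 * math.pi)
        for xi, mi, si in zip(x, mu, sigma)
    )


def logsumexp(values):
    """Stable log(sum(exp(v))) obtained by factoring out the maximum."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def stroke_nll(x, log_pis, mus, sigmas):
    """-log sum_j pi_j N(x | mu_j, sigma_j) for one time step."""
    terms = [lp + log_normal_diag(x, mu, s) for lp, mu, s in zip(log_pis, mus, sigmas)]
    return -logsumexp(terms)


# Made-up mixture parameters for one time step (M = 2 components)
log_pis = [math.log(0.7), math.log(0.3)]
mus = [(0.0, 0.0), (2.0, 2.0)]
sigmas = [(1.0, 1.0), (0.5, 0.5)]

nll_near = stroke_nll((0.1, -0.1), log_pis, mus, sigmas)   # offset close to a component mean
nll_far = stroke_nll((10.0, 10.0), log_pis, mus, sigmas)   # offset far from every component

# Pen term for the same step: -log q[p] with made-up probabilities q and target p = 0
q = [0.8, 0.15, 0.05]
pen_nll = -math.log(q[0])

print(nll_near, nll_far, pen_nll)
```

Offsets that are well explained by one of the components get a low loss, while offsets far from all components are heavily penalized; working in log space keeps the far case finite.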

```
!pip install svgwrite
```

Collecting svgwrite Downloading svgwrite-1.4.3-py3-none-any.whl (67 kB) Installing collected packages: svgwrite Successfully installed svgwrite-1.4.3

```
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import MultivariateNormal, OneHotCategorical, Categorical
import requests
import io
import svgwrite
from IPython.display import Image, SVG, display, HTML
from google.colab.output import eval_js
```

```
torch.manual_seed(1337)
```

```
# Hyper-parameters
data_classes = ["cat"]
batch_size = 64 # how many independent sequences will we process in parallel?
# block_size = None # what is the maximum context length for predictions? here we take block_size = max number of strokes of our database
training_iters = 24000
eval_interval = training_iters // 20
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 10
embd = 384
embd_ffn = 4 * embd # 4times as in "attention is all you need paper"
num_heads = 6 #every head is embd/num_head dimensional
n_layers = 6
dropout = 0.2 # 20% of operations are randomly masked at each forward/backward pass
n_components = 20 # number of gaussians in the MDN output layer
```

```
# helper function for draw_strokes: bounding box of the absolute pen positions
def get_bounds(data, factor):
    min_x = 0
    max_x = 0
    min_y = 0
    max_y = 0
    abs_x = 0
    abs_y = 0
    for i in range(len(data)):
        x = float(data[i,0])/factor
        y = float(data[i,1])/factor
        abs_x += x
        abs_y += y
        min_x = min(min_x, abs_x)
        min_y = min(min_y, abs_y)
        max_x = max(max_x, abs_x)
        max_y = max(max_y, abs_y)
    return (min_x, max_x, min_y, max_y)

# helper function for draw_strokes: build the SVG path string from relative offsets
def create_path(data, factor, abs_x, abs_y, lift_pen=1):
    command = "m"
    p = "M%s,%s " % (abs_x, abs_y)
    for i in range(len(data)):
        # "m" = move without drawing (pen was lifted), "l" = draw a line
        if (lift_pen == 1):
            command = "m"
        elif (command != "l"):
            command = "l"
        else:
            command = ""
        x = float(data[i,0])/factor
        abs_x += x
        y = float(data[i,1])/factor
        abs_y += y
        lift_pen = data[i, 2]
        p += command+str(x)+","+str(y)+" "
    return p, abs_x, abs_y

# little function that displays vector images
def draw_strokes(data, factor=0.2, svg_filename='sample.svg', the_color="black", stroke_width=1):
    min_x, max_x, min_y, max_y = get_bounds(data, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
    abs_x = 25 - min_x
    abs_y = 25 - min_y
    p, _, _ = create_path(data, factor, abs_x, abs_y)
    dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
    dwg.save()
    svg_str = dwg.tostring()
    return svg_str

# same as draw_strokes, but draws two consecutive stroke sequences in different colors
def draw_two_strokes(data1, data2, color1="black", color2="brown", factor=0.2, svg_filename="sample.svg", stroke_width=1):
    min_x, max_x, min_y, max_y = get_bounds(torch.concatenate([data1, data2]), factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
    abs_x = 25 - min_x
    abs_y = 25 - min_y
    p1, abs_x, abs_y = create_path(data1, factor, abs_x, abs_y)
    dwg.add(dwg.path(p1).stroke(color1, stroke_width).fill("none"))
    p2, _, _ = create_path(data2, factor, abs_x, abs_y, lift_pen=0)
    dwg.add(dwg.path(p2).stroke(color2, stroke_width).fill("none"))
    dwg.save()
    svg_str = dwg.tostring()
    return svg_str
```

```
train_set, valid_set, test_set = [], [], []
for data_class in data_classes:
    # download the sketch-rnn version of the Quick, Draw! data for this class
    data_url = f"https://storage.googleapis.com/quickdraw_dataset/sketchrnn/{data_class}.npz"
    response = requests.get(data_url)
    response.raise_for_status()
    load_data = np.load(io.BytesIO(response.content), allow_pickle=True, encoding='latin1')
    train_set += load_data['train'].tolist()
    valid_set += load_data['valid'].tolist()
    test_set += load_data['test'].tolist()
print(len(train_set))
print(len(valid_set))
print(len(test_set))
```

70000 2500 2500

```
# get max len
max_len = 0
for x in train_set:
    max_len = max(max_len,len(x))
print(max_len)
block_size = max_len
assert block_size <= 250
```

129

```
max_w = 0
max_h = 0
x_mean = 0.
y_mean = 0.
x_std = 1.
y_std = 1.
N = 0
for x in train_set:
    min_x, max_x, min_y, max_y = get_bounds(x, factor=1)
    max_w = max(max_w, max_x-min_x)
    max_h = max(max_h, max_y-min_y)
    # accumulate squared offsets to estimate the standard deviations
    x_std += ((x[:,0] - x_mean)**2).sum()
    y_std += ((x[:,1] - y_mean)**2).sum()
    N += len(x)
x_std = np.sqrt(x_std/N)
y_std = np.sqrt(y_std/N)
print(max_w, max_h, x_std, y_std)
```

2263.0 1116.0 51.54030056217729 32.8188575795014

```
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    if split == "train":
        data = train_set
    elif split == "valid":
        data = valid_set
    else:
        data = test_set
    ix = torch.randint(len(data), (batch_size,))
    xs = []
    ys = []
    lengths = []
    for i in ix:
        # convert to (dx, dy, p) tensors with normalized x-y values to be roughly N(0,I) and p in {0, 1, 2}
        x, y, p = torch.tensor(data[i]).T
        # clip extreme offsets to [-1000, 1000] before normalizing
        x = torch.maximum(torch.minimum(x, torch.tensor([1000])), torch.tensor([-1000]))
        x = (x - x_mean) / x_std
        y = torch.maximum(torch.minimum(y, torch.tensor([1000])), torch.tensor([-1000]))
        y = (y - y_mean) / y_std
        p[-1] = 2.  # mark the last action as end of sketch
        d = torch.stack([x, y, p], -1)
        # pad with empty values till block size
        xs.append(torch.concatenate([d, torch.tensor([0., 0., 2.]).repeat((block_size-len(d), 1))]))
        # targets are the inputs shifted by one step
        ys.append(torch.concatenate([d[1:], torch.tensor([0., 0., 2.]).repeat((block_size+1-len(d), 1))]))
        lengths.append(len(data[i]))
    xs, ys, lengths = torch.stack(xs), torch.stack(ys), torch.tensor(lengths)
    # mask is True for positions t <= sketch length, False on the remaining padding
    mask = torch.arange(block_size).expand((batch_size, block_size)) <= lengths.unsqueeze(1).expand((batch_size, block_size))
    xs, ys, mask = xs.to(device), ys.to(device), mask.to(device)
    return xs, ys, mask
```

```
# draw random examples from the train set
xs, ys, mask = get_batch("train")
print(xs.shape)
n_samples = 10
svg_samples = [draw_strokes(xs[i], factor=0.1) for i in range(n_samples)]
no_wrap_div = '<div style="white-space: nowrap">'+'{}'*n_samples+'</div>'
display(HTML(no_wrap_div.format(*svg_samples)))
```

torch.Size([64, 129, 3])

```
class MDN(nn.Module):
    """
    Mixture density network compatible with full covariance.
    Adapted from https://github.com/haimengzhao/full-cov-mdn
    [ Bishop, 1994 ]
    Parameters
    ----------
    dim_in: int; dimensionality of the covariates
    dim_out: int; dimensionality of the response variable
    n_components: int; number of components in the mixture model
    full_cov: bool; whether to use full or diagonal covariance matrix
    """
    def __init__(self, dim_in, dim_out, n_components, full_cov=True):
        super().__init__()
        self.pi_net = OneHotCategoricalNetwork(dim_in, n_components)
        self.normal_net = NormalNetwork(dim_in, dim_out, n_components, full_cov)

    def forward(self, x, tau=1.):
        # returns the mixture weights distribution and the Gaussian components
        return self.pi_net(x, tau), self.normal_net(x, tau)

class NormalNetwork(nn.Module):
    def __init__(self, in_dim, out_dim, n_components, full_cov=True):
        super().__init__()
        self.n_components = n_components
        self.out_dim = out_dim
        self.full_cov = full_cov
        self.tril_indices = torch.tril_indices(row=out_dim, col=out_dim, offset=0)
        self.mean_net = nn.Linear(in_dim, out_dim * n_components)
        if full_cov:
            # Cholesky decomposition of the covariance matrix
            self.tril_net = nn.Linear(in_dim, int(out_dim * (out_dim + 1) / 2 * n_components))
        else:
            self.tril_net = nn.Linear(in_dim, out_dim * n_components)

    def forward(self, x, tau=1.):
        mean = self.mean_net(x).reshape(x.shape[0], x.shape[1], self.n_components, self.out_dim) # B, T, M, d
        if self.full_cov:
            tril_values = self.tril_net(x).reshape(x.shape[0], x.shape[1], self.n_components, -1) # B, T, M, (d**2+d)/2
            tril = torch.zeros(mean.shape[0], mean.shape[1], mean.shape[2], mean.shape[3], mean.shape[3]).to(x.device) # B, T, M, d, d
            tril[:, :, :, self.tril_indices[0], self.tril_indices[1]] = tril_values
            # use diag = exp(diag) to ensure strict positivity of diagonal elements
            tril.diagonal(dim1=-2, dim2=-1)[:] = tril.diagonal(dim1=-2, dim2=-1).exp()
        else:
            tril = self.tril_net(x).reshape(x.shape[0], x.shape[1], self.n_components, -1)
            tril = torch.diag_embed(tril.exp())
        # the temperature tau scales the Cholesky factor of the covariance
        tril *= tau
        return MultivariateNormal(mean, scale_tril=tril)

class OneHotCategoricalNetwork(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.network = nn.Linear(in_dim, out_dim)

    def forward(self, x, tau=1.):
        logits = self.network(x) / tau
        return OneHotCategorical(logits=logits)

class CategoricalNetwork(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.network = nn.Linear(in_dim, out_dim)

    def forward(self, x, tau=1.):
        logits = self.network(x) / tau
        return Categorical(logits=logits)
```

```
# Attention Head
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.query = nn.Linear(embd, head_size, bias=False)
        self.key = nn.Linear(embd, head_size, bias=False)
        self.value = nn.Linear(embd, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x) # B, T, head_size
        k = self.key(x) # B, T, head_size
        # compute an attention score ("affinities")
        # note: scaled by C = embd here, whereas standard scaled dot-product attention uses head_size
        wei = q@k.transpose(-2, -1) * C **(-0.5) # B, T, T
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # "decoder" block because of triangular masking (autoregressive setting)
        wei = F.softmax(wei, dim=-1)
        # perform the weighted aggregation of the values
        v = self.value(x) # B, T, head_size
        out = wei @ v # B, T, head_size
        return out

class MultiHead(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(embd, embd) #projection layer going back into the residual pathway

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embd, embd_ffn),
            nn.ReLU(),
            nn.Linear(embd_ffn, embd), # projection layer going back into the residual pathway
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer Block: communication/sensing followed by computation/update"""
    def __init__(self, embd, num_heads):
        super().__init__()
        self.sa_heads = MultiHead(num_heads, embd//num_heads)
        self.ffwd = FeedForward(embd)
        self.ln1 = nn.LayerNorm(embd) #should be equivalent to LayerNorm1D
        self.ln2 = nn.LayerNorm(embd)

    def forward(self, x):
        # self-attention is the "communication" step, the feedforward layer is the per-token "update" step;
        # layer norm is applied before each of them (pre-LN)
        x = x + self.sa_heads(self.ln1(x)) # residual connection <=> "highway" of information and residual paths
        x = x + self.ffwd(self.ln2(x)) # residual connection
        return x

class TransformerModel(nn.Module):
    def __init__(self):
        super().__init__()
        # input layer: stroke projection + pen embedding + positional embedding
        self.stroke_embedding_proj = nn.Linear(2, embd, bias=False)
        self.pen_embedding_table = nn.Embedding(3, embd)
        self.position_embedding_table = nn.Embedding(block_size, embd)
        self.blocks = nn.Sequential(*[Block(embd, num_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(embd)
        self.mdn_head = MDN(embd, 2, n_components)
        self.pen_head = CategoricalNetwork(embd, 3)

    def forward(self, x, tau=1.):
        B, T, C = x.shape
        # assert C == 3
        # x is a (B, T, 3) tensor of (dx, dy, p) actions
        stroke_emb = self.stroke_embedding_proj(x[:, :, :2]) # (B,T,2) @ (2, embd) = (B, T, embd)
        pen_emb = self.pen_embedding_table(x[:, :, 2].long()) # (B, T, embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=x.device)) #(T, embd)
        x = stroke_emb + pen_emb + pos_emb # (B, T, embd)
        # forward through attention heads
        x = self.blocks(x) # (B, T, embd)
        x = self.ln_f(x)
        # forward through the mdn head and the pen head
        pi_net, normal_net = self.mdn_head(x, tau=tau)
        q_net = self.pen_head(x, tau=tau)
        return pi_net, normal_net, q_net

    def loss(self, x, targets, mask):
        pi, normal, q = self.forward(x)
        # stroke loss: negative log-likelihood of the target offsets under the GMM
        ys = targets[:, :, :2]
        loglik = normal.log_prob(ys.unsqueeze(-2).expand_as(normal.loc))
        Ls = -torch.logsumexp(torch.log(pi.probs) + loglik, dim=-1)
        Ls *= mask
        # pen loss: negative log-likelihood of the target pen actions (not masked)
        yp = targets[:, :, 2]
        Lp = -q.log_prob(yp)
        return Ls + Lp

    def sample(self, x, tau=1.):
        pi, normal, q = self.forward(x, tau)
        # pick one mixture component (one-hot) and keep the sample of that component
        s_samples = torch.sum(pi.sample().unsqueeze(-1) * normal.sample(), dim=-2)
        p_samples = q.sample()
        return torch.cat([s_samples, p_samples.unsqueeze(-1)], dim=-1)

    @torch.no_grad()
    def generate(self, x, max_new_tokens, tau=1., break_eos=True):
        # x is (1, T, 3)
        for _ in range(max_new_tokens):
            # get the predictions
            samples_next = self.sample(x, tau=tau)[:, -1, :].unsqueeze(1)
            # append sampled stroke + pen index to the running sequence
            x = torch.cat([x, samples_next], dim=1)
            # break if end of sketch
            if break_eos:
                if samples_next[0,0,2] == 2:
                    return x
        return x
```

```
model = TransformerModel()
model = model.to(device)
print(f"Model has {sum([p.nelement() for p in model.parameters()])} parameters")
```

Model has 10739451 parameters

```
X, Y, mask = get_batch("train")
Y_pred = model.sample(X)
print(X.shape, Y.shape, Y_pred.shape)
```

torch.Size([64, 129, 3]) torch.Size([64, 129, 3]) torch.Size([64, 129, 3])

```
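@torch.no_grad()
def estimate_loss():
    out = {}
    # get_batch expects "valid" for the validation split (any other name falls back to the test set)
    for split, data_split in [('train', 'train'), ('val', 'valid')]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y, mask = get_batch(data_split)
            loss = model.loss(X, Y, mask)
            losses[k] = loss.mean()
        out[split] = losses.mean()
    return out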
```

```
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
#lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=training_iters)
for iter in range(training_iters):
    #every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        model.eval()
        # evaluate loss
        losses = estimate_loss()
        print(f'step {iter}: lr {optimizer.param_groups[0]["lr"]:.6f}, train loss {losses["train"]:.4f}, val loss {losses["val"]:.4f}')
        # display random samples at current stage of training
        n_samples = 10
        svg_samples = [draw_strokes(model.generate(torch.zeros(1, 1, 3).to(device), max_new_tokens=block_size-1, break_eos=True, tau=.4)[0], factor=0.1) for _ in range(n_samples)]
        no_wrap_div = '<div style="white-space: nowrap">'+'{}'*n_samples+'</div>'
        display(HTML(no_wrap_div.format(*svg_samples)))
        model.train()
    # sample a batch of data
    xb, yb, mask = get_batch('train')
    # evaluate the loss
    loss = model.loss(xb, yb, mask).mean()
    # backward pass
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    #lr_scheduler.step()
```

step 0: lr 0.000100, train loss 3.1467, val loss 3.1501

step 1200: lr 0.000100, train loss 0.6012, val loss 0.6239

step 2400: lr 0.000100, train loss 0.4337, val loss 0.4464

step 3600: lr 0.000100, train loss 0.4150, val loss 0.4295

step 4800: lr 0.000100, train loss 0.5068, val loss 0.5206

step 6000: lr 0.000100, train loss 0.3361, val loss 0.3572

step 7200: lr 0.000100, train loss 0.3663, val loss 0.3860

step 8400: lr 0.000100, train loss 0.3024, val loss 0.3062

step 9600: lr 0.000100, train loss 0.3049, val loss 0.3309

step 10800: lr 0.000100, train loss 0.5361, val loss 0.5466

step 12000: lr 0.000100, train loss 0.3183, val loss 0.3592

step 13200: lr 0.000100, train loss 0.2766, val loss 0.2947

step 14400: lr 0.000100, train loss 0.2584, val loss 0.2685

step 15600: lr 0.000100, train loss 0.4060, val loss 0.4023

step 16800: lr 0.000100, train loss 0.2405, val loss 0.2707

step 18000: lr 0.000100, train loss 0.3355, val loss 0.3577

step 19200: lr 0.000100, train loss 0.2157, val loss 0.2660

step 20400: lr 0.000100, train loss 0.2235, val loss 0.2534

step 21600: lr 0.000100, train loss 0.4038, val loss 0.4099

step 22800: lr 0.000100, train loss 0.1654, val loss 0.1942

👉 We see the loss as well as example generated samples every 1200 training steps (one step = one batch of 64 sketches).

We can see that before training (step 0), the sketches are very short because the pen action $p=2$ (end of sketch) is sampled roughly uniformly, hence too early.
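A back-of-the-envelope check of the step-0 behaviour, under the simplifying assumption that an untrained model assigns exactly probability 1/3 to each pen action at every step: the number of generated actions then follows a geometric distribution.

```python
# Assumption: at initialization the end-of-sketch action has probability 1/3 at each step
p_end = 1.0 / 3.0

# Mean of a geometric distribution: expected number of steps until the first p = 2
expected_length = 1.0 / p_end


def prob_longer_than(n, p_end=p_end):
    """Probability that no end-of-sketch action is sampled in the first n steps."""
    return (1.0 - p_end) ** n


print(expected_length, prob_longer_than(10))
```

Under this assumption a sketch lasts about 3 actions on average, and fewer than 2% of the sketches exceed 10 actions, which is far too short to draw a cat.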

During training, the model quickly learns to draw longer sequences with already curvy shapes reminiscent of face contours (step 1200). As training progresses it seems to learn a semblance of eyes, ears, and weird moustaches, although not very coherent yet (step 6000).

After 6000 steps, it becomes more and more evident that the model learns to capture the essential components of a cat sketch: a round face, two pointy ears, two eyes, a nose and some slightly better (yet still weird) moustaches.

Obviously the model is far from perfect and there are still several failed scribbles and funny generalizations of cats...

In the end I let the model train for a total of 24000 steps, and maybe more would have been useful as the training/validation loss kept going down, although the loss I'm plotting here is quite noisy (estimated on 10 batches only).

Ultimately we can see the kind of sketches we obtain at the end of training. It's not perfect but I'm quite happy with the results and confident that more can be done to improve these doodles even further 🐈.

[Optional] The code below saves the trained weights to your local machine, which can be useful if you do not want to re-train the model each time. Note that we already provide trained weights for the `cat` data class on the GitHub repository.

```
from google.colab import files
torch.save(model.state_dict(), f'model_{"_".join(data_classes)}.pth')
files.download(f'model_{"_".join(data_classes)}.pth')
```
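To reuse saved weights later, the usual PyTorch pattern is to rebuild the same architecture and call `load_state_dict`. The sketch below shows the round trip on a small stand-in `nn.Linear` module and a temporary file (both are assumptions for illustration); in the notebook you would instead instantiate `TransformerModel()` with the same hyperparameters and point to the saved `.pth` file.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained model (hypothetical; the notebook uses TransformerModel)
trained = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_cat.pth")
    torch.save(trained.state_dict(), path)

    # Rebuild the same architecture, then load the saved parameters into it
    restored = nn.Linear(4, 2)
    state_dict = torch.load(path, map_location="cpu")
    restored.load_state_dict(state_dict)
    restored.eval()  # switch to inference mode (disables dropout layers, if any)

# Check that both modules now hold identical parameters
same_weights = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(same_weights)
```

`map_location="cpu"` lets weights trained on a GPU be loaded on a machine without one; the model can then be moved with `.to(device)` as elsewhere in the notebook.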

```
model.eval()
```

```
# Evaluate Reconstruction Loss on the full test set
@torch.no_grad()
def test_loss():
    model.eval()
    # convert to (dx, dy, p) tensors with normalized x-y values to be roughly N(0,I) and p in {0, 1, 2}
    xs, ys, lengths = [], [], []
    for i in range(len(test_set)):
        x, y, p = torch.tensor(test_set[i]).T
        x = (x - x_mean) / x_std
        y = (y - y_mean) / y_std
        p[-1] = 2.  # mark the last action as end of sketch
        d = torch.stack([x, y, p], -1)
        # pad with empty values till block size
        xs.append(torch.concatenate([d, torch.tensor([0., 0., 2.]).repeat((block_size-len(d), 1))]))
        # targets are the inputs shifted by one step
        ys.append(torch.concatenate([d[1:], torch.tensor([0., 0., 2.]).repeat((block_size+1-len(d), 1))]))
        lengths.append(len(test_set[i]))
    xs, ys, lengths = torch.stack(xs), torch.stack(ys), torch.tensor(lengths)
    mask = torch.arange(block_size).expand((len(test_set), block_size)) <= lengths.unsqueeze(1).expand((len(test_set), block_size))
    xs, ys, mask = xs.to(device), ys.to(device), mask.to(device)
    loss = model.loss(xs, ys, mask)
    return loss

print(f"LR: {test_loss().mean()}")
```

LR: 0.21352873742580414

👉 The loss over the whole test set is $L_R \approx 0.21$.

```
# Complete Sketch from the test set with the trained model
X, Y, _ = get_batch("test")
n_samples = 5
n_per_sample = 10
t_init = 5
for i in range(n_samples):
    # complete the first t_init actions of a test sketch n_per_sample times
    svg_samples = [draw_strokes(model.generate(X[i, :t_init].unsqueeze(0).to(device), max_new_tokens=block_size-t_init, break_eos=True, tau=.4)[0], factor=0.1) for _ in range(n_per_sample)]
    no_wrap_div = '<div style="white-space: nowrap">'+'{}'*n_per_sample+'</div>'
    display(HTML(no_wrap_div.format(*svg_samples)))
```