Learning to generate human-like sketches with transformers

Mayalen Etcheverry | March, 2024


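This blogpost provides a tutorial on how to combine a decoder-only Transformer, in the style of the nanoGPT tutorial, with the Mixture Density Network (MDN) output layer and reconstruction loss of Sketch-RNN, in order to generate human-like sketches.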

You can run the code directly in your browser using Google Colab, or download it as a notebook and run it on your local machine. Note that it requires training the transformer model auto-regressively on a database of sketches, which can take a while depending on your hardware. I personally trained it on Colab with a V100 GPU, which takes approximately 20 minutes with the proposed hyperparameters. If you do not have a GPU, I recommend playing with a smaller model, for instance by decreasing embd, num_heads or n_layers (keeping embd divisible by num_heads).

Disclaimer: This tutorial is just me playing with transformers on a fun project that I can combine with our CNC drawing machine ✏ Nothing groundbreaking here, as both Transformers and SketchRNN date from 2017, old times in the fast-paced world of machine learning! In fact I'm quite late to the party: a 2020 paper by Ribeiro et al. called "SketchFormer" already did something very similar, although with a slightly different architecture and pipeline. They used an encoder-decoder architecture with discretized tokens (or continuous but deterministic tokens), whereas I use a decoder-only architecture with a mixture density network (MDN) output layer. More below 👇

Overview of the Sketch-Transformer Pipeline

Before diving into the code, let's look at what we will need to implement for the dataset, the neural network model, and the training loss.


The code uses the Quick, Draw! Dataset as training data. The model is currently trained on the cat class of this dataset, but other sets of classes can easily be tested by changing the data_classes hyperparameter in the notebook.

For each class, the dataset contains 70K sketches for training, 2.5K for validation and 2.5K for testing. A sketch is represented as a sequence of pen stroke actions, where each action $a$ is a vector of 3 elements $a=(\Delta x, \Delta y, p)$. The $(\Delta x, \Delta y)$ values are continuous and represent the offset from the previous pen position to the current one; in the notebook they are normalized to have a standard deviation of 1. The $p$ value is discrete: $0$ for drawing, $1$ for lifting the pen and $2$ for the end of the sketch, indicating that no subsequent points will be rendered.
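To make the stroke-3 format more concrete, here is a minimal NumPy sketch (not part of the notebook) that accumulates the offsets of a toy sketch into absolute pen positions and splits them into polylines at pen lifts. The helper `to_polylines` and the toy data are hypothetical, and the sketch assumes the dataset convention that $p=1$ on a point means the pen is lifted after that point.

```python
import numpy as np

def to_polylines(strokes):
    """Convert a stroke-3 array of (dx, dy, p) rows into a list of polylines.

    Offsets are accumulated into absolute positions; a row with p == 1 ends
    the current polyline (pen lifted after that point), p == 2 ends the sketch.
    """
    xy = np.cumsum(strokes[:, :2], axis=0)  # absolute pen positions
    polylines, current = [], []
    for point, p in zip(xy, strokes[:, 2]):
        current.append(point)
        if p >= 1:  # pen lifted (1) or end of sketch (2): close the polyline
            polylines.append(np.array(current))
            current = []
        if p == 2:
            break
    if current:
        polylines.append(np.array(current))
    return polylines

# hypothetical toy sketch: a 3-point line, a pen lift, then a 2-point line
toy = np.array([
    [0, 0, 0],
    [10, 0, 0],
    [0, 10, 1],   # pen lifted after this point
    [5, 5, 0],    # the pen moves here without drawing
    [-5, 0, 2],   # end of sketch
], dtype=float)

for line in to_polylines(toy):
    print(line.tolist())
```

The `get_bounds` and `create_path` helpers below rely on the same idea when they accumulate `abs_x` and `abs_y` to render a sketch as SVG.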

Sketch-Transformer Model

The model architecture is summarized below.

Decoder Backbone

The backbone is very similar to the decoder of the famous Transformer model from the "Attention Is All You Need" paper.

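There are however some minor differences within the decoder blocks, which follow the nanoGPT tutorial: for instance, LayerNorm is applied before (rather than after) the self-attention and feed-forward sub-layers, and a final LayerNorm is applied after the last block (see the Block and TransformerModel classes below).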

👉 For more details on the self-attention mechanism and for an in-depth understanding of this architecture be sure to check the nanoGPT tutorial which I highly recommend to anyone interested!
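The masking trick that turns the backbone into a decoder can be checked in a few lines. Below is a NumPy sketch of a single causal self-attention head, loosely mirroring the `Head` class further down, with random weights and toy sizes. It uses the standard $1/\sqrt{d_k}$ scaling with $d_k$ the head size, and the check simply verifies that perturbing the last token leaves the outputs at earlier positions unchanged.

```python
import numpy as np

def causal_attention_head(x, w_q, w_k, w_v):
    """One self-attention head with a causal (lower-triangular) mask.

    x: (T, C) token embeddings; w_q, w_k, w_v: (C, d_k) projections.
    """
    T = x.shape[0]
    q, k, v = x @ w_q, x @ w_k, x @ w_v            # (T, d_k) each
    scores = q @ k.T / np.sqrt(k.shape[-1])        # (T, T) affinities
    mask = np.tril(np.ones((T, T), dtype=bool))    # token t only sees tokens <= t
    scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                             # (T, d_k)

rng = np.random.default_rng(0)
T, C, d_k = 5, 8, 4
x = rng.normal(size=(T, C))
w_q, w_k, w_v = (rng.normal(size=(C, d_k)) for _ in range(3))

out = causal_attention_head(x, w_q, w_k, w_v)

# perturbing the last token must not change the outputs of earlier tokens
x_perturbed = x.copy()
x_perturbed[-1] += 10.0
out_perturbed = causal_attention_head(x_perturbed, w_q, w_k, w_v)
print(np.allclose(out[:-1], out_perturbed[:-1]))  # True
```

This is what makes auto-regressive training possible: the prediction at position $t$ can be trained against the stroke at $t+1$ without leaking information from the future.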

Input and Output Layers

The input and output layers, however, are quite different from those of the original Transformer, as we're dealing with sequences of continuous strokes and discrete pen actions rather than text tokens.

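Let's start with the input layer. The input data, which is a partial sketch, is a sequence of stroke-3 tuples $(dx, dy, p)$ where $(dx, dy) \in \mathbb{R}^2$ and $p \in \lbrace 0,1,2 \rbrace$. Each tuple is embedded as the sum of a linear projection of $(dx, dy)$, a learned embedding of the pen state $p$, and a learned positional embedding.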
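Here is a minimal NumPy sketch of this input embedding, loosely mirroring `stroke_embedding_proj`, `pen_embedding_table` and `position_embedding_table` in the `TransformerModel` class below. The random matrices stand in for learned weights and all sizes are toy values chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
embd, block_size = 16, 10                        # toy sizes

W_stroke = rng.normal(size=(2, embd))            # linear projection of (dx, dy), no bias
pen_table = rng.normal(size=(3, embd))           # one vector per pen state 0, 1, 2
pos_table = rng.normal(size=(block_size, embd))  # one vector per position

def embed(seq):
    """seq: (T, 3) array of stroke-3 tuples (dx, dy, p) -> (T, embd) embeddings."""
    stroke_emb = seq[:, :2] @ W_stroke              # continuous part
    pen_emb = pen_table[seq[:, 2].astype(int)]      # discrete part, table lookup
    pos_emb = pos_table[np.arange(len(seq))]        # position of each tuple
    return stroke_emb + pen_emb + pos_emb

seq = np.array([[0.1, -0.3, 0], [0.5, 0.2, 0], [-0.4, 0.0, 1], [0.0, 0.0, 2]])
print(embed(seq).shape)  # (4, 16)
```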

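For the output layer, we split it into two heads: an MDN head that outputs the parameters of a Gaussian mixture over the next stroke offset $(\Delta x, \Delta y)$, and a pen head that outputs a categorical distribution over the three pen states $p$.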

Mixture Density Network (MDN)

Mixture Density Networks, originally proposed by Christopher Bishop in 1994, use the outputs of a neural network as the parameters of a probability distribution instead of as direct output values.

In our case, the network outputs the parameters $\lbrace \pi_j, \mu_j, \Sigma_j \rbrace_{j=1..M}$ of a multivariate Gaussian Mixture Model (GMM) with $M=20$ components, from which we sample the stroke offsets:

$$\begin{align*} (\Delta x, \Delta y) &\sim f(\theta) \newline f(\theta) &= \sum^{M}_{j=1} \pi_j \: \mathcal{N} (\mu_j, \Sigma_j) \end{align*}$$
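Sampling from such a mixture is a two-step process: pick a component $j$ with probability $\pi_j$, then draw from $\mathcal{N}(\mu_j, \Sigma_j)$. The `TransformerModel.sample` method below does this with torch distributions; here is a hypothetical NumPy sketch of the same idea with toy parameters ($M=2$ instead of 20), drawing the Gaussian sample through a lower-triangular factor $L_j$ with $\Sigma_j = L_j L_j^\top$.

```python
import numpy as np

def sample_gmm(pi, mu, scale_tril, rng):
    """Draw one 2-D sample from sum_j pi[j] * N(mu[j], L_j @ L_j.T).

    pi: (M,) mixture weights, mu: (M, 2) means,
    scale_tril: (M, 2, 2) lower-triangular factors L_j.
    """
    j = rng.choice(len(pi), p=pi)           # 1) choose a component
    z = rng.standard_normal(mu.shape[-1])   # 2) standard normal noise
    return mu[j] + scale_tril[j] @ z        #    reparameterized Gaussian draw

rng = np.random.default_rng(0)
pi = np.array([0.7, 0.3])
mu = np.array([[-2.0, 0.0], [2.0, 1.0]])
scale_tril = np.array([[[0.5, 0.0], [0.1, 0.3]],
                       [[0.2, 0.0], [0.0, 0.2]]])

samples = np.array([sample_gmm(pi, mu, scale_tril, rng) for _ in range(5000)])
print(samples.mean(axis=0))            # close to 0.7 * mu[0] + 0.3 * mu[1] = [-0.8, 0.3]
print(np.mean(samples[:, 0] < 0))      # close to 0.7: samples stay around the two modes
```

Note that the mean of the samples, $(-0.8, 0.3)$, lies between the two modes where almost no sample falls, which is exactly the problem discussed next.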

The intuition here is that if we directly trained our network to predict strokes as a vanilla regression task, e.g. with an MSE loss, it would only learn to recover the average next stroke for each input sequence. That average is often wrong: when drawing a cat face, for instance, after drawing the ears some people move on to the eyes whereas others draw the face contour, and averaging these options results in something pretty bad looking. By modelling the outputs with a GMM (or another distribution) instead, the network can assign probability to several distinct outputs for the same input, and is therefore more likely to recover the true data distribution.
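A one-dimensional toy illustration of this averaging problem, with hypothetical data and NumPy only: if half of the "next strokes" go left and half go right, the constant prediction that minimizes the squared error is their mean, a point where no training example actually lies.

```python
import numpy as np

rng = np.random.default_rng(0)
# two equally likely continuations: move to x = -1 or to x = +1 (small noise)
targets = np.concatenate([rng.normal(-1.0, 0.05, 500), rng.normal(1.0, 0.05, 500)])

# search for the constant prediction with the lowest mean squared error
candidates = np.linspace(-1.5, 1.5, 301)
mse = [np.mean((targets - c) ** 2) for c in candidates]
best = candidates[int(np.argmin(mse))]

print(best)                                    # close to 0, halfway between the two modes
print(np.mean(np.abs(targets - best) < 0.5))   # 0.0: no target lies near the MSE-optimal prediction
```

A two-component mixture fitted to the same data would instead put its means near $-1$ and $+1$.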

If you think about it, Transformers already do a similar thing with discrete sequences: they probabilistically generate an output token given an input token sequence by modelling the outputs with a categorical distribution. MDNs are a way to extend this idea to continuous tokens.

👉 David Ha has written a complete tutorial on MDNs that I highly recommend for intuitively understanding why they can be useful for many modern ML tasks. You can also have a look at this Google Colab, where I extend David Ha's tutorial to play with MDNs with a full covariance matrix on a task with a 2-dimensional output. It uses a new implementation where the GMM is built from the MultivariateNormal and OneHotCategorical torch distributions, and the loss is computed with torch.logsumexp for numerical stability, akin to what is done in this repo.
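The full-covariance trick, as used in the `NormalNetwork` class below, is to predict the $d(d+1)/2$ entries of a lower-triangular matrix $L$ and exponentiate its diagonal, so that $\Sigma = L L^\top$ is always a valid (symmetric positive-definite) covariance. A NumPy sketch for $d=2$, with made-up raw network outputs for a single mixture component:

```python
import numpy as np

def covariance_from_raw(raw, d=2):
    """Map d*(d+1)/2 unconstrained values to a positive-definite covariance.

    The values fill the lower triangle of L (same row-major order as
    torch.tril_indices); the diagonal is exponentiated so it is strictly
    positive, which makes Sigma = L @ L.T positive definite.
    """
    L = np.zeros((d, d))
    rows, cols = np.tril_indices(d)
    L[rows, cols] = raw
    L[np.diag_indices(d)] = np.exp(np.diag(L))
    return L, L @ L.T

raw = np.array([-1.2, 0.8, 0.3])   # hypothetical raw outputs for one component
L, sigma = covariance_from_raw(raw)
print(np.linalg.eigvalsh(sigma))   # both eigenvalues are strictly positive
```

Because $L$ has a positive diagonal, it is exactly the Cholesky factor of $\Sigma$, which is why it can be passed directly as `scale_tril` to `MultivariateNormal`.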


Our loss is the same as the reconstruction loss $L_R$ of the Sketch-RNN paper, which maximizes the likelihood of the training data under the predicted probability distributions. More precisely, $L_R$ is the sum of the negative log-likelihood of 1) the stroke offsets $\lbrace (\Delta x_i, \Delta y_i) \rbrace$ under the predicted GMM distributions ($L_s$) and 2) the pen actions $\lbrace p_i \rbrace$ under the predicted categorical distributions ($L_p$):

$$\begin{align*} L_R &= L_s + L_p \newline L_s &= - \frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \Big( \sum_{j=1}^{M} \pi_{i,j} \, \mathcal{N} \big( (\Delta x_i, \Delta y_i) \mid \mu_{i,j}, \Sigma_{i,j} \big) \Big) \newline L_p &= - \frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \log q_{i}[p_i] \end{align*}$$

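where $(\pi_{i,j}, \mu_{i,j}, \Sigma_{i,j})_{j=1..M}$ are the outputs of the MDN head for the $i$-th entry, $q_i$ are the outputs of the pen head, $N_s$ is the number of strokes in the sketch and $N_{max}$ is the maximum sequence length (block_size in the code). As in Sketch-RNN, $L_s$ only sums over the actual strokes, whereas $L_p$ also covers the padded end-of-sketch entries.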
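To see why the code computes $L_s$ with `torch.logsumexp` rather than by taking the log of the summed densities, here is a NumPy sketch of the per-stroke term $-\log \sum_j \pi_{i,j}\,\mathcal{N}(\cdot)$. It is not the notebook's code: it uses toy parameters, $M=3$, and diagonal covariances for brevity, and the helper names are hypothetical.

```python
import numpy as np

def log_normal_diag(x, mu, sigma):
    """Log-density of a 2-D Gaussian with diagonal std `sigma`, one value per component."""
    z = (x - mu) / sigma
    return -0.5 * np.sum(z ** 2, axis=-1) - np.sum(np.log(sigma), axis=-1) - np.log(2 * np.pi)

def logsumexp(a):
    """log(sum(exp(a))) computed without overflow/underflow."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))

pi = np.array([0.5, 0.3, 0.2])                        # mixture weights
mu = np.array([[0.0, 0.0], [1.0, -1.0], [-2.0, 0.5]])
sigma = np.array([[0.1, 0.1], [0.2, 0.3], [0.05, 0.05]])

def stroke_nll(x):
    log_terms = np.log(pi) + log_normal_diag(x, mu, sigma)   # log(pi_j * N_j(x))
    stable = -logsumexp(log_terms)                           # logsumexp formulation
    with np.errstate(divide="ignore"):
        naive = -np.log(np.sum(np.exp(log_terms)))           # log of the summed densities
    return stable, naive

print(stroke_nll(np.array([0.1, 0.0])))    # both values agree for a typical target
print(stroke_nll(np.array([30.0, 30.0])))  # the naive value underflows to inf
```

For an outlier target every density underflows to zero, so the naive loss becomes infinite while the logsumexp version stays finite and keeps giving useful gradients.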

Let's implement it! 💻

Imports and Utils

!pip install svgwrite
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import MultivariateNormal, OneHotCategorical, Categorical
import requests
import io
import svgwrite
from IPython.display import Image, SVG, display, HTML
from google.colab.output import eval_js
# Hyper-parameters
data_classes = ["cat"]
batch_size = 64 # how many independent sequences will we process in parallel?
# block_size = None # what is the maximum context length for predictions? here we take block_size = max number of strokes of our database
training_iters = 24000
eval_interval = training_iters // 20
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 10
embd = 384
embd_ffn = 4 * embd # 4 times embd, as in the "Attention Is All You Need" paper
num_heads = 6 # every head is embd // num_heads dimensional
n_layers = 6
dropout = 0.2 # probability of randomly zeroing an activation during training
n_components = 20 # number of gaussians in the MDN output layer
# helper function for draw_strokes
def get_bounds(data, factor):
  min_x = 0
  max_x = 0
  min_y = 0
  max_y = 0

  abs_x = 0
  abs_y = 0
  for i in range(len(data)):
    x = float(data[i,0])/factor
    y = float(data[i,1])/factor
    abs_x += x
    abs_y += y
    min_x = min(min_x, abs_x)
    min_y = min(min_y, abs_y)
    max_x = max(max_x, abs_x)
    max_y = max(max_y, abs_y)

  return (min_x, max_x, min_y, max_y)

def create_path(data, factor, abs_x, abs_y, lift_pen=1):
  command = "m"
  p = "M%s,%s " % (abs_x, abs_y)
  for i in range(len(data)):
    if (lift_pen == 1):
      command = "m"
    elif (command != "l"):
      command = "l"
    else:
      command = ""
    x = float(data[i,0])/factor
    abs_x += x
    y = float(data[i,1])/factor
    abs_y += y
    lift_pen = data[i, 2]
    p += command+str(x)+","+str(y)+" "
  return p, abs_x, abs_y

# little function that displays vector images
def draw_strokes(data, factor=0.2, svg_filename='sample.svg', the_color="black", stroke_width=1):
  min_x, max_x, min_y, max_y = get_bounds(data, factor)
  dims = (50 + max_x - min_x, 50 + max_y - min_y)
  dwg = svgwrite.Drawing(svg_filename, size=dims)
  dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
  abs_x = 25 - min_x
  abs_y = 25 - min_y
  p, _, _ = create_path(data, factor, abs_x, abs_y)
  dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill("none"))
  svg_str = dwg.tostring()
  return svg_str

def draw_two_strokes(data1, data2, color1="black", color2="brown", factor=0.2, svg_filename="sample.svg", stroke_width=1):
  min_x, max_x, min_y, max_y = get_bounds(torch.concatenate([data1, data2]), factor)
  dims = (50 + max_x - min_x, 50 + max_y - min_y)
  dwg = svgwrite.Drawing(svg_filename, size=dims)
  dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
  abs_x = 25 - min_x
  abs_y = 25 - min_y
  p1, abs_x, abs_y = create_path(data1, factor, abs_x, abs_y)
  dwg.add(dwg.path(p1).stroke(color1, stroke_width).fill("none"))
  p2, _, _ = create_path(data2, factor, abs_x, abs_y, lift_pen=0)
  dwg.add(dwg.path(p2).stroke(color2, stroke_width).fill("none"))
  svg_str = dwg.tostring()
  return svg_str

Prepare the Dataset

train_set, valid_set, test_set = [], [], []
for data_class in data_classes:
  data_url = f"https://storage.googleapis.com/quickdraw_dataset/sketchrnn/{data_class}.npz"
  response = requests.get(data_url)
  load_data = np.load(io.BytesIO(response.content), allow_pickle=True, encoding='latin1')
  train_set += load_data['train'].tolist()
  valid_set += load_data['valid'].tolist()
  test_set += load_data['test'].tolist()

# get max len
max_len = 0
for x in train_set:
    max_len = max(max_len,len(x))
block_size = max_len
assert block_size <= 250
max_w = 0
max_h = 0
x_mean = 0.
y_mean = 0.
x_std = 1.
y_std = 1.
N = 0
for x in train_set:
    min_x, max_x, min_y, max_y = get_bounds(x, factor=1)
    max_w = max(max_w, max_x-min_x)
    max_h = max(max_h, max_y-min_y)
    x_std += ((x[:,0] - x_mean)**2).sum()
    y_std += ((x[:,1] - y_mean)**2).sum()
    N += len(x)
x_std = np.sqrt(x_std/N)
y_std = np.sqrt(y_std/N)
print(max_w, max_h, x_std, y_std)
2263.0 1116.0 51.54030056217729 32.8188575795014
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    if split == "train":
        data = train_set
    elif split in ("valid", "val"):
        data = valid_set
    else:
        data = test_set
    ix = torch.randint(len(data), (batch_size,))
    xs = []
    ys = []
    lengths = []
    for i in ix:
        # convert to stroke-3 with x-y offsets clipped to [-1000, 1000] and normalized to roughly unit variance
        x, y, p = torch.tensor(data[i]).T
        x = torch.maximum(torch.minimum(x, torch.tensor([1000])), torch.tensor([-1000]))
        x = (x - x_mean) / x_std
        y = torch.maximum(torch.minimum(y, torch.tensor([1000])), torch.tensor([-1000]))
        y = (y - y_mean) / y_std
        p[-1] = 2. # mark the last point as end of sketch
        d = torch.stack([x, y, p], -1)
        lengths.append(len(d))
        # pad with end-of-sketch tokens (0, 0, 2) up to block_size; targets are the inputs shifted by one step
        xs.append(torch.concatenate([d, torch.tensor([0., 0., 2.]).repeat((block_size-len(d), 1))]))
        ys.append(torch.concatenate([d[1:], torch.tensor([0., 0., 2.]).repeat((block_size+1-len(d), 1))]))
    xs, ys, lengths = torch.stack(xs), torch.stack(ys), torch.tensor(lengths)
    mask = torch.arange(block_size).expand((batch_size, block_size)) <=  lengths.unsqueeze(1).expand((batch_size, block_size))
    xs, ys, mask = xs.to(device), ys.to(device), mask.to(device)
    return xs, ys, mask
# draw random examples from the train set
xs, ys, mask = get_batch("train")
print(xs.shape)

n_samples = 10
svg_samples = [draw_strokes(xs[i], factor=0.1) for i in range(n_samples)]
no_wrap_div = '<div style="white-space: nowrap">'+'{}'*n_samples+'</div>'
display(HTML(no_wrap_div.format(*svg_samples)))
torch.Size([64, 129, 3])

Model: Autoregressive Transformer (Decoder) + MDN

class MDN(nn.Module):
    """
    Mixture density network compatible with full covariance.
    Adapted from https://github.com/haimengzhao/full-cov-mdn

    [ Bishop, 1994 ]

    dim_in: int; dimensionality of the covariates
    dim_out: int; dimensionality of the response variable
    n_components: int; number of components in the mixture model
    full_cov: bool; whether to use full or diagonal covariance matrix
    """
    def __init__(self, dim_in, dim_out, n_components, full_cov=True):
        super().__init__()
        self.pi_net = OneHotCategoricalNetwork(dim_in, n_components)
        self.normal_net = NormalNetwork(dim_in, dim_out, n_components, full_cov)

    def forward(self, x, tau=1.):
        return self.pi_net(x, tau), self.normal_net(x, tau)

class NormalNetwork(nn.Module):
    def __init__(self, in_dim, out_dim, n_components, full_cov=True):
        super().__init__()
        self.n_components = n_components
        self.out_dim = out_dim
        self.full_cov = full_cov
        self.tril_indices = torch.tril_indices(row=out_dim, col=out_dim, offset=0)
        self.mean_net = nn.Linear(in_dim, out_dim * n_components)
        if full_cov:
            # Cholesky decomposition of the covariance matrix
            self.tril_net = nn.Linear(in_dim, int(out_dim * (out_dim + 1) / 2 * n_components))
        else:
            # diagonal covariance: one log standard deviation per output dimension
            self.tril_net = nn.Linear(in_dim, out_dim * n_components)

    def forward(self, x, tau=1.):
        mean = self.mean_net(x).reshape(x.shape[0], x.shape[1], self.n_components, self.out_dim) # B, T, M, d
        if self.full_cov:
            tril_values = self.tril_net(x).reshape(x.shape[0], x.shape[1], self.n_components, -1) # B, T, M, (d**2+d)/2
            tril = torch.zeros(mean.shape[0], mean.shape[1], mean.shape[2], mean.shape[3], mean.shape[3]).to(x.device) # B, T, M, d, d
            tril[:, :, :, self.tril_indices[0], self.tril_indices[1]] = tril_values
            # use diag = exp(diag) to ensure strict positivity of diagonal elements
            tril.diagonal(dim1=-2, dim2=-1)[:] = tril.diagonal(dim1=-2, dim2=-1).exp()
        else:
            tril = self.tril_net(x).reshape(x.shape[0], x.shape[1], self.n_components, -1)
            tril = torch.diag_embed(tril.exp())
        tril *= tau # temperature: scales the Cholesky factor of every component
        return MultivariateNormal(mean, scale_tril=tril)

class OneHotCategoricalNetwork(nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.network = nn.Linear(in_dim, out_dim)

    def forward(self, x, tau=1.):
        logits = self.network(x) / tau
        return OneHotCategorical(logits=logits)

class CategoricalNetwork(nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.network = nn.Linear(in_dim, out_dim)

    def forward(self, x, tau=1.):
        logits = self.network(x) / tau
        return Categorical(logits=logits)
# Attention Head
class Head(nn.Module):

    def __init__(self, head_size):
        super().__init__()
        self.query = nn.Linear(embd, head_size, bias=False)
        self.key = nn.Linear(embd, head_size, bias=False)
        self.value = nn.Linear(embd, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape

        q = self.query(x) # B, T, head_size
        k = self.key(x) # B, T, head_size

        # compute an attention score ("affinities")
        # note: scaled by C**-0.5 with C = embd; standard scaled dot-product attention uses head_size**-0.5
        wei = q@k.transpose(-2, -1) * C **(-0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # "decoder" block because of triangular masking (autoregressive setting)
        wei = F.softmax(wei, dim=-1)

        # perform the weighted aggregation of the values
        v = self.value(x)  # B, T, head_size
        out = wei @ v # B, T, head_size

        return out

class MultiHead(nn.Module):

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(embd, embd) #projection layer going back into the residual pathway

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

class FeedForward(nn.Module):

    def __init__(self, embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embd, embd_ffn),
            nn.ReLU(), # non-linearity (ReLU, as in the nanoGPT tutorial)
            nn.Linear(embd_ffn, embd), # projection layer going back into the residual pathway
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer Block: communication/sensing followed by computation/update"""

    def __init__(self, embd, num_heads):
        super().__init__()
        self.sa_heads = MultiHead(num_heads, embd//num_heads)
        self.ffwd = FeedForward(embd)
        self.ln1 = nn.LayerNorm(embd) #should be equivalent to LayerNorm1D
        self.ln2 = nn.LayerNorm(embd)

    def forward(self, x):
        # x = self.sa_heads(x) # apply one head of self-attention (B, T, C) <=> "communication" or "sense"
        # x = self.ffwd(x) # (B, T, C) => this is on a per-token level <=> "update"
        x = x + self.sa_heads(self.ln1(x)) # residual connection <=> "highway" of information and residual paths
        x = x + self.ffwd(self.ln2(x)) # residual connection

        return x

class TransformerModel(nn.Module):

    def __init__(self):
        super().__init__()
        # input embeddings: linear projection of (dx, dy), lookup tables for the pen state and the position
        self.stroke_embedding_proj = nn.Linear(2, embd, bias=False)
        self.pen_embedding_table = nn.Embedding(3, embd)
        self.position_embedding_table = nn.Embedding(block_size, embd)
        self.blocks = nn.Sequential(*[Block(embd, num_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(embd)
        self.mdn_head = MDN(embd, 2, n_components)
        self.pen_head = CategoricalNetwork(embd, 3)

    def forward(self, x, tau=1.):
        B, T, C = x.shape
        # assert C == 3

        # x is a (B, T, 3) tensor of stroke-3 tuples (dx, dy, p)
        stroke_emb = self.stroke_embedding_proj(x[:, :, :2]) # (B,T,2) @ (2, embd) = (B, T, embd)
        pen_emb = self.pen_embedding_table(x[:, :, 2].long()) # (B, T, embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) #(T, embd)
        x = stroke_emb + pen_emb + pos_emb # (B, T, embd)

        # forward through attention heads
        x = self.blocks(x)  # (B, T, C)
        x = self.ln_f(x)

        # forward through the MDN and pen heads
        pi_net, normal_net = self.mdn_head(x, tau=tau)
        q_net = self.pen_head(x, tau=tau)

        return pi_net, normal_net, q_net

    def loss(self, x, targets, mask):
        pi, normal, q = self.forward(x)
        ys = targets[:, :, :2]
        loglik = normal.log_prob(ys.unsqueeze(-2).expand_as(normal.loc))
        Ls = -torch.logsumexp(torch.log(pi.probs) + loglik, dim=-1)
        Ls *= mask

        yp = targets[:, :, 2]
        Lp = -q.log_prob(yp)
        return Ls + Lp

    def sample(self, x, tau=1.):
        pi, normal, q = self.forward(x, tau)
        s_samples = torch.sum(pi.sample().unsqueeze(-1) * normal.sample(), dim=-2)
        p_samples = q.sample()
        return torch.cat([s_samples, p_samples.unsqueeze(-1)], dim=-1)

    def generate(self, x, max_new_tokens, tau=1., break_eos=True):

        # x is (1, T, 3)
        for _ in range(max_new_tokens):

            # get the predictions
            samples_next = self.sample(x, tau=tau)[:, -1, :].unsqueeze(1)

            # append sampled stroke + pen index to the running sequence
            x = torch.cat([x, samples_next], dim=1)

            # break if end of sketch
            if break_eos:
                if samples_next[0,0,2] == 2:
                    return x

        return x
model = TransformerModel()
model = model.to(device)
print(f"Model has {sum([p.nelement() for p in model.parameters()])} parameters")
Model has 10739451 parameters
X, Y, mask = get_batch("train")
Y_pred = model.sample(X)
print(X.shape, Y.shape, Y_pred.shape)
torch.Size([64, 129, 3]) torch.Size([64, 129, 3]) torch.Size([64, 129, 3])


@torch.no_grad()
def estimate_loss():
    out = {}
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y, mask = get_batch(split)
            loss = model.loss(X, Y, mask)
            losses[k] = loss.mean()
        out[split] = losses.mean()
    return out
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
#lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=training_iters)

for iter in range(training_iters):

  #every once in a while evaluate the loss on train and val sets
  if iter % eval_interval == 0:

      # evaluate loss
      losses = estimate_loss()
      print(f'step {iter}: lr {optimizer.param_groups[0]["lr"]:.6f}, train loss {losses["train"]:.4f}, val loss {losses["val"]:.4f}')

      # display random samples at current stage of training
      n_samples = 10
      with torch.no_grad(): # no gradients needed for sampling
        svg_samples = [draw_strokes(model.generate(torch.zeros(1, 1, 3).to(device), max_new_tokens=block_size-1, break_eos=True, tau=.4)[0], factor=0.1) for _ in range(n_samples)]
      no_wrap_div = '<div style="white-space: nowrap">'+'{}'*n_samples+'</div>'
      display(HTML(no_wrap_div.format(*svg_samples)))

  # sample a batch of data
  xb, yb, mask = get_batch('train')

  # evaluate the loss
  loss = model.loss(xb, yb, mask).mean()

  # backward pass
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()
step 0: lr 0.000100, train loss 3.1467, val loss 3.1501
step 1200: lr 0.000100, train loss 0.6012, val loss 0.6239
step 2400: lr 0.000100, train loss 0.4337, val loss 0.4464
step 3600: lr 0.000100, train loss 0.4150, val loss 0.4295
step 4800: lr 0.000100, train loss 0.5068, val loss 0.5206
step 6000: lr 0.000100, train loss 0.3361, val loss 0.3572
step 7200: lr 0.000100, train loss 0.3663, val loss 0.3860
step 8400: lr 0.000100, train loss 0.3024, val loss 0.3062
step 9600: lr 0.000100, train loss 0.3049, val loss 0.3309
step 10800: lr 0.000100, train loss 0.5361, val loss 0.5466
step 12000: lr 0.000100, train loss 0.3183, val loss 0.3592
step 13200: lr 0.000100, train loss 0.2766, val loss 0.2947
step 14400: lr 0.000100, train loss 0.2584, val loss 0.2685
step 15600: lr 0.000100, train loss 0.4060, val loss 0.4023
step 16800: lr 0.000100, train loss 0.2405, val loss 0.2707
step 18000: lr 0.000100, train loss 0.3355, val loss 0.3577
step 19200: lr 0.000100, train loss 0.2157, val loss 0.2660
step 20400: lr 0.000100, train loss 0.2235, val loss 0.2534
step 21600: lr 0.000100, train loss 0.4038, val loss 0.4099
step 22800: lr 0.000100, train loss 0.1654, val loss 0.1942

👉 We see the loss as well as example generated samples every 1200 training steps (one step = one batch of 64 sketches).

We can see that before training (step 0), the sketches are very short, because the pen action $p=2$ (end of sketch) is sampled roughly uniformly among the three pen states, and hence far too early.
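A back-of-the-envelope check, assuming as a simplification that the untrained pen head is exactly uniform over the three pen states and that steps are independent: the number of generated steps up to and including the first $p=2$ then follows a geometric distribution with success probability $1/3$, whose mean is $3$.

```python
import numpy as np

p_end = 1 / 3                 # assumed probability of sampling p = 2 at each step
expected_len = 1 / p_end      # mean of a geometric distribution on {1, 2, ...}

rng = np.random.default_rng(0)
simulated = rng.geometric(p_end, size=100_000)  # steps until the first "end of sketch"
print(expected_len, simulated.mean())           # both close to 3
```

So an untrained model produces sketches of only a few strokes on average, far from the typical length of a Quick, Draw! cat.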

During training, the model quickly learns to draw longer sequences, with already curvy shapes reminiscent of face contours (step 1200). As training progresses it seems to learn semblances of eyes, ears and weird moustaches, although they are not very coherent yet (step 6000).

After 6000 steps, it becomes more and more evident that the model learns to capture the essential components of a cat sketch: a round face, two pointy ears, two eyes, a nose and some slightly better (yet still weird) moustaches.

Obviously the model is far from perfect and there are still several failed scribbles and funny generalizations of cats...

In the end, I let the model train for a total of 24000 steps. Training longer might have helped, as the training and validation losses kept going down, although the losses printed here are quite noisy (estimated on 10 batches only).

Ultimately, we can see the kind of sketches we obtain at the end of training. It's not perfect, but I'm quite happy with the results and confident that more can be done to improve these doodles even further 🐈.

[Optional] The code below saves the trained weights to your local machine, which can be useful if you do not want to re-train the model each time. Note that we already provide trained weights for the cat data class on the GitHub repository.

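from google.colab import files
torch.save(model.state_dict(), f'model_{"_".join(data_classes)}.pth')
files.download(f'model_{"_".join(data_classes)}.pth') # download the saved weights from Colab to your machine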


# Evaluate Reconstruction Loss on the full test set
@torch.no_grad()
def test_loss():
    # convert to stroke-3 with normalized x-y values (roughly unit variance)
    xs, ys, lengths = [], [], []
    for i in range(len(test_set)):
      x, y, p = torch.tensor(test_set[i]).T
      x = (x - x_mean) / x_std
      y = (y - y_mean) / y_std
      p[-1] = 2.
      d = torch.stack([x, y, p], -1)
      lengths.append(len(d))
      # pad with end-of-sketch tokens (0, 0, 2) up to block size
      xs.append(torch.concatenate([d, torch.tensor([0., 0., 2.]).repeat((block_size-len(d), 1))]))
      ys.append(torch.concatenate([d[1:], torch.tensor([0., 0., 2.]).repeat((block_size+1-len(d), 1))]))
    xs, ys, lengths = torch.stack(xs), torch.stack(ys), torch.tensor(lengths)
    mask = torch.arange(block_size).expand((len(test_set), block_size)) <=  lengths.unsqueeze(1).expand((len(test_set), block_size))
    xs, ys, mask = xs.to(device), ys.to(device), mask.to(device)
    loss = model.loss(xs, ys, mask)
    return loss

print(f"LR: {test_loss().mean()}")
LR: 0.21352873742580414

👉 The loss over the whole test set is $L_R \approx 0.21$.

# Complete Sketch from the test set with the trained model
X, Y, _ = get_batch("test")
n_samples = 5
n_per_sample = 10
t_init = 5
for i in range(n_samples):
  svg_samples = [draw_strokes(model.generate(X[i, :t_init].unsqueeze(0).to(device), max_new_tokens=block_size-t_init, break_eos=True, tau=.4)[0], factor=0.1) for _ in range(n_per_sample)]
  no_wrap_div = '<div style="white-space: nowrap">'+'{}'*n_per_sample+'</div>'
  display(HTML(no_wrap_div.format(*svg_samples)))