Mayalen Etcheverry | March, 2024
Reproduce in Notebook | Github Repository
This blogpost provides a tutorial on how to combine a decoder-only Transformer (following the nanoGPT tutorial) with a Mixture Density Network (MDN) output layer (in the spirit of SketchRNN) to generate sketches, and how to draw the results with our CNC-drawing machine.
You can run the code directly in your browser using Google Colab, or you can download it as a notebook and run it on your local machine. Note that it requires training the transformer model in an auto-regressive fashion on a database of sketches, which can take a varying amount of time depending on your hardware setup. I personally ran it on Colab with a V100 GPU, which takes approximately 20 minutes with the proposed hyperparameters.
If you do not have a GPU, I recommend playing with a smaller model, for instance by decreasing embd, num_heads, or n_layers.
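As a rough illustration, a reduced configuration could look like the following (these values are hypothetical and not tuned, so expect lower sample quality):

# Hypothetical smaller configuration for CPU-only runs (illustrative values, not tuned)
embd = 192      # embedding dimension (down from 384)
num_heads = 4   # each head is embd/num_heads = 48-dimensional
n_layers = 4    # number of Transformer blocks (down from 6)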
Disclaimer: This tutorial is just me playing with transformers around a fun project that I can combine with our CNC-drawing machine ✏ Nothing groundbreaking here, as both Transformers and SketchRNN are from 2017, old times in the fast-paced world of machine learning! In fact I'm quite late to the party: a 2020 paper by Ribeiro et al. called "Sketchformer" already did something very similar, although with a slightly different architecture/pipeline, as they used an encoder-decoder architecture with discretized tokens (or continuous but deterministic tokens) whereas I use a decoder-only architecture with a mixture density network (MDN) output layer. More below 👇
Before diving into the code, let's look at what we will need/implement for the dataset, neural network model, and training loss.
The code uses the Quick, Draw! Dataset as training data.
The model is currently trained on the cat class of this dataset, but other sets of classes can easily be tested by changing the data_classes hyperparameter in the notebook.
For each class, the dataset contains a set of 70K sketches for training, 2.5K for validation and 2.5K for testing. A sketch is represented as a sequence of pen stroke actions where each action $a$ is a vector of 3 elements $a=(\Delta x, \Delta y, p)$. The $(\Delta x, \Delta y)$ values are continuous and represent the offset of the pen from its previous position; they are normalized in the notebook to have a standard deviation of 1. The $p$ value is discrete: $0$ for drawing, $1$ for lifting the pen and $2$ for end of sketch, indicating that no subsequent points will be rendered.
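As a concrete toy example (made-up numbers, not taken from the dataset), a very short sketch in this stroke-3 format could look like this:

# A hypothetical 3-action sketch in stroke-3 format (dx, dy, p)
# p = 0: drawing, p = 1: pen lifted, p = 2: end of sketch
import numpy as np
tiny_sketch = np.array([
    [ 0.5,  0.1, 0.],   # draw a segment offset by (0.5, 0.1)
    [-0.2,  0.3, 1.],   # draw another segment, then lift the pen
    [ 0.0,  0.0, 2.],   # end of sketch: no subsequent points are rendered
])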
The model architecture is summarized in the image on the right and detailed below.
As you can see, the backbone is very similar to the decoder architecture of the famous Transformer model from the "Attention is all you Need" paper.
There are however some minor differences within the decoder blocks, which follow the nanoGPT tutorial: layer normalization is applied before (rather than after) the self-attention and feed-forward sub-layers, and there is no cross-attention sub-layer since the architecture is decoder-only.
👉 For more details on the self-attention mechanism and for an in-depth understanding of this architecture be sure to check the nanoGPT tutorial which I highly recommend to anyone interested!
The input and output layers are however quite different from those of the original Transformer, as we're dealing with sequences of continuous strokes and discrete pen actions, and not simply strings.
Let's start with the input layer. The input data, which is a partial sketch, is a sequence of stroke-3 tuples $(dx, dy, p)$ where $(dx, dy) \in \mathbb{R}^2$ and $p \in \lbrace 0,1,2 \rbrace$.
We embed it with three layers whose outputs are summed:
- A Pen Embedding layer, a simple table of 3 d-dimensional embeddings, which we implement with the nn.Embedding layer.
- A Stroke Embedding layer, a projector mapping the 2D strokes into a d-dimensional space.
- A Position Embedding layer, a table of block_size (max context length) d-dimensional embeddings which describes the position of a stroke in the sequence.

For the output layer, we divide it in two heads:
- An MDN Head, which stands for Mixture Density Network and which we detail below.
- A Pen Head, which is simply a linear layer that outputs logits of size 3. These logits are used as parameters of a Categorical distribution to sample pen actions at each time step, akin to what is done for string sequence generation.

Mixture Density Networks, originally proposed by Christopher Bishop in 1994, use the outputs of a neural network as parameters of a probability distribution instead of as direct output values.
In our case, as shown above, the network outputs are parameters $\lbrace \pi_j, \mu_j, \Sigma_j \rbrace_{j=1..M}$ of a Multivariate Gaussian Mixture Model (GMM) with $M=20$ components that we use to sample stroke actions from:
$$\begin{align*} a &\sim f(\theta) \newline f(\theta) &= \sum^{M}_{j=1} \pi_j \: \mathcal{N} (\mu_j, \Sigma_j) \end{align*}$$
The intuition here is that if we directly trained our network to predict strokes with a vanilla regression task, e.g. with an MSE loss, it would only learn to recover the average output stroke per input sequence. That average is often not a good prediction: when drawing a cat face for instance, after drawing the ears some people might move on to the eyes whereas others might draw the face contour, and learning the average of these stroke options results in something pretty bad looking. By modelling the output values with a GMM (or another distribution) instead, the network can learn to generate multiple output values for a given input and hence is more likely to recover the true data distribution.
If you think about it, Transformers already do a similar thing with discrete sequences: they probabilistically generate an output token given an input token sequence by modelling the outputs with a categorical distribution. MDN is a way to extend this idea to continuous tokens.
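To make this concrete, here is a minimal standalone sketch (with toy numbers, unrelated to the model defined below) of how one can sample a 2D stroke offset from a GMM whose parameters would be produced by the network:

# Toy example: sample a (dx, dy) offset from a 3-component GMM (illustrative values)
import torch
from torch.distributions import Categorical, MultivariateNormal

pi = torch.tensor([0.5, 0.3, 0.2])                          # mixture weights
mu = torch.tensor([[0.5, 0.1], [-0.2, 0.3], [0.0, -0.4]])   # one 2D mean per component
scale_tril = 0.1 * torch.eye(2).repeat(3, 1, 1)             # Cholesky factors of the covariances

j = Categorical(probs=pi).sample()                                # 1) pick a mixture component
a = MultivariateNormal(mu[j], scale_tril=scale_tril[j]).sample()  # 2) sample the offset from it
print(j.item(), a)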
👉 David Ha has made a complete tutorial on MDN that I highly recommend to intuitively understand why this can be useful for many modern ML tasks. You can also have a look at this google colab where I extend David Ha's tutorial to play with MDN with a full covariance matrix on a task with 2-dimensional outputs, and with a new implementation where we use the MultivariateNormal and OneHotCategorical torch distributions to implement the GMM, as well as torch.logsumexp to compute the loss for numerical stability, akin to what is done in this repo.
Our loss is the same as the reconstruction loss $L_R$ of the Sketch-RNN paper, which basically maximizes the log-likelihood of the generated probability distributions in explaining the training data. More precisely, $L_R$ is the sum of the negative log-likelihoods of 1) the predicted GMM distributions in explaining the $\lbrace\Delta x_i\rbrace$ stroke actions ($L_s$) and 2) the predicted categorical distributions in explaining the $\lbrace p_i \rbrace$ pen actions ($L_p$):
$$\begin{align*} L_R &= L_s + L_p \newline L_s &= - \frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \Big( \sum_{j=1}^{M} \pi_{i,j} \, \mathcal{N} (\Delta x_i | \mu_{i,j}, \Sigma_{i,j}) \Big) \newline L_p &= - \frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \log q_{i}[p_i] \end{align*}$$
where $(\pi_{i,j}, \mu_{i,j}, \Sigma_{i,j})_{j=1..M}$ are the outputs of the MDN head for the i-th entry, and $q_i$ are the outputs of the Pen head.
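For intuition, here is a rough standalone sketch (toy shapes and values, not the model code given further down) of how the stroke term $L_s$ for a single target offset can be computed in a numerically stable way with torch.logsumexp:

# Toy example: negative log-likelihood of one (dx, dy) target under a 20-component GMM
import torch
from torch.distributions import MultivariateNormal

M = 20
target = torch.randn(2)                                   # one (dx, dy) stroke target
log_pi = torch.log_softmax(torch.randn(M), dim=-1)        # log mixture weights
components = MultivariateNormal(torch.randn(M, 2), scale_tril=torch.eye(2).repeat(M, 1, 1))
log_probs = components.log_prob(target)                   # log N(target | mu_j, Sigma_j), shape (M,)
Ls = -torch.logsumexp(log_pi + log_probs, dim=-1)         # = -log sum_j pi_j N(target | mu_j, Sigma_j)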
!pip install svgwrite
Collecting svgwrite
  Downloading svgwrite-1.4.3-py3-none-any.whl (67 kB)
Installing collected packages: svgwrite
Successfully installed svgwrite-1.4.3
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import MultivariateNormal, OneHotCategorical, Categorical
import requests
import io
import svgwrite
from IPython.display import Image, SVG, display, HTML
from google.colab.output import eval_js
torch.manual_seed(1337)
# Hyper-parameters
data_classes = ["cat"]
batch_size = 64 # how many independent sequences will we process in parallel?
# block_size = None # what is the maximum context length for predictions? here we take block_size = max number of strokes of our database
training_iters = 24000
eval_interval = training_iters // 20
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 10
embd = 384
embd_ffn = 4 * embd # 4 times embd, as in the "Attention is all you Need" paper
num_heads = 6 # every head is embd/num_heads dimensional
n_layers = 6
dropout = 0.2 # 20% of activations are randomly dropped (zeroed) at each forward/backward pass
n_components = 20 # number of gaussians in the MDN output layer
# helper function for draw_strokes
def get_bounds(data, factor):
min_x = 0
max_x = 0
min_y = 0
max_y = 0
abs_x = 0
abs_y = 0
for i in range(len(data)):
x = float(data[i,0])/factor
y = float(data[i,1])/factor
abs_x += x
abs_y += y
min_x = min(min_x, abs_x)
min_y = min(min_y, abs_y)
max_x = max(max_x, abs_x)
max_y = max(max_y, abs_y)
return (min_x, max_x, min_y, max_y)
def create_path(data, factor, abs_x, abs_y, lift_pen=1):
command = "m"
p = "M%s,%s " % (abs_x, abs_y)
for i in range(len(data)):
if (lift_pen == 1):
command = "m"
elif (command != "l"):
command = "l"
else:
command = ""
x = float(data[i,0])/factor
abs_x += x
y = float(data[i,1])/factor
abs_y += y
lift_pen = data[i, 2]
p += command+str(x)+","+str(y)+" "
return p, abs_x, abs_y
# little function that displays vector images
def draw_strokes(data, factor=0.2, svg_filename='sample.svg', the_color="black", stroke_width=1):
min_x, max_x, min_y, max_y = get_bounds(data, factor)
dims = (50 + max_x - min_x, 50 + max_y - min_y)
dwg = svgwrite.Drawing(svg_filename, size=dims)
dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
abs_x = 25 - min_x
abs_y = 25 - min_y
p, _, _ = create_path(data, factor, abs_x, abs_y)
dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
dwg.save()
svg_str = dwg.tostring()
return svg_str
def draw_two_strokes(data1, data2, color1="black", color2="brown", factor=0.2, svg_filename="sample.svg", stroke_width=1):
min_x, max_x, min_y, max_y = get_bounds(torch.concatenate([data1, data2]), factor)
dims = (50 + max_x - min_x, 50 + max_y - min_y)
dwg = svgwrite.Drawing(svg_filename, size=dims)
dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
abs_x = 25 - min_x
abs_y = 25 - min_y
p1, abs_x, abs_y = create_path(data1, factor, abs_x, abs_y)
dwg.add(dwg.path(p1).stroke(color1, stroke_width).fill("none"))
p2, _, _ = create_path(data2, factor, abs_x, abs_y, lift_pen=0)
dwg.add(dwg.path(p2).stroke(color2, stroke_width).fill("none"))
dwg.save()
svg_str = dwg.tostring()
return svg_str
train_set, valid_set, test_set = [], [], []
for data_class in data_classes:
data_url = f"https://storage.googleapis.com/quickdraw_dataset/sketchrnn/{data_class}.npz"
response = requests.get(data_url)
response.raise_for_status()
load_data = np.load(io.BytesIO(response.content), allow_pickle=True, encoding='latin1')
train_set += load_data['train'].tolist()
valid_set += load_data['valid'].tolist()
test_set += load_data['test'].tolist()
print(len(train_set))
print(len(valid_set))
print(len(test_set))
70000 2500 2500
# get max len
max_len = 0
for x in train_set:
max_len = max(max_len,len(x))
print(max_len)
block_size = max_len
assert block_size <= 250
129
max_w = 0
max_h = 0
x_mean = 0.
y_mean = 0.
x_std = 1.
y_std = 1.
N = 0
for x in train_set:
min_x, max_x, min_y, max_y = get_bounds(x, factor=1)
max_w = max(max_w, max_x-min_x)
max_h = max(max_h, max_y-min_y)
x_std += ((x[:,0] - x_mean)**2).sum()
y_std += ((x[:,1] - y_mean)**2).sum()
N += len(x)
x_std = np.sqrt(x_std/N)
y_std = np.sqrt(y_std/N)
print(max_w, max_h, x_std, y_std)
2263.0 1116.0 51.54030056217729 32.8188575795014
# data loading
def get_batch(split):
# generate a small batch of data of inputs x and targets y
if split == "train":
data = train_set
elif split == "valid":
data = valid_set
else:
data = test_set
ix = torch.randint(len(data), (batch_size,))
xs = []
ys = []
lengths = []
for i in ix:
        # convert to stroke-3 with normalized x-y values to be roughly N(0,I)
x, y, p = torch.tensor(data[i]).T
x = torch.maximum(torch.minimum(x, torch.tensor([1000])), torch.tensor([-1000]))
x = (x - x_mean) / x_std
y = torch.maximum(torch.minimum(y, torch.tensor([1000])), torch.tensor([-1000]))
y = (y - y_mean) / y_std
p[-1] = 2.
d = torch.stack([x, y, p], -1)
# pad with empty values till block size
xs.append(torch.concatenate([d, torch.tensor([0., 0., 2.]).repeat((block_size-len(d), 1))]))
ys.append(torch.concatenate([d[1:], torch.tensor([0., 0., 2.]).repeat((block_size+1-len(d), 1))]))
lengths.append(len(data[i]))
xs, ys, lengths = torch.stack(xs), torch.stack(ys), torch.tensor(lengths)
mask = torch.arange(block_size).expand((batch_size, block_size)) <= lengths.unsqueeze(1).expand((batch_size, block_size))
xs, ys, mask = xs.to(device), ys.to(device), mask.to(device)
return xs, ys, mask
# draw random examples from the train set
xs, ys, mask = get_batch("train")
print(xs.shape)
n_samples = 10
svg_samples = [draw_strokes(xs[i], factor=0.1) for i in range(n_samples)]
no_wrap_div = '<div style="white-space: nowrap">'+'{}'*n_samples+'</div>'
display(HTML(no_wrap_div.format(*svg_samples)))
torch.Size([64, 129, 3])
class MDN(nn.Module):
"""
Mixture density network compatible with full covariance.
Adapted from https://github.com/haimengzhao/full-cov-mdn
[ Bishop, 1994 ]
Parameters
----------
dim_in: int; dimensionality of the covariates
dim_out: int; dimensionality of the response variable
n_components: int; number of components in the mixture model
full_cov: bool; whether to use full or diagonal covariance matrix
"""
def __init__(self, dim_in, dim_out, n_components, full_cov=True):
super().__init__()
self.pi_net = OneHotCategoricalNetwork(dim_in, n_components)
self.normal_net = NormalNetwork(dim_in, dim_out, n_components, full_cov)
def forward(self, x, tau=1.):
return self.pi_net(x, tau), self.normal_net(x, tau)
class NormalNetwork(nn.Module):
def __init__(self, in_dim, out_dim, n_components, full_cov=True):
super().__init__()
self.n_components = n_components
self.out_dim = out_dim
self.full_cov = full_cov
self.tril_indices = torch.tril_indices(row=out_dim, col=out_dim, offset=0)
self.mean_net = nn.Linear(in_dim, out_dim * n_components)
if full_cov:
# Cholesky decomposition of the covariance matrix
self.tril_net = nn.Linear(in_dim, int(out_dim * (out_dim + 1) / 2 * n_components))
else:
self.tril_net = nn.Linear(in_dim, out_dim * n_components)
def forward(self, x, tau=1.):
mean = self.mean_net(x).reshape(x.shape[0], x.shape[1], self.n_components, self.out_dim) # B, T, M, d
if self.full_cov:
tril_values = self.tril_net(x).reshape(x.shape[0], x.shape[1], self.n_components, -1) # B, T, M, (d**2+d)/2
tril = torch.zeros(mean.shape[0], mean.shape[1], mean.shape[2], mean.shape[3], mean.shape[3]).to(x.device) # B, T, M, d, d
tril[:, :, :, self.tril_indices[0], self.tril_indices[1]] = tril_values
            # use diag = exp(diag) to ensure strict positivity of diagonal elements
tril.diagonal(dim1=-2, dim2=-1)[:] = tril.diagonal(dim1=-2, dim2=-1).exp()
else:
tril = self.tril_net(x).reshape(x.shape[0], x.shape[1], self.n_components, -1)
tril = torch.diag_embed(tril.exp())
tril *= tau
return MultivariateNormal(mean, scale_tril=tril)
class OneHotCategoricalNetwork(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.network = nn.Linear(in_dim, out_dim)
def forward(self, x, tau=1.):
logits = self.network(x) / tau
return OneHotCategorical(logits=logits)
class CategoricalNetwork(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.network = nn.Linear(in_dim, out_dim)
def forward(self, x, tau=1.):
logits = self.network(x) / tau
return Categorical(logits=logits)
# Attention Head
class Head(nn.Module):
def __init__(self, head_size):
super().__init__()
self.query = nn.Linear(embd, head_size, bias=False)
self.key = nn.Linear(embd, head_size, bias=False)
self.value = nn.Linear(embd, head_size, bias=False)
self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
def forward(self, x):
B, T, C = x.shape
q = self.query(x) # B, T, C
k = self.key(x) # B, T, C
# compute an attention score ("affinities")
wei = q@k.transpose(-2, -1) * C **(-0.5)
wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # "decoder" block because of triangular masking (autoregressive setting)
wei = F.softmax(wei, dim=-1)
# perform the weighted aggregation of the values
v = self.value(x) # B, T, C
out = wei @ v # B, T, C
return out
class MultiHead(nn.Module):
def __init__(self, num_heads, head_size):
super().__init__()
self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
self.proj = nn.Linear(embd, embd) #projection layer going back into the residual pathway
def forward(self, x):
out = torch.cat([h(x) for h in self.heads], dim=-1)
out = self.proj(out)
return out
class FeedForward(nn.Module):
def __init__(self, embd):
super().__init__()
self.net = nn.Sequential(
nn.Linear(embd, embd_ffn),
nn.ReLU(),
nn.Linear(embd_ffn, embd), # projection layer going back into the residual pathway
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Block(nn.Module):
""" Transformer Block: communication/sensing followed by computation/update"""
def __init__(self, embd, num_heads):
super().__init__()
self.sa_heads = MultiHead(num_heads, embd//num_heads)
self.ffwd = FeedForward(embd)
self.ln1 = nn.LayerNorm(embd) #should be equivalent to LayerNorm1D
self.ln2 = nn.LayerNorm(embd)
def forward(self, x):
# x = self.sa_heads(x) # apply one head of self-attention (B, T, C) <=> "comunication" or "sense"
        # x = self.ffwd(x) # (B, T, C) => this is on a per-token level <=> "update"
x = x + self.sa_heads(self.ln1(x)) # residual connection <=> "highway" of information and residual paths
x = x + self.ffwd(self.ln2(x)) # residual connection
return x
class TransformerModel(nn.Module):
def __init__(self):
super().__init__()
        # input embeddings: stroke projector, pen lookup table and positional lookup table
self.stroke_embedding_proj = nn.Linear(2, embd, bias=False)
self.pen_embedding_table = nn.Embedding(3, embd)
self.position_embedding_table = nn.Embedding(block_size, embd)
self.blocks = nn.Sequential(*[Block(embd, num_heads) for _ in range(n_layers)])
self.ln_f = nn.LayerNorm(embd)
self.mdn_head = MDN(embd, 2, n_components)
self.pen_head = CategoricalNetwork(embd, 3)
def forward(self, x, tau=1.):
B, T, C = x.shape
# assert C == 3
# idx and targets are both (B,T) tensor of integers
stroke_emb = self.stroke_embedding_proj(x[:, :, :2]) # (B,T,2) @ (2, embd) = (B, T, embd)
pen_emb = self.pen_embedding_table(x[:, :, 2].long()) # (B, T, embd)
pos_emb = self.position_embedding_table(torch.arange(T, device=device)) #(T, embd)
x = stroke_emb + pen_emb + pos_emb # (B, T, embd)
# forward through attention heads
x = self.blocks(x) # (B, T, C)
x = self.ln_f(x)
# forward though mdn and head
pi_net, normal_net = self.mdn_head(x, tau=tau)
q_net = self.pen_head(x, tau=tau)
return pi_net, normal_net, q_net
def loss(self, x, targets, mask):
pi, normal, q = self.forward(x)
ys = targets[:, :, :2]
loglik = normal.log_prob(ys.unsqueeze(-2).expand_as(normal.loc))
Ls = -torch.logsumexp(torch.log(pi.probs) + loglik, dim=-1)
Ls *= mask
yp = targets[:, :, 2]
Lp = -q.log_prob(yp)
return Ls + Lp
def sample(self, x, tau=1.):
pi, normal, q = self.forward(x, tau)
s_samples = torch.sum(pi.sample().unsqueeze(-1) * normal.sample(), dim=-2)
p_samples = q.sample()
return torch.cat([s_samples, p_samples.unsqueeze(-1)], dim=-1)
@torch.no_grad()
def generate(self, x, max_new_tokens, tau=1., break_eos=True):
# x is (1, T, 3)
for _ in range(max_new_tokens):
# get the predictions
samples_next = self.sample(x, tau=tau)[:, -1, :].unsqueeze(1)
# append sampled stroke + pen index to the running sequence
x = torch.cat([x, samples_next], dim=1)
# break if end of sketch
if break_eos:
if samples_next[0,0,2] == 2:
return x
return x
model = TransformerModel()
model = model.to(device)
print(f"Model has {sum([p.nelement() for p in model.parameters()])} parameters")
Model has 10739451 parameters
X, Y, mask = get_batch("train")
Y_pred = model.sample(X)
print(X.shape, Y.shape, Y_pred.shape)
torch.Size([64, 129, 3]) torch.Size([64, 129, 3]) torch.Size([64, 129, 3])
@torch.no_grad()
def estimate_loss():
out = {}
for split in ['train', 'val']:
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
X, Y, mask = get_batch(split)
loss = model.loss(X, Y, mask)
losses[k] = loss.mean()
out[split] = losses.mean()
return out
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
#lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=training_iters)
for iter in range(training_iters):
#every once in a while evaluate the loss on train and val sets
if iter % eval_interval == 0:
model.eval()
# evaluate loss
losses = estimate_loss()
print(f'step {iter}: lr {optimizer.param_groups[0]["lr"]:.6f}, train loss {losses["train"]:.4f}, val loss {losses["val"]:.4f}')
# display random samples at current stage of training
n_samples = 10
svg_samples = [draw_strokes(model.generate(torch.zeros(1, 1, 3).to(device), max_new_tokens=block_size-1, break_eos=True, tau=.4)[0], factor=0.1) for _ in range(n_samples)]
no_wrap_div = '<div style="white-space: nowrap">'+'{}'*n_samples+'</div>'
display(HTML(no_wrap_div.format(*svg_samples)))
model.train()
# sample a batch of data
xb, yb, mask = get_batch('train')
# evaluate the loss
loss = model.loss(xb, yb, mask).mean()
# backward pass
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
#lr_scheduler.step()
step 0: lr 0.000100, train loss 3.1467, val loss 3.1501
step 1200: lr 0.000100, train loss 0.6012, val loss 0.6239
step 2400: lr 0.000100, train loss 0.4337, val loss 0.4464
step 3600: lr 0.000100, train loss 0.4150, val loss 0.4295
step 4800: lr 0.000100, train loss 0.5068, val loss 0.5206
step 6000: lr 0.000100, train loss 0.3361, val loss 0.3572
step 7200: lr 0.000100, train loss 0.3663, val loss 0.3860
step 8400: lr 0.000100, train loss 0.3024, val loss 0.3062
step 9600: lr 0.000100, train loss 0.3049, val loss 0.3309
step 10800: lr 0.000100, train loss 0.5361, val loss 0.5466
step 12000: lr 0.000100, train loss 0.3183, val loss 0.3592
step 13200: lr 0.000100, train loss 0.2766, val loss 0.2947
step 14400: lr 0.000100, train loss 0.2584, val loss 0.2685
step 15600: lr 0.000100, train loss 0.4060, val loss 0.4023
step 16800: lr 0.000100, train loss 0.2405, val loss 0.2707
step 18000: lr 0.000100, train loss 0.3355, val loss 0.3577
step 19200: lr 0.000100, train loss 0.2157, val loss 0.2660
step 20400: lr 0.000100, train loss 0.2235, val loss 0.2534
step 21600: lr 0.000100, train loss 0.4038, val loss 0.4099
step 22800: lr 0.000100, train loss 0.1654, val loss 0.1942
👉 We see the loss as well as example generated samples every 1200 training steps (one step = one batch of 64 sketches).
We can see that before training (step 0), the sketches are very short because the pen action $p=2$ (end of sketch) is sampled uniformly and hence too early.
During training, the model quickly learns to draw longer sequences with already curvy shapes reminiscent of face contours (step 1200). As training progresses it seems to learn a semblance of eyes, ears, and weird moustaches, although not very coherent yet (step 6000).
After 6000 steps, it becomes more and more evident that the model learns to capture the essential components of a cat sketch: a round face, two pointy ears, two eyes, a nose and some slightly better (yet still weird) moustaches.
Obviously the model is far from perfect and there are still several failed scribbles and funny generalizations of cats...
I've finally let the model train for a total of 24000 steps, and maybe more would have been useful as the training/validation loss kept going down, although the loss I'm plotting here is quite noisy (estimated on 10 batches only).
Ultimately we can see the kind of sketches we obtain at the end of training. It's not perfect but I'm quite happy with the results and confident that more can be done to improve these doodles even further 🐈.
[Optional] The below code saves the trained weights to your local machine, which can be useful if you do not want to re-train the model each time. Note that we already provide trained weights for the cat data class in the github repository.
from google.colab import files
torch.save(model.state_dict(), f'model_{"_".join(data_classes)}.pth')
files.download(f'model_{"_".join(data_classes)}.pth')
model.eval()
# Evaluate Reconstruction Loss on the full test set
@torch.no_grad()
def test_loss():
out = {}
model.eval()
    # convert to stroke-3 with normalized x-y values to be roughly N(0,I)
xs, ys, lengths = [], [], []
for i in range(len(test_set)):
x, y, p = torch.tensor(test_set[i]).T
x = (x - x_mean) / x_std
y = (y - y_mean) / y_std
p[-1] = 2.
d = torch.stack([x, y, p], -1)
# pad with empty values till block size
xs.append(torch.concatenate([d, torch.tensor([0., 0., 2.]).repeat((block_size-len(d), 1))]))
ys.append(torch.concatenate([d[1:], torch.tensor([0., 0., 2.]).repeat((block_size+1-len(d), 1))]))
lengths.append(len(test_set[i]))
xs, ys, lengths = torch.stack(xs), torch.stack(ys), torch.tensor(lengths)
mask = torch.arange(block_size).expand((len(test_set), block_size)) <= lengths.unsqueeze(1).expand((len(test_set), block_size))
xs, ys, mask = xs.to(device), ys.to(device), mask.to(device)
loss = model.loss(xs, ys, mask)
return loss
print(f"LR: {test_loss().mean()}")
LR: 0.21352873742580414
👉 The loss over the whole test set is $L_R \approx 0.21$.
# Complete Sketch from the test set with the trained model
X, Y, _ = get_batch("test")
n_samples = 5
n_per_sample = 10
t_init = 5
for i in range(n_samples):
svg_samples = [draw_strokes(model.generate(X[i, :t_init].unsqueeze(0).to(device), max_new_tokens=block_size-t_init, break_eos=True, tau=.4)[0], factor=0.1) for _ in range(n_per_sample)]
no_wrap_div = '<div style="white-space: nowrap">'+'{}'*n_per_sample+'</div>'
display(HTML(no_wrap_div.format(*svg_samples)))
# model_state_dict = torch.load("model_cat.pth")
# model.load_state_dict(model_state_dict)
In the github repository you can play with the generate_samples.py script, which loads the trained model (saved under model_cat.pth) and generates sketches as SVG images without having to re-run the whole notebook.
Below are example sketches we obtain with a temperature $\tau=0.4$:
Note that the random generations can result in weird sketches, like the one on the left here. The model can also, but rarely, draw cats with a body, although this is often badly done as bodies are not often seen in the training data.
We can also have the model interact with a human, simply by letting the human draw a portion of the sketch and letting the model complete the sequence. Below we provide a "demo code" enabling the human to generate the starting curve (a sequence of strokes without lifting the pen) through a drawing interface. We then let the model complete the drawing sequence to finish the cat.
# Get human's strokes
size = 400
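# NOTE: draw() is assumed to be a small helper (not shown in this excerpt, presumably built on
# google.colab's eval_js canvas mechanism) that opens an interactive w x h canvas and returns
# the absolute (x, y) pixel positions of the curve drawn by the human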
input_strokes = draw(w=size, h=size)
# Process strokes
x_start = torch.tensor(input_strokes).float()
x_start = (x_start - size/2) / torch.tensor([x_std, y_std], dtype=torch.float32)
x_start = torch.concatenate([x_start[0].unsqueeze(0), x_start[1:] - x_start[:-1]])
x_start = torch.concatenate([x_start, torch.zeros((len(x_start), 1))], dim=1).unsqueeze(0).to(device)
print(x_start.shape)
# Feed to model
n_samples = 10
xs = [model.generate(x_start, max_new_tokens=block_size-x_start.shape[1], break_eos=True, tau=.4) for _ in range(n_samples)]
# Display completion
svg_samples = [draw_two_strokes(x_start[0], x[0, x_start.shape[1]:], factor=0.1) for x in xs]
no_wrap_div = '<div style="white-space: nowrap">'+'{}'*n_samples+'</div>'
display(HTML(no_wrap_div.format(*svg_samples)))
torch.Size([1, 36, 3])
👉 Here the human drew an oval shape (shown in black) and the model proposed 10 possible completions (shown in red), again with $\tau=0.4$.
Finally, if you are interested in having a robot draw the results, the simplest way is to buy a drawing machine 😀 You could for instance have a look at this one or this one, with everything set up to hold a pen and start drawing!
In our case, we opted for another option (see aliexpress link) which requires a bit of fine-tuning, as it is not originally intended for this but for laser engraving. It is a CNC machine which you can buy without the laser/engraving frame, and it has the advantage of a much bigger working area. However it needs a few steps and a bit of 3D printing to make it work, with all credits going to Antun Skuric for that 😀 :
- First you need to add an additional motor to motorise the z-axis; originally the laser engraver does not have one as it does not move along the z-axis (but the electronics have everything you need for the z-axis, we checked before buying).
- Then you need to 3D-print (or make somehow differently) a pen holder of some kind.
- Then you might need to add a few wooden slats to ensure proper fixing of the paper (it needs to stay horizontal and not move during drawing).
Once you have your machine, most of them use G-code based protocols, and the absolute simplest program we've found online that does the job of sending commands to the machine from the PC is the Universal Gcode Sender (it's also open source, which is nice).
We provide code below to convert the model outputs into G-code format and download it to your local machine 👇
# Generate and Download G-code file to be printed
def generate_gcodes(xs, filename="sketch.gcode", factor=0.2):
with open(filename, "w") as f:
prev_left = 0
prev_bottom = 0
for i, x in enumerate(xs):
x = x[0]
min_x, max_x, min_y, max_y = get_bounds(x, factor)
abs_x = 10 - min_x + prev_left
abs_y = 10 - min_y + prev_bottom
lift_pen = 1
for dx, dy, p in x:
abs_x += dx/factor
abs_y -= dy/factor
if lift_pen == 1:
f.write(f"G0 Z5\n")
f.write(f"G0 X{abs_x} Y{abs_y}\n")
f.write(f"G0 Z0\n")
elif lift_pen == 0:
f.write(f"G1 X{abs_x} Y{abs_y}\n")
elif lift_pen == 2:
break
else:
raise ValueError
lift_pen = p
f.write(f"G0 Z5\n")
prev_left = prev_left + 10 + (max_x-min_x)
if i == 4:
prev_left = 0
prev_bottom = prev_bottom + 10 + 2*(max_x - min_x)
generate_gcodes(xs, "sketchs.gcode")
files.download("sketchs.gcode")
Awesome, we've reached the end of this tutorial! You should now have everything you need to go and draw those silly cat faces by yourself 😹
Things that I'd like to try next: