Introduction

In the first part I reimplemented the convolutional autoencoder from TimeCluster by Ali et al. This time, I will adapt the model to handle all 300 flock agents.

In this notebook a 1D convolutional approach is evaluated.

Data normalisation and preparation

Normalise the data the same way as before: into the range [0, 1], as per the paper.

We then create a sliding window using the defaults from the paper: stride = 1 and window_size = 60.

Then we shuffle the data and split it into train, validate, and test subsets.

function normalise(M)
    # Scale the whole matrix into [0, 1] using the global min and max
    lo = minimum(M)
    hi = maximum(M)
    return (M .- lo) ./ (hi - lo)
end

normalised = Array(df) |> normalise

window_size = 60

# Each observation is a 900 × 60 view (features × window_size)
data = slidingwindow(normalised', window_size, stride=1)

train, validate, test = splitobs(shuffleobs(data), (0.7,0.2));
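
As a quick sanity check (not in the original notebook), we can confirm the window shape and the sizes of the splits:

# Sanity-check sketch (illustrative)
using MLDataPattern

size(getobs(data, 1))                    # expected (900, 60): features × window_size
nobs(train), nobs(validate), nobs(test)  # roughly 70% / 20% / 10% of the windows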

Define the encoder and decoder

We can define the network shape in a couple of different ways:

  • Keeping the convolution one-dimensional and simply increasing the number of features from 3 to 900 (3 × num_of_agents)
  • Using a 2D convolution: window_size × num_of_agents × dimensions (3) × batch

In this notebook we will look at the 1D approach; the shape sketch below shows the input layout for both options.
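
A minimal shape sketch (not from the original notebook; the column ordering and the 16 output channels are illustrative):

# Assumes the 900 columns are grouped coordinate-major (all x's, then all y's, then all z's)
using Flux

x = rand(Float32, 60, 900, 1)                    # one window: 60 × 900 × batch

# Option 1: 1D convolution over time with 900 input channels
c1 = Conv((9,), 900 => 16, relu; pad = SamePad())
size(c1(x))                                      # (60, 16, 1)

# Option 2: 2D convolution with window_size × num_of_agents × 3 × batch input
x2 = reshape(x, 60, 300, 3, 1)
c2 = Conv((9, 5), 3 => 16, relu; pad = SamePad())
size(c2(x2))                                     # (60, 300, 16, 1)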

1D Convolution

The dimension expansion is adjusted from ≈21× down to 10×. An additional Conv/ConvTranspose step is also added to reduce the dimensionality of the encoded space further.

function create_ae_1d()
    # Define the encoder and decoder networks
    encoder = Chain(
        # 60x900xb
        Conv((9,), 900 => 9000, relu; pad = SamePad()),
        MaxPool((2,)),
        # 30x9000xb
        Conv((5,), 9000 => 4500, relu; pad = SamePad()),
        MaxPool((2,)),
        # 15x4500xb
        Conv((5,), 4500 => 2250, relu; pad = SamePad()),
        # 15x2250xb
        MaxPool((3,)),
        Conv((3,), 2250 => 1000, relu; pad = SamePad()),
        Conv((3,), 1000 => 100, relu; pad = SamePad()),
        # 5x100xb
        Flux.flatten,
        Dense(500, 100)
    )
    decoder = Chain(
        Dense(100, 500),
        (x -> reshape(x, 5, 100, :)),
        # 5x100xb
        ConvTranspose((3,), 100  => 1000, relu; pad = SamePad()),
        ConvTranspose((3,), 1000 => 2250, relu; pad = SamePad()),
        Upsample((3,)),
        # 15x2250xb
        ConvTranspose((5,), 2250 => 4500, relu; pad = SamePad()),
        Upsample((2,)),
        # 30x4500xb
        ConvTranspose((5,), 4500 => 9000, relu; pad = SamePad()),
        Upsample((2,)),
        # 60x9000xb
        ConvTranspose((9,), 9000 => 900, relu; pad = SamePad()),
        # 60x900xb
    )
    return (encoder, decoder)
end
create_ae_1d (generic function with 1 method)

Training

Training needs to be slightly adapted for each version of the model. We now also use the train/validation/test split for a more accurate performance estimate, and I've added learning-rate adjustment and automatic model saving.

function save_model(m, epoch, loss)
    # Save the model weights (moved to the CPU) with LegolasFlux, tagged with epoch and loss
    model_row = LegolasFlux.ModelRow(; weights = fetch_weights(cpu(m)), architecture_version = 1, loss = loss)
    write_model_row("1d_300_model-$epoch-$loss.arrow", model_row)
end

function rearrange_1D(x)
    # Stack the batch of 900×60 windows along dim 3, then permute to 60×900×batch
    permutedims(cat(x..., dims=3), [2,1,3])
end

function train_model_1D!(model, train, validate, opt; epochs=20, bs=16, dev=Flux.gpu)
    ps = Flux.params(model)
    local train_loss, train_loss_acc
    local validate_loss, validate_loss_acc
    local last_improvement = 0
    local prev_best_loss = 0.01
    local improvement_thresh = 5.0
    validate_losses = Vector{Float64}()
    for e in 1:epochs
        train_loss_acc = 0.0
        for x in eachbatch(train, size=bs)
            x  = rearrange_1D(x) |> dev
            gs = Flux.gradient(ps) do
                train_loss = Flux.Losses.mse(model(x),x)
                return train_loss
            end
            train_loss_acc += train_loss
            Flux.update!(opt, ps, gs)
        end
        validate_loss_acc = 0.0
        for y in eachbatch(validate, size=bs)
            y  = rearrange_1D(y) |> dev
            validate_loss = Flux.Losses.mse(model(y), y)
            validate_loss_acc += validate_loss
        end
        validate_loss_acc = round(validate_loss_acc / (length(validate)/bs); digits=6)
        train_loss_acc = round(train_loss_acc / (length(train)/bs) ;digits=6)
        # Only checkpoint / adjust the learning rate once the validation loss is below 0.001
        if validate_loss_acc < 0.001
            if validate_loss_acc < prev_best_loss
                @info "new best accuracy $validate_loss_acc saving model..."
                save_model(model, e, validate_loss_acc)
                last_improvement = e
                prev_best_loss = validate_loss_acc
            elseif (e - last_improvement) >= improvement_thresh && opt.eta > 1e-5
                @info "Not improved in $improvement_thresh epochs. Dropping learning rate to $(opt.eta / 2.0)"
                opt.eta /= 2.0
                last_improvement = e # give it some time to improve
                improvement_thresh = improvement_thresh * 1.5
            elseif (e - last_improvement) >= 15
                @info "Not improved in 15 epochs. Converged I guess"
                break
            end
        end
        push!(validate_losses, validate_loss_acc)
        println("Epoch $e/$epochs\t train loss: $train_loss_acc\t validate loss: $validate_loss_acc")
    end
    validate_losses
end
train_model_1D! (generic function with 1 method)
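
The construction of model isn't shown in the notebook; presumably it is the encoder and decoder from create_ae_1d chained together and moved to the GPU, something like:

# Presumed model construction (sketch, not from the original notebook)
encoder, decoder = create_ae_1d()
model = Chain(encoder, decoder) |> Flux.gpu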
losses_0001       = train_model_1D!(model, train, validate, Flux.Optimise.ADAM(0.0001); epochs=200, bs=48);

┌ Warning: The specified values for size and/or count will result in 21 unused data points
└ @ MLDataPattern /opt/julia/packages/MLDataPattern/KlSmO/src/dataview.jl:205
Epoch 1/200	 train loss: 0.430836	 validate loss: 0.057712
Epoch 2/200	 train loss: 0.055009	 validate loss: 0.051833
Epoch 3/200	 train loss: 0.053515	 validate loss: 0.051719
Epoch 4/200	 train loss: 0.053483	 validate loss: 0.051669
Epoch 5/200	 train loss: 0.053481	 validate loss: 0.051648
Epoch 6/200	 train loss: 0.05348	 validate loss: 0.051622
Epoch 7/200	 train loss: 0.053478	 validate loss: 0.051622
Epoch 8/200	 train loss: 0.053471	 validate loss: 0.051615
Epoch 9/200	 train loss: 0.053464	 validate loss: 0.051602
Epoch 10/200	 train loss: 0.053461	 validate loss: 0.051603
Epoch 11/200	 train loss: 0.053447	 validate loss: 0.051597
Epoch 12/200	 train loss: 0.053427	 validate loss: 0.051574
Epoch 13/200	 train loss: 0.053222	 validate loss: 0.050061
Epoch 14/200	 train loss: 0.040074	 validate loss: 0.033838
Epoch 15/200	 train loss: 0.034933	 validate loss: 0.033222
Epoch 16/200	 train loss: 0.029958	 validate loss: 0.018663
Epoch 17/200	 train loss: 0.015721	 validate loss: 0.012654
Epoch 18/200	 train loss: 0.011821	 validate loss: 0.010214
Epoch 19/200	 train loss: 0.009049	 validate loss: 0.007028
Epoch 20/200	 train loss: 0.006481	 validate loss: 0.005952
Epoch 21/200	 train loss: 0.004894	 validate loss: 0.003345
Epoch 22/200	 train loss: 0.002088	 validate loss: 0.001259
Epoch 23/200	 train loss: 0.001	 validate loss: 0.001243
┌ Info: new best accuracy 0.000762 saving model...
└ @ Main In[9]:41
Epoch 24/200	 train loss: 0.000691	 validate loss: 0.000762
┌ Info: new best accuracy 0.000597 saving model...
└ @ Main In[9]:41
Epoch 25/200	 train loss: 0.000683	 validate loss: 0.000597
┌ Info: new best accuracy 0.000309 saving model...
└ @ Main In[9]:41
Epoch 26/200	 train loss: 0.000584	 validate loss: 0.000309
┌ Info: new best accuracy 0.000142 saving model...
└ @ Main In[9]:41
Epoch 27/200	 train loss: 0.000244	 validate loss: 0.000142
┌ Info: new best accuracy 9.1e-5 saving model...
└ @ Main In[9]:41
Epoch 28/200	 train loss: 0.000118	 validate loss: 9.1e-5
┌ Info: new best accuracy 7.4e-5 saving model...
└ @ Main In[9]:41
Epoch 29/200	 train loss: 8.6e-5	 validate loss: 7.4e-5
┌ Info: new best accuracy 6.3e-5 saving model...
└ @ Main In[9]:41
Epoch 30/200	 train loss: 7.1e-5	 validate loss: 6.3e-5
Epoch 31/200	 train loss: 6.5e-5	 validate loss: 6.5e-5
┌ Info: new best accuracy 5.6e-5 saving model...
└ @ Main In[9]:41
Epoch 32/200	 train loss: 6.3e-5	 validate loss: 5.6e-5
┌ Info: new best accuracy 4.8e-5 saving model...
└ @ Main In[9]:41
Epoch 33/200	 train loss: 5.3e-5	 validate loss: 4.8e-5
┌ Info: new best accuracy 4.2e-5 saving model...
└ @ Main In[9]:41
Epoch 34/200	 train loss: 4.6e-5	 validate loss: 4.2e-5
┌ Info: new best accuracy 3.9e-5 saving model...
└ @ Main In[9]:41
Epoch 35/200	 train loss: 4.2e-5	 validate loss: 3.9e-5
┌ Info: new best accuracy 3.6e-5 saving model...
└ @ Main In[9]:41
Epoch 36/200	 train loss: 3.8e-5	 validate loss: 3.6e-5
┌ Info: new best accuracy 3.3e-5 saving model...
└ @ Main In[9]:41
Epoch 37/200	 train loss: 3.5e-5	 validate loss: 3.3e-5
┌ Info: new best accuracy 3.1e-5 saving model...
└ @ Main In[9]:41
Epoch 38/200	 train loss: 3.3e-5	 validate loss: 3.1e-5
┌ Info: new best accuracy 2.9e-5 saving model...
└ @ Main In[9]:41
Epoch 39/200	 train loss: 3.1e-5	 validate loss: 2.9e-5
┌ Info: new best accuracy 2.7e-5 saving model...
└ @ Main In[9]:41
Epoch 40/200	 train loss: 2.9e-5	 validate loss: 2.7e-5
Epoch 41/200	 train loss: 2.7e-5	 validate loss: 2.8e-5
┌ Info: new best accuracy 2.6e-5 saving model...
└ @ Main In[9]:41
Epoch 42/200	 train loss: 2.7e-5	 validate loss: 2.6e-5
Epoch 43/200	 train loss: 3.0e-5	 validate loss: 2.9e-5
Epoch 44/200	 train loss: 6.2e-5	 validate loss: 8.7e-5
Epoch 45/200	 train loss: 0.000182	 validate loss: 0.000175
Epoch 46/200	 train loss: 0.000311	 validate loss: 0.000224
Epoch 47/200	 train loss: 0.00104	 validate loss: 0.000602
┌ Info: Not improved in 5 epochs. Dropping learning rate to 5.0e-5
└ @ Main In[9]:46
Epoch 48/200	 train loss: 0.000221	 validate loss: 5.7e-5
Epoch 49/200	 train loss: 3.7e-5	 validate loss: 2.6e-5
┌ Info: new best accuracy 2.4e-5 saving model...
└ @ Main In[9]:41
Epoch 50/200	 train loss: 2.5e-5	 validate loss: 2.4e-5
┌ Info: new best accuracy 2.2e-5 saving model...
└ @ Main In[9]:41
Epoch 51/200	 train loss: 2.2e-5	 validate loss: 2.2e-5
┌ Info: new best accuracy 2.1e-5 saving model...
└ @ Main In[9]:41
Epoch 52/200	 train loss: 2.1e-5	 validate loss: 2.1e-5
┌ Info: new best accuracy 2.0e-5 saving model...
└ @ Main In[9]:41
Epoch 53/200	 train loss: 2.0e-5	 validate loss: 2.0e-5
┌ Info: new best accuracy 1.9e-5 saving model...
└ @ Main In[9]:41
Epoch 54/200	 train loss: 1.9e-5	 validate loss: 1.9e-5
┌ Info: new best accuracy 1.8e-5 saving model...
└ @ Main In[9]:41
Epoch 55/200	 train loss: 1.9e-5	 validate loss: 1.8e-5
Epoch 56/200	 train loss: 1.8e-5	 validate loss: 1.8e-5
┌ Info: new best accuracy 1.7e-5 saving model...
└ @ Main In[9]:41
Epoch 57/200	 train loss: 1.8e-5	 validate loss: 1.7e-5
Epoch 58/200	 train loss: 1.7e-5	 validate loss: 1.7e-5
┌ Info: new best accuracy 1.6e-5 saving model...
└ @ Main In[9]:41
Epoch 59/200	 train loss: 1.7e-5	 validate loss: 1.6e-5
Epoch 60/200	 train loss: 1.7e-5	 validate loss: 1.6e-5
Epoch 61/200	 train loss: 1.6e-5	 validate loss: 1.6e-5
Epoch 62/200	 train loss: 1.6e-5	 validate loss: 1.6e-5
┌ Info: new best accuracy 1.5e-5 saving model...
└ @ Main In[9]:41
Epoch 63/200	 train loss: 1.6e-5	 validate loss: 1.5e-5
Epoch 64/200	 train loss: 1.5e-5	 validate loss: 1.5e-5
Epoch 65/200	 train loss: 1.5e-5	 validate loss: 1.5e-5
Epoch 66/200	 train loss: 1.5e-5	 validate loss: 1.5e-5
┌ Info: new best accuracy 1.4e-5 saving model...
└ @ Main In[9]:41
Epoch 67/200	 train loss: 1.5e-5	 validate loss: 1.4e-5
Epoch 68/200	 train loss: 1.4e-5	 validate loss: 1.4e-5
Epoch 69/200	 train loss: 1.4e-5	 validate loss: 1.4e-5
Epoch 70/200	 train loss: 1.4e-5	 validate loss: 1.4e-5
┌ Info: new best accuracy 1.3e-5 saving model...
└ @ Main In[9]:41
Epoch 71/200	 train loss: 1.4e-5	 validate loss: 1.3e-5
Epoch 72/200	 train loss: 1.4e-5	 validate loss: 1.3e-5
Epoch 73/200	 train loss: 1.3e-5	 validate loss: 1.3e-5
Epoch 74/200	 train loss: 1.3e-5	 validate loss: 1.3e-5
Epoch 75/200	 train loss: 1.3e-5	 validate loss: 1.3e-5
┌ Info: new best accuracy 1.2e-5 saving model...
└ @ Main In[9]:41
Epoch 76/200	 train loss: 1.3e-5	 validate loss: 1.2e-5
Epoch 77/200	 train loss: 1.2e-5	 validate loss: 1.2e-5
Epoch 78/200	 train loss: 1.2e-5	 validate loss: 1.2e-5
Epoch 79/200	 train loss: 1.2e-5	 validate loss: 1.2e-5
Epoch 80/200	 train loss: 1.2e-5	 validate loss: 1.2e-5
┌ Info: new best accuracy 1.1e-5 saving model...
└ @ Main In[9]:41
Epoch 81/200	 train loss: 1.2e-5	 validate loss: 1.1e-5
Epoch 82/200	 train loss: 1.1e-5	 validate loss: 1.1e-5
Epoch 83/200	 train loss: 1.1e-5	 validate loss: 1.1e-5
Epoch 84/200	 train loss: 1.1e-5	 validate loss: 1.1e-5
Epoch 85/200	 train loss: 1.1e-5	 validate loss: 1.1e-5
Epoch 86/200	 train loss: 1.1e-5	 validate loss: 1.1e-5
Epoch 87/200	 train loss: 1.1e-5	 validate loss: 1.1e-5
Epoch 88/200	 train loss: 1.1e-5	 validate loss: 1.1e-5
Epoch 89/200	 train loss: 1.1e-5	 validate loss: 1.1e-5
┌ Info: Not improved in 5 epochs. Dropping learning rate to 2.5e-5
└ @ Main In[9]:46
┌ Info: new best accuracy 1.0e-5 saving model...
└ @ Main In[9]:41
Epoch 90/200	 train loss: 1.0e-5	 validate loss: 1.0e-5
Epoch 91/200	 train loss: 1.0e-5	 validate loss: 1.0e-5
Epoch 92/200	 train loss: 1.0e-5	 validate loss: 1.0e-5
Epoch 93/200	 train loss: 1.0e-5	 validate loss: 1.0e-5
Epoch 94/200	 train loss: 1.0e-5	 validate loss: 1.0e-5
Epoch 95/200	 train loss: 1.0e-5	 validate loss: 1.0e-5
Epoch 96/200	 train loss: 1.0e-5	 validate loss: 1.0e-5
┌ Info: new best accuracy 9.0e-6 saving model...
└ @ Main In[9]:41
Epoch 97/200	 train loss: 1.0e-5	 validate loss: 9.0e-6
Epoch 98/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 99/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 100/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 101/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 102/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 103/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 104/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 105/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 106/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 107/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 108/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 109/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
┌ Info: Not improved in 5 epochs. Dropping learning rate to 1.25e-5
└ @ Main In[9]:46
Epoch 110/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
Epoch 111/200	 train loss: 9.0e-6	 validate loss: 9.0e-6
┌ Info: new best accuracy 8.0e-6 saving model...
└ @ Main In[9]:41
Epoch 112/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 113/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 114/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 115/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 116/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 117/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 118/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 119/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 120/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 121/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 122/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 123/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 124/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 125/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
Epoch 126/200	 train loss: 8.0e-6	 validate loss: 8.0e-6
┌ Info: Not improved in 10 epochs. Converged I guess
└ @ Main In[9]:51

Results

We managed to get a validation loss of 8.0e-6, compressing each 900 × 60 window (54,000 values) down to a 100-element latent vector.

Let's see what this means in practical terms by comparing an input to an output:

test_data = rand(test)   # pick a random 900 × 60 window from the test set
create_gif_from_raw(test_data)
┌ Info: Saved animation to 
│   fn = /notebooks/anim_fps30.gif
└ @ Plots /opt/julia/packages/Plots/LSKOd/src/animation.jl:114
input = Flux.unsqueeze(test_data', 3)   # 60 × 900 × 1
output = new_model(input)
output = reshape(output, 60, 900)'      # back to 900 × 60 for the gif helper
create_gif_from_raw(output)
┌ Info: Saved animation to 
│   fn = /notebooks/anim_fps30.gif
└ @ Plots /opt/julia/packages/Plots/LSKOd/src/animation.jl:114
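
The new_model used above isn't defined in this notebook; presumably it is the autoencoder reloaded from the best checkpoint saved during training. A minimal sketch, assuming the LegolasFlux API used in save_model (the filename is illustrative):

# Sketch: reload the best checkpoint into a fresh (CPU) model
using Flux, LegolasFlux

encoder, decoder = create_ae_1d()
new_model = Chain(encoder, decoder)
model_row = read_model_row("1d_300_model-112-8.0e-6.arrow")  # illustrative filename
load_weights!(new_model, collect(model_row.weights))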

Conclusion

After a few hours of training on the GPU, we can now reasonably encode the movement of the whole swarm (300 agents) over 60 timesteps into 100 variables. However, I want to reduce that encoding even further, to around 10 parameters that can be used to sonify the dynamics.

Next time I will see how much further I can reduce the latent space, as well as how useful other dimensionality reduction (DR) methods are when applied to it.