Importing the Relevant Libraries
Make sure to remove CUDA if you don’t have an Nvidia GPU.
using Plots, Flux, LinearAlgebra, CUDA, ProgressMeter
using Flux: @functor
using Statistics: mean
using MLDatasets: MNIST
Defining the Neural Network
Creating a Module Instance
Julia does not offer class-based object orientation like Python does, so if you’re coming from PyTorch, things will look quite a bit different, though not completely so. The fundamental paradigm of Julia’s design is multiple dispatch, so pretty much everything we do relies heavily on function overloading. If those two terms are foreign to you, I recommend asking your favorite LLM or search engine about them before continuing.
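Here is a minimal sketch of multiple dispatch for readers new to it. It uses only Base Julia and a hypothetical function, describe. One generic function gets several methods, and Julia picks the most specific one based on the types of all arguments. Python, by contrast, looks up a method only on the class of the object it is called on.

```julia
# One generic function with three methods. Julia selects the method at
# call time from the types of *all* arguments (multiple dispatch).
describe(x::Integer, y::Integer) = "two integers"
describe(x::AbstractFloat, y::Integer) = "a float and an integer"
describe(x, y) = "something else"  # fallback for any other combination

println(describe(1, 2))    # two integers
println(describe(1.5, 2))  # a float and an integer
println(describe("a", 2))  # something else

# All three methods belong to the same generic function.
println(length(methods(describe)))  # 3
```

The outer constructors and callable structs in the rest of this article are built on exactly this mechanism.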
The first thing we need to do is create a struct to hold our layers/submodules. As an example, let’s look at a block of a residual neural network:
struct ResNetBlock
conv1
conv2
skipconv
end
Julia already provides a default inner constructor for this struct, which we can call as ResNetBlock(arg1, arg2, arg3). However, we will create an additional outer constructor that makes sure arg1, arg2 and arg3 are actually the layers we want to end up in the struct’s attributes. An outer constructor is not a special language construct. It is simply an additional method of the same function ResNetBlock that calls the inner constructor and returns an instance of the struct. We did not specify what types the struct’s attributes should have, so the inner constructor takes arguments of any type and returns an instance of ResNetBlock. Its signature is therefore ResNetBlock(::Any, ::Any, ::Any)::ResNetBlock.
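Here is a minimal sketch of the default inner constructor on its own. It uses a hypothetical toy struct, Untyped, with the same layout as ResNetBlock. Its fields carry no type annotations, so Julia treats them as Any, and the generated constructor accepts arguments of any type.

```julia
# A toy struct with untyped fields, mirroring ResNetBlock's layout.
struct Untyped
    a
    b
    c
end

# Fields without annotations are implicitly of type Any.
println(fieldtypes(Untyped))  # (Any, Any, Any)

# The default inner constructor accepts three arguments of any type.
u = Untyped(1, "two", [3.0])
println(u.b)  # two
```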
We will now overload this by defining the following outer constructor. It takes two Integers (Integer is the abstract supertype of all integer types in Julia) and returns an instance of ResNetBlock by calling the inner constructor with the layers we actually want in those attributes. Also note that return is not needed in Julia, because a function returns the value of its last expression, but it could be added for readability.
function ResNetBlock(in_channels::Integer, out_channels::Integer)
ResNetBlock(
Conv((3,3), in_channels => out_channels, pad=SamePad()),
Conv((3,3), out_channels => out_channels, pad=SamePad()),
Conv((1,1), in_channels => out_channels, pad=SamePad())
)
end
Output:
ResNetBlock
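The same pattern can be tried without Flux. In this minimal sketch, the hypothetical struct TinyBlock stores two weight matrices instead of convolution layers. Its outer constructor turns two Integers into randomly initialised Float32 matrices and then calls the default inner constructor. Dispatch keeps the two methods apart. TinyBlock(3, 4) matches the more specific (Integer, Integer) method, while the inner call with two matrices falls through to the (Any, Any) inner constructor, so there is no infinite recursion.

```julia
# Toy analogue of ResNetBlock: untyped fields, so the default inner
# constructor is TinyBlock(::Any, ::Any).
struct TinyBlock
    w1
    w2
end

# Outer constructor: a second method of the same function TinyBlock
# that turns channel counts into freshly initialised weights.
function TinyBlock(in_ch::Integer, out_ch::Integer)
    TinyBlock(
        randn(Float32, out_ch, in_ch),   # maps in_ch -> out_ch
        randn(Float32, out_ch, out_ch),  # maps out_ch -> out_ch
    )
end

b = TinyBlock(3, 4)
println(size(b.w1), " ", size(b.w2))  # (4, 3) (4, 4)
```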
Making our Module behave like a Function
In PyTorch we would use the __call__ or forward methods to make objects callable. In Julia we can define function-like objects, or functors, by defining a method whose "function" is an instance of our struct, i.e. (m::ResNetBlock)(x). We also type-annotate the argument as a 4-dimensional array of single-precision floats. AbstractArray is the supertype of all Julia arrays, which includes the standard Array as well as the CuArray for Nvidia GPUs.
function (m::ResNetBlock)(x::AbstractArray{Float32, 4})
skip = m.skipconv(x)
x = relu(m.conv1(x))
x = m.conv2(x) + skip
return relu(x)
end
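Here is a minimal Flux-free sketch of the same callable-struct idea, using a hypothetical Affine layer. It also shows what the Float32 annotation buys us. The method only exists for Float32 input, so a Float64 vector raises a MethodError instead of silently running in double precision.

```julia
# A hypothetical dense layer y = W*x .+ b as a callable struct.
struct Affine
    W::Matrix{Float32}
    b::Vector{Float32}
end

# Make instances callable, analogous to PyTorch's forward/__call__.
(m::Affine)(x::AbstractVector{Float32}) = m.W * x .+ m.b

layer = Affine(Float32[1 2; 3 4], Float32[1, 1])
println(layer(Float32[1, 1]))  # Float32[4.0, 8.0]

# Float64 input does not match the annotated signature.
try
    layer([1.0, 1.0])
catch err
    println(err isa MethodError)  # true
end
```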
As the final step, we need to ensure that our ResNetBlock behaves like a neural network module. This means Flux should be able to recursively collect all of its parameters. That way we can either differentiate with respect to them (for training) or move all of them together between our CPU and our GPU. In PyTorch, this would be taken care of by subclassing nn.Module; in Flux, it is taken care of by registering our module with @functor.
@functor ResNetBlock
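Here is a toy sketch that makes "recursively collect all parameters" concrete. It uses a hypothetical collect_params built on plain reflection (fieldnames and getfield). This is not how Flux or Functors.jl implement @functor. It only illustrates the idea: walk every field of a nested struct, descend into containers such as the Vector of blocks in ResNetMNIST below, and gather the numeric arrays found along the way.

```julia
# Toy illustration only: gather every numeric array inside a nested object.
collect_params(x::AbstractArray{<:Number}) = Any[x]  # a leaf parameter
collect_params(x::Number) = Any[]                    # plain scalars are skipped

function collect_params(x)
    out = Any[]
    if x isa Union{Tuple, AbstractArray}
        # Containers (e.g. a Vector of layers): recurse into each element.
        for child in x
            append!(out, collect_params(child))
        end
    else
        # Any other object: recurse into each of its fields.
        for name in fieldnames(typeof(x))
            append!(out, collect_params(getfield(x, name)))
        end
    end
    return out
end

# Hypothetical toy layers standing in for Conv/Dense.
struct ToyLayer
    W
    b
end
struct ToyNet
    layers
    activation
end

net = ToyNet([ToyLayer(ones(Float32, 4, 2), zeros(Float32, 4)),
              ToyLayer(ones(Float32, 1, 4), zeros(Float32, 1))], tanh)
ps = collect_params(net)
println(length(ps))       # 4
println(sum(length, ps))  # 17 = 8 + 4 + 4 + 1
```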
Repeating the same for the final network structure
struct ResNetMNIST
convs
pool
fc
end
function ResNetMNIST()
conv1 = ResNetBlock(1, 64)
conv2 = ResNetBlock(64, 128)
conv3 = ResNetBlock(128, 256)
pool = MaxPool((2,2), pad=SamePad())
out_feat = 256 * 4 * 4
fc = Dense(out_feat => 10)
return ResNetMNIST(
[conv1, conv2, conv3],
pool,
fc
)
end
function (m::ResNetMNIST)(x::AbstractArray{Float32, 4})
for conv in m.convs
x = m.pool(conv(x))
end
bs = size(x, ndims(x))  # the batch is the last dimension (W H C B)
return m.fc(reshape(x, :, bs))  # flatten all other dimensions into features
end
@functor ResNetMNIST
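Where does out_feat = 256 * 4 * 4 come from? The ResNetBlocks keep the spatial size because all their convolutions use SamePad(), so only the three pooling layers shrink the 28×28 input. The sketch below uses the usual output-size formula floor((n + p - k) / s) + 1, for window k, stride s and total padding p. It assumes that SamePad() gives MaxPool((2,2)) a total padding of p = k - 1 = 1, which is consistent with the 4×4 feature map hard-coded above.

```julia
# Output size of one spatial dimension after pooling
# (window k, stride s, total padding p; p = 1 is an assumption here).
pool_out(n; k = 2, s = 2, p = 1) = fld(n + p - k, s) + 1

function feature_size(n::Integer, n_blocks::Integer)
    sizes = [n]
    for _ in 1:n_blocks
        # ResNetBlock keeps n unchanged (SamePad convs); MaxPool shrinks it.
        n = pool_out(n)
        push!(sizes, n)
    end
    return sizes
end

sizes = feature_size(28, 3)
println(sizes)               # [28, 14, 7, 4]
println(256 * sizes[end]^2)  # 4096, i.e. out_feat = 256 * 4 * 4
```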
Loading the Dataset
This will look very familiar to PyTorch users, except perhaps for the odd-looking :train argument that selects the training split. :train is of type Symbol, but explaining Julia’s symbols is out of the scope of this article. One important thing to note here is that the array dimension order in Flux is reversed compared to PyTorch: PyTorch uses the B C H W layout, whereas Flux uses W H C B.
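Converting between the two layouts amounts to reversing the axis order, which a single permutedims call does. The sketch below uses a random array as a stand-in for a PyTorch-style B C H W batch. It also shows the reshape used in the training loop further down, which inserts a singleton channel axis into the 28×28×batch feature arrays that MNIST provides.

```julia
# A stand-in batch in PyTorch order: (batch, channel, height, width).
x_torch = rand(Float32, 8, 1, 28, 28)

# Reverse the axis order to get Flux's (width, height, channel, batch).
x_flux = permutedims(x_torch, (4, 3, 2, 1))
println(size(x_flux))  # (28, 28, 1, 8)

# The same element, with its indices listed in the opposite order.
println(x_flux[5, 3, 1, 2] == x_torch[2, 1, 3, 5])  # true

# MNIST features come without a channel axis (28×28×batch);
# reshape inserts a singleton channel dimension before the batch axis.
imgs = rand(Float32, 28, 28, 8)
x = reshape(imgs, 28, 28, 1, :)
println(size(x))  # (28, 28, 1, 8)
```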
bs = 8
train_ds = MNIST(:train)
test_ds = MNIST(:test)
train_dl = Flux.DataLoader(train_ds, batchsize=bs, shuffle=true)
test_dl = Flux.DataLoader(test_ds, batchsize=bs, shuffle=false)
Output:
7500-element DataLoader(::MNIST, batchsize=8)
with first element:
(; features = 28×28×8 Array{Float32, 3}, targets = 8-element Vector{Int64})
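The 7500 in the output is just the 60000 MNIST training images split into batches of 8. Here is a rough illustration of what a shuffling data loader does per epoch: permute the observation indices and cut them into consecutive chunks. The batch_indices helper is hypothetical and is not Flux's actual implementation.

```julia
using Random

# Hypothetical helper: the index batches a shuffling loader would visit in one epoch.
function batch_indices(n_obs::Integer, batchsize::Integer;
                       shuffle::Bool = true, rng = Random.default_rng())
    idx = shuffle ? randperm(rng, n_obs) : collect(1:n_obs)
    # Consecutive chunks of batchsize indices; the last chunk may be shorter.
    return collect(Iterators.partition(idx, batchsize))
end

batches = batch_indices(60_000, 8)
println(length(batches))         # 7500
println(length(first(batches)))  # 8
```

Each chunk of indices would then select the matching features and targets, which is what the DataLoader hands us as x.features and x.targets.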
The Training Loop
Julia puts an emphasis on the functional programming paradigm, and Flux’s automatic differentiation engine Zygote makes use of that. Therefore, the syntax for computing gradients is much closer to JAX than to PyTorch. It helps to recall the mathematical definition of our optimization problem. We want to find the optimal parameters $\theta$ of our model by minimizing a loss $\mathcal{L}$ between the true class distribution of our data $p(c|x)$ and our predicted class distribution $f(c|\theta, x)$.
$$ \theta^\star = \operatorname*{arg\,min}_{\theta}\mathcal{L} (f(c|\theta, x), p(c|x)) $$
We do this by gradient descent, so in every optimization step we need to calculate
$$ \frac{\partial \mathcal{L} (f(c|\theta, x), p(c|x))}{\partial \theta} $$
This is why we need a function that takes the model (the parameters $\theta$) as input and returns the loss. We can then differentiate that function with respect to the model:
$$ \theta \mapsto \mathcal{L} (f(c|\theta, x), p(c|x)) $$
In Julia we can define such a function in a single line, assuming that img and target are variables in the scope surrounding the function:
loss(m) = Flux.logitcrossentropy(m(img), target)
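Here is a plain-Julia sketch of what this one-liner computes. The helpers onehot_cols and logitcrossentropy_ref are hypothetical stand-ins for Flux.onehotbatch and Flux.logitcrossentropy. The sketch assumes the latter's default of averaging over the batch, with one column per sample as in Flux. The key part is the closure: loss captures img and target from the surrounding scope, so its only argument is the model. That is exactly the map θ ↦ L above.

```julia
# Hypothetical stand-in for Flux.onehotbatch: one one-hot column per label.
onehot_cols(labels, classes) = Float32[c == l for c in classes, l in labels]

# Numerically stable log-softmax over each column.
function logsoftmax_cols(z::AbstractMatrix)
    zs = z .- maximum(z; dims = 1)
    return zs .- log.(sum(exp.(zs); dims = 1))
end

# Hypothetical stand-in for Flux.logitcrossentropy: mean over the batch.
logitcrossentropy_ref(logits, y) = -sum(y .* logsoftmax_cols(logits)) / size(y, 2)

# A toy "model": any callable that maps a batch to 10 logits per sample.
W = zeros(Float32, 10, 4)  # all-zero weights give uniform predictions
model = x -> W * x

img = rand(Float32, 4, 3)             # 3 samples with 4 features each
target = onehot_cols([0, 7, 3], 0:9)  # 10×3 one-hot matrix

# The closure: img and target are captured, only the model is an argument.
loss(m) = logitcrossentropy_ref(m(img), target)
println(loss(model) ≈ log(10f0))  # true: uniform logits give a loss of log(10)
```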
For any function, e.g. func(arg1, arg2), Flux lets you calculate the gradients with respect to each of its arguments as
grad1, grad2 = Flux.gradient(func, arg1, arg2)
and since our loss function takes only the model as its argument, we can get the gradients as
model = ResNetMNIST()
grads_theta = Flux.gradient(loss, model)  # a 1-tuple: grads_theta[1] is the gradient w.r.t. model
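Both calls return a tuple with one gradient per argument, which is why the training loop below indexes grads[1]. To see that shape without Flux, here is a crude central finite-difference stand-in, fd_gradient. It is hypothetical, handles only Float64 arrays and is far too slow for real networks. It does, however, follow the same call pattern as grad1, grad2 = Flux.gradient(func, arg1, arg2).

```julia
# Hypothetical helper: central finite differences, returning one gradient
# per argument in a tuple, like the call pattern of Flux.gradient.
function fd_gradient(f, args::AbstractArray{Float64}...; h = 1e-6)
    grads = map(1:length(args)) do k
        g = zero(args[k])
        for i in eachindex(args[k])
            plus = map(copy, args)   # perturb copies, never the inputs
            minus = map(copy, args)
            plus[k][i] += h
            minus[k][i] -= h
            g[i] = (f(plus...) - f(minus...)) / (2h)
        end
        g
    end
    return Tuple(grads)
end

func(a, b) = sum(a .* b) + sum(abs2, a)  # ∂/∂a = b + 2a, ∂/∂b = a
a, b = [1.0, 2.0], [3.0, -1.0]
grad1, grad2 = fd_gradient(func, a, b)
println(grad1)  # ≈ [5.0, 3.0]
println(grad2)  # ≈ [1.0, 2.0]
```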
Alternatively, Flux.withgradient(loss, model) returns the loss value together with the gradients. With this knowledge we are now able to formulate our training loop.
model = ResNetMNIST() |> gpu
optim = Flux.setup(Flux.Adam(0.0001), model)
epochs = 30
train_losses = []
test_losses = []
classes = 0:9
@showprogress for epoch in 1:epochs
train_losses_step = []
for (i,x) in enumerate(train_dl)
img = x.features |> gpu
img = reshape(img, (28, 28, 1, size(img)[end]))
target = x.targets
target = Flux.onehotbatch(target, classes) |> gpu
loss_fn(m) = Flux.logitcrossentropy(m(img), target)  # closure over img and target
batch_loss, grads = Flux.withgradient(loss_fn, model)
Flux.update!(optim, model, grads[1])
push!(train_losses_step, batch_loss |> cpu)
end
push!(train_losses, mean(train_losses_step))
test_losses_step = []
for (i,x) in enumerate(test_dl)
img = x.features |> gpu
img = reshape(img, (28, 28, 1, size(img)[end]))
target = x.targets
target = Flux.onehotbatch(target, classes) |> gpu
loss_fn(m) = Flux.logitcrossentropy(m(img), target)
batch_loss = loss_fn(model)  # evaluation only: no gradients, no update
push!(test_losses_step, batch_loss |> cpu)
end
push!(test_losses, mean(test_losses_step))
end
Output:
Progress: 100%|█████████████████████████████████████████| Time: 0:03:50
p1 = plot(train_losses, xlabel="epoch", ylabel="loss", label="train loss")
p2 = plot(test_losses, xlabel="epoch", ylabel="loss", label="test loss")
plot(p1, p2, layout = 2, fmt=:png, size=(900, 400))
Output:
[Figure: training loss (left) and test loss (right) per epoch]
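Stripped of Flux, CUDA and MNIST, the training loop above has a simple skeleton. For every batch, compute the loss and its gradient, update the parameters and record the loss. After each epoch, average the recorded losses and evaluate the test set without updating. The sketch below assumes a toy softmax-regression model on synthetic data. It uses a hand-derived gradient, ∂L/∂W = (softmax(WX) - Y)Xᵀ / B, and plain gradient descent instead of Adam. It illustrates the structure only and is not a replacement for the Flux code.

```julia
using Random, Statistics

rng = MersenneTwister(0)
n_feat, n_cls, bs = 5, 3, 16

# Toy data: labels are produced by a hidden linear "teacher" matrix.
W_true = randn(rng, n_cls, n_feat)
function make_data(n)
    X = randn(rng, n_feat, n)
    labels = [argmax(W_true * X[:, j]) for j in 1:n]
    Y = Float64[c == l for c in 1:n_cls, l in labels]  # one-hot columns
    return X, Y
end
X_train, Y_train = make_data(512)
X_test, Y_test = make_data(128)

# Numerically stable log-softmax over each column (one column per sample).
logsoftmax_cols(Z) = (Zs = Z .- maximum(Z; dims = 1); Zs .- log.(sum(exp.(Zs); dims = 1)))
# Mean cross-entropy of the linear model W over a batch.
xent(W, X, Y) = -sum(Y .* logsoftmax_cols(W * X)) / size(X, 2)
# Hand-derived gradient of xent with respect to W.
xent_grad(W, X, Y) = (exp.(logsoftmax_cols(W * X)) .- Y) * X' ./ size(X, 2)

W = zeros(n_cls, n_feat)  # the parameters θ, initialised to zero
η = 0.5                   # learning rate for plain gradient descent
train_losses, test_losses = Float64[], Float64[]

for epoch in 1:20
    step_losses = Float64[]
    # Training pass: shuffled mini-batches, one update per batch.
    for idx in Iterators.partition(randperm(rng, size(X_train, 2)), bs)
        X, Y = X_train[:, idx], Y_train[:, idx]
        push!(step_losses, xent(W, X, Y))  # loss before the update
        W .-= η .* xent_grad(W, X, Y)      # θ ← θ - η ∂L/∂θ
    end
    push!(train_losses, mean(step_losses))
    # Evaluation pass: loss only, no parameter update.
    push!(test_losses, xent(W, X_test, Y_test))
end

println(round.(train_losses[[1, end]]; digits = 3))
println(round.(test_losses[[1, end]]; digits = 3))
```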