Senior Research Software Engineer
ICCS - University of Cambridge
2025-07-02
To access links or follow along on your own device, these slides can be found at:
jackatkinson.net/slides
Except where otherwise noted, these presentation materials are licensed under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) License.
Vectors and icons by SVG Repo under CC0 1.0 or Font Awesome under SIL OFL 1.1
Large, complex, many-part systems.
Subgrid processes are the largest source of uncertainty.
Microphysics by Sisi Chen under Public Domain
Staggered grid by NOAA under Public Domain
Globe grid with box by Caltech under Fair use
Neural Net by 3Blue1Brown under fair dealing.
Pikachu © The Pokemon Company, used under fair dealing.
Many large scientific models are written in Fortran (or C, or C++).
Much machine learning is conducted in Python.
Mathematical Bridge by cmglee used under CC BY-SA 3.0
PyTorch, the PyTorch logo and any related marks are trademarks of The Linux Foundation.
TensorFlow, the TensorFlow logo and any related marks are trademarks of Google Inc.
We consider 2 types:
Computational
Developer
At the academic end of research both have an equal effect on ‘time-to-science’, especially when extensive research software support is unavailable.
iso_c_binding
We will:
[Diagram: coupling directly via the libtorch C++ API, avoiding any Python env or Python runtime.]
xkcd #1987 by Randall Munroe, used under CC BY-NC 2.5
Build using CMake, or link via Make as you would NetCDF (instructions included):
FCFLAGS += -I<path/to/install>/include/ftorch
LDFLAGS += -L<path/to/install>/lib64 -lftorch
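For illustration, a hypothetical Make rule using these flags (the compiler variable FC, file names, and paths are assumptions, not part of FTorch):
# Hypothetical rule; adjust FC and paths to your setup
model: model.f90
	$(FC) $(FCFLAGS) $< -o $@ $(LDFLAGS)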
Find it on GitHub: github.com/Cambridge-ICCS/FTorch
pt2ts.py aids users in saving PyTorch models to TorchScript.
import torch
import torchvision
# Load pre-trained model and put in eval mode
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.eval()
# Create dummy input
dummy_input = torch.ones(1, 3, 224, 224)
# Trace model and save
traced_model = torch.jit.trace(model, dummy_input)
frozen_model = torch.jit.freeze(traced_model)
frozen_model.save("/path/to/saved_model.pt")
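As an optional sanity check (using the illustrative path from above), the saved TorchScript file can be reloaded and run without the original Python class definitions:
# Reload the TorchScript model and run it on the dummy input
loaded_model = torch.jit.load("/path/to/saved_model.pt")
output = loaded_model(dummy_input)
print(output.shape)  # torch.Size([1, 1000]) for ResNet-18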
TorchScript
use ftorch
implicit none
real, dimension(5), target :: in_data, out_data ! Fortran data structures
type(torch_tensor), dimension(1) :: input_tensors, output_tensors ! Set up Torch data structures
type(torch_model) :: torch_net
in_data = ... ! Prepare data in Fortran
! Create Torch input/output tensors from the Fortran arrays
call torch_tensor_from_array(input_tensors(1), in_data, torch_kCPU)
call torch_tensor_from_array(output_tensors(1), out_data, torch_kCPU)
call torch_model_load(torch_net, 'path/to/saved/model.pt', torch_kCPU) ! Load ML model
call torch_model_forward(torch_net, input_tensors, output_tensors) ! Infer
call further_code(out_data) ! Use output data in Fortran immediately
! Cleanup
call torch_delete(torch_net)
call torch_delete(input_tensors)
call torch_delete(output_tensors)
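Note that the Fortran arrays are declared with the target attribute: the Torch tensors point at this memory directly rather than copying it, so out_data is ready to use as soon as the forward call returns.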
Cast Tensors to GPU in Fortran:
! Load in from Torchscript
call torch_model_load(torch_net, 'path/to/saved/model.pt', torch_kCUDA, device_index=0)
! Cast Fortran data to Tensors
call torch_tensor_from_array(input_tensors(1), in_data, torch_kCUDA, device_index=0)
call torch_tensor_from_array(output_tensors(1), out_data, torch_kCPU)
Effective HPC simulation requires MPI_Gather() for efficient data transfer.
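A minimal sketch of that pattern (variable names, sizes, and the already-loaded model net are assumptions): gather each rank's inputs to root, run one batched forward pass there, then scatter the results back.
! Gather each rank's n inputs to root, infer once on the batch, scatter results back
call MPI_Gather(local_in, n, MPI_REAL, batch_in, n, MPI_REAL, 0, MPI_COMM_WORLD, ierr)
if (rank == 0) then
  call torch_tensor_from_array(input_tensors(1), batch_in, torch_kCPU)
  call torch_tensor_from_array(output_tensors(1), batch_out, torch_kCPU)
  call torch_model_forward(net, input_tensors, output_tensors)
end if
call MPI_Scatter(batch_out, n, MPI_REAL, local_out, n, MPI_REAL, 0, MPI_COMM_WORLD, ierr)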
forpy, as used in Espinosa et al. (2022)
libtorch is included in the software stack on Derecho.
Derecho by NCAR
Work by Will Chapman of NCAR/M2LInES
Representations of physics in models have inherent, sometimes systematic, biases.
Run CESM for 9 years, relaxing hourly to ERA5 observations (data assimilation).
Train a CNN to predict the anomaly increment at each level.
Apply online as part of predictive runs.
To replace a BiCGStab bottleneck in the GloSea6 seasonal forecasting model. See Park and Chung (2025), DOI: 10.3390/atmos16010060
Implementation of nonlinear interactions in the WaveWatch III model. Ikuyajolu et al., preprint at essopenarchive.org
Implementation of a new convection trigger in the CAM model. Miller et al. In Preparation.
Our own paper in JOSS; please cite if you use FTorch! FTorch: a library for coupling PyTorch models to Fortran, Atkinson et al. (2025), DOI: 10.21105/joss.07602
Suppose we want to use a loss function involving downstream model code, e.g.,
\[J(\theta)=\int_\Omega(u-u_{ML}(\theta))^2\;\mathrm{d}x,\]
where \(u\) is the solution from the physical model and \(u_{ML}(\theta)\) is the solution from a hybrid model with some ML parameters \(\theta\).
Computing \(\mathrm{d}J/\mathrm{d}\theta\) requires differentiating Fortran code as well as ML code.
We expose the autograd functionality from Torch: the requires_grad argument and backward method, plus overloaded operators (=, +, -, *, /, **).
interface operator (*)
module procedure torch_tensor_multiply
end interface
!> Overloads multiplication operator for two tensors.
function torch_tensor_multiply(tensor1, tensor2) result(output)
use, intrinsic :: iso_c_binding, only : c_associated
type(torch_tensor), intent(in) :: tensor1 !! First tensor to be multiplied
type(torch_tensor), intent(in) :: tensor2 !! Second tensor to be multiplied
type(torch_tensor) :: output !! Tensor to hold the product
! [C interface definition]
call torch_tensor_multiply_c(output%p, tensor1%p, tensor2%p)
end function torch_tensor_multiply
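With this interface in scope, Fortran code can write expressions such as Q = multiplier * (a**3 - b * b / divisor) directly on torch_tensor objects, as in the Fortran example below.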
"""Based on https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html."""
import torch
# Construct input tensors with requires_grad=True
a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([6.0, 4.0], requires_grad=True)
# Compute some mathematical expression
Q = 3 * (a**3 - b * b / 3)
# Reverse mode
Q.backward(gradient=torch.ones_like(Q))
print(a.grad)
print(b.grad)
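Analytically \(Q = 3a^3 - b^2\), so \(\partial Q/\partial a = 9a^2 = [36, 81]\) and \(\partial Q/\partial b = -2b = [-12, -8]\); both this script and the Fortran version below print these values.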
use ftorch
type(torch_tensor) :: a, b, Q, multiplier, divisor, dQda, dQdb
real, dimension(2), target :: Q_arr, dQda_arr, dQdb_arr
! Construct input tensors with requires_grad=.true.
call torch_tensor_from_array(a, [2.0, 3.0], torch_kCPU, requires_grad=.true.)
call torch_tensor_from_array(b, [6.0, 4.0], torch_kCPU, requires_grad=.true.)
! Workaround for scalar multiplication and division using single-element tensors
call torch_tensor_from_array(multiplier, [3.0], torch_kCPU)
call torch_tensor_from_array(divisor, [3.0], torch_kCPU)
! Compute some mathematical expression
call torch_tensor_from_array(Q, Q_arr, torch_kCPU)
Q = multiplier * (a**3 - b * b / divisor)
! Reverse mode
call torch_tensor_backward(Q)
call torch_tensor_from_array(dQda, dQda_arr, torch_kCPU)
call torch_tensor_from_array(dQdb, dQdb_arr, torch_kCPU)
call torch_tensor_get_gradient(a, dQda)
call torch_tensor_get_gradient(b, dQdb)
print *, dQda_arr
print *, dQdb_arr
\[\begin{bmatrix}f_1\\f_2\\f_3\\f_4\end{bmatrix}=\mathbf{f}(\mathbf{x};\mathbf{a})=\mathbf{a}\bullet\mathbf{x}\equiv\begin{bmatrix}a_1x_1\\a_2x_2\\a_3x_3\\a_4x_4\end{bmatrix}\] Starting from \(\mathbf{a}=\mathbf{x}:=\begin{bmatrix}1,1,1,1\end{bmatrix}^T\), optimise the \(\mathbf{a}\) vector such that \(\mathbf{f}(\mathbf{x};\mathbf{a})=\mathbf{b}:=\begin{bmatrix}1,2,3,4\end{bmatrix}^T\).
Loss function: \(\ell(\mathbf{a})=\overline{(\mathbf{f}(\mathbf{x};\mathbf{a})-\mathbf{b})^2}\).
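For reference, a PyTorch sketch of this exercise (illustrative only; the FTorch version uses the exposed optimisers and torch_tensor_mean analogously):
import torch

x = torch.ones(4)                       # x := [1, 1, 1, 1]^T
b = torch.tensor([1.0, 2.0, 3.0, 4.0])  # target b
a = torch.ones(4, requires_grad=True)   # start from a = x

optimiser = torch.optim.SGD([a], lr=0.1)
for _ in range(200):
    optimiser.zero_grad()
    loss = torch.mean((a * x - b) ** 2)  # l(a) = mean((f(x; a) - b)^2)
    loss.backward()
    optimiser.step()
print(a)  # converges towards [1, 2, 3, 4]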
Call torch_tensor_get_gradient after each call to torch_tensor_backward or torch_tensor_zero_grad.
We expose torch::optim::Adam, torch::optim::SGD, etc., as well as their zero_grad and step methods.
Losses are currently limited to those expressible via torch_tensor_sum and torch_tensor_mean, though.
In both cases we achieve \(\mathbf{f}(\mathbf{x};\mathbf{a})=\begin{bmatrix}1,2,3,4\end{bmatrix}^T\).
autograd functionality exposed using iso_c_binding.
Get in touch:
The ICCS received support from:
Wilkes3 (CSD3)
Following the comparisons and MiMA experiments we performed detailed benchmarking to examine the library's performance.
Observations:
Use MPI_Gather to reduce overheads.