FTorch - facilitating Hybrid Modelling

Jack Atkinson

Senior Research Software Engineer
ICCS - University of Cambridge

Joe Wallwork

Senior Research Software Engineer
ICCS - University of Cambridge

2025-07-02

Precursors

Slides and Materials

To access links or follow on your own device, these slides can be found at:
jackatkinson.net/slides

Licensing

Except where otherwise noted, these presentation materials are licensed under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) License.

Vectors and icons by SVG Repo under CC0 1.0 or Font Awesome under SIL OFL 1.1

Motivation

Weather and Climate Models

Large, complex, many-part systems.

Parameterisation

Subgrid processes are the largest source of uncertainty

Microphysics by Sisi Chen Public Domain
Staggered grid by NOAA under Public Domain
Globe grid with box by Caltech under Fair use


Machine Learning in Science

Neural Net by 3Blue1Brown under fair dealing.
Pikachu © The Pokemon Company, used under fair dealing.

Challenges

  • Reproducibility
    • Ensure the net functions the same in situ
  • Re-usability
    • Make ML parameterisations available to many models
    • Facilitate easy re-training/adaptation
  • Language Interoperation

Language interoperation

Many large scientific models are written in Fortran (or C, or C++).
Much machine learning is conducted in Python.

Mathematical Bridge by cmglee used under CC BY-SA 3.0
PyTorch, the PyTorch logo and any related marks are trademarks of The Linux Foundation.
TensorFlow, the TensorFlow logo and any related marks are trademarks of Google Inc.

Possible solutions

  • Implement a NN in Fortran
    • Additional work, reproducibility issues, hard for complex architectures
  • Forpy
    • Easy to add, harder to use with ML, GPL, barely-maintained
  • SmartSim
    • Python ‘control centre’ around Redis: generic/versatile, learning curve, data copying
  • Fortran-Keras Bridge
    • Keras only, abandonware(?)

Efficiency

We consider 2 types:

Computational

Developer

At the academic end of research both have an equal effect on ‘time-to-science’, especially when extensive research software support is unavailable.

FTorch

Approach

  • PyTorch has a C++ backend and provides an API.
  • Binding Fortran to C has been straightforward since Fortran 2003 using the iso_c_binding module.

We will:

  • Save the PyTorch models in a portable TorchScript format
    • to be run by libtorch C++
  • Provide a Fortran API
    • wrapping the libtorch C++ API
    • abstracting complex details from users

Approach

[Diagram: Python environment and Python runtime]
xkcd #1987 by Randall Munroe, used under CC BY-NC 2.5

Highlights - Developer

  • Easy to clone and install
    • CMake, supported on Linux/Unix and Windows™
  • Easy to link
    • Build using CMake,

    • or link via Make like NetCDF (instructions included)

      FCFLAGS += -I<path/to/install>/include/ftorch
      LDFLAGS += -L<path/to/install>/lib64 -lftorch

Find it on GitHub: github.com/Cambridge-ICCS/FTorch

Highlights - Developer

  • User tools
    • pt2ts.py aids users in saving PyTorch models to Torchscript
  • Examples suite
    • Take users through full process from trained net to Fortran inference
  • FOSS
    • licensed under MIT
    • contributions from users via GitHub welcome

Find it on GitHub: github.com/Cambridge-ICCS/FTorch

Highlights - Computation

  • Use framework’s implementations directly
    • feature and future support, and reproducible
  • Make use of the Torch backends for GPU offload
    • CUDA, HIP, MPS, and XPU enabled
  • No-copy access in memory (on CPU).
  • Indexing issues and the associated reshape avoided with the Torch strided accessor.
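
As an illustration of the stride point (a sketch, not FTorch internals): for a Fortran array of shape \((m,n)\), element \((i,j)\) sits at linear offset

\[\mathrm{offset}(i,j)=(i-1)\,s_1+(j-1)\,s_2,\qquad\text{column-major (Fortran): }(s_1,s_2)=(1,m),\qquad\text{row-major (C/Torch default): }(s_1,s_2)=(n,1).\]

By viewing the existing Fortran buffer as a Torch tensor with strides \((1,m)\), index \((i,j)\) refers to the same memory location in both languages, so no transpose, reshape, or copy is needed.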

Find it on GitHub: github.com/Cambridge-ICCS/FTorch


Some code

Model - Saving from Python

import torch
import torchvision

# Load pre-trained model and put in eval mode
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.eval()

# Create dummy input
dummy_input = torch.ones(1, 3, 224, 224)

# Trace model and save
traced_model = torch.jit.trace(model, dummy_input)
frozen_model = torch.jit.freeze(traced_model)
frozen_model.save("/path/to/saved_model.pt")

TorchScript

  • Statically typed subset of Python
  • Read by the Torch C++ interface (or any Torch API)
  • Produces intermediate representation/graph of NN
    • Including weights and biases
  • Trace for simple models; scripting is also available (see the sketch below)
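
As a hedged sketch of the scripting route (mirroring the tracing example above; the model and save path are placeholders):

import torch
import torchvision

# Load a pre-trained model and put it in eval mode
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.eval()

# Script rather than trace: compiles the model's code directly, so
# data-dependent control flow is captured without a dummy input
scripted_model = torch.jit.script(model)
scripted_model.save("/path/to/saved_model.pt")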

Fortran

 use ftorch
 
 implicit none
 
 real, dimension(5), target :: in_data, out_data  ! Fortran data structures
 
 type(torch_tensor), dimension(1) :: input_tensors, output_tensors  ! Set up Torch data structures
 type(torch_model) :: torch_net
 integer, dimension(1) :: tensor_layout = [1]
 
 in_data = ...  ! Prepare data in Fortran
 
 ! Create Torch input/output tensors from the Fortran arrays
 call torch_tensor_from_array(input_tensors(1), in_data, torch_kCPU)
 call torch_tensor_from_array(output_tensors(1), out_data, torch_kCPU)
 
 call torch_model_load(torch_net, 'path/to/saved/model.pt', torch_kCPU)  ! Load ML model
 call torch_model_forward(torch_net, input_tensors, output_tensors)      ! Infer
 
 call further_code(out_data)  ! Use output data in Fortran immediately
 
 ! Cleanup
 call torch_delete(torch_net)
 call torch_delete(input_tensors)
 call torch_delete(output_tensors)

GPU Acceleration

Cast Tensors to GPU in Fortran:

! Load the model from TorchScript
call torch_model_load(torch_net, 'path/to/saved/model.pt', torch_kCUDA, device_index=0)

! Cast Fortran data to Tensors
call torch_tensor_from_array(in_tensor(1), in_data, tensor_layout, torch_kCUDA, device_index=0)
call torch_tensor_from_array(out_tensor(1), out_data, tensor_layout, torch_kCPU)



For efficient HPC simulation at scale, gather data across MPI ranks (e.g. with MPI_Gather()) before transferring it to the GPU, reducing transfer overheads (see the sketch below).
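
A minimal Python sketch of this gather-before-offload pattern (using mpi4py and a traced model; the shapes, path, and rank-0-only GPU layout are illustrative assumptions, not FTorch code):

import numpy as np
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# Each rank produces a local batch of inputs
local = np.random.rand(1, 3, 224, 224).astype(np.float32)

# Gather all inputs onto rank 0 so only one host-to-device copy is needed
gathered = np.empty((size, 3, 224, 224), dtype=np.float32) if rank == 0 else None
comm.Gather(local, gathered, root=0)

if rank == 0:
    model = torch.jit.load("/path/to/saved_model.pt").to("cuda")
    with torch.no_grad():
        out = model(torch.from_numpy(gathered).to("cuda"))  # one batched GPU call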

Applications and Case Studies

MiMA - proof of concept

  • The origins of FTorch
    • Emulation of existing parameterisation
    • Coupled to an atmospheric model using forpy in Espinosa et al. (2022)
    • Prohibitively slow and hard to implement
    • The authors asked for a faster, user-friendly implementation that could be used in future studies.


  • Follow-up paper using FTorch: Uncertainty Quantification of a Machine Learning Subgrid-Scale Parameterization for Atmospheric Gravity Waves (Mansfield and Sheshadri 2024)
    • “Identical” offline networks have very different behaviours when deployed online.

ICON

  • Icosahedral Nonhydrostatic Weather and Climate Model
    • Developed by DKRZ (Deutsches Klimarechenzentrum)
    • Used by the DWD and Meteo-Swiss
  • Interpretable Multiscale Machine Learning-Based Parameterizations of Convection for ICON (Heuer et al. 2023)
    • Train U-Net convection scheme on high-res simulation
    • Deploy in ICON via FTorch coupling
    • Evaluate physical realism (causality) using SHAP values
    • Online stability improved when non-causal relations were eliminated from the net

CESM coupling

  • The Community Earth System Model
  • Part of CMIP (Coupled Model Intercomparison Project)
  • Make it easy for users
    • FTorch integrated into the build system (CIME)
    • libtorch is included on the software stack on Derecho
      • Improves reproducibility

Derecho by NCAR

CESM - Bias Correction

Work by Will Chapman of NCAR/M2LInES

  • As representations of physics, models have inherent, sometimes systematic, biases.

  • Run CESM for 9 years, relaxing hourly to ERA5 observations (data assimilation)

  • Train CNN to predict anomaly increment at each level

    • targeting just the MJO region
    • targeting globally
  • Apply online as part of predictive runs

  • Low-hanging fruit: don’t load the model (with all its weights) at every timestep!

Others

  • To replace a BiCGStab bottleneck in the GloSea6 Seasonal Forecasting model. See Park and Chung (2025), DOI: 10.3390/atmos16010060

  • Implementation of nonlinear interactions in the WaveWatch III model. Ikuyajolu et al. Preprint at essopenarchive.org

  • Implementation of a new convection trigger in the CAM model. Miller et al. In Preparation.

  • Our own paper in JOSS: FTorch: a library for coupling PyTorch models to Fortran, Atkinson et al. (2025), DOI: 10.21105/joss.07602. Please cite it if you use FTorch!

Online Training and Autograd

Pros and Cons

  • +Avoids saving large volumes of training data.
  • +Avoids need to convert between Python and Fortran data formats.
  • +Possibility to expand loss function scope to include downstream model code.
  • -Difficult to implement in most frameworks.

Expanded Loss function

Suppose we want to use a loss function involving downstream model code, e.g.,

\[J(\theta)=\int_\Omega(u-u_{ML}(\theta))^2\;\mathrm{d}x,\]

where \(u\) is the solution from the physical model and \(u_{ML}(\theta)\) is the solution from a hybrid model with some ML parameters \(\theta\).

Computing \(\mathrm{d}J/\mathrm{d}\theta\) requires differentiating Fortran code as well as ML code.
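
To see where the Fortran derivative enters, write the hybrid solution as \(u_{ML}(\theta)=F(g_\theta)\), where \(g_\theta\) is the ML parameterisation and \(F\) is the Fortran model code that consumes its output (this decomposition is illustrative). Then

\[\frac{\mathrm{d}J}{\mathrm{d}\theta}=-2\int_\Omega\left(u-u_{ML}(\theta)\right)\frac{\mathrm{d}u_{ML}}{\mathrm{d}\theta}\;\mathrm{d}x,\qquad\frac{\mathrm{d}u_{ML}}{\mathrm{d}\theta}=\frac{\partial F}{\partial g}\,\frac{\partial g_\theta}{\partial\theta}.\]

Torch autograd provides \(\partial g_\theta/\partial\theta\); the factor \(\partial F/\partial g\) requires differentiating the Fortran model.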

Implementing AD in FTorch

  • Expose autograd functionality from Torch.
    • e.g., requires_grad argument and backward methods.
  • Overload mathematical operators (=,+,-,*,/,**).
  interface operator (*)
    module procedure torch_tensor_multiply
  end interface
  
  !> Overloads multiplication operator for two tensors.
  function torch_tensor_multiply(tensor1, tensor2) result(output)
    use, intrinsic :: iso_c_binding, only : c_associated
    type(torch_tensor), intent(in) :: tensor1  !! First tensor to be multiplied
    type(torch_tensor), intent(in) :: tensor2  !! Second tensor to be multiplied
    type(torch_tensor) :: output               !! Tensor to hold the product

    ! [C binding interface declaration omitted]
    
    call torch_tensor_multiply_c(output%p, tensor1%p, tensor2%p)
  end function torch_tensor_multiply

Using AD - PyTorch

"""Based on https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html."""

import torch

# Construct input tensors with requires_grad=True
a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([6.0, 4.0], requires_grad=True)

# Compute some mathematical expression
Q = 3 * (a**3 - b * b / 3)

# Reverse mode
Q.backward(gradient=torch.ones_like(Q))
print(a.grad)
print(b.grad)
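
Since \(Q=3a^3-b^2\) element-wise, the analytic gradients are \(\partial Q/\partial a=9a^2\) and \(\partial Q/\partial b=-2b\), so the two prints should give \([36, 81]\) and \([-12, -8]\); the FTorch translation below should print the same values.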

Using AD - FTorch

use ftorch

type(torch_tensor) :: a, b, Q, multiplier, divisor, dQda, dQdb
real, dimension(2), target :: Q_arr, dQda_arr, dQdb_arr

! Construct input tensors with requires_grad=.true.
call torch_tensor_from_array(a, [2.0, 3.0], torch_kCPU, requires_grad=.true.)
call torch_tensor_from_array(b, [6.0, 4.0], torch_kCPU, requires_grad=.true.)

! Workaround for scalar multiplication and division using 0D tensors
call torch_tensor_from_array(multiplier, [3.0], torch_kCPU)
call torch_tensor_from_array(divisor, [3.0], torch_kCPU)

! Compute some mathematical expression
call torch_tensor_from_array(Q, Q_arr, torch_kCPU)
Q = multiplier * (a**3 - b * b / divisor)

! Reverse mode
call torch_tensor_backward(Q)
call torch_tensor_from_array(dQda, dQda_arr, torch_kCPU)
call torch_tensor_from_array(dQdb, dQdb_arr, torch_kCPU)
call torch_tensor_get_gradient(a, dQda)
call torch_tensor_get_gradient(b, dQdb)
print *, dQda_arr
print *, dQdb_arr

Putting it together - running an optimiser in FTorch

\[\begin{bmatrix}f_1\\f_2\\f_3\\f_4\end{bmatrix}=\mathbf{f}(\mathbf{x};\mathbf{a})=\mathbf{a}\bullet\mathbf{x}\equiv\begin{bmatrix}a_1x_1\\a_2x_2\\a_3x_3\\a_4x_4\end{bmatrix}\] Starting from \(\mathbf{a}=\mathbf{x}:=\begin{bmatrix}1,1,1,1\end{bmatrix}^T\), optimise the \(\mathbf{a}\) vector such that \(\mathbf{f}(\mathbf{x};\mathbf{a})=\mathbf{b}:=\begin{bmatrix}1,2,3,4\end{bmatrix}^T\).

Loss function: \(\ell(\mathbf{a})=\overline{(\mathbf{f}(\mathbf{x};\mathbf{a})-\mathbf{b})^2}\).
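
For reference, a minimal PyTorch sketch of the same optimisation (the Python analogue of what follows in FTorch; the learning rate and iteration count are illustrative):

import torch

# Problem setup: f(x; a) = a * x (element-wise), target b
x = torch.tensor([1.0, 1.0, 1.0, 1.0])
b = torch.tensor([1.0, 2.0, 3.0, 4.0])
a = torch.ones(4, requires_grad=True)

optimiser = torch.optim.Adam([a], lr=0.1)

for step in range(1000):
    optimiser.zero_grad()                # reset accumulated gradients
    loss = torch.mean((a * x - b) ** 2)  # mean-squared-error loss
    loss.backward()                      # reverse-mode AD
    optimiser.step()                     # update a

print(a)  # expect approximately [1.0, 2.0, 3.0, 4.0]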

Gradients, Optimizers, and Loss

  • Gradients
    • Need to call torch_tensor_get_gradient after each call to torch_tensor_backward or torch_tensor_zero_grad.
      • Due to pointer management on C++ side, probably avoidable.
  • Optimizers
    • Expose torch::optim::Adam, torch::optim::SGD, etc., as well as their zero_grad and step methods.
    • This already enables some cool AD applications in FTorch.
  • Loss functions
    • We haven’t exposed any built-in loss functions yet.
    • Implemented torch_tensor_sum and torch_tensor_mean, though.

Putting it together - running an optimiser in FTorch

[Plot: optimisation loss curves]

In both cases we achieve \(\mathbf{f}(\mathbf{x};\mathbf{a})=\begin{bmatrix}1,2,3,4\end{bmatrix}^T\).

Case study - UKCA

  • Implicit timestepping, quasi-Newton, full LU decomposition.
  • For each time subinterval to be integrated:
    • Start with \(\Delta t=3600\).
    • Try to integrate with the current timestep size.
    • If any grid-box fails, halve the step and try again (see the sketch below).
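
A minimal sketch of that halving logic (the function and its attempt_step callback are hypothetical names, not UKCA or FTorch API):

def integrate_subinterval(state, t_end, attempt_step, dt_start=3600.0, dt_min=1.0):
    """Integrate [0, t_end] with adaptive halving of the timestep."""
    t, dt = 0.0, dt_start
    while t < t_end:
        step = min(dt, t_end - t)
        if attempt_step(state, step):  # implicit quasi-Newton solve succeeded everywhere
            t += step
        else:                          # some grid-box failed: halve the step and retry
            dt *= 0.5
            if dt < dt_min:
                raise RuntimeError("timestep underflow")
    return state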

Summary

  • Use of ML within traditional numerical models
    • A growing area that presents challenges
  • Language interoperation
    • FTorch provides a solution for scientists looking to implement Torch models in Fortran
    • Designed with both computational and developer efficiency in mind
    • Has helped deliver science in climate research and beyond
      (Heuer et al. (2023), Mansfield and Sheshadri (2024))
    • Built into CESM, giving its user base access
  • FTorch is exploring options for online training and AD
    • Torch autograd functionality exposed using iso_c_binding.
    • Exposed tools for optimization.
    • Work in progress on setting up online ML training.

Thanks for Listening

The ICCS received support from

References

Espinosa, Zachary I, Aditi Sheshadri, Gerald R Cain, Edwin P Gerber, and Kevin J DallaSanta. 2022. “Machine Learning Gravity Wave Parameterization Generalizes to Capture the QBO and Response to Increased CO2.” Geophysical Research Letters 49 (8): e2022GL098174.
Heuer, Helge, Mierk Schwabe, Pierre Gentine, Marco A Giorgetta, and Veronika Eyring. 2023. “Interpretable Multiscale Machine Learning-Based Parameterizations of Convection for ICON.” arXiv Preprint arXiv:2311.03251.
Mansfield, Laura A, and Aditi Sheshadri. 2024. “Uncertainty Quantification of a Machine Learning Subgrid-Scale Parameterization for Atmospheric Gravity Waves.” Authorea Preprints.

Strong Scaling

Wilkes3 (CSD3)

  • 3rd Generation AMD EPYC 64-Core CPUs
  • NVIDIA A100-SXM-80GB GPUs

Observations:

  • Data transfer to GPU becomes important
    • Suggest using MPI_Gather to reduce overheads
  • CPU Net scales well

Benchmarking

Following the comparisons and MiMA experiments, we performed detailed benchmarking to examine the library’s performance.