Reducing the overheads for coupling machine learning models to Fortran

ML&DL Seminars, LSCE - IPSL, Paris

Jack Atkinson

ICCS/Cambridge

Simon Clifford

ICCS/Cambridge

Athena Elafrou

NVIDIA

Elliott Kasoar

STFC/ICCS

Tom Meltzer

ICCS/Cambridge

Dominic Orchard

ICCS/Cambridge/Kent

2023-11-28

Precursors

Licensing

Except where otherwise noted, these presentation materials are licensed under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) License.

Vectors and icons by SVG Repo used under CC0 1.0

Slides

To access links or follow on your own device these slides can be found at
https://jackatkinson.net/slides

Introduction

The ICCS

The Institute of Computing for Climate Science

  • Domain-specific group based at the University of Cambridge
  • Embedded support to several international climate science projects

Climate Modelling

Climate models are large, complex, many-part systems.

Machine Learning

We typically think of Deep Learning as an end-to-end process:
a black box with an input and an output.

Who’s that Pokémon?

\[\begin{bmatrix}\vdots\\a_{23}\\a_{24}\\a_{25}\\a_{26}\\a_{27}\\\vdots\\\end{bmatrix}=\begin{bmatrix}\vdots\\0\\0\\1\\0\\0\\\vdots\\\end{bmatrix}\] It’s Pikachu!

Neural Net by 3Blue1Brown under fair dealing.
Pikachu © The Pokémon Company, used under fair dealing.

Machine Learning in Science

Neural Net by 3Blue1Brown under fair dealing.
Pikachu © The Pokémon Company, used under fair dealing.

Language interoperation

Many large scientific models are written in Fortran (or C, or C++).
Much machine learning is conducted in Python.

Mathematical Bridge by cmglee used under CC BY-SA 3.0
PyTorch, the PyTorch logo and any related marks are trademarks of The Linux Foundation.
TensorFlow, the TensorFlow logo and any related marks are trademarks of Google Inc.

Solutions

Considerations

There are two types of efficiency:

  • Computational

  • Developer

An ideal solution should:

  • not generate excess additional work,
    • not require advanced computing skills,
    • have a minimal learning curve,
  • not add excess dependencies,
  • be easy to maintain, and
  • maximise performance.

Possible solutions

  • Implement a NN in Fortran
  • Forpy/CFFI
  • SmartSim
  • Fortran-Keras Bridge

Possible solutions: Implement a NN in Fortran

  • e.g. inference-engine, neural-fortran, or your own custom solution

Pros:

  • Removes the two-language problem

Cons:

  • How do you ensure you port the model correctly?
  • ML libraries are highly optimised, probably more so than your code.

Possible solutions: Forpy/CFFI

  • Brings Python types into Fortran

Pros:

  • Easy to add the forpy.mod file and compile

Cons:

  • Verbose, with a learning curve
  • Need to manage and link a Python environment
  • Increases dependencies

Possible solutions: SmartSim

  • Passes data between workers through a network glue layer
  • May be necessary for certain architectures

Pros:

  • Highly versatile: deals with data, not endpoints

Cons:

  • Learning curve
  • Involves data copying

Possible solutions: Fortran-Keras Bridge

Pros:

  • Pure Fortran

Cons:

  • TensorFlow (Keras) only
  • Inactive and incomplete



Other suggestions include FIFO pipes and YAC (Arnold et al. 2023).
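A minimal sketch of the FIFO-pipe idea in Fortran is shown below; the pipe path, the plain-text format, and the Python reader on the other end are all illustrative assumptions rather than a recommended design.

! Minimal sketch: send model inputs to a Python process through a named pipe.
! Assumes the pipe already exists (`mkfifo /tmp/ml_pipe`) and that a Python
! process is reading from it; the path and text format are hypothetical.
program fifo_send
  implicit none
  integer :: u
  real :: inputs(4)

  inputs = [1.0, 2.0, 3.0, 4.0]

  ! Opening a FIFO for writing blocks until a reader attaches
  open(newunit=u, file='/tmp/ml_pipe', action='write', status='old')
  write(u, *) inputs   ! list-directed write; the Python side parses the text
  close(u)
end program fifo_send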

Possible solutions

[Figure: xkcd #1987, depicting the tangled state of a typical Python environment and runtime]

xkcd #1987 by Randall Munroe, used under CC BY-NC 2.5

Introducing FTorch

Approach

  • PyTorch (and TensorFlow) have C++ backends and provide APIs to access them.
  • Binding Fortran to C is straightforward from Fortran 2003 using iso_c_binding (see the sketch after this list).

We will:

  • Archive the PyTorch model as TorchScript
    • Statically typed subset of Python
    • Produces an Intermediate Representation/graph of the NN
    • Can be read and run via any Torch API
  • Provide a Fortran API to abstract complex details from users
    • Wraps the libtorch C++ API
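To give a flavour of the mechanism FTorch builds on, here is a minimal sketch of a Fortran binding to a C routine via iso_c_binding; the routine c_forward is hypothetical and not part of the FTorch API.

! Minimal sketch of Fortran-to-C interoperation with iso_c_binding.
! The C routine is assumed to have the (hypothetical) prototype:
!   void c_forward(const float* in, int n_in, float* out, int n_out);
module c_bindings
  use, intrinsic :: iso_c_binding, only : c_float, c_int
  implicit none

  interface
    subroutine c_forward(in, n_in, out, n_out) bind(c, name="c_forward")
      import :: c_float, c_int
      real(c_float), intent(in) :: in(*)
      integer(c_int), value, intent(in) :: n_in
      real(c_float), intent(out) :: out(*)
      integer(c_int), value, intent(in) :: n_out
    end subroutine c_forward
  end interface
end module c_bindings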

Performant - Computational

No-copy access in memory (CPU).

Indexing issues and the associated reshape can be avoided using the Torch accessor.
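As a minimal sketch of what no-copy means in practice (using only the FTorch calls shown in the Code section below; the array shape is illustrative):

! The tensor is a view of the existing Fortran array on CPU, so no data
! is duplicated: updates to the array are visible through the tensor.
use, intrinsic :: iso_fortran_env, only : sp => real32
use :: ftorch

implicit none

real(sp), dimension(2,3), target :: x
integer :: layout(2) = [1,2]
type(torch_tensor) :: tx

x = 0.0_sp
tx = torch_tensor_from_array(x, layout, torch_kCPU)
x(1,1) = 42.0_sp  ! shared memory: the tensor sees this change, no copy made

call torch_tensor_delete(tx)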

Ease of use - Installation

CMake

  • Install libtorch (or PyTorch)
  • Clone FTorch
  • Build using CMake
    (instructions provided)
  • Install
  • Link


Alternatively:

  • Clone and build alongside your code
  • Guidance on linking is provided
    (similar to any other library, e.g. NetCDF)

Tested on:

  • Linux
  • macOS
  • Windows

CMake is a trademark of Kitware.

Ease of use

Examples

  • Guide users through the entire process:
    • from saving a Python model
    • to running it in Fortran
  • Includes:
    • Basic user-defined net
    • Preloaded (ResNet-18) case
    • Gravity wave drag scientific example

Tools

  • The pt2ts.py script facilitates saving PyTorch models to TorchScript.

Support

  • Use the frameworks’ implementations directly:
    • feature support
    • future support
    • direct translation of Python models

Licensing and FOSS

The libraries are licensed under MIT and available as FOSS.

  • Highly permissive for use by all
  • Open-source development on GitHub using issues and PRs

Code

Saving model to TorchScript

import torch
import torchvision

# Load pre-trained model and put in eval mode
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.eval()

# Create dummy input
dummy_input = torch.ones(1, 3, 224, 224)

# Trace model and save
traced_model = torch.jit.trace(model, dummy_input)
frozen_model = torch.jit.freeze(traced_model)
frozen_model.save("saved_model.pt")

Loading a Torch model

! Use the FTorch Library
use :: ftorch

implicit none

! Define a Torch module
type(torch_module) :: model

! Load in from TorchScript
model = torch_module_load('/path/to/saved/model.pt')

Creating Tensors

use, intrinsic :: iso_fortran_env, only : sp => real32

! Use the FTorch Library
use :: ftorch

implicit none

! Fortran variables
real(sp), dimension(1,3,224,224), target :: in_data
real(sp), dimension(1, 1000), target :: out_data
integer, parameter :: n_inputs = 1
integer :: in_layout(4) = [1,2,3,4]
integer :: out_layout(2) = [1,2]

! Torch Tensors
type(torch_tensor), dimension(1) :: in_tensor
type(torch_tensor) :: out_tensor

! Populate Fortran data
call random_number(in_data)

! Create Torch input/output tensors from the Fortran arrays (no copy on CPU)
in_tensor(1) = torch_tensor_from_array(in_data, in_layout, torch_kCPU)
out_tensor = torch_tensor_from_array(out_data, out_layout, torch_kCPU)

Running model

! Infer
call torch_module_forward(model, in_tensor, n_inputs, out_tensor)

Cleaning up

! Cleanup
call torch_module_delete(model)
call torch_tensor_delete(in_tensor(1))
call torch_tensor_delete(out_tensor)

! Use Fortran array `out_data` elsewhere in code

Complete Code

use, intrinsic :: iso_fortran_env, only : sp => real32

! Use the FTorch Library
use :: ftorch

implicit none

! Define a Torch module
type(torch_module) :: model

! Fortran variables
real(sp), dimension(1,3,224,224), target :: in_data
real(sp), dimension(1,1000), target :: out_data
integer, parameter :: n_inputs = 1
integer :: in_layout(4) = [1,2,3,4]
integer :: out_layout(2) = [1,2]

! Torch tensors
type(torch_tensor), dimension(1) :: in_tensor
type(torch_tensor) :: out_tensor

! Load in from TorchScript
model = torch_module_load('/path/to/saved/model.pt')

! Populate Fortran data
call random_number(in_data)

! Create Torch input/output tensors from the Fortran arrays (no copy on CPU)
in_tensor(1) = torch_tensor_from_array(in_data, in_layout, torch_kCPU)
out_tensor = torch_tensor_from_array(out_data, out_layout, torch_kCPU)

! Infer
call torch_module_forward(model, in_tensor, n_inputs, out_tensor)

! Cleanup
call torch_module_delete(model)
call torch_tensor_delete(in_tensor(1))
call torch_tensor_delete(out_tensor)

! Use Fortran array `out_data` elsewhere in code

GPU Acceleration

Save to TorchScript for GPU from Python:

# Set device as cuda
device = torch.device('cuda')

# Move model and dummy input to device before saving to TorchScript
model = model.to(device)
model.eval()
dummy_input = dummy_input.to(device)

# Trace model and save
traced_model = torch.jit.trace(model, dummy_input)
frozen_model = torch.jit.freeze(traced_model)
frozen_model.save("saved_gpu_model.pt")

Place the input tensor on the GPU in Fortran (the output tensor can remain on the CPU):

! Load in from TorchScript
model = torch_module_load('/path/to/saved/gpu/model.pt')

! Cast Fortran data to Tensors
in_tensor(1) = torch_tensor_from_array(in_data, in_layout, torch_kCUDA)
out_tensor = torch_tensor_from_array(out_data, out_layout, torch_kCPU)

Case Study

Gravity Wave parameterisation in MiMA

  • Neural Net
    • Emulating Alexander and Dunkerton (1999) gravity wave parameterisation.
    • Fully-connected multi-layer net with identical PyTorch and TensorFlow versions
    • Trained offline in PyTorch
    • Initially interfaced (slowly) using forpy (Espinosa et al. 2022)

Results

  • Zonal-mean zonal winds (m/s)

  • Averaged over ±5 deg lat.

  • Pressure (height) vs. time Hovmöller diagram

  • FTorch exactly reproduces the results of a direct Python call

  • The NN is stable and reproduces the QBO
    (tends to slightly over/under-predict)

Coding example

Replace the forpy-connected net with our directly coupled approach.

Test both PyTorch and TensorFlow.


Given a Fortran program with model inputs in arrays, the original coupling using forpy requires 45 lines of boilerplate code, whilst our library takes 27.


A fork of MiMA with these implementations of the interfaces is at:
https://github.com/DataWaveProject/MiMA-machine-learning

Strong Scaling

Wilkes3 (CSD3)

  • 3rd Generation AMD EPYC 64-Core CPUs
  • NVIDIA A100-SXM-80GB GPUs

Observations:

  • Data transfer to the GPU becomes important
    • Suggest using MPI_Gather to batch messages and reduce overheads (see the sketch below)
  • The CPU net scales well
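A minimal sketch of the MPI_Gather suggestion: batch the per-rank inputs onto one rank before a single forward pass, then scatter the outputs back. The sizes and the placeholder inference step are illustrative assumptions; the FTorch forward call itself is elided (see the code examples above).

! Gather per-rank inputs to rank 0, run one batched forward pass there,
! then scatter the outputs back to the owning ranks.
program gather_infer
  use mpi
  implicit none

  integer, parameter :: n_features = 40
  real :: local_in(n_features), local_out(n_features)
  real, allocatable :: all_in(:,:), all_out(:,:)
  integer :: ierr, rank, nranks

  call MPI_Init(ierr)
  call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr)
  call MPI_Comm_size(MPI_COMM_WORLD, nranks, ierr)

  call random_number(local_in)
  if (rank == 0) then
    allocate(all_in(n_features, nranks), all_out(n_features, nranks))
  else
    allocate(all_in(0, 0), all_out(0, 0))  ! buffers only significant on root
  end if

  ! One large transfer per step instead of one GPU transfer per rank
  call MPI_Gather(local_in, n_features, MPI_REAL, all_in, n_features, &
                  MPI_REAL, 0, MPI_COMM_WORLD, ierr)

  if (rank == 0) then
    all_out = all_in  ! placeholder for a single batched FTorch forward pass
  end if

  call MPI_Scatter(all_out, n_features, MPI_REAL, local_out, n_features, &
                   MPI_REAL, 0, MPI_COMM_WORLD, ierr)

  call MPI_Finalize(ierr)
end program gather_infer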

Benchmarking

Following the comparisons and MiMA experiments, we performed detailed benchmarking to examine the library's performance.

Conclusions

Take away messages

  • Machine learning has many potential applications in scientific computing
  • Leveraging it effectively requires care
  • FTorch allows easy and efficient deployment of ML within Fortran models
    • Minimal change required to instrument existing codes
    • Designed to be simple and familiar
  • For new ML projects we advise using PyTorch

Future work

  • Continuous development to improve UX
    • Abstraction of C bindings, Torch efficiency etc.
  • Implement functionalities beyond inference?
    • Online training is likely to become important
  • Implementation into CESM and other models
    • Including general guidelines

Get involved

  • Inform potential users
    • Further testing and feedback wanted!
  • FTorch team always keen to assist

Further discussion

EGU

PASC (TBC)

  • 3rd – 5th June 2024, Zurich, Switzerland
  • Minisymposium: Interfacing Machine Learning with Physics-Based Models
  • Speaker interest required ASAP

Thank You

The ICCS is funded by  

References

Alexander, MJ, and TJ Dunkerton. 1999. “A Spectral Parameterization of Mean-Flow Forcing Due to Breaking Gravity Waves.” Journal of the Atmospheric Sciences 56 (24): 4167–82.
Arnold, C., S. Sharma, T. Weigel, and D. Greenberg. 2023. “Efficient and Stable Coupling of the SuperdropNet Deep Learning-Based Cloud Microphysics (V0.1.0) to the ICON Climate and Weather Model (V2.6.5).” EGUsphere 2023: 1–17. https://doi.org/10.5194/egusphere-2023-2047.
Espinosa, Zachary I, Aditi Sheshadri, Gerald R Cain, Edwin P Gerber, and Kevin J DallaSanta. 2022. “Machine Learning Gravity Wave Parameterization Generalizes to Capture the QBO and Response to Increased CO2.” Geophysical Research Letters 49 (8): e2022GL098174.
Jucker, Martin, and EP Gerber. 2017. “Untangling the Annual Cycle of the Tropical Tropopause Layer with an Idealized Moist Model.” Journal of Climate 30 (18): 7339–58.

Bonus Content

Fortran-Tensorflow-lib

TensorFlow

  • C++ and C APIs
  • Archive model as Keras SavedModel
  • A process_model script is provided to extract the required opaque parameters and use the API

Forpy vs FTorch comparison

e.g. Loading a Torch model

type(module_py) :: run_emulator
type(list) :: paths
type(object) :: model
type(tuple) :: args
type(str) :: py_model_dir
integer :: ie

! Initialise forpy before any other calls
ie = forpy_initialize()

! Add the saved model's directory to the Python path
ie = str_create(py_model_dir, trim('/path/to/saved/model'))
ie = get_sys_path(paths)
ie = paths%append(py_model_dir)

! import python modules to `run_emulator`
ie = import_py(run_emulator, trim(model_name))
if (ie .ne. 0) then
    call err_print
    call error_mesg(__FILE__, __LINE__, "forpy model not loaded")
end if

! use python module `run_emulator` to load a trained model
ie = call_py(model, run_emulator, "name_of_init_function")
if (ie .ne. 0) then
    call err_print
    call error_mesg(__FILE__, __LINE__, "call to `initialize` failed")
end if

e.g. Loading a Torch model

type(torch_module) :: model

model = torch_module_load('/path/to/saved/model.pt')