Skip to content

FNO battery module cooling example #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Battery Module Cooling Analysis with Fourier Neural Operator

This example demonstrates applying a Fourier Neural Operator [1] to heat analysis of a 3D battery module.

## Setup

Run the example by running [`example.mlx`](./example.mlx).

## Requirements

Requires
- [MATLAB](https://www.mathworks.com/products/matlab.html) (R2025a or newer)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this run in earlier releases? What 25a features are requied?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK the example states we're using 25a functionality to handle the geometry.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the helper function in this example should work in some older releases than 25a. I'm not sure what the most recent Deep Learning Toolbox code I've used is, maybe the networkLayer from 24a, though it isn't strictly necessary. The PDE Toolbox introduced the femodel in 23a, but there was ThermalModel before that which might suffice for this example. So potentially we could write a version of the example that's supported quite a few releases back.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the old APIs like ThermalModel will eventually be removed, so I would strongly suggest staying with femodel (23a as Ben mentioned)

- [Deep Learning Toolbox™](https://www.mathworks.com/products/deep-learning.html)
- [Partial Differential Equation Toolbox™](https://uk.mathworks.com/products/pde.html)
- [Parallel Computing Toolbox™](https://uk.mathworks.com/products/parallel-computing.html) (for training on a GPU)

## References
[1] Li, Zongyi, Nikola Borislavov Kovachki, Kamyar Azizzadenesheli, Burigede Liu, Kaushik Bhattacharya, Andrew Stuart, and Anima Anandkumar. 2021. "Fourier Neural Operator for Parametric Partial Differential Equations." In International Conference on Learning Representations. https://openreview.net/forum?id=c8P9NQVtmnO.

#

Copyright 2025 The MathWorks, Inc.
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
function [geomModule, domainIDs, boundaryIDs, volume, boundaryArea, ReferencePoint] = createBatteryModuleGeometry(numCellsInModule, cellWidth,cellThickness,tabThickness,tabWidth,cellHeight,tabHeight, connectorHeight )
%% Uses Boolean geometry functionality in PDE Toolbox, which requires release R2025a or later.
% If you have an older version, use the helper function in this example:
% https://www.mathworks.com/help/pde/ug/battery-module-cooling-analysis-and-reduced-order-thermal-model.html

% First, create a single pouch cell by unioning the cell, tab and connector
% Cell creation
cell1 = fegeometry(multicuboid(cellThickness,cellWidth,cellHeight));
cell1 = translate(cell1,[cellThickness/2,cellWidth/2,0]);
% Tab creation
tab = fegeometry(multicuboid(tabThickness,tabWidth,tabHeight));
tabLeft = translate(tab,[cellThickness/2,tabWidth,cellHeight]);
tabRight = translate(tab,[cellThickness/2,cellWidth-tabWidth,cellHeight]);
% Union tabs to cells
geomPouch = union(cell1, tabLeft, KeepBoundaries=true);
geomPouch = union(geomPouch, tabRight, KeepBoundaries=true);
% Connector creation
overhang = (cellThickness-tabThickness)/2;
connector = fegeometry(multicuboid(tabThickness+overhang,tabWidth,connectorHeight));
connectorRight = translate(connector,[cellThickness/2+overhang/2,tabWidth,cellHeight+tabHeight]);
connectorLeft = translate(connector,[(cellThickness/2-overhang/2),cellWidth-tabWidth,cellHeight+tabHeight]);
% Union connectors to tabs
geomPouch = union(geomPouch,connectorLeft,KeepBoundaries=true);
geomPouch = union(geomPouch,connectorRight,KeepBoundaries=true);
% Scale and translate completed pouch cell to create mirrored cell
geomPouchMirrored = translate(scale(geomPouch,[-1 1 1]),[cellThickness,0,0]);
% Union individual pouches to create full module
% Union even-numbered pouch cells together (original cells)
geomForward = fegeometry;
for i = 0:2:numCellsInModule-1
offset = cellThickness*i;
geom_to_append = translate(geomPouch,[offset,0,0]);
geomForward = union(geomForward,geom_to_append);
end
% Union odd-numbered pouch cells together (mirrored cells)
geomBackward = fegeometry;
for i = 1:2:numCellsInModule-1
offset = cellThickness*i;
geom_to_append = translate(geomPouchMirrored,[offset,0,0]);
geomBackward = union(geomBackward,geom_to_append);
end
% Union to create completed geometry module
geomModule = union(geomForward,geomBackward,KeepBoundaries=true);
% Rotate and translate the geometry
geomModule = translate(scale(geomModule,[1 -1 1]),[0 cellWidth 0]);
% Mesh the geometry to use query functions for identifying cells and faces
geomModule = generateMesh(geomModule,GeometricOrder="linear");
% Create Reference Points for each geometry future
ReferencePoint.Cell = [cellThickness/2,cellWidth/2,cellHeight/2];
ReferencePoint.TabLeft = [cellThickness/2,tabWidth,cellHeight+tabHeight/2];
ReferencePoint.TabRight = [cellThickness/2,cellWidth-tabWidth,cellHeight+tabHeight/2];
ReferencePoint.ConnectorLeft = [cellThickness/2,tabWidth,cellHeight+tabHeight+connectorHeight/2];
ReferencePoint.ConnectorRight = [cellThickness/2,cellWidth-tabWidth,cellHeight+tabHeight+connectorHeight/2];
% Helper function to get the cell IDs belonging to cell, tab and connector
[~,~,t] = meshToPet(geomModule.Mesh);
elementDomain = t(end,:);
tr = triangulation(geomModule.Mesh.Elements',geomModule.Mesh.Nodes');
getCellID = @(point,cellNumber) elementDomain(pointLocation(tr,point+(cellNumber(:)-1)*[cellThickness,0,0]));
% Helper function to get the volume of the cells, tabs, and connectors
getVolumeOneCell = @(geomCellID) geomModule.Mesh.volume(findElements(geomModule.Mesh,"region",Cell=geomCellID));
getVolume = @(geomCellIDs) arrayfun(@(n) getVolumeOneCell(n),geomCellIDs);
% Initialize cell ID and volume structs
domainIDs(1:numCellsInModule) = struct(Cell=[], ...
TabLeft=[],TabRight=[], ...
ConnectorLeft=[],ConnectorRight=[]);
volume(1:numCellsInModule) = struct(Cell=[], ...
TabLeft=[],TabRight=[], ...
ConnectorLeft=[],ConnectorRight=[]);
% Helper function to get the IDs belonging to the left, right, front, back, top and bottom faces
getFaceID = @(offsetVal,offsetDirection,cellNumber) nearestFace(geomModule,...
ReferencePoint.Cell + offsetVal/2 .*offsetDirection ... % offset ref. point to face
+ cellThickness*(cellNumber(:)-1)*[1,0,0]); % offset to cell
% Initialize face ID and area structs
boundaryIDs(1:numCellsInModule) = struct(FrontFace=[],BackFace=[], ...
RightFace=[],LeftFace=[], ...
TopFace=[],BottomFace=[]);
boundaryArea(1:numCellsInModule) = struct(FrontFace=[],BackFace=[], ...
RightFace=[],LeftFace=[], ...
TopFace=[],BottomFace=[]);
% Loop over cell, left tab, right tab, left connector, and right connector to get cell IDs and volumes
for part = string(fieldnames(domainIDs))'
partid = num2cell(getCellID(ReferencePoint.(part),1:numCellsInModule));
[domainIDs.(part)] = partid{:};
volumesPart = num2cell(getVolume([partid{:}]));
[volume.(part)] = volumesPart{:};
end
% Loop over front, back, right, left, top, and bottom faces IDs and areas
dimensions = [cellThickness;cellThickness;cellWidth;cellWidth;cellHeight;cellHeight];
vectors = [-1,0,0;1,0,0;0,1,0;0,-1,0;0,0,1;0,0,-1];
areaFormula = [cellHeight*cellWidth;cellHeight*cellWidth;cellThickness*cellHeight;cellThickness*cellHeight;cellThickness*cellWidth - tabThickness*tabWidth;cellThickness*cellWidth - tabThickness*tabWidth];
i = 1;
for face = string(fieldnames(boundaryIDs))'
faceid = num2cell(getFaceID(dimensions(i),vectors(i,:),1:numCellsInModule));
[boundaryIDs.(face)] = faceid{:};
areasFace = num2cell(areaFormula(i)*ones(1,numCellsInModule));
[boundaryArea.(face)] = areasFace{:};
i = i+1;
end
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
classdef spectralConvolution3dLayer < nnet.layer.Layer ...
& nnet.layer.Formattable ...
& nnet.layer.Acceleratable
% spectralConvolution3dLayer A custom layer implementation of
% spectral convolution for data with 3 spatial dimensions.

properties
Cin
Cout
NumModes
end

properties (Learnable)
Weights
end

methods
function this = spectralConvolution3dLayer(numModes, outChannels, nvargs)
arguments
numModes (1,1) double
outChannels (1,1) double
nvargs.Name {mustBeTextScalar} = "spectralConv3d"
nvargs.Weights = []
end

this.Cout = outChannels;
this.NumModes = numModes;
this.Name = nvargs.Name;
this.Weights = nvargs.Weights;
end

function this = initialize(this, ndl)
inChannels = ndl.Size( finddim(ndl,'C') );
outChannels = this.Cout;
numModes = this.NumModes;

if isempty(this.Weights)
this.Cin = inChannels;
this.Weights = 1./(inChannels*outChannels).*( ...
rand([outChannels inChannels numModes numModes numModes]) + ...
1i.*rand([outChannels inChannels numModes numModes numModes]) );
end
end

function y = predict(this, x)

% Compute the 3d fft and retain only the low frequency modes as
% specified by NumModes.
x = real(x);
x = stripdims(x);
N = size(x, 1);
Nm = this.NumModes;
xft = fft(x, [], 1);
xft = xft(1:Nm,:,:,:,:);
xft = fft(xft, [], 2);
xft = xft(:,1:Nm,:,:,:);
xft = fft(xft, [], 3);
xft = xft(:,:,1:Nm,:,:);

% Multiply selected Fourier modes with the learnable weights.
xft = permute(xft, [4 5 1 2 3]);
yft = pagemtimes( this.Weights, xft );
yft = permute(yft, [3, 4, 5, 1, 2]);

% Make the frequency representation conjugate-symmetric such
% that the inverse Fourier transform is real-valued.
S = floor(N/2)+1 - this.NumModes;
idx = ceil(N/2):-1:2;
yft = cat(1, yft, zeros([S size(yft, 2:5)], 'like', yft));
yft = cat(1, yft, conj(yft(idx,:,:,:,:)));

yft = cat(2, yft, zeros([size(yft,1), S, size(yft,3:5)], like=yft));
yft = cat(2, yft, conj(yft(:,idx,:,:,:)));

yft = cat(3, yft, zeros([size(yft,[1,2]), S, size(yft,4:5)], like=yft));
yft = cat(3, yft, conj(yft(:,:,idx,:,:)));

% Return to physical space via 3d ifft
y = ifft(yft, [], 3, 'symmetric');
y = ifft(y,[],2, 'symmetric');
y = ifft(y,[],1, 'symmetric');

% Re-apply labels
y = dlarray(y, 'SSSCB');
y = real(y);
end
end
end