-
-
Notifications
You must be signed in to change notification settings - Fork 164
New tutorial: Partitioned Burgers eq. 1D #670
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vidulejs
wants to merge
61
commits into
develop
Choose a base branch
from
partitioned-burgers-1d
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
61 commits
Select commit
Hold shift + click to select a range
4149918
Partitioned Burgers eq. 1D initial commit. Dirichlet Neumann coupling.
vidulejs 1dabcb0
Rename Flux to Gradient (du_dx). Print residual and compare gradient…
vidulejs 6c8660c
Fully working Neumann surrogate participant. Scripts: run-partitione…
vidulejs 82accc2
Rename surrogate burgers to neumann-surrogate
vidulejs 2c425e4
activate-env.sh script
vidulejs 5ab08f6
Add readme and requirements.txt. Move DG old files
vidulejs c5844c1
Rename solver-scipy-fvolumes to solver-scipy. Add proper run and clea…
vidulejs 8b78407
Update README.md. Conform to tutorial images standard. Update run scr…
vidulejs 2357008
add equation to readme
vidulejs 5d66f9f
Remove DG solver
vidulejs 8213367
Create utils directory and move scripts. Create output directory, sep…
vidulejs 6546901
Update run scripts to conform to other preCICE tutorials. Fix package…
vidulejs 20cab48
clean.sh to clean-tutorial.sh
vidulejs fd63b1d
remove comment from precice-config.xml
vidulejs 9beb26b
Update README and add domain diagram image
vidulejs 1e8b1ea
Specify IC script in readme
vidulejs 11b91fc
Changelog entry
vidulejs 854ad44
Updated format: #704
vidulejs 292a672
Pre-commit formatting
vidulejs b193d05
Pre-commit formatting fix 2
vidulejs 6c99c35
Allow skipping venv setup
vidulejs 3c66728
Update readme ## Visualization
vidulejs bb434d9
Pre-commit formatting fix
vidulejs c06bede
Update changelog-entries/670.md
vidulejs 83218df
Update partitioned-burgers-1d/README.md
vidulejs 602a2b0
Merge branch 'develop' into partitioned-burgers-1d
vidulejs 51d1b2f
Update clean.sh scripts in solver directories to match the other tuto…
vidulejs ea81634
Exclude utils and output folders in cleaning-tools.sh script. This wi…
vidulejs e692a53
Update precice-config.xml. Fix logging and change from serial to para…
vidulejs bbf7f20
Use pytorch CPU version by default, user can configure GPU version by…
vidulejs 9a8dc51
Remove unnecessary comments and commented out code
vidulejs 84cafd5
Fix formatting with pre-commit hook. Why did it flatten one but not t…
vidulejs a476880
Remove html image formatting from README.md. Change section Visualiza…
vidulejs 1c69601
Update README.md
vidulejs ad6b1ee
Minor README changes
vidulejs 4f36c9d
Update changelog-entries/670.md
MakisH d50d7fa
Merge branch 'develop' into partitioned-burgers-1d
MakisH a64d1ca
Add metadata.yaml
MakisH 285b58e
Update venv handling
MakisH 98cc4a8
Merge branch 'develop' into partitioned-burgers-1d
MakisH 17b1244
Add *.npz to .gitignore
MakisH 6d85d91
Setup logging earlier
MakisH d18ade4
Update venv handling in additional scripts
MakisH 8bed094
Remove thin helper scripts
MakisH a739e1e
Adjust requirements.txt
MakisH 7005053
Add more details about using generate_ic.py
MakisH 20f8722
Make utils/generate-training-data.sh use solver-scipy/run.sh
MakisH ab00a15
Integrate into the system tests
MakisH 5153005
Set TUTORIALS_REF
MakisH 8c5deab
Add reference results for partitioned-burgers-1d
a57a54a
Reset TUTORIALS_REF
MakisH 5bbae3d
Merge branch 'develop' into partitioned-burgers-1d
MakisH 795f225
Further update venv handling
MakisH 89ce0ad
Merge branch 'develop' into partitioned-burgers-1d
MakisH 0b3ed91
Merge branch 'develop' into partitioned-burgers-1d
MakisH 973de67
Add additional info on the NN checkpoints
e9124d8
Comment about the visualization timestep argument
04ff643
Add reference to thesis
494a5b2
Fix typo in argument
a224e9a
Optionally point to my development repo?
b360199
Fix trailing space
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,7 @@ Makefile | |
| __pycache__/ | ||
| *.pyc | ||
| *.pyo | ||
| *.npz | ||
|
|
||
| # Rust | ||
| Cargo.lock | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| - Added new case: 1D partitioned Burgers' equation with one finite volume and one NN surrogate participant. [#670](https://github.com/precice/tutorials/pull/670) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| --- | ||
| title: Partitioned Burgers' equation 1D | ||
| permalink: tutorials-partitioned-burgers-1d.html | ||
| keywords: Python, Neural Network, Surrogate, Burgers Equation, Finite Volume, CFD | ||
| summary: This tutorial demonstrates the partitioned solution of the 1D Burgers' equation using preCICE and a neural network surrogate solver. | ||
| --- | ||
|
|
||
| {% note %} | ||
| Get the [case files of this tutorial](https://github.com/precice/tutorials/tree/develop/partitioned-burgers-1d), as continuously rendered here, or see the [latest released version](https://github.com/precice/tutorials/tree/master/partitioned-burgers-1d) (if there is already one). Read how in the [tutorials introduction](https://precice.org/tutorials.html). | ||
| {% endnote %} | ||
|
|
||
| ## Setup | ||
|
|
||
| We solve the 1D viscous Burgers' equation on the domain $[0,2]$: | ||
|
|
||
| $$ | ||
| \frac{\partial u}{\partial t} = \nu \frac{\partial^2 u}{\partial x^2} - u \frac{\partial u}{\partial x}, | ||
| $$ | ||
|
|
||
| where $u(x,t)$ is the scalar velocity field and $\nu$ is the viscosity. In this tutorial by default $\nu$ is very small ($10^{-12}$), but can be changed in the solver. | ||
|
|
||
| The domain is partitioned into participants at $x=1$: | ||
|
|
||
| - **Dirichlet**: Solves the left half $[0,1]$ and receives Dirichlet boundary conditions at the interface. | ||
| - **Neumann**: Solves the right half $[1,2]$ and receives Neumann boundary conditions at the interface. | ||
|
|
||
| Both outer boundaries use zero-gradient conditions $\frac{\partial u}{\partial x} = 0$. The problem can be solved for different initial conditions of superimposed sine waves, which can be generated using the provided script `utils/generate_ic.py`. | ||
|
|
||
|  | ||
| Diagram of the partitioned domain with an example initial condition. | ||
|
|
||
| ## Configuration | ||
|
|
||
| preCICE configuration (image generated using the [precice-config-visualizer](https://precice.org/tooling-config-visualization.html)): | ||
|
|
||
|  | ||
|
|
||
| ## Available solvers | ||
|
|
||
| Currently, the SciPy solver (`solver-scipy`) can be used for both sides, the NN surrogate solver (`neumann-surrogate`) only for the Neumann side. | ||
|
|
||
| - SciPy. Simple finite volume solver using Lax-Friedrichs fluxes and implicit Euler time stepping. | ||
| - Surrogate. Pre-trained neural network surrogate model. | ||
|
|
||
| The conservative formulation of the Burgers' equation is implemented in the SciPy solver. The surrogate is a neural network trained to autoregressively predict the next time step solution based on the current solution. The surrogate model was trained on solutions obtained with the SciPy solver. See [Initial condition](#initial-condition) for how to generate the training data. | ||
|
|
||
| Two pre-trained model checkpoints are provided in `neumann-surrogate/`, differing in how many unroll timesteps were used during the Backpropagation Through Time (BPTT) training phase. The two checkpoints, `CNN_RES_UNROLL_1.pth` and `CNN_RES_UNROLL_7.pth`, were trained, respectively, with a single-step prediction (rollout length 1) and a 7-step rollout, which improves stability over long autoregressive predictions. The checkpoint can be selected by changing `MODEL_NAME` in `neumann-surrogate/config.py`. | ||
|
|
||
| The full training pipeline is available in a separate development repository: [github.com/vidulejs/neural-adapter](https://github.com/vidulejs/neural-adapter). | ||
|
|
||
| {% note %} | ||
| The surrogate participant requires PyTorch and related dependencies. By default, the CPU version is installed. If you wish to use the GPU version, see the comment in `neumann-surrogate/requirements.txt`. The GPU version requires several gigabytes of disk space (~6.2Gb). | ||
| {% endnote %} | ||
|
|
||
| ## Running the simulation | ||
|
|
||
| ### Running the participants | ||
|
|
||
| To run the partitioned simulation, open two separate terminals and start each participant individually: | ||
|
|
||
| You can find the corresponding `run.sh` script for running the case in the folders corresponding to the participant you want to use: | ||
|
|
||
| ```bash | ||
| cd dirichlet-scipy | ||
| ./run.sh | ||
| ``` | ||
|
|
||
| and | ||
|
|
||
| ```bash | ||
| cd neumann-scipy | ||
| ./run.sh | ||
| ``` | ||
|
|
||
| or, to use the pretrained neural network surrogate participant: | ||
|
|
||
| ```bash | ||
| cd neumann-surrogate | ||
| ./run.sh | ||
| ``` | ||
|
|
||
| ### Initial condition | ||
|
|
||
| The initial condition file `initial_condition.npz` is automatically generated by the run scripts if it does not exist. | ||
| You can also manually generate it using the `utils/generate_ic.py` script: | ||
|
|
||
| ```bash | ||
| python3 utils/generate_ic.py | ||
| ``` | ||
|
|
||
| This script requires the Python libraries `numpy` and `matplotlib`. It accepts an optional argument `--epoch` as a random number generator seed, which defaults to zero. | ||
|
|
||
| To generate the training data, you can use the `utils/generate-training-data.sh` script from the tutorial root directory, which will generate data for different `--epoch` values: | ||
|
|
||
| ```bash | ||
| ./utils/generate-training-data.sh | ||
| ``` | ||
|
|
||
| ### Monolithic solution (reference) | ||
|
|
||
| You can run the whole domain using the monolithic solver for comparison: | ||
|
|
||
| ```bash | ||
| cd solver-scipy | ||
| ./run.sh | ||
| ``` | ||
|
|
||
| ## Post-processing | ||
|
|
||
| After both participants (and/or monolithic simulation) have finished, you can run the visualization script. | ||
| `visualize_partitioned_domain.py` generates plots comparing the partitioned and monolithic solutions. You can specify which timestep to plot. Call from the root of the tutorial: | ||
|
|
||
| ```bash | ||
| python3 utils/visualize_partitioned_domain.py --neumann neumann-surrogate/surrogate.npz 10 | ||
| ``` | ||
|
|
||
| The final argument defines the coupling time step to plot (here, `10`). It can range from `0` up to the total number of time steps performed in the run. | ||
|
|
||
| The script will produce the following output files in the `output/` directory: | ||
|
|
||
| - `full-domain-timestep-slice.png`: Solution $u$ at the selected timestep | ||
|
|
||
|  | ||
|
|
||
| - `gradient-timestep-slice.png`: Gradient $du/dx$ at the selected timestep | ||
|
|
||
| - `full-domain-evolution.png`: Time evolution of the solution | ||
|
|
||
|  | ||
|
|
||
| ## References | ||
|
|
||
| Dagis Daniels Vidulejs. "Coupling Neural Surrogates with Traditional Solvers using preCICE." Master's thesis, Technical University of Munich, 2025. |
|
MakisH marked this conversation as resolved.
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| #!/usr/bin/env sh | ||
| set -e -u | ||
|
|
||
| # shellcheck disable=SC1091 | ||
| . ../tools/cleaning-tools.sh | ||
|
|
||
| clean_tutorial . | ||
| clean_precice_logs . | ||
| rm -fv ./*.log | ||
| rm -fv ./*.vtu | ||
|
|
||
| # Clean up root directory | ||
| rm -f initial_condition.npz | ||
| rm -rf output/ |
|
vidulejs marked this conversation as resolved.
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| #!/usr/bin/env sh | ||
| set -e -u | ||
|
|
||
| # shellcheck disable=SC1091 | ||
| . ../../tools/cleaning-tools.sh | ||
|
|
||
| clean_precice_logs . | ||
| clean_case_logs . | ||
| rm -f dirichlet.npz |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| #!/usr/bin/env bash | ||
| set -e -u | ||
|
|
||
| . ../../tools/log.sh | ||
| exec > >(tee --append "$LOGFILE") 2>&1 | ||
|
|
||
| if [ ! -v PRECICE_TUTORIALS_NO_VENV ] | ||
| then | ||
| if [ ! -d ".venv" ]; then | ||
| python3 -m venv .venv | ||
| source .venv/bin/activate | ||
| pip install -r ../solver-scipy/requirements.txt && pip freeze > pip-installed-packages.log | ||
| else | ||
| source .venv/bin/activate | ||
| fi | ||
| fi | ||
|
|
||
| if [ ! -f "../initial_condition.npz" ]; then | ||
| echo "Generating initial condition..." | ||
| python3 ../utils/generate_ic.py | ||
| fi | ||
|
|
||
| python3 ../solver-scipy/solver.py Dirichlet | ||
|
|
||
| close_log |
Binary file added
BIN
+47.9 KB
...oned-burgers-1d/images/tutorials-partitioned-burgers-1d-full-domain-diagram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+127 KB
...ed-burgers-1d/images/tutorials-partitioned-burgers-1d-full-domain-evolution.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+176 KB
...rgers-1d/images/tutorials-partitioned-burgers-1d-full-domain-timestep-slice.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+88.3 KB
partitioned-burgers-1d/images/tutorials-partitioned-burgers-1d-precice-config.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| name: Partitioned Burgers' equation 1D | ||
| path: partitioned-burgers-1d # relative to git repo | ||
| url: https://precice.org/tutorials-partitioned-burgers-1d.html | ||
|
|
||
| participants: | ||
| - Dirichlet | ||
| - Neumann | ||
|
|
||
| cases: | ||
| dirichlet-scipy: | ||
| participant: Dirichlet | ||
| directory: ./dirichlet-scipy | ||
| run: ./run.sh | ||
| component: python-bindings | ||
|
|
||
| neumann-scipy: | ||
| participant: Neumann | ||
| directory: ./neumann-scipy | ||
| run: ./run.sh | ||
| component: python-bindings | ||
|
|
||
| neumann-surrogate: | ||
| participant: Neumann | ||
| directory: neumann-surrogate | ||
| run: ./run.sh | ||
| component: python-bindings |
|
vidulejs marked this conversation as resolved.
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| #!/usr/bin/env sh | ||
| set -e -u | ||
|
|
||
| # shellcheck disable=SC1091 | ||
| . ../../tools/cleaning-tools.sh | ||
|
|
||
| clean_precice_logs . | ||
| clean_case_logs . | ||
| rm -f neumann.npz |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| #!/usr/bin/env bash | ||
| set -e -u | ||
|
|
||
| . ../../tools/log.sh | ||
| exec > >(tee --append "$LOGFILE") 2>&1 | ||
|
|
||
| if [ ! -v PRECICE_TUTORIALS_NO_VENV ] | ||
| then | ||
| if [ ! -d ".venv" ]; then | ||
| python3 -m venv .venv | ||
| source .venv/bin/activate | ||
| pip install -r ../solver-scipy/requirements.txt && pip freeze > pip-installed-packages.log | ||
| else | ||
| source .venv/bin/activate | ||
| fi | ||
| fi | ||
|
|
||
| if [ ! -f "../initial_condition.npz" ]; then | ||
| echo "Generating initial condition..." | ||
| python3 ../utils/generate_ic.py | ||
| fi | ||
|
|
||
| python3 ../solver-scipy/solver.py Neumann | ||
|
|
||
| close_log |
|
vidulejs marked this conversation as resolved.
|
Binary file not shown.
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| #!/usr/bin/env sh | ||
| set -e -u | ||
|
|
||
| # shellcheck disable=SC1091 | ||
| . ../../tools/cleaning-tools.sh | ||
|
|
||
| clean_precice_logs . | ||
| clean_case_logs . | ||
| rm -f surrogate.npz |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| import torch | ||
|
|
||
| # Model architecture | ||
| INPUT_SIZE = 128 + 2 # +2 for ghost cells | ||
| HIDDEN_SIZE = 64 # num filters | ||
| OUTPUT_SIZE = 128 | ||
|
|
||
| assert INPUT_SIZE >= OUTPUT_SIZE, "Input size must be greater or equal to output size." | ||
| assert (INPUT_SIZE - OUTPUT_SIZE) % 2 == 0, "Input and output sizes must differ by an even number (for ghost cells)." | ||
|
|
||
| NUM_RES_BLOCKS = 4 | ||
| KERNEL_SIZE = 5 | ||
| ACTIVATION = torch.nn.ReLU | ||
|
|
||
| MODEL_NAME = "CNN_RES_UNROLL_7.pth" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
| from torch.nn.utils import weight_norm | ||
|
|
||
|
|
||
| def pad_with_ghost_cells(input_seq, bc_left, bc_right): | ||
| return torch.cat([bc_left, input_seq, bc_right], dim=1) | ||
|
|
||
|
|
||
| class LinearExtrapolationPadding1D(nn.Module): | ||
| """Applies 'same' padding using linear extrapolation.""" | ||
|
|
||
| def __init__(self, kernel_size: int, dilation: int = 1): | ||
| super().__init__() | ||
| self.pad_total = dilation * (kernel_size - 1) | ||
| self.pad_beg = self.pad_total // 2 | ||
| self.pad_end = self.pad_total - self.pad_beg | ||
|
|
||
| def forward(self, x): | ||
| # Don't pad if not necessary | ||
| if self.pad_total == 0: | ||
| return x | ||
|
|
||
| ghost_cell_left = x[:, :, :1] | ||
| ghost_cell_right = x[:, :, -1:] | ||
|
|
||
| # Calculate the gradient at each boundary | ||
| grad_left = x[:, :, 1:2] - ghost_cell_left | ||
| grad_right = ghost_cell_right - x[:, :, -2:-1] | ||
|
|
||
| # Extrapolated padding tensors | ||
| left_ramp = torch.arange(self.pad_beg, 0, -1, device=x.device, dtype=x.dtype).view(1, 1, -1) | ||
| left_padding = ghost_cell_left - left_ramp * grad_left | ||
|
|
||
| right_ramp = torch.arange(1, self.pad_end + 1, device=x.device, dtype=x.dtype).view(1, 1, -1) | ||
| right_padding = ghost_cell_right + right_ramp * grad_right | ||
|
|
||
| return torch.cat([left_padding, x, right_padding], dim=2) | ||
|
|
||
|
|
||
| class ResidualBlock1D(nn.Module): | ||
| """A residual block that uses custom 'same' padding with linear extrapolation and weight normalization.""" | ||
|
|
||
| def __init__(self, channels, kernel_size=3, activation=nn.ReLU): | ||
| super(ResidualBlock1D, self).__init__() | ||
| self.activation = activation() | ||
| # Apply weight normalization | ||
| self.conv1 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding='valid', bias=True)) | ||
| self.ghost_padding1 = LinearExtrapolationPadding1D(kernel_size) | ||
| self.conv2 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding='valid', bias=True)) | ||
| self.ghost_padding2 = LinearExtrapolationPadding1D(kernel_size) | ||
|
|
||
| def forward(self, x): | ||
| identity = x | ||
|
|
||
| out = self.ghost_padding1(x) | ||
| out = self.conv1(out) | ||
| out = self.activation(out) | ||
|
|
||
| out = self.ghost_padding2(out) | ||
| out = self.conv2(out) | ||
|
|
||
| return self.activation(out) + identity | ||
|
|
||
|
|
||
| class CNN_RES(nn.Module): | ||
| """ | ||
| A CNN with residual blocks for 1D data. | ||
| Expects a pre-padded input with ghost_cells//2 number ghost cells on each side. | ||
| Applies a custom linear extrapolation padding for inner layers. | ||
| """ | ||
|
|
||
| def __init__(self, hidden_channels, num_blocks=2, kernel_size=3, activation=nn.ReLU, ghost_cells=2): | ||
| super(CNN_RES, self).__init__() | ||
| self.activation = activation() | ||
| self.hidden_channels = hidden_channels | ||
| self.num_blocks = num_blocks | ||
| self.kernel_size = kernel_size | ||
| assert ghost_cells % 2 == 0, "ghost_cells must be even" | ||
| self.ghost_cells = ghost_cells | ||
|
|
||
| self.ghost_padding = LinearExtrapolationPadding1D(self.ghost_cells + self.kernel_size) | ||
|
|
||
| # Apply weight normalization to the input convolution | ||
| self.conv_in = weight_norm(nn.Conv1d(1, hidden_channels, kernel_size=1, bias=True)) | ||
|
|
||
| layers = [ResidualBlock1D(hidden_channels, kernel_size, activation=activation) for _ in range(num_blocks)] | ||
| self.res_blocks = nn.Sequential(*layers) | ||
|
|
||
| self.conv_out = nn.Conv1d(hidden_channels, 1, kernel_size=1) | ||
|
|
||
| def forward(self, x): | ||
|
|
||
| if x.dim() == 2: | ||
| x = x.unsqueeze(1) # Add channel dim: (B, 1, L) | ||
|
|
||
| if not self.ghost_cells == 0: | ||
| x_padded = self.ghost_padding(x) | ||
|
|
||
| else: | ||
| x_padded = x | ||
|
|
||
| total_pad_each_side = self.ghost_padding.pad_beg + self.ghost_cells // 2 | ||
|
|
||
| out = self.activation(self.conv_in(x_padded)) # no extra padding here | ||
| out = self.res_blocks(out) | ||
| out = self.conv_out(out) # no extra padding here | ||
|
|
||
| if not self.ghost_cells == 0: | ||
| out = out[:, :, total_pad_each_side:-total_pad_each_side] # remove ghost cells, return only internal domain | ||
|
|
||
| return out.squeeze(1) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.