working checkpoint system

This commit is contained in:
Thomas Faour 2025-05-31 12:52:30 -04:00
parent ed1e913366
commit a67b188ed9
12 changed files with 54 additions and 40 deletions

BIN
last_checkpoint.npz Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,7 +1,7 @@
from enum import Enum from enum import Enum
import numpy as np import numpy as np
from ..units import Position, Velocity, Mass, Acceleration, Force from ..units import Position, Velocity, Mass, Acceleration
class Body: class Body:
""" """

View File

@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
import numpy as np
from body import Body from .body import Body
from ..units import Position, Velocity, Mass, Acceleration, Force
class Simulator: class Simulator:
""" """
@ -31,7 +31,8 @@ class Simulator:
bodies: list[Body], bodies: list[Body],
step_size: float, step_size: float,
steps_per_save: int, steps_per_save: int,
output_file: Path output_file: Path,
current_step: int = 0,
overwrite_output: bool = False overwrite_output: bool = False
): ):
if output_file.exists() and not overwrite_output: if output_file.exists() and not overwrite_output:
@ -40,43 +41,63 @@ class Simulator:
self.output_file = output_file self.output_file = output_file
self.bodies = bodies self.bodies = bodies
self.step_size = step_size self.step_size = step_size
self.steps_to_take = steps_to_take
self.steps_per_save = steps_per_save self.steps_per_save = steps_per_save
self.current_step = current_step
if output_file.exists() and overwrite_output: if output_file.exists() and overwrite_output:
print(f"Warning! Overwriting file: {output_file}") print(f"Warning! Overwriting file: {output_file}")
_save_body_masses_to_file() #self._save_body_masses_to_file()
_checkpoint() self._checkpoint()
self.current_step = 0
@classmethod @classmethod
def from_checkpoint(cls, input_file: Path): def from_checkpoint(cls, output_file: Path):
data = np.load("last_checkpoint.npz")
positions = data["positions"]
velocities = data["velocities"]
masses = data["masses"]
step_size = data["steps"][0]
current_step = data["steps"][1]
steps_per_save = data["steps"][2]
bodies = [
Body(val[0], val[1], val[2]) for val in zip(
positions, velocities, masses
)
]
return cls(
bodies,
step_size,
steps_per_save,
output_file,
current_step,
)
def _save_body_masses_to_file(self):
masses_str = ' '.join([
str(body.m) for body in bodies
]) + '\n'
#if saving masses, we are always starting a new file,
#this will overwrite file.
self.output_file.write_text(masses_str)
def _checkpoint(self): def _checkpoint(self):
""" """
Two things - save high precision last checkpoint for resuming Two things - save high precision last checkpoint for resuming
then save lower precision text for trajectories then save lower precision text for trajectories
""" """
body_X_V_np = np.array([ body_X_np = np.array([
[body.X, body.Y] for body in self.bodies body.X for body in self.bodies
]) ])
body_ints_np = np.array([ body_V_np = np.array([
body.V for body in self.bodies
])
body_m_np = np.array([
body.m for body in self.bodies body.m for body in self.bodies
]) ])
body_m_np.append(self.step_size, self.current_step) stepsz_n_np = np.array([
np.savez("last_checkpoint.npz", self.step_size,
array=body_X_V_np, self.current_step,
ints = body_m_np) self.steps_per_save
])
np.savez("last_checkpoint.npz",
positions=body_X_np,
velocities=body_V_np,
masses=body_m_np,
steps=stepsz_n_np)

9
orbiter/units.py Normal file
View File

@ -0,0 +1,9 @@
import numpy as np
Position = np.array
Velocity = np.array
Acceleration = np.array
Mass = int

View File

@ -1,16 +0,0 @@
import numpy as np
class Position(np.array):
pass
class Velocity(np.array):
pass
class Acceleration(np.array):
pass
class Force(np.array):
pass
class Mass(int):
pass