working checkpoint system

This commit is contained in:
Thomas Faour 2025-05-31 12:52:30 -04:00
parent ed1e913366
commit a67b188ed9
12 changed files with 54 additions and 40 deletions

BIN
last_checkpoint.npz Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,7 +1,7 @@
from enum import Enum
import numpy as np
from ..units import Position, Velocity, Mass, Acceleration, Force
from ..units import Position, Velocity, Mass, Acceleration
class Body:
"""

View File

@ -1,7 +1,7 @@
from pathlib import Path
import numpy as np
from body import Body
from ..units import Position, Velocity, Mass, Acceleration, Force
from .body import Body
class Simulator:
"""
@ -31,7 +31,8 @@ class Simulator:
bodies: list[Body],
step_size: float,
steps_per_save: int,
output_file: Path
output_file: Path,
current_step: int = 0,
overwrite_output: bool = False
):
if output_file.exists() and not overwrite_output:
@ -40,43 +41,63 @@ class Simulator:
self.output_file = output_file
self.bodies = bodies
self.step_size = step_size
self.steps_to_take = steps_to_take
self.steps_per_save = steps_per_save
self.current_step = current_step
if output_file.exists() and overwrite_output:
print(f"Warning! Overwriting file: {output_file}")
_save_body_masses_to_file()
_checkpoint()
#self._save_body_masses_to_file()
self._checkpoint()
self.current_step = 0
@classmethod
def from_checkpoint(cls, input_file: Path):
def from_checkpoint(cls, output_file: Path):
data = np.load("last_checkpoint.npz")
positions = data["positions"]
velocities = data["velocities"]
masses = data["masses"]
step_size = data["steps"][0]
current_step = data["steps"][1]
steps_per_save = data["steps"][2]
bodies = [
Body(val[0], val[1], val[2]) for val in zip(
positions, velocities, masses
)
]
return cls(
bodies,
step_size,
steps_per_save,
output_file,
current_step,
)
def _save_body_masses_to_file(self):
masses_str = ' '.join([
str(body.m) for body in bodies
]) + '\n'
#if saving masses, we are always starting a new file,
#this will overwrite file.
self.output_file.write_text(masses_str)
def _checkpoint(self):
"""
Two things - save high precision last checkpoint for resuming
then save lower precision text for trajectories
"""
body_X_V_np = np.array([
[body.X, body.Y] for body in self.bodies
body_X_np = np.array([
body.X for body in self.bodies
])
body_ints_np = np.array([
body_V_np = np.array([
body.V for body in self.bodies
])
body_m_np = np.array([
body.m for body in self.bodies
])
body_m_np.append(self.step_size, self.current_step)
np.savez("last_checkpoint.npz",
array=body_X_V_np,
ints = body_m_np)
stepsz_n_np = np.array([
self.step_size,
self.current_step,
self.steps_per_save
])
np.savez("last_checkpoint.npz",
positions=body_X_np,
velocities=body_V_np,
masses=body_m_np,
steps=stepsz_n_np)

9
orbiter/units.py Normal file
View File

@ -0,0 +1,9 @@
import numpy as np
Position = np.array
Velocity = np.array
Acceleration = np.array
Mass = int

View File

@ -1,16 +0,0 @@
import numpy as np
class Position(np.array):
pass
class Velocity(np.array):
pass
class Acceleration(np.array):
pass
class Force(np.array):
pass
class Mass(int):
pass