OrbitalSimulator/plot_trajectories.py

571 lines
21 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Plot orbital trajectories from the binary output file generated by the orbital simulator.

The binary file contains snapshots serialized with bincode, where each snapshot has:
- time: f64
- bodies: array of Body structs, each with:
    - name: String
    - mass: f64
    - position: [f64; 3] (x, y, z)
    - velocity: [f64; 3] (vx, vy, vz)
    - acceleration: [f64; 3] (ax, ay, az)
All values are little-endian; the body count and each string's byte length are
stored as u64 prefixes, and strings are UTF-8.

Usage:
    python plot_trajectories.py <trajectory_file.bin> [options]
"""
import sys
import struct
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import numpy as np
from collections import defaultdict
import argparse
class BinaryReader:
    """Sequential little-endian decoder over an in-memory byte buffer.

    Each ``read_*`` call decodes one value at ``pos`` and advances ``pos`` past
    it. If fewer bytes remain than the value needs, ``struct.error`` is raised
    and the failing read does not advance ``pos``. Every field type therefore
    reports "ran out of data" the same way.
    """

    _U64 = struct.Struct('<Q')
    _F64 = struct.Struct('<d')
    _VEC3 = struct.Struct('<3d')

    def __init__(self, data):
        self.data = data  # bytes holding the whole file
        self.pos = 0      # offset of the next unread byte

    def _advance(self, size):
        """Reserve ``size`` bytes at the cursor and return their start offset.

        Raises:
            struct.error: if fewer than ``size`` bytes remain (``pos`` unchanged).
        """
        start = self.pos
        end = start + size
        if end > len(self.data):
            raise struct.error(
                f"need {size} byte(s) at offset {start}, "
                f"only {len(self.data) - start} available"
            )
        self.pos = end
        return start

    def read_u64(self):
        """Read an unsigned 64-bit little-endian integer."""
        return self._U64.unpack_from(self.data, self._advance(self._U64.size))[0]

    def read_f64(self):
        """Read a 64-bit little-endian IEEE-754 float."""
        return self._F64.unpack_from(self.data, self._advance(self._F64.size))[0]

    def read_string(self):
        """Read a u64 byte length followed by that many UTF-8 bytes.

        Raises:
            struct.error: if the buffer ends before the length prefix or the
                string bytes do (the length prefix stays consumed in the latter
                case).
            UnicodeDecodeError: if the bytes are not valid UTF-8.
        """
        length = self.read_u64()
        start = self._advance(length)
        return self.data[start:self.pos].decode('utf-8')

    def read_vec3(self):
        """Read three consecutive f64 values (e.g. x, y, z) as a numpy array."""
        offset = self._advance(self._VEC3.size)
        return np.array(self._VEC3.unpack_from(self.data, offset))
def _read_body(reader):
    """Decode one Body record (name, mass, position, velocity, acceleration) from ``reader``."""
    name = reader.read_string()
    mass = reader.read_f64()
    position = reader.read_vec3()
    velocity = reader.read_vec3()
    acceleration = reader.read_vec3()
    return {
        'name': name,
        'mass': mass,
        'position': position,
        'velocity': velocity,
        'acceleration': acceleration,
    }


def read_trajectory_file(filename):
    """Read a binary trajectory file and return its snapshots in file order.

    Snapshots are decoded back to back until the buffer is exhausted. If the
    tail cannot be decoded (truncated write, corrupted bytes, invalid UTF-8 in
    a body name), every complete snapshot before it is kept and the partial one
    is dropped. A warning naming the ignored byte range is written to stderr.

    Args:
        filename: Path of the file written by the simulator.

    Returns:
        A list of dicts with keys 'time' (float) and 'bodies' (list of dicts
        with keys 'name', 'mass', 'position', 'velocity', 'acceleration'; the
        three vector fields are length-3 numpy arrays).

    Raises:
        OSError: if the file cannot be opened or read.
    """
    with open(filename, 'rb') as f:
        data = f.read()

    reader = BinaryReader(data)
    snapshots = []
    while reader.pos < len(data):
        snapshot_start = reader.pos
        try:
            time = reader.read_f64()
            num_bodies = reader.read_u64()
            bodies = [_read_body(reader) for _ in range(num_bodies)]
        except (struct.error, UnicodeDecodeError) as exc:
            # Best effort: keep what decoded cleanly, but report what was dropped
            # instead of discarding it silently.
            print(
                f"Warning: ignoring {len(data) - snapshot_start} trailing byte(s) "
                f"starting at offset {snapshot_start} ({exc})",
                file=sys.stderr,
            )
            break
        snapshots.append({'time': time, 'bodies': bodies})
    return snapshots
def organize_trajectories(snapshots):
    """Regroup snapshot data into one position array per body.

    Args:
        snapshots: Snapshot dicts with 'time' and 'bodies'; each body dict
            provides 'name' and 'position'.

    Returns:
        ``(trajectories, times)``. ``trajectories`` maps each body name (in
        first-seen order) to an array stacking that body's positions. ``times``
        is an array holding every snapshot's time. A body's rows line up with
        ``times`` only if the body appears in every snapshot; positions are
        appended only for snapshots that contain the body.
    """
    times = np.array([snap['time'] for snap in snapshots])

    grouped = defaultdict(list)
    for snap in snapshots:
        for body in snap['bodies']:
            grouped[body['name']].append(body['position'])

    per_body = {label: np.array(points) for label, points in grouped.items()}
    return per_body, times
def plot_trajectories_2d(trajectories, times, center_body=None):
    """Draw every body's path projected onto the X-Y plane in a new figure.

    Each path is a line with a circle at its first sample and a square at its
    last. ``times`` is accepted for signature parity with the other plotting
    helpers but is not used. If ``center_body`` is given, its name is
    appended to the title.
    """
    plt.figure(figsize=(12, 10))
    palette = plt.cm.tab10(np.linspace(0, 1, len(trajectories)))

    for color, (label, path) in zip(palette, trajectories.items()):
        xs, ys = path[:, 0], path[:, 1]
        plt.plot(xs, ys, color=color, alpha=0.7, linewidth=1.5, label=label)
        plt.plot(xs[0], ys[0], 'o', color=color, markersize=8, alpha=0.8)    # first sample
        plt.plot(xs[-1], ys[-1], 's', color=color, markersize=6, alpha=0.8)  # last sample

    suffix = f' - Centered on {center_body}' if center_body else ''
    plt.title('Orbital Trajectories (X-Y Plane)' + suffix)
    plt.xlabel('X Position (m)')
    plt.ylabel('Y Position (m)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.axis('equal')
    plt.tight_layout()
def plot_trajectories_3d(trajectories, times, center_body=None, margin=0.1):
    """Draw every body's full path in a new 3D figure.

    Each path is a line with a circle at its first sample and a square at its
    last. ``times`` is accepted for signature parity with the other plotting
    helpers but is not used. If ``center_body`` is given, its name is
    appended to the title.

    The X, Y and Z limits share one range that is symmetric about the origin.
    It is sized to the largest absolute coordinate plus ``margin`` (a fraction;
    0.1 adds 10% padding), so paths do not sit on the edge of the box.
    """
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    colors = plt.cm.tab10(np.linspace(0, 1, len(trajectories)))
    for i, (name, positions) in enumerate(trajectories.items()):
        x = positions[:, 0]
        y = positions[:, 1]
        z = positions[:, 2]
        ax.plot(x, y, z, color=colors[i], alpha=0.7, linewidth=1.5, label=name)
        # Mark starting position
        ax.scatter(x[0], y[0], z[0], color=colors[i], s=100, alpha=0.8, marker='o')
        # Mark final position
        ax.scatter(x[-1], y[-1], z[-1], color=colors[i], s=60, alpha=0.8, marker='s')
    ax.set_xlabel('X Position (m)')
    ax.set_ylabel('Y Position (m)')
    ax.set_zlabel('Z Position (m)')
    title = 'Orbital Trajectories (3D)'
    if center_body:
        title += f' - Centered on {center_body}'
    ax.set_title(title)
    ax.legend()

    # Equal scales on all three axes: one symmetric range shared by X, Y and Z.
    extent = max(
        (np.max(np.abs(positions)) for positions in trajectories.values() if positions.size),
        default=0.0,
    )
    if not np.isfinite(extent) or extent <= 0:
        # Every sample sits at the origin (e.g. a lone body centred on itself)
        # or the data is non-finite; identical lower/upper limits would make
        # the axes singular, so fall back to a unit box.
        extent = 1.0
    half_width = extent * (1 + margin)
    ax.set_xlim(-half_width, half_width)
    ax.set_ylim(-half_width, half_width)
    ax.set_zlim(-half_width, half_width)
def plot_energy_over_time(snapshots, times, G=6.67430e-11):
    """Plot the system's kinetic and total mechanical energy for each snapshot.

    Kinetic energy is the sum of 0.5 * m * |v|^2 over bodies. Potential energy
    is the Newtonian pairwise sum of -G * m_i * m_j / r_ij over every unordered
    pair of bodies. The total is their sum. Coincident bodies (r_ij == 0) give
    a non-finite potential.

    Args:
        snapshots: Snapshot dicts as produced by read_trajectory_file; each body
            needs 'mass', 'position' and 'velocity'.
        times: Snapshot times used as the x axis; same length as ``snapshots``.
        G: Gravitational constant. The default is the SI value
            (m^3 kg^-1 s^-2), matching the plot's J and s axis labels. Pass
            the constant the simulator uses if it runs in normalized units.
    """
    plt.figure(figsize=(12, 6))
    total_energies = []
    kinetic_energies = []
    for snapshot in snapshots:
        bodies = snapshot['bodies']
        if not bodies:
            kinetic_energies.append(0.0)
            total_energies.append(0.0)
            continue

        masses = np.array([body['mass'] for body in bodies], dtype=float)
        positions = np.array([body['position'] for body in bodies], dtype=float)
        velocities = np.array([body['velocity'] for body in bodies], dtype=float)

        ke = 0.5 * np.sum(masses * np.sum(velocities ** 2, axis=1))

        # Each unordered pair (i < j) contributes exactly once to the potential.
        i, j = np.triu_indices(len(bodies), k=1)
        separations = np.linalg.norm(positions[i] - positions[j], axis=1)
        pe = -G * np.sum(masses[i] * masses[j] / separations)

        kinetic_energies.append(ke)
        total_energies.append(ke + pe)

    plt.plot(times, kinetic_energies, label='Kinetic Energy', alpha=0.8)
    plt.plot(times, total_energies, label='Total Energy', alpha=0.8)
    plt.xlabel('Time (s)')
    plt.ylabel('Energy (J)')
    plt.title('Energy Conservation Over Time')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
def plot_animated_2d(trajectories, times, interval=50, center_body=None):
    """Create an animated X-Y projection of the trajectories.

    For each frame ``k`` and each body, the animation shows:
    - the path up to sample ``k``,
    - a marker at sample ``k``,
    - faint markers at up to 20 samples before it.
    An elapsed-time label derived from ``times[k]`` is also drawn.

    Args:
        trajectories: Mapping of body name to a position array whose columns 0
            and 1 are X and Y; row ``k`` is drawn in frame ``k``.
        times: Simulation time of each frame in seconds; one frame per entry.
        interval: Delay between frames in milliseconds.
        center_body: Optional body name appended to the title.

    Returns:
        ``(fig, anim)``. The caller must keep ``anim`` referenced while the
        animation is shown or saved.
    """
    fig, ax = plt.subplots(figsize=(12, 10))
    colors = plt.cm.tab10(np.linspace(0, 1, len(trajectories)))
    body_names = list(trajectories.keys())

    # Set up the plot
    ax.set_xlabel('X Position (m)')
    ax.set_ylabel('Y Position (m)')
    title = 'Animated Orbital Trajectories (X-Y Plane)'
    if center_body:
        title += f' - Centered on {center_body}'
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

    # Square window around the X-Y bounding box of every sample, padded by 10%.
    margin = 0.1
    if trajectories:
        all_xy = np.concatenate(list(trajectories.values()))[:, :2]
        xy_min = all_xy.min(axis=0)
        xy_max = all_xy.max(axis=0)
        x_center, y_center = (xy_min + xy_max) / 2
        max_range = np.max(xy_max - xy_min) * (1 + margin)
    else:
        x_center = y_center = 0.0
        max_range = 0.0
    if not np.isfinite(max_range) or max_range <= 0:
        # All samples share one point (e.g. a lone body centred on itself);
        # identical limits would make the axes singular, so use a unit window.
        max_range = 1.0
    ax.set_xlim(x_center - max_range / 2, x_center + max_range / 2)
    ax.set_ylim(y_center - max_range / 2, y_center + max_range / 2)
    ax.set_aspect('equal')

    # One growing path, one current-position marker and one trail per body.
    trajectory_lines = []
    body_points = []
    body_trails = []
    for i, name in enumerate(body_names):
        line, = ax.plot([], [], color=colors[i], alpha=0.7, linewidth=1.5, label=name)
        trajectory_lines.append(line)
        point, = ax.plot([], [], 'o', color=colors[i], markersize=8, alpha=0.9)
        body_points.append(point)
        trail, = ax.plot([], [], 'o', color=colors[i], markersize=3, alpha=0.3)
        body_trails.append(trail)

    time_text = ax.text(0.02, 0.98, '', transform=ax.transAxes, fontsize=12,
                        verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    ax.legend(loc='upper right')

    def format_elapsed(seconds):
        """Render a simulation time in days, hours or seconds, whichever fits."""
        if seconds >= 86400:
            return f"Time: {seconds/86400:.1f} days"
        if seconds >= 3600:
            return f"Time: {seconds/3600:.1f} hours"
        return f"Time: {seconds:.1f} seconds"

    def animate(frame):
        time_text.set_text(format_elapsed(times[frame]))
        for i, name in enumerate(body_names):
            positions = trajectories[name]
            # Path from the first sample through the current one
            trajectory_lines[i].set_data(positions[:frame + 1, 0], positions[:frame + 1, 1])
            if frame < len(positions):
                current_pos = positions[frame]
                body_points[i].set_data([current_pos[0]], [current_pos[1]])
            # Trail: up to 20 samples immediately before the current frame
            trail_start = max(0, frame - 20)
            body_trails[i].set_data(positions[trail_start:frame, 0],
                                    positions[trail_start:frame, 1])
        return trajectory_lines + body_points + body_trails + [time_text]

    anim = animation.FuncAnimation(fig, animate, frames=len(times),
                                   interval=interval, blit=True, repeat=True)
    return fig, anim
def plot_animated_3d(trajectories, times, interval=50, center_body=None):
    """Create an animated 3D view of the trajectories.

    Frame ``k`` shows each body's path up to sample ``k``, a marker at sample
    ``k``, and an elapsed-time label derived from ``times[k]``.

    Args:
        trajectories: Mapping of body name to a position array whose columns
            0, 1 and 2 are X, Y and Z; row ``k`` is drawn in frame ``k``.
        times: Simulation time of each frame in seconds; one frame per entry.
        interval: Delay between frames in milliseconds.
        center_body: Optional body name appended to the title.

    Returns:
        ``(fig, anim)``. The caller must keep ``anim`` referenced while the
        animation is shown or saved.
    """
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    colors = plt.cm.tab10(np.linspace(0, 1, len(trajectories)))
    body_names = list(trajectories.keys())

    # Set up the plot
    ax.set_xlabel('X Position (m)')
    ax.set_ylabel('Y Position (m)')
    ax.set_zlabel('Z Position (m)')
    title = 'Animated Orbital Trajectories (3D)'
    if center_body:
        title += f' - Centered on {center_body}'
    ax.set_title(title)

    # One symmetric range for all three axes keeps their scales equal; 10% padding.
    extent = max(
        (np.max(np.abs(positions)) for positions in trajectories.values() if positions.size),
        default=0.0,
    )
    if not np.isfinite(extent) or extent <= 0:
        # All samples at the origin (e.g. a lone body centred on itself):
        # identical limits would make the axes singular, so use a unit box.
        extent = 1.0
    max_range = extent * 1.1
    ax.set_xlim(-max_range, max_range)
    ax.set_ylim(-max_range, max_range)
    ax.set_zlim(-max_range, max_range)

    trajectory_lines = []
    body_points = []
    for i, name in enumerate(body_names):
        line, = ax.plot([], [], [], color=colors[i], alpha=0.7, linewidth=1.5, label=name)
        trajectory_lines.append(line)
        point = ax.scatter([], [], [], color=colors[i], s=100, alpha=0.9)
        body_points.append(point)

    time_text = ax.text2D(0.02, 0.98, '', transform=ax.transAxes, fontsize=12,
                          verticalalignment='top',
                          bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    ax.legend(loc='upper right')

    def format_elapsed(seconds):
        """Render a simulation time in days, hours or seconds, whichever fits."""
        if seconds >= 86400:
            return f"Time: {seconds/86400:.1f} days"
        if seconds >= 3600:
            return f"Time: {seconds/3600:.1f} hours"
        return f"Time: {seconds:.1f} seconds"

    def animate(frame):
        time_text.set_text(format_elapsed(times[frame]))
        for i, name in enumerate(body_names):
            positions = trajectories[name]
            # Path from the first sample through the current one
            trajectory_lines[i].set_data(positions[:frame + 1, 0], positions[:frame + 1, 1])
            trajectory_lines[i].set_3d_properties(positions[:frame + 1, 2])
            if frame < len(positions):
                x, y, z = positions[frame]
                # The marker is replaced (removed and re-created) every frame
                # rather than mutated in place.
                body_points[i].remove()
                body_points[i] = ax.scatter([x], [y], [z], color=colors[i], s=100, alpha=0.9)
        return trajectory_lines + body_points + [time_text]

    anim = animation.FuncAnimation(fig, animate, frames=len(times),
                                   interval=interval, blit=False, repeat=True)
    return fig, anim
def calculate_animation_params(num_frames, target_duration_sec=60, manual_interval=None):
    """Choose the per-frame delay for an animation of ``num_frames`` frames.

    Args:
        num_frames: Number of frames to play; must be at least 1.
        target_duration_sec: Desired total playback length in seconds.
        manual_interval: Explicit per-frame delay in milliseconds. If given,
            it is used as-is instead of being derived from the target.

    Returns:
        ``(interval_ms, duration_sec, time_scale_factor, is_manual)``:
        - ``duration_sec`` is the playback length ``interval_ms`` actually
          produces.
        - ``time_scale_factor`` is ``target_duration_sec / duration_sec``;
          values above 1 mean playback is shorter than the target.

    Raises:
        ValueError: if ``num_frames`` < 1 or ``manual_interval`` <= 0. Both
            would otherwise divide by zero or yield a meaningless duration.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be at least 1, got {num_frames}")

    if manual_interval is not None:
        if manual_interval <= 0:
            raise ValueError(f"manual_interval must be positive, got {manual_interval} ms")
        interval_ms = manual_interval
        is_manual = True
    else:
        # Spread the target evenly over the frames (integer ms), but never go
        # below 10 ms per frame.
        interval_ms = max(10, (target_duration_sec * 1000) // num_frames)
        is_manual = False

    duration_sec = num_frames * interval_ms / 1000.0
    return interval_ms, duration_sec, target_duration_sec / duration_sec, is_manual
def center_trajectories_on_body(trajectories, center_body_name):
    """Return a new mapping of trajectories expressed relative to one body.

    For every body, the reference body's position at the same row index is
    subtracted, so the reference body itself ends up at the origin. The input
    mapping is left unmodified and key order is preserved.

    Raises:
        ValueError: if ``center_body_name`` is not one of the mapping's keys.
    """
    if center_body_name not in trajectories:
        known = list(trajectories.keys())
        raise ValueError(f"Body '{center_body_name}' not found. Available bodies: {known}")

    reference = trajectories[center_body_name]
    return {label: path - reference for label, path in trajectories.items()}
def _build_arg_parser():
    """Return the argparse parser describing this script's command line."""
    parser = argparse.ArgumentParser(description='Plot orbital trajectories from binary file')
    parser.add_argument('trajectory_file', help='Binary trajectory file to plot')
    # --2d-only together with --3d-only would suppress every plot, so reject it.
    dimensions = parser.add_mutually_exclusive_group()
    dimensions.add_argument('--2d-only', action='store_true', dest='two_d_only', help='Only show 2D plot')
    dimensions.add_argument('--3d-only', action='store_true', dest='three_d_only', help='Only show 3D plot')
    parser.add_argument('--energy', action='store_true', help='Show energy plot')
    parser.add_argument('--animate', action='store_true', help='Show animated trajectories')
    parser.add_argument('--save-animation', type=str,
                        help='Save animation(s) as MP4; the value is a filename prefix '
                             '(writes PREFIX_2d.mp4 and/or PREFIX_3d.mp4)')
    parser.add_argument('--static', action='store_true', help='Show static plots (default if no --animate)')
    parser.add_argument('--interval', type=int, help='Animation interval in milliseconds (default: auto-scaled to ~60s total)')
    parser.add_argument('--target-duration', type=int, default=60, help='Target animation duration in seconds (default: 60)')
    parser.add_argument('--center', type=str, help='Center animation on specified body (e.g., "Sun", "Earth")')
    parser.add_argument('--list-bodies', action='store_true', help='List available bodies and exit')
    return parser


def _print_animation_settings(args, times, interval_ms, anim_duration_sec, time_scale, is_manual):
    """Report the frame timing and how much simulated time each playback second covers."""
    print(f"\n🎬 Animation Settings:")
    print(f" Total frames: {len(times)}")
    print(f" Animation duration: {anim_duration_sec:.1f} seconds")
    print(f" Frame interval: {interval_ms}ms")
    pace = 'faster' if time_scale > 1 else 'slower'
    if is_manual:
        print(f" ⚙️ Using manual interval (--interval {args.interval})")
        if time_scale != 1.0:
            print(f" ⏱️ Time scale: {time_scale:.2f}x (animation {pace} than target)")
    else:
        print(f" 🤖 Auto-scaled for {args.target_duration}s target duration")
        # The 10 ms minimum interval (and integer rounding) can push the real
        # duration away from the target, so report the actual factor then.
        if abs(time_scale - 1.0) < 0.01:
            print(f" ⏱️ Time scale: 1.0x (optimized)")
        else:
            print(f" ⏱️ Time scale: {time_scale:.2f}x (animation {pace} than target)")
    simulation_duration = times[-1] - times[0]
    if simulation_duration > 0:
        compression_ratio = simulation_duration / anim_duration_sec
        if compression_ratio >= 86400:
            print(f" 📈 Compression: {compression_ratio/86400:.1f} days of simulation per second of animation")
        elif compression_ratio >= 3600:
            print(f" 📈 Compression: {compression_ratio/3600:.1f} hours of simulation per second of animation")
        else:
            print(f" 📈 Compression: {compression_ratio:.1f}x real-time")
    print()


def _run_animation(args, snapshots, trajectories, times):
    """Build the requested 2D/3D animations, save them if asked, then display all figures."""
    interval_ms, anim_duration_sec, time_scale, is_manual = calculate_animation_params(
        len(times), args.target_duration, args.interval
    )
    _print_animation_settings(args, times, interval_ms, anim_duration_sec, time_scale, is_manual)

    print("Creating animated plots...")
    # Keep every animation object referenced until plt.show() returns.
    animations = []
    if not args.three_d_only:
        print("Creating 2D animation...")
        fig_2d, anim_2d = plot_animated_2d(trajectories, times, interval_ms, args.center)
        animations.append((fig_2d, anim_2d, '2d'))
    if not args.two_d_only:
        print("Creating 3D animation...")
        fig_3d, anim_3d = plot_animated_3d(trajectories, times, interval_ms, args.center)
        animations.append((fig_3d, anim_3d, '3d'))

    if args.save_animation:
        for _fig, anim, plot_type in animations:
            filename = f"{args.save_animation}_{plot_type}.mp4"
            print(f"Saving {plot_type.upper()} animation to {filename}...")
            try:
                # No explicit fps: matplotlib derives it from the animation's
                # frame interval, so the video plays at the same pace as the
                # on-screen animation.
                anim.save(filename, writer='ffmpeg')
                print(f"Animation saved to {filename}")
            except Exception as e:
                # Best effort: a missing ffmpeg must not prevent displaying the plots.
                print(f"Error saving animation: {e}")
                print("Note: You may need to install ffmpeg for video export")

    if args.energy:
        plot_energy_over_time(snapshots, times)
    plt.show()


def _run_static(args, snapshots, trajectories, times):
    """Draw the requested static 2D/3D (and optional energy) plots, then display them."""
    print("Creating static plots...")
    if not args.three_d_only:
        plot_trajectories_2d(trajectories, times, args.center)
    if not args.two_d_only:
        plot_trajectories_3d(trajectories, times, args.center)
    if args.energy:
        plot_energy_over_time(snapshots, times)
    plt.show()


def main():
    """Parse the command line, load the trajectory file and show the requested plots.

    Returns:
        Process exit status. 0 on success. 1 if the file cannot be read,
        contains no snapshots, or ``--center`` names an unknown body. Invalid
        arguments exit through argparse with status 2. Unexpected plotting
        errors propagate with a traceback instead of being reported as read
        errors.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.interval is not None and args.interval <= 0:
        parser.error("--interval must be a positive number of milliseconds")
    if args.target_duration <= 0:
        parser.error("--target-duration must be a positive number of seconds")

    print(f"Reading trajectory file: {args.trajectory_file}")
    try:
        snapshots = read_trajectory_file(args.trajectory_file)
    except FileNotFoundError:
        print(f"Error: File '{args.trajectory_file}' not found!")
        return 1
    except (OSError, ValueError) as e:
        # OSError: unreadable path; ValueError includes UnicodeDecodeError
        # raised while decoding a corrupt body name.
        print(f"Error reading file: {e}")
        return 1
    print(f"Loaded {len(snapshots)} snapshots")
    if not snapshots:
        print("No data found in file!")
        return 1

    trajectories, times = organize_trajectories(snapshots)

    if args.list_bodies:
        print("Available bodies in trajectory file:")
        for body_name in sorted(trajectories):
            print(f" - {body_name}")
        return 0

    print(f"Bodies found: {list(trajectories.keys())}")
    print(f"Time range: {times[0]:.2e} - {times[-1]:.2e} seconds")
    print(f"Number of time steps: {len(times)}")

    if args.center:
        try:
            trajectories = center_trajectories_on_body(trajectories, args.center)
        except ValueError as e:
            print(f"❌ Error: {e}")
            return 1
        print(f"🎯 Centering animation on: {args.center}")

    if args.animate or args.save_animation:
        _run_animation(args, snapshots, trajectories, times)
    else:
        _run_static(args, snapshots, trajectories, times)
    return 0
if __name__ == "__main__":
    # Pass main()'s return value to the shell as the exit status (None -> 0).
    sys.exit(main())