# Distributed Checkpointing
The checkpointing system in Strata is designed for high-throughput model state persistence across distributed worker nodes. It utilizes an asynchronous, non-blocking architecture that supports multiple storage backends, including local filesystems and Amazon S3.
## Core Concepts
Checkpointing in Strata is governed by a Manager-Writer pattern:
- `CheckpointManager`: Orchestrates the high-level lifecycle, including versioning, metadata tracking, and retention policies.
- `AsyncCheckpointWriter`: Handles the low-level data stream, ensuring that I/O operations do not block the main training loop.
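The non-blocking writer can be sketched in a few lines of Python. This is purely illustrative of the pattern, not the actual Rust implementation: a background thread drains a queue of write jobs so the caller returns immediately.

```python
import queue
import threading

class AsyncCheckpointWriter:
    """Illustrative sketch of the Manager-Writer pattern: a background
    thread drains a queue of (path, payload) jobs so the training loop
    never blocks on I/O."""

    def __init__(self):
        self.jobs = queue.Queue()
        self.worker = threading.Thread(target=self._drain, daemon=True)
        self.worker.start()

    def submit(self, path: str, payload: bytes) -> None:
        # Returns immediately; the write happens on the worker thread.
        self.jobs.put((path, payload))

    def _drain(self) -> None:
        while True:
            path, payload = self.jobs.get()
            with open(path, "wb") as f:
                f.write(payload)
            self.jobs.task_done()

    def flush(self) -> None:
        # Block until all pending writes have been persisted.
        self.jobs.join()
```

A caller would `submit` serialized state at a step boundary and only `flush` when durability must be confirmed (e.g., before acknowledging to the Coordinator).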
### Checkpoint Types
Strata supports granular checkpointing strategies to balance recovery speed with storage overhead:
| Type | Description |
| :--- | :--- |
| Full | A complete snapshot of model weights and optimizer state. |
| Incremental | Stores only the delta from the previous checkpoint. |
| OptimizerOnly | Persists optimizer states (e.g., Adam moments) without weights. |
| ModelOnly | Persists model weights only, typically used for inference exports. |
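The Full/Incremental trade-off can be illustrated with a toy delta computation. The helpers below are hypothetical and say nothing about Strata's actual wire format; they only show why incremental checkpoints are smaller but require replay on restore.

```python
def full_checkpoint(state: dict) -> dict:
    # A full snapshot copies every key (model weights, optimizer state).
    return dict(state)

def incremental_checkpoint(state: dict, previous: dict) -> dict:
    # Store only the keys whose values changed since the last checkpoint.
    return {k: v for k, v in state.items() if previous.get(k) != v}

def restore(base: dict, deltas: list) -> dict:
    # Recovery from incrementals replays each delta on top of the
    # last full snapshot, in order.
    merged = dict(base)
    for d in deltas:
        merged.update(d)
    return merged
```

This is also why retention policies usually anchor on Full checkpoints: an Incremental is only usable while its base snapshot is still retained.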
### Metadata Structure
Every checkpoint is accompanied by a `CheckpointMetadata` object, which is used by the Coordinator to track training progress and facilitate recovery.
```rust
pub struct CheckpointMetadata {
    pub id: String,                        // Unique identifier
    pub step: u64,                         // Training step
    pub epoch: u64,                        // Training epoch
    pub path: String,                      // Storage URI (e.g., s3://bucket/path)
    pub size_bytes: u64,                   // Checkpoint size
    pub checkpoint_type: CheckpointType,
    pub model_hash: Option<String>,        // For integrity verification
    pub metadata: HashMap<String, String>, // Custom user attributes
}
```
## Using the Checkpoint Manager (Rust)
The `CheckpointManager` is the primary interface for workers to persist state. It is initialized with a configuration that defines the storage backend and retention logic.
### Initialization
```rust
use checkpoint::{CheckpointManager, CheckpointManagerConfig};

let config = CheckpointManagerConfig {
    storage_backend: "s3".to_string(),
    checkpoint_dir: "checkpoints/my-experiment".to_string(),
    keep_count: 5, // Retention: keep the 5 most recent checkpoints
};
let manager = CheckpointManager::new(config).await?;
```
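The `keep_count` retention policy amounts to pruning everything except the N most recent checkpoints. A minimal sketch of that logic, using a hypothetical helper rather than the crate's actual code:

```python
def prune_checkpoints(checkpoints: list, keep_count: int) -> list:
    """Return the checkpoints that should be deleted, keeping only
    the keep_count most recent ones (ordered by training step)."""
    by_step = sorted(checkpoints, key=lambda c: c["step"], reverse=True)
    # Everything past the first keep_count entries is eligible for deletion.
    return by_step[keep_count:]
```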
### Saving a Checkpoint
Checkpointing is asynchronous. The save method returns a handle or future, allowing the training loop to continue while data is streamed to storage.
```rust
// Example: Saving model state from a worker
let state_bytes = serialize_model(model);
manager.save(
    state_bytes,
    step,
    epoch,
    CheckpointType::Full,
).await?;
```
## Python Bindings
For ML practitioners, Strata provides high-level Python bindings via PyO3, allowing seamless integration with PyTorch or JAX training scripts.
```python
from strata import CheckpointManager

# Initialize the manager
ckpt_manager = CheckpointManager(
    storage_path="s3://my-bucket/checkpoints",
    keep_count=10,
)

# Save during training loop
if step % 1000 == 0:
    # model_bytes is a bytes-like object (e.g., from torch.save into a buffer)
    ckpt_info = ckpt_manager.save(
        model_bytes,
        step=step,
        epoch=current_epoch,
    )
    print(f"Checkpoint saved to {ckpt_info.path}")
```
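Producing `model_bytes` is the caller's responsibility. The sketch below shows one way to round-trip a state dict through an in-memory buffer; it uses `pickle` so the example has no framework dependency, but with PyTorch you would call `torch.save(model.state_dict(), buf)` instead.

```python
import io
import pickle  # stand-in for torch.save / torch.load in this sketch

def serialize_state(state_dict) -> bytes:
    # Write the state into an in-memory buffer and hand the raw
    # bytes to the checkpoint manager.
    buf = io.BytesIO()
    pickle.dump(state_dict, buf)
    return buf.getvalue()

def deserialize_state(payload: bytes):
    # Inverse operation, used when restoring from a checkpoint.
    return pickle.loads(payload)
```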
## Coordination and Consistency
The Coordinator service tracks all active checkpoints across the cluster. When a worker finishes writing a checkpoint, it sends a `CheckpointAck` to the coordinator.
### Barrier Synchronization
To ensure consistent global checkpoints across hundreds of workers, Strata uses a Barrier Sync mechanism. Workers reach a training step boundary and wait for the Coordinator to release the barrier before initiating the checkpoint write.
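The barrier mechanism can be modeled with a standard-library barrier. In this sketch the `threading.Barrier` stands in for the Coordinator's release signal, and `write_checkpoint` is a hypothetical per-worker callback; the point is only that no worker starts writing until all workers have reached the step boundary.

```python
import threading

def run_global_checkpoint(num_workers: int, write_checkpoint):
    """Illustrative barrier sync: every worker pauses at the step
    boundary, and only once all have arrived does each begin its
    checkpoint write."""
    barrier = threading.Barrier(num_workers)
    results = []

    def worker(rank: int):
        # ... worker reaches the training step boundary ...
        barrier.wait()  # stands in for the Coordinator releasing the barrier
        results.append(write_checkpoint(rank))

    threads = [threading.Thread(target=worker, args=(r,))
               for r in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

Because every worker snapshots the same step, the resulting set of per-shard checkpoints forms one consistent global checkpoint.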
Key Metrics for Monitoring:
- Checkpoint Throughput: Observed at ~500 MB/s for local NVMe and ~200 MB/s for S3.
- Barrier Latency: The p99 latency for 100 workers to synchronize is typically <50ms.
## Recovery Workflow
In the event of a worker or node failure, the system performs the following steps:
- Detection: The Coordinator detects a heartbeat timeout.
- Lookup: The Coordinator identifies the latest successful `Full` checkpoint for the specific task and shard.
- Reassignment: The `ShardManager` reassigns the failed worker's shards to an active worker.
- Restoration: The new worker queries the Coordinator for the `RecoveryResponse`, which contains the path to the last valid checkpoint.
- Resume: The worker loads the state and resumes training from the exact `epoch` and `step` stored in the metadata.
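The steps above can be sketched end-to-end. All names here (`recover_worker`, the dict fields, the shard-assignment choice) are illustrative stand-ins for the Coordinator and `ShardManager` RPCs, not the actual API:

```python
def recover_worker(failed_rank: int, checkpoints: list, active_workers: list) -> dict:
    """Sketch of the recovery flow: find the latest Full checkpoint
    for the failed shard, reassign the shard, and return what the new
    owner needs to resume (path, epoch, step)."""
    # Lookup: latest successful Full checkpoint for the failed worker's shard
    candidates = [c for c in checkpoints
                  if c["shard"] == failed_rank and c["type"] == "Full"]
    latest = max(candidates, key=lambda c: c["step"])
    # Reassignment: hand the shard to an active worker (naive choice here)
    new_owner = active_workers[0]
    # Restoration/Resume: the new owner loads from path and continues
    # at the stored epoch/step.
    return {"worker": new_owner, "path": latest["path"],
            "epoch": latest["epoch"], "step": latest["step"]}
```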