Python SDK Reference
The dtruntime Python package provides high-performance bindings to the Strata distributed runtime. It allows ML engineers to integrate distributed data loading and coordinated checkpointing into PyTorch or JAX training loops with minimal overhead, leveraging the underlying Rust core for non-blocking I/O.
Installation
The SDK is available as a pre-compiled wheel or can be built from source:
```shell
pip install dtruntime
```
TrainingOrchestrator
The TrainingOrchestrator is the primary interface for worker-to-coordinator communication. It manages the worker lifecycle, including registration, heartbeats, and synchronization barriers.
Constructor
```python
from dtruntime import TrainingOrchestrator, WorkerConfig

config = WorkerConfig(
    coordinator_url="http://coordinator-service:50051",
    gpu_count=8,
    heartbeat_interval_ms=5000
)
orchestrator = TrainingOrchestrator(config)
```
Methods
register()
Registers the worker with the global coordinator. This must be called before any data or checkpoint operations.
- Returns: str. The unique worker ID assigned by the coordinator.
barrier(name: str, timeout_ms: int = 30000)
Blocks execution until all workers in the training job reach this point.
- Arguments:
  - name: Unique identifier for the sync point (e.g., "epoch_start").
  - timeout_ms: Maximum time to wait before raising a timeout error.
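The rendezvous semantics of barrier() can be illustrated in a single process with Python's standard threading.Barrier, which behaves the same way: every participant blocks at the sync point until all of them arrive. This is only an analogy for the SDK's distributed barrier, not its implementation.

```python
import threading

# Stand-in for a 4-worker job: threading.Barrier gives the same
# rendezvous semantics as orchestrator.barrier(), but in one process.
NUM_WORKERS = 4
sync_point = threading.Barrier(NUM_WORKERS, timeout=30.0)  # ~ timeout_ms=30000

arrivals = []

def worker(rank):
    arrivals.append(rank)   # work done before the sync point
    sync_point.wait()       # blocks until all 4 workers arrive
    # past this line, every worker is guaranteed to have reached the barrier

threads = [threading.Thread(target=worker, args=(r,)) for r in range(NUM_WORKERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

If any participant fails to arrive within the timeout, threading.Barrier raises BrokenBarrierError; the SDK's equivalent is dtruntime.SyncError.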
get_status()
Returns the current state of the worker as tracked by the coordinator.
DatasetRegistry
The DatasetRegistry manages distributed data sharding using consistent hashing. It ensures that data is partitioned evenly across workers and handles re-sharding automatically if a worker fails.
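To build intuition for why consistent hashing limits re-sharding on failure, here is a minimal pure-Python sketch of a hash ring with virtual nodes. This is an illustration of the technique, not the Rust core's actual implementation; all names here are hypothetical.

```python
import bisect
import hashlib

def _ring_pos(key):
    # Hash a key to a position on a 32-bit ring.
    return int.from_bytes(hashlib.md5(key.encode()).digest()[:4], "big")

def build_ring(workers, vnodes=64):
    # Each worker gets `vnodes` virtual points for an even spread.
    ring = [(_ring_pos(f"{w}#{i}"), w) for w in workers for i in range(vnodes)]
    ring.sort()
    return ring

def assign(ring, shard_id):
    # A shard belongs to the first worker clockwise from its hash position.
    pos = _ring_pos(f"shard-{shard_id}")
    idx = bisect.bisect(ring, (pos, "")) % len(ring)
    return ring[idx][1]

ring = build_ring(["worker-0", "worker-1", "worker-2"])
before = {s: assign(ring, s) for s in range(100)}

# Remove worker-1: only shards that lived on worker-1 get new owners;
# every other shard keeps its assignment.
ring2 = build_ring(["worker-0", "worker-2"])
after = {s: assign(ring2, s) for s in range(100)}
moved = [s for s in before if before[s] != after[s]]
```

Removing a worker deletes only that worker's points from the ring, so the clockwise successor of every other shard is unchanged; that is the property the registry relies on when it re-shards after a failure.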
Methods
register_dataset(dataset_id: str, total_samples: int, shard_size: int, shuffle: bool = True)
Registers a new dataset with the coordinator.
- Arguments:
  - dataset_id: Unique name for the dataset.
  - total_samples: Total number of records in the dataset.
  - shard_size: Number of samples per shard.
  - shuffle: Whether to enable deterministic shuffling across epochs.
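As a sketch of how these parameters interact, the following hypothetical helpers partition total_samples into shard_size ranges and derive a deterministic per-epoch shuffle by seeding with the epoch number, so every worker computes the same order without extra communication. The coordinator's actual scheme may differ; this only illustrates the idea.

```python
import random

def make_shards(total_samples, shard_size):
    # Partition [0, total_samples) into [start, end) ranges of shard_size;
    # the final shard may be shorter.
    return [(s, min(s + shard_size, total_samples))
            for s in range(0, total_samples, shard_size)]

def shard_order(num_shards, epoch, shuffle=True):
    # Deterministic shuffle: seeding with the epoch means every worker
    # derives the same permutation for that epoch.
    order = list(range(num_shards))
    if shuffle:
        random.Random(epoch).shuffle(order)
    return order

shards = make_shards(total_samples=10_000, shard_size=1_024)
# 10 shards; the last covers the 784-sample remainder.
order_e0 = shard_order(len(shards), epoch=0)
```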
get_shards(dataset_id: str, epoch: int) -> List[ShardInfo]
Retrieves the specific shard assignments for the local worker for a given epoch.
- Arguments:
  - dataset_id: The ID of the registered dataset.
  - epoch: The current training epoch.
- Returns: A list of ShardInfo objects.
ShardInfo (Data Object)
| Property | Type | Description |
| :--- | :--- | :--- |
| shard_id | int | Unique index of the shard. |
| start_index | int | Global start index of the first sample in this shard. |
| end_index | int | Global end index (exclusive). |
| file_paths | List[str] | Physical paths or URIs associated with this shard. |
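The table above maps directly onto a small data object. The sketch below uses a hypothetical local mirror of ShardInfo to show one common pattern: flattening a worker's assigned shards into global sample indices, e.g. to drive a torch.utils.data.Sampler.

```python
from dataclasses import dataclass, field

@dataclass
class ShardInfo:
    # Hypothetical mirror of the SDK's data object, per the table above.
    shard_id: int
    start_index: int
    end_index: int  # exclusive
    file_paths: list = field(default_factory=list)

def indices_for(shards):
    # Flatten assigned shards into global sample indices.
    return [i for s in shards for i in range(s.start_index, s.end_index)]

mine = [ShardInfo(0, 0, 4), ShardInfo(3, 12, 16)]
idx = indices_for(mine)  # [0, 1, 2, 3, 12, 13, 14, 15]
```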
CheckpointManager
The CheckpointManager handles asynchronous state persistence to local storage or S3. It coordinates with the Rust AsyncCheckpointWriter to ensure model weights and optimizer states are saved without stalling the training loop.
Constructor
```python
from dtruntime import CheckpointManager

manager = CheckpointManager(
    storage_path="s3://my-bucket/checkpoints",
    keep_count=5  # Number of recent checkpoints to retain
)
```
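The keep_count retention policy amounts to pruning everything older than the N most recent checkpoints after each successful save. A minimal sketch of that logic, with hypothetical names (the manager performs this internally):

```python
def prune(checkpoint_steps, keep_count):
    # Keep the `keep_count` most recent checkpoints (by step);
    # return (kept, deleted) so stale objects can be removed from storage.
    ordered = sorted(checkpoint_steps, reverse=True)
    return ordered[:keep_count], ordered[keep_count:]

kept, deleted = prune([1000, 2000, 3000, 4000, 5000, 6000], keep_count=5)
# kept -> [6000, 5000, 4000, 3000, 2000]; deleted -> [1000]
```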
Methods
save(state_dict: bytes, step: int, epoch: int, is_full: bool = True)
Streams a state dictionary to the configured storage backend.
- Arguments:
  - state_dict: The serialized model/optimizer state (typically via io.BytesIO).
  - step: Current training step.
  - epoch: Current training epoch.
  - is_full: If False, marks the checkpoint as incremental.
load_latest() -> CheckpointInfo
Retrieves metadata for the most recent successful checkpoint.
get_checkpoint(checkpoint_id: str) -> bytes
Downloads and returns the raw bytes of a specific checkpoint.
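Because save() takes opaque bytes and get_checkpoint() returns them unchanged, resuming is just a deserialize of whatever you serialized. The sketch below shows that round trip with pickle as a stand-in for torch.save/torch.load, so it runs without PyTorch:

```python
import io
import pickle

# Stand-in serializers: in a real training loop you would use
# torch.save / torch.load on the model's state_dict.
def to_bytes(state_dict):
    buffer = io.BytesIO()
    pickle.dump(state_dict, buffer)
    return buffer.getvalue()   # the form CheckpointManager.save() expects

def from_bytes(raw):
    return pickle.load(io.BytesIO(raw))  # restore what get_checkpoint() returned

state = {"step": 1000, "lr": 3e-4}
restored = from_bytes(to_bytes(state))
```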
Integration Example: PyTorch
```python
import torch
import io
from dtruntime import TrainingOrchestrator, DatasetRegistry, CheckpointManager

# 1. Initialize Runtime
orchestrator = TrainingOrchestrator(...)
dataset_reg = DatasetRegistry()
ckpt_manager = CheckpointManager("s3://checkpoints/llama-7b")
worker_id = orchestrator.register()

# 2. Get Data Assignments
shards = dataset_reg.get_shards("imagenet", epoch=0)
# Use shards to initialize your DataLoader...

# 3. Training Loop
for epoch in range(10):
    for step, batch in enumerate(dataloader):
        # ... train ...
        if step % 1000 == 0:
            # Sync workers before checkpoint
            orchestrator.barrier(f"ckpt_{step}")
            if worker_id == "worker-0":  # Primary worker saves state
                buffer = io.BytesIO()
                torch.save(model.state_dict(), buffer)
                ckpt_manager.save(buffer.getvalue(), step, epoch)
    orchestrator.barrier(f"epoch_{epoch}_complete")
```
Error Handling
The SDK raises specific exceptions mapped from the Rust core:
- dtruntime.ConnectionError: Failed to reach the coordinator.
- dtruntime.StorageError: Issue writing to S3 or local disk.
- dtruntime.SyncError: Barrier timeout or membership mismatch.
- dtruntime.ShardError: Dataset not found or invalid shard request.
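Connection failures are often transient (e.g., the coordinator pod is restarting), so a common pattern is to retry them with backoff while letting storage, sync, and shard errors propagate. The sketch below defines a local stand-in exception class so it runs without the SDK; in real code you would catch dtruntime.ConnectionError instead. The helper names are hypothetical.

```python
import time

class ConnectionError_(Exception):
    # Local stand-in for dtruntime.ConnectionError so this sketch is
    # self-contained; substitute the real exception in production code.
    pass

def with_retries(op, attempts=3, base_delay=0.0):
    # Retry transient connection failures with exponential backoff;
    # other SDK errors propagate, since retrying rarely helps them.
    for attempt in range(attempts):
        try:
            return op()
        except ConnectionError_:
            if attempt == attempts - 1:
                raise
            time.sleep(base_delay * (2 ** attempt))

# Simulated flaky call: fails twice, then succeeds.
calls = {"n": 0}
def flaky_register():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError_("coordinator unreachable")
    return "worker-0"

worker_id = with_retries(flaky_register)  # succeeds on the third attempt
```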