Training

Run and manage training jobs on Riven's distributed compute infrastructure.

Training Overview

The Riven AI Platform provides a managed training infrastructure that handles resource provisioning, job scheduling, and artifact management. Training jobs run as isolated Kubernetes pods with dedicated GPU allocation, and all outputs — checkpoints, logs, and metrics — are automatically persisted to the platform's object store.

The training lifecycle follows four stages: Pending (waiting for resources), Running (actively training), Completed (finished successfully), and Failed (terminated with errors). You can monitor each stage through the CLI, API, or Grafana dashboards.
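The lifecycle above can be expressed as a small polling helper. This is a sketch, not platform code: `get_status` is a hypothetical stand-in for a real status lookup via the CLI or API.

```python
import time

# The four lifecycle stages described above; only two are terminal.
TERMINAL_STATES = {"Completed", "Failed"}

def is_terminal(state: str) -> bool:
    """Return True once a job can no longer change state."""
    return state in TERMINAL_STATES

def wait_for_job(get_status, poll_interval=1.0, max_polls=1000):
    """Poll a status callable until the job reaches a terminal state.

    get_status is a hypothetical stand-in for a real status lookup
    (e.g. a CLI or API call).
    """
    for _ in range(max_polls):
        state = get_status()
        if is_terminal(state):
            return state
        time.sleep(poll_interval)
    raise TimeoutError("job did not reach a terminal state")

# Example with a stubbed status sequence: Pending -> Running -> Completed.
states = iter(["Pending", "Running", "Running", "Completed"])
final = wait_for_job(lambda: next(states), poll_interval=0.0)
print(final)  # Completed
```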

Creating a Training Job

Training jobs can be submitted via the CLI or the REST API. The CLI provides a streamlined workflow for common scenarios:

Terminal
bash
# Submit a training job from a config file
riven train submit --config train-config.yaml
 
# Submit with inline overrides
riven train submit \
  --image riven/pytorch-trainer:latest \
  --gpu 4 \
  --gpu-type a100 \
  --dataset s3://datasets/my-corpus \
  --output s3://models/my-model/v2 \
  --epochs 3 \
  --batch-size 32

Training Configuration

Here is a complete training configuration, including the dataset locations, resource requests, and checkpointing options:

train-config.yaml
yaml
name: qwen3-summarization-finetune
image: riven/pytorch-trainer:latest
 
model:
  base: "Qwen/Qwen3-8B"
  output: s3://models/qwen3-summarization/v1
 
dataset:
  train: s3://datasets/code-summaries/train.jsonl
  eval: s3://datasets/code-summaries/eval.jsonl
  format: jsonl    # Each line: {"input": "...", "output": "..."}
 
resources:
  gpu: 4
  gpu_type: a100
  memory: 64Gi
  cpu: 16
 
config:
  epochs: 3
  batch_size: 8            # per-GPU batch size
  learning_rate: 2e-5
  lr_scheduler: cosine
  warmup_ratio: 0.1
  weight_decay: 0.01
  gradient_accumulation_steps: 4
  max_seq_length: 4096
  fp16: true
  save_steps: 500
  eval_steps: 250
  logging_steps: 10
 
checkpointing:
  strategy: best           # best, last, every_n_steps
  metric: eval_loss
  max_checkpoints: 3
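The strategy: best policy with max_checkpoints: 3 amounts to keeping the three checkpoints with the lowest eval_loss. A minimal sketch of that pruning logic (the platform's actual implementation is not shown here):

```python
def prune_checkpoints(checkpoints, metric="eval_loss", max_checkpoints=3):
    """Keep the best max_checkpoints by metric (lower is better),
    mirroring `strategy: best` from the config above.

    checkpoints: list of dicts like {"step": 500, "eval_loss": 1.756}.
    Returns (kept, deleted), each ranked best-first.
    """
    ranked = sorted(checkpoints, key=lambda c: c[metric])
    return ranked[:max_checkpoints], ranked[max_checkpoints:]

# Illustrative checkpoint history (values are made up).
saved = [
    {"step": 250, "eval_loss": 1.90},
    {"step": 500, "eval_loss": 1.76},
    {"step": 750, "eval_loss": 1.61},
    {"step": 1000, "eval_loss": 1.68},
]
kept, deleted = prune_checkpoints(saved)
print([c["step"] for c in kept])     # [750, 1000, 500]
print([c["step"] for c in deleted])  # [250]
```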

Dataset Format

Training datasets use JSONL format with one example per line:

train.jsonl
json
{"input": "def fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)", "output": "Recursive implementation of the Fibonacci sequence. Returns the nth Fibonacci number."}
{"input": "class UserService:\n    def __init__(self, db):\n        self.db = db\n    async def get_user(self, id):\n        return await self.db.users.find_one({'_id': id})", "output": "User service with MongoDB backend. Provides async user lookup by ID."}

You can also submit jobs programmatically via the REST API:

api-example.py
python
import requests
 
response = requests.post(
    "https://api.riven-ai.dev/v1/training/jobs",
    headers={"Authorization": "Bearer <token>"},
    json={
        "name": "qwen3-finetune",
        "image": "riven/pytorch-trainer:latest",
        "resources": {"gpu": 4, "gpu_type": "a100"},
        "config": {
            "dataset": "s3://datasets/my-corpus",
            "epochs": 3,
            "batch_size": 32,
            "learning_rate": 2e-5,
        },
    },
)
print(response.json()["job_id"])

Distributed Training

For large models that require multiple nodes, the platform supports PyTorch Distributed Data Parallel (DDP) out of the box. Specify the number of nodes and GPUs per node, and the platform handles worker coordination, NCCL setup, and failure recovery.

distributed-config.yaml
yaml
name: qwen3-distributed-finetune
image: riven/pytorch-trainer:latest
 
distributed:
  strategy: ddp
  nodes: 4
  gpus_per_node: 8
  backend: nccl
 
resources:
  gpu_type: a100
  memory: 64Gi
  cpu: 16
 
config:
  dataset: s3://datasets/large-corpus
  epochs: 5
  batch_size: 16  # per-GPU batch size
  learning_rate: 1e-5
  gradient_accumulation_steps: 4

Distributed training requires at least 2 available nodes with matching GPU types. Check resource availability with riven cluster resources before submitting large jobs.
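Under DDP, the effective global batch size per optimizer step is the per-GPU batch size times GPUs per node, nodes, and gradient accumulation steps. For the distributed config above that works out as:

```python
def global_batch_size(per_gpu_batch, gpus_per_node, nodes, grad_accum_steps):
    """Effective number of examples per optimizer step under DDP:
    each of nodes * gpus_per_node workers processes per_gpu_batch
    examples per forward pass, accumulated over grad_accum_steps."""
    return per_gpu_batch * gpus_per_node * nodes * grad_accum_steps

# Values from distributed-config.yaml above.
print(global_batch_size(per_gpu_batch=16, gpus_per_node=8, nodes=4,
                        grad_accum_steps=4))  # 2048
```

Doubling gradient accumulation halves per-step memory pressure at the same global batch size, which is often the easier knob to turn than adding nodes.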

Hyperparameter Tuning

The platform includes a built-in hyperparameter sweep engine. Define a search space in your config and the platform will launch parallel trials, track metrics, and select the best configuration.

sweep-config.yaml
yaml
name: lr-sweep
base_config: train-config.yaml
 
sweep:
  strategy: bayesian   # or grid, random
  metric: eval_loss
  goal: minimize
  max_trials: 20
  parallel_trials: 4
 
parameters:
  learning_rate:
    distribution: log_uniform
    min: 1e-6
    max: 1e-3
  batch_size:
    values: [8, 16, 32, 64]
  warmup_ratio:
    distribution: uniform
    min: 0.0
    max: 0.2
Terminal
bash
# Launch a hyperparameter sweep
riven train sweep --config sweep-config.yaml
 
# Monitor sweep progress
riven train sweep status <sweep-id>
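The distributions in sweep-config.yaml can be sampled as follows. This is a sketch of plain random search under the stated search space; the platform's Bayesian strategy is more involved and not reproduced here.

```python
import math
import random

def sample_trial(rng: random.Random) -> dict:
    """Draw one trial from the search space in sweep-config.yaml:
    log-uniform learning rate, categorical batch size,
    uniform warmup ratio."""
    lr = math.exp(rng.uniform(math.log(1e-6), math.log(1e-3)))
    return {
        "learning_rate": lr,
        "batch_size": rng.choice([8, 16, 32, 64]),
        "warmup_ratio": rng.uniform(0.0, 0.2),
    }

rng = random.Random(0)  # seeded for reproducibility
trial = sample_trial(rng)
assert 1e-6 <= trial["learning_rate"] <= 1e-3
assert trial["batch_size"] in (8, 16, 32, 64)
assert 0.0 <= trial["warmup_ratio"] <= 0.2
```

Sampling the learning rate in log space matters: a uniform draw over [1e-6, 1e-3] would spend almost all trials near 1e-3.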

Monitoring Jobs

All training jobs emit real-time metrics that are collected by Prometheus and visualized in pre-built Grafana dashboards. Key metrics include:

  • Training Loss — Per-step and smoothed loss curves.
  • Evaluation Metrics — Validation loss, accuracy, and custom metrics at each checkpoint.
  • GPU Utilization — Per-GPU compute and memory utilization.
  • Throughput — Tokens per second and samples per second.

Example Metric Output

A typical training run produces output like:

Training output
text
Step  100/3000 | Loss: 2.341 | LR: 1.2e-5 | GPU Mem: 21.3/40.0 GB | Tokens/s: 12,450
Step  200/3000 | Loss: 1.892 | LR: 1.8e-5 | GPU Mem: 21.3/40.0 GB | Tokens/s: 12,520
Step  250/3000 | Eval Loss: 1.756 | Eval Accuracy: 0.723
Step  300/3000 | Loss: 1.654 | LR: 2.0e-5 | GPU Mem: 21.3/40.0 GB | Tokens/s: 12,480
Step  500/3000 | Loss: 1.203 | LR: 2.0e-5 | GPU Mem: 21.3/40.0 GB | Tokens/s: 12,510
Step  500/3000 | Checkpoint saved → s3://models/qwen3-summarization/v1/checkpoint-500
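For ad-hoc analysis outside Grafana, lines in this shape can be parsed with a regular expression. The exact log format here is assumed from the sample output above, not a platform guarantee.

```python
import re

# Matches the training-step lines shown above (eval and checkpoint
# lines deliberately don't match).
STEP_RE = re.compile(
    r"Step\s+(?P<step>\d+)/(?P<total>\d+) \| Loss: (?P<loss>[\d.]+)"
)

def parse_step_line(line: str):
    """Extract step, total steps, and loss from a training log line,
    or return None for non-training lines."""
    m = STEP_RE.search(line)
    if not m:
        return None
    return {"step": int(m["step"]), "total": int(m["total"]),
            "loss": float(m["loss"])}

line = "Step  100/3000 | Loss: 2.341 | LR: 1.2e-5 | Tokens/s: 12,450"
print(parse_step_line(line))  # {'step': 100, 'total': 3000, 'loss': 2.341}
```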
Terminal
bash
# View job status and metrics
riven train status <job-id>
 
# Stream live logs
riven train logs <job-id> --follow
 
# Open Grafana dashboard for a job
riven train dashboard <job-id>

Training logs and checkpoints are retained for 30 days by default. You can configure retention policies per project in the platform settings.