---
name: ml-failfast-validation
description: POC validation patterns to catch issues before committing to long-running ML experiments.
---

# ML Fail-Fast Validation

POC validation patterns to catch issues before committing to long-running ML experiments.
## When to Use This Skill
Use this skill when:
- Starting a new ML experiment that will run for hours
- Validating model architecture before full training
- Checking gradient flow and data pipeline integrity
- Implementing POC validation checklists
- Debugging prediction collapse or gradient explosion issues
## 1. Why Fail-Fast?
| Without Fail-Fast | With Fail-Fast |
|---|---|
| Discover crash 4 hours in | Catch in 30 seconds |
| Debug from cryptic error | Clear error message |
| Lose GPU time | Validate before commit |
| Silent data issues | Explicit schema checks |
**Principle:** Validate everything that can go wrong *before* the expensive computation.
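A minimal sketch of the principle, using a toy `nn.Linear` in place of the real model: a one-batch dry run surfaces a feature-count mismatch in seconds, not hours into training.

```python
import torch
import torch.nn as nn

# Toy stand-in: the real model expects 10 input features.
model = nn.Linear(10, 1)
bad_batch = torch.randn(32, 12)  # data pipeline produced 12 features instead

try:
    model(bad_batch)  # dry run: fails immediately, before any training loop
except RuntimeError as e:
    print(f"Caught before training: {e}")
```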
## 2. POC Validation Checklist

### Minimum Viable POC (5 Checks)
```python
import json
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F

# Assumes experiment-level context: create_model, create_selector, train_one_epoch,
# architecture, n_features, seq_len, device, output_dir, train_loader.


def run_poc_validation():
    """Fast validation before the full experiment."""
    print("=" * 60)
    print("FAIL-FAST POC VALIDATION")
    print("=" * 60)

    # [1/5] Model instantiation
    print("\n[1/5] Model instantiation...")
    model = create_model(architecture, input_size=n_features).to(device)
    x = torch.randn(32, seq_len, n_features).to(device)
    out = model(x)
    assert out.shape == (32, 1), f"Output shape wrong: {out.shape}"
    print(f"  Input: (32, {seq_len}, {n_features}) -> Output: {tuple(out.shape)}")
    print("  Status: PASS")

    # [2/5] Gradient flow
    print("\n[2/5] Gradient flow...")
    y = torch.randn(32, 1).to(device)
    loss = F.mse_loss(out, y)
    loss.backward()
    grad_norms = [p.grad.norm().item() for p in model.parameters() if p.grad is not None]
    assert len(grad_norms) > 0, "No gradients!"
    assert all(np.isfinite(g) for g in grad_norms), "NaN/Inf gradients!"
    print(f"  Max grad norm: {max(grad_norms):.4f}")
    print("  Status: PASS")

    # [3/5] NDJSON artifact validation
    print("\n[3/5] NDJSON artifact validation...")
    log_path = output_dir / "experiment.jsonl"
    with open(log_path, "a") as f:
        f.write(json.dumps({"phase": "poc_start", "timestamp": datetime.now().isoformat()}) + "\n")
    assert log_path.exists(), "Log file not created"
    print(f"  Log file: {log_path}")
    print("  Status: PASS")

    # [4/5] Epoch selector variation
    print("\n[4/5] Epoch selector variation...")
    epochs = []
    for seed in [1, 2, 3]:
        np.random.seed(seed)  # different simulated results per seed
        selector = create_selector()
        # Simulate different validation results
        for e in range(10, 201, 10):
            selector.record(epoch=e, sortino=np.random.randn() * 0.1, sparsity=np.random.rand())
        epochs.append(selector.select())
    print(f"  Selected epochs: {epochs}")
    assert len(set(epochs)) > 1, f"Selector not varying: always picks {epochs[0]}"
    print("  Status: PASS")

    # [5/5] Mini training (10 epochs)
    print("\n[5/5] Mini training (10 epochs)...")
    model = create_model(architecture, input_size=n_features).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
    initial_loss = None
    for epoch in range(10):
        loss = train_one_epoch(model, train_loader, optimizer)
        if initial_loss is None:
            initial_loss = loss
    print(f"  Initial loss: {initial_loss:.4f}")
    print(f"  Final loss:   {loss:.4f}")
    print("  Status: PASS")

    print("\n" + "=" * 60)
    print("POC RESULT: ALL 5 CHECKS PASSED")
    print("=" * 60)
```
### Extended POC (10 Checks)

Add these for comprehensive validation:
```python
# [6/10] Data loading
print("\n[6/10] Data loading...")
df = fetch_data(symbol, threshold)
assert len(df) > min_required_bars, f"Insufficient data: {len(df)} bars"
print(f"  Loaded: {len(df):,} bars")
print("  Status: PASS")

# [7/10] Schema validation
print("\n[7/10] Schema validation...")
validate_schema(df, required_columns, "raw_data")
print("  Status: PASS")

# [8/10] Feature computation
print("\n[8/10] Feature computation...")
df = compute_features(df)
validate_schema(df, feature_columns, "features")
print(f"  Features: {len(feature_columns)}")
print("  Status: PASS")

# [9/10] Prediction sanity
print("\n[9/10] Prediction sanity...")
preds = model(X_test).detach().cpu().numpy()
pred_std = preds.std()
target_std = y_test.std()
pred_ratio = pred_std / target_std
assert pred_ratio > 0.005, f"Predictions collapsed: ratio={pred_ratio:.4f}"
print(f"  Pred std ratio: {pred_ratio:.2%}")
print("  Status: PASS")

# [10/10] Checkpoint save/load
print("\n[10/10] Checkpoint save/load...")
torch.save(model.state_dict(), checkpoint_path)
model2 = create_model(architecture, input_size=n_features)
model2.load_state_dict(torch.load(checkpoint_path))
print("  Status: PASS")
```
## 3. Schema Validation Pattern

### The Problem

```python
# BAD: a cryptic error two hours into the experiment
KeyError: 'returns_vs'  # Which file? Which function? What columns exist?
```

### The Solution
```python
def validate_schema(df, required: list[str], stage: str) -> None:
    """Fail-fast schema validation with actionable error messages."""
    # Handle both DataFrame columns and a named index (e.g. a DatetimeIndex)
    available = list(df.columns)
    if hasattr(df.index, "name") and df.index.name:
        available.append(df.index.name)

    missing = [c for c in required if c not in available]
    if missing:
        raise ValueError(
            f"[{stage}] Missing columns: {missing}\n"
            f"Available: {sorted(available)}\n"
            f"DataFrame shape: {df.shape}"
        )
    print(f"  Schema validation PASSED ({stage}): {len(required)} columns", flush=True)


# Usage at pipeline boundaries
REQUIRED_RAW = ["open", "high", "low", "close", "volume"]
REQUIRED_FEATURES = ["returns_vs", "momentum_z", "atr_pct", "volume_z",
                     "rsi_14", "bb_pct_b", "vol_regime", "return_accel", "pv_divergence"]

df = fetch_data(symbol)
validate_schema(df, REQUIRED_RAW, "raw_data")
df = compute_features(df)
validate_schema(df, REQUIRED_FEATURES, "features")
```
## 4. Gradient Health Checks

### Basic Gradient Check
```python
import numpy as np
import torch
import torch.nn as nn


def check_gradient_health(model: nn.Module, sample_input: torch.Tensor) -> dict:
    """Verify gradients flow correctly through the model."""
    model.train()
    model.zero_grad()  # start from clean gradients
    out = model(sample_input)
    loss = out.sum()
    loss.backward()

    stats = {"total_params": 0, "params_with_grad": 0, "grad_norms": []}
    for name, param in model.named_parameters():
        stats["total_params"] += 1
        if param.grad is not None:
            stats["params_with_grad"] += 1
            norm = param.grad.norm().item()
            stats["grad_norms"].append(norm)
            # Check for issues
            if not np.isfinite(norm):
                raise ValueError(f"Non-finite gradient in {name}: {norm}")
            if norm > 100:
                print(f"  WARNING: Large gradient in {name}: {norm:.2f}")

    stats["max_grad"] = max(stats["grad_norms"]) if stats["grad_norms"] else 0
    stats["mean_grad"] = np.mean(stats["grad_norms"]) if stats["grad_norms"] else 0
    return stats
```
### Architecture-Specific Checks
```python
def check_lstm_gradients(model: nn.Module) -> dict:
    """Check LSTM-specific gradient patterns."""
    stats = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        # Forget-gate slice of the bias gradient (PyTorch LSTM gate order: [i, f, g, o]).
        # The forget-gate bias itself should not drift too negative during training.
        if "bias_hh" in name or "bias_ih" in name:
            hidden_size = param.shape[0] // 4
            forget_bias_grad = param.grad[hidden_size:2 * hidden_size]
            stats["forget_bias_grad_mean"] = forget_bias_grad.mean().item()
        # Hidden-to-hidden weight gradients (vanishing/exploding indicator)
        if "weight_hh" in name:
            stats["hh_weight_grad_norm"] = param.grad.norm().item()
    return stats
```
## 5. Prediction Sanity Checks

### Collapse Detection
```python
import numpy as np


def check_prediction_sanity(preds: np.ndarray, targets: np.ndarray) -> dict:
    """Detect prediction collapse or explosion."""
    stats = {
        "pred_mean": preds.mean(),
        "pred_std": preds.std(),
        "pred_min": preds.min(),
        "pred_max": preds.max(),
        "target_std": targets.std(),
    }
    # Relative threshold (not absolute!)
    stats["pred_std_ratio"] = stats["pred_std"] / stats["target_std"]

    # Collapse detection
    if stats["pred_std_ratio"] < 0.005:  # < 0.5% of target std
        raise ValueError(
            f"Predictions collapsed!\n"
            f"  pred_std: {stats['pred_std']:.6f}\n"
            f"  target_std: {stats['target_std']:.6f}\n"
            f"  ratio: {stats['pred_std_ratio']:.4%}"
        )

    # Explosion detection
    if stats["pred_std_ratio"] > 100:  # > 100x target std
        raise ValueError(
            f"Predictions exploded!\n"
            f"  pred_std: {stats['pred_std']:.2f}\n"
            f"  target_std: {stats['target_std']:.6f}\n"
            f"  ratio: {stats['pred_std_ratio']:.1f}x"
        )

    # Unique value check
    stats["unique_values"] = len(np.unique(np.round(preds, 6)))
    if stats["unique_values"] < 10:
        print(f"  WARNING: Only {stats['unique_values']} unique prediction values")

    return stats
```
### Correlation Check
```python
def check_prediction_correlation(preds: np.ndarray, targets: np.ndarray) -> float:
    """Check whether predictions have any correlation with targets."""
    corr = np.corrcoef(preds.flatten(), targets.flatten())[0, 1]
    if not np.isfinite(corr):
        print("  WARNING: correlation is NaN (likely collapsed predictions)")
        return 0.0
    # Note: negative correlation may still be useful (short signal)
    print(f"  Prediction-target correlation: {corr:.4f}")
    return corr
```
## 6. NDJSON Logging Validation

### Required Event Types
```python
import json
from pathlib import Path

REQUIRED_EVENTS = {
    "experiment_start": ["architecture", "features", "config"],
    "fold_start": ["fold_id", "train_size", "val_size", "test_size"],
    "epoch_complete": ["epoch", "train_loss", "val_loss"],
    "fold_complete": ["fold_id", "test_sharpe", "test_sortino"],
    "experiment_complete": ["total_folds", "mean_sharpe", "elapsed_seconds"],
}


def validate_ndjson_schema(log_path: Path) -> None:
    """Validate that the NDJSON log has all required events and fields."""
    events = {}
    with open(log_path) as f:
        for line in f:
            event = json.loads(line)
            phase = event.get("phase", "unknown")
            events.setdefault(phase, []).append(event)

    for phase, required_fields in REQUIRED_EVENTS.items():
        if phase not in events:
            raise ValueError(f"Missing event type: {phase}")
        sample = events[phase][0]
        missing = [f for f in required_fields if f not in sample]
        if missing:
            raise ValueError(f"Event '{phase}' missing fields: {missing}")

    print(f"  NDJSON schema valid: {len(events)} event types")
```
## 7. POC Timing Guide
| Check | Typical Time | Max Time | Action if Exceeded |
|---|---|---|---|
| Model instantiation | < 1s | 5s | Check device, reduce model size |
| Gradient flow | < 2s | 10s | Check batch size |
| Schema validation | < 0.1s | 1s | Check data loading |
| Mini training (10 epochs) | < 30s | 2min | Reduce batch, check data loader |
| Full POC (10 checks) | < 2min | 5min | Something is wrong |
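One way to enforce these budgets is a small timing wrapper; `timed_check` and the stand-in workload below are illustrative, not part of the checklist's API:

```python
import time


def timed_check(name: str, fn, max_seconds: float) -> None:
    """Run one POC check against its time budget and fail fast if exceeded."""
    start = time.perf_counter()
    fn()
    elapsed = time.perf_counter() - start
    print(f"  {name}: {elapsed:.2f}s (budget {max_seconds}s)", flush=True)
    if elapsed > max_seconds:
        raise TimeoutError(f"{name} took {elapsed:.1f}s, budget is {max_seconds}s")


# Example with a trivial stand-in workload
timed_check("model_instantiation", lambda: sum(range(10_000)), max_seconds=5.0)
```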
## 8. Failure Response Guide
| Failure | Likely Cause | Fix |
|---|---|---|
| Shape mismatch | Wrong input_size or seq_len | Check feature count |
| NaN gradients | LR too high, bad init | Reduce LR, check init |
| Zero gradients | Dead layers, missing params | Check model architecture |
| Predictions collapsed | Normalizer issue, bad loss | Check sLSTM normalizer |
| Predictions exploded | Gradient explosion | Add/tighten gradient clipping |
| Schema missing columns | Wrong data source | Check fetch function |
| Checkpoint load fails | State dict key mismatch | Check model architecture match |
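For the gradient-explosion row, the standard PyTorch remedy is `torch.nn.utils.clip_grad_norm_`, applied between `backward()` and `optimizer.step()`; the toy model below is only for illustration:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
loss = (model(torch.randn(4, 8)) * 1e4).pow(2).sum()  # inflated loss -> huge gradients
loss.backward()

# Rescale gradients in place so their total norm is at most 1.0;
# the call returns the total norm measured before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"grad norm before clipping: {total_norm:.1f}")
```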
## 9. Integration Example
```python
import sys


def main():
    # Parse args, set up the output dir...

    # PHASE 1: Fail-fast POC
    print("=" * 60)
    print("FAIL-FAST POC VALIDATION")
    print("=" * 60)
    try:
        run_poc_validation()
    except Exception as e:
        print(f"\n{'=' * 60}")
        print(f"POC FAILED: {type(e).__name__}")
        print("=" * 60)
        print(f"Error: {e}")
        print("\nFix the issue before running the full experiment.")
        sys.exit(1)

    # PHASE 2: Full experiment (runs only if the POC passes)
    print("\n" + "=" * 60)
    print("STARTING FULL EXPERIMENT")
    print("=" * 60)
    run_full_experiment()
```
## 10. Anti-Patterns to Avoid
**DON'T: Skip validation to "save time"**

```python
# BAD: "I'll just run it and see"
run_full_experiment()  # 4 hours later: crash
```
**DON'T: Use absolute thresholds for relative quantities**

```python
# BAD: absolute threshold
assert pred_std > 1e-4  # meaningless for returns ~0.001

# GOOD: relative threshold
assert pred_std / target_std > 0.005  # 0.5% of target std
```
**DON'T: Catch all exceptions silently**

```python
# BAD: hides real issues
try:
    result = risky_operation()
except Exception:
    result = default_value  # What went wrong?

# GOOD: catch specific exceptions and re-raise
try:
    result = risky_operation()
except (ValueError, RuntimeError) as e:
    logger.error(f"Operation failed: {e}")
    raise
```
**DON'T: Print without flush**

```python
# BAD: output is buffered, so progress is invisible
print(f"Processing fold {i}...")

# GOOD: see output immediately
print(f"Processing fold {i}...", flush=True)
```
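If adding `flush=True` to every call gets tedious, one alternative (Python 3.7+) is to line-buffer stdout once at startup:

```python
import sys

# Line-buffer stdout so every newline-terminated print is flushed automatically.
# Guarded because some redirected streams lack reconfigure().
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(line_buffering=True)

print("Processing fold 1...")  # flushed at the newline, no flush=True needed
```

Setting the `PYTHONUNBUFFERED=1` environment variable achieves a similar effect without code changes.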
## Troubleshooting
| Issue | Cause | Solution |
|---|---|---|
| NaN gradients in POC | Learning rate too high | Reduce LR by 10x, check weight initialization |
| Zero gradients | Dead layers or missing params | Check model architecture, verify requires_grad=True |
| Predictions collapsed | Normalizer issue or bad loss | Check target normalization, verify loss function |
| Predictions exploded | Gradient explosion | Add gradient clipping, reduce learning rate |
| Schema missing columns | Wrong data source or transform | Verify fetch function returns expected columns |
| Checkpoint load fails | State dict key mismatch | Ensure model architecture matches saved checkpoint |
| POC timeout (>5 min) | Data loading or model too large | Reduce batch size, check DataLoader num_workers |
| Mini training no progress | Learning rate too low or frozen | Increase LR, verify optimizer updates all parameters |
| NDJSON validation fails | Missing required event types | Check all phases emit expected fields |
| Shape mismatch error | Wrong input_size or seq_len | Verify feature count matches model input dimension |
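For the "frozen parameters" row, a quick pre-training check (sketched on a toy model; run it against the real one):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Any parameter with requires_grad=False will never be updated by the optimizer.
frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
assert not frozen, f"Frozen parameters: {frozen}"
print(f"All {sum(1 for _ in model.parameters())} parameter tensors are trainable")
```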
Source repository: terrylica/cc-skills