The OG Framework

Overview

The Overfit-to-Generalization (OG) framework is a two-stage machine learning approach designed to improve prediction quality in spatiotemporal tasks, particularly when dealing with heterogeneous data density.

Motivation

Traditional machine learning models face challenges with spatiotemporal data:

  • Density Imbalance: Some regions have dense observations while others are sparse
  • Local Overfitting: Models may memorize patterns in dense areas without learning generalizable features
  • Spatial Heterogeneity: Different spatial regions may have fundamentally different relationships

Two-Stage Architecture

Stage 1: High-Variance (HV) Model

The HV model is designed to:

  1. Capture local patterns through high model complexity
  2. Generate pseudo-labels for training the LV model
  3. Apply density-aware sampling to balance learning across regions
# HV model options
hv_models = ['lightgbm', 'xgboost', 'catboost', 'random_forest']

Oscillation Noise

During pseudo-label generation, oscillation noise is injected into features:

\[X_{noisy} = X + \epsilon \cdot \sigma_X \cdot \mathcal{N}(0, 1)\]

where \(\epsilon\) is the oscillation parameter (typically 0.05) and \(\sigma_X\) is the feature standard deviation.

This regularization technique helps the LV model learn smoother, more generalizable patterns.
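The injection step above can be sketched in a few lines of NumPy. Note that `inject_oscillation_noise` is a hypothetical helper name used for illustration, not part of the og_learn API:

```python
import numpy as np

def inject_oscillation_noise(X, epsilon=0.05, seed=None):
    """Perturb features as X + epsilon * sigma_X * N(0, 1),
    where sigma_X is the per-feature standard deviation."""
    rng = np.random.default_rng(seed)
    sigma = X.std(axis=0, keepdims=True)  # per-feature scale
    return X + epsilon * sigma * rng.standard_normal(X.shape)
```

With the default epsilon of 0.05, each feature is jittered by a few percent of its own spread, so pseudo-labels are generated on slightly perturbed inputs rather than the exact training points.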

Stage 2: Low-Variance (LV) Model

The LV model is trained to:

  1. Learn from pseudo-labels generated by the HV model
  2. Capture generalizable patterns through controlled model capacity
  3. Apply density-aware weighting during training
# LV model options
lv_models = ['mlp', 'bigmlp', 'resnet', 'transformer']
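The two stages can be illustrated end to end with simple stand-ins for the real models: a 1-nearest-neighbour regressor as a deliberately overfit HV model, and a least-squares linear fit as the low-capacity LV model. This is a conceptual sketch of the pseudo-label hand-off, not og_learn internals:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: y = 2x + noise
X = rng.uniform(-1, 1, size=(200, 1))
y = 2 * X[:, 0] + 0.1 * rng.standard_normal(200)

# Stage 1 (HV): a 1-nearest-neighbour regressor, which memorizes the training set
def hv_predict(Xq):
    d = np.abs(Xq[:, 0][:, None] - X[:, 0][None, :])
    return y[d.argmin(axis=1)]

# Pseudo-labels generated on oscillation-noised inputs
X_noisy = X + 0.05 * X.std(axis=0) * rng.standard_normal(X.shape)
pseudo = hv_predict(X_noisy)

# Stage 2 (LV): a low-capacity linear model fit on the pseudo-labels
A = np.c_[X_noisy, np.ones(len(X_noisy))]
coef, *_ = np.linalg.lstsq(A, pseudo, rcond=None)
```

The LV model recovers the underlying relationship (a slope near 2) from the HV model's pseudo-labels rather than from the raw targets, which is the core of the distillation step.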

Density-Aware Sampling

A key innovation in OG is density-aware sampling, which prioritizes learning from sparse regions:

\[w_i = \frac{1}{density_i^\alpha}\]

where:

  • \(w_i\) is the sampling weight for observation \(i\)
  • \(density_i\) is the local data density
  • \(\alpha\) is the sampling parameter (typically 0.1)

This discourages the model from overfitting to dense regions at the expense of sparse areas.
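A minimal sketch of the weighting formula; `density_weights` is a hypothetical helper name, not the og_learn API:

```python
import numpy as np

def density_weights(density, alpha=0.1):
    """w_i = 1 / density_i ** alpha, normalized so the weights sum to 1."""
    w = 1.0 / np.power(np.asarray(density, dtype=float), alpha)
    return w / w.sum()
```

With alpha = 0.1 the correction is gentle; a larger alpha shifts more probability mass toward sparse regions, at the cost of up-weighting potentially noisy observations.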


Using OGModel

Basic Usage

from og_learn import OGModel

model = OGModel(
    hv='lightgbm',       # HV model type
    lv='mlp',            # LV model type
    oscillation=0.05,    # Noise injection strength
    sampling_alpha=0.1,  # Density sampling exponent
    epochs=100,          # Training epochs for LV
    seed=42              # Random seed for reproducibility
)

model.fit(X_train, y_train, density=density_train)
predictions = model.predict(X_test)

With Validation and Early Stopping

model = OGModel(
    hv='lightgbm',
    lv='resnet',
    oscillation=0.05,
    sampling_alpha=0.1,
    epochs=200,
    early_stopping=True,
    patience=20
)

model.fit(
    X_train, y_train,
    density=density_train,
    X_valid=X_valid,
    y_valid=y_valid
)

With TensorBoard Logging

model = OGModel(
    hv='lightgbm',
    lv='mlp',
    tensorboard_dir='runs/experiment1',
    tensorboard_name='og_mlp',
    eval_every_epochs=5
)

model.fit(X_train, y_train, density=density_train)

Model Comparison

Use compare_models to benchmark different configurations:

from og_learn import OGModel, compare_models
from og_learn.presets import get_lv_model

# Define models to compare
models = {
    'MLP': get_lv_model('mlp', num_features=X_train.shape[1]),
    'OG_MLP': OGModel(hv='lightgbm', lv='mlp'),
    'OG_ResNet': OGModel(hv='lightgbm', lv='resnet'),
}

# Run comparison
results = compare_models(
    models,
    X_train, y_train,
    X_test, y_test,
    density=density_train,
    tensorboard_dir='runs/comparison',
    save_dir='checkpoints'
)

# View results
print(results)

Output:

============================================================
         Model Comparison
============================================================
              Model  Train_R2  Test_R2
0               MLP    0.6012   0.5743
1            OG_MLP    0.6524   0.6127
2         OG_ResNet    0.6891   0.6352
============================================================

Best Practices

Choosing HV Models

  • LightGBM: Fast, good default choice
  • XGBoost: Often yields slightly higher accuracy
  • CatBoost: Handles categorical features natively

Choosing LV Models

  • MLP: Simple, fast, good baseline
  • ResNet: Better for complex patterns
  • Transformer: Best for sequential/temporal patterns

Common Pitfalls

  • Don't forget to provide density during training
  • Use validation data for early stopping with neural networks
  • Set seed for reproducibility
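The density argument must be supplied by the caller; one common approach is to estimate local density from k-nearest-neighbour distances over the spatial coordinates. A brute-force NumPy sketch (the helper name `knn_density` is hypothetical, and real datasets would warrant a spatial index):

```python
import numpy as np

def knn_density(coords, k=5):
    """Local density as k over the mean distance to the k nearest neighbours.
    Brute-force pairwise distances; fine for small datasets."""
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)          # exclude self-distance
    knn = np.sort(dist, axis=1)[:, :k]      # k nearest neighbour distances
    return k / (knn.mean(axis=1) + 1e-12)
```

Points in tightly clustered regions get large density values and points in sparse regions get small ones, which is the shape of input the density-aware sampling weights expect.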