The OG Framework¶
Overview¶
The Overfit-to-Generalization (OG) framework is a two-stage machine learning approach designed to improve prediction quality in spatiotemporal tasks, particularly when dealing with heterogeneous data density.
Motivation¶
Traditional machine learning models face challenges with spatiotemporal data:
| Challenge | Description |
|---|---|
| Density Imbalance | Some regions have dense observations while others are sparse |
| Local Overfitting | Models may memorize patterns in dense areas without learning generalizable features |
| Spatial Heterogeneity | Different spatial regions may have fundamentally different relationships |
Two-Stage Architecture¶
Stage 1: High-Variance (HV) Model¶
The HV model is designed to:
- Capture local patterns through high model complexity
- Generate pseudo-labels for training the LV model
- Apply density-aware sampling to balance learning across regions
Oscillation Noise¶
During pseudo-label generation, oscillation noise is injected into features:
\[
\tilde{X} = X + \epsilon \, \sigma_X \, \eta
\]

where \(\epsilon\) is the oscillation parameter (typically 0.05), \(\sigma_X\) is the feature standard deviation, and \(\eta\) is a zero-mean random noise term.
This regularization technique helps the LV model learn smoother, more generalizable patterns.
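As a concrete illustration, the injection step can be sketched in a few lines of NumPy. This is not og_learn's implementation; `inject_oscillation_noise` is a hypothetical helper, and Gaussian noise is an assumption (the library's exact noise distribution may differ):

```python
import numpy as np

def inject_oscillation_noise(X, epsilon=0.05, rng=None):
    """Perturb features with zero-mean noise scaled per-feature by
    epsilon * sigma_X (illustrative sketch, assuming Gaussian noise)."""
    rng = np.random.default_rng(rng)
    sigma = X.std(axis=0)                      # per-feature standard deviation
    noise = rng.standard_normal(X.shape)       # zero-mean unit-variance noise
    return X + epsilon * sigma * noise

X = np.random.default_rng(0).normal(size=(1000, 4))
X_noisy = inject_oscillation_noise(X, epsilon=0.05, rng=1)
```

Because `epsilon` is small, the perturbed features stay close to the originals while still breaking exact memorization of the HV model's training inputs.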
Stage 2: Low-Variance (LV) Model¶
The LV model is trained to:
- Learn from pseudo-labels generated by the HV model
- Capture generalizable patterns through controlled model capacity
- Apply density-aware weighting during training
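The two-stage flow described above can be sketched end-to-end with scikit-learn stand-ins: a flexible gradient-boosting model plays the HV role, a ridge regression plays the LV role, and the pseudo-labels are HV predictions on noise-perturbed inputs. All model choices here are illustrative assumptions, not og_learn internals:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import Ridge

rng = np.random.default_rng(42)
X = rng.normal(size=(500, 4))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=500)

# Stage 1: a high-variance model captures local structure
hv = GradientBoostingRegressor(max_depth=6).fit(X, y)

# Generate pseudo-labels on oscillation-noised inputs
eps = 0.05
X_noisy = X + eps * X.std(axis=0) * rng.standard_normal(X.shape)
pseudo_y = hv.predict(X_noisy)

# Stage 2: a low-variance model learns smoothed patterns from pseudo-labels
lv = Ridge().fit(X, pseudo_y)
```

The LV model never sees the raw labels directly; it distills the HV model's predictions, which the noise has smoothed.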
Density-Aware Sampling¶
A key innovation in OG is density-aware sampling, which prioritizes learning from sparse regions:
\[
w_i = \frac{1}{density_i^{\,\alpha}}
\]

where:

- \(w_i\) is the sampling weight for observation \(i\)
- \(density_i\) is the local data density
- \(\alpha\) is the sampling parameter (typically 0.1)
This ensures the model doesn't overfit to dense regions at the expense of sparse areas.
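A minimal sketch of computing such weights, normalized so they can be used as sampling probabilities (`density_weights` is illustrative, not part of og_learn):

```python
import numpy as np

def density_weights(density, alpha=0.1):
    """Inverse-density sampling weights w_i = 1 / density_i ** alpha,
    normalized to sum to 1 for use as sampling probabilities."""
    w = np.asarray(density, dtype=float) ** -alpha
    return w / w.sum()

density = np.array([100.0, 10.0, 1.0])   # from dense to sparse regions
w = density_weights(density, alpha=0.1)
```

Observations in sparse regions get the largest weights, so they are drawn more often during training.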
Using OGModel¶
Basic Usage¶
```python
from og_learn import OGModel

model = OGModel(
    hv='lightgbm',        # HV model type
    lv='mlp',             # LV model type
    oscillation=0.05,     # Noise injection strength
    sampling_alpha=0.1,   # Density sampling exponent
    epochs=100,           # Training epochs for LV
    seed=42               # Random seed for reproducibility
)

model.fit(X_train, y_train, density=density_train)
predictions = model.predict(X_test)
```
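`fit` expects a per-observation `density` array. If you don't already have one, a simple heuristic is to estimate local density from nearest-neighbour distances. `knn_density` below is a hypothetical helper, not an og_learn function, and the pairwise-distance approach only suits modest dataset sizes:

```python
import numpy as np

def knn_density(X, k=10):
    """Estimate each point's local density as the inverse of its mean
    distance to the k nearest neighbours (illustrative heuristic)."""
    diff = X[:, None, :] - X[None, :, :]
    d = np.sqrt((diff ** 2).sum(axis=-1))                    # pairwise distances
    knn_dist = np.sort(d, axis=1)[:, 1:k + 1].mean(axis=1)   # skip self (distance 0)
    return 1.0 / (knn_dist + 1e-12)

rng = np.random.default_rng(0)
cluster = rng.normal(0.0, 0.1, size=(50, 2))   # dense region
outlier = np.array([[10.0, 10.0]])             # sparse region
X_demo = np.vstack([cluster, outlier])
density_demo = knn_density(X_demo, k=5)
```

Points inside the cluster receive high density estimates; the isolated point receives a low one, so density-aware sampling would upweight it.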
With Validation and Early Stopping¶
```python
model = OGModel(
    hv='lightgbm',
    lv='resnet',
    oscillation=0.05,
    sampling_alpha=0.1,
    epochs=200,
    early_stopping=True,
    patience=20
)

model.fit(
    X_train, y_train,
    density=density_train,
    X_valid=X_valid,
    y_valid=y_valid
)
```
With TensorBoard Logging¶
```python
model = OGModel(
    hv='lightgbm',
    lv='mlp',
    tensorboard_dir='runs/experiment1',
    tensorboard_name='og_mlp',
    eval_every_epochs=5
)

model.fit(X_train, y_train, density=density_train)
```
Model Comparison¶
Use `compare_models` to benchmark different configurations:

```python
from og_learn import OGModel, compare_models
from og_learn.presets import get_lv_model

# Define models to compare
models = {
    'MLP': get_lv_model('mlp', num_features=X_train.shape[1]),
    'OG_MLP': OGModel(hv='lightgbm', lv='mlp'),
    'OG_ResNet': OGModel(hv='lightgbm', lv='resnet'),
}

# Run comparison
results = compare_models(
    models,
    X_train, y_train,
    X_test, y_test,
    density=density_train,
    tensorboard_dir='runs/comparison',
    save_dir='checkpoints'
)

# View results
print(results)
```
Output:
```
============================================================
Model Comparison
============================================================
        Model  Train_R2  Test_R2
0         MLP    0.6012   0.5743
1      OG_MLP    0.6524   0.6127
2   OG_ResNet    0.6891   0.6352
============================================================
```
Best Practices¶
Choosing HV Models
- LightGBM: Fast, good default choice
- XGBoost: Often slightly better accuracy
- CatBoost: Better with categorical features
Choosing LV Models
- MLP: Simple, fast, good baseline
- ResNet: Better for complex patterns
- Transformer: Best for sequential/temporal patterns
Common Pitfalls

- Don't forget to provide `density` during training
- Use validation data for early stopping with neural networks
- Set `seed` for reproducibility