Calibration with temperature scaling (#2)
Added John's temperature scaling and some additional documentation
nukularrr committed Jul 6, 2023
1 parent c53246a commit 9edec19
Showing 11 changed files with 540 additions and 281 deletions.
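For background: temperature scaling calibrates a trained classifier by dividing its logits by a single scalar temperature T, fit on held-out data by minimizing cross-entropy, so that predicted confidences track observed accuracy more closely. A minimal, self-contained sketch of the idea (illustrative only, not the EQUINE API; the function name and defaults are assumptions):

import torch

def fit_temperature(logits: torch.Tensor, labels: torch.Tensor,
                    num_epochs: int = 100, lr: float = 0.01) -> torch.Tensor:
    # Learn a single scalar T that rescales fixed, precomputed logits
    temperature = torch.ones(1, requires_grad=True)
    optimizer = torch.optim.Adam([temperature], lr=lr)
    for _ in range(num_epochs):
        optimizer.zero_grad()
        # Cross-entropy of the temperature-scaled logits against held-out labels
        loss = torch.nn.functional.cross_entropy(logits / temperature, labels)
        loss.backward()
        optimizer.step()
    return temperature.detach()  # at inference, use softmax(logits / T)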
40 changes: 40 additions & 0 deletions .github/workflows/gh-pages.yml
@@ -0,0 +1,40 @@
name: github pages

on:
  push:
    branches:
      - main

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0

      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Upgrade pip
        run: |
          # install pip>=20.1 to use "pip cache dir"
          python3 -m pip install --upgrade pip
      - name: Get pip cache dir
        id: pip-cache
        run: echo "::set-output name=dir::$(pip cache dir)"

      - name: Cache dependencies
        uses: actions/cache@v3
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml', 'mkdocs.yml') }}
          restore-keys: |
            ${{ runner.os }}-pip-
      - name: Install package
        run: python3 -m pip install .[docs]

      - name: Run mkdocs
        run: mkdocs gh-deploy --force --clean --verbose
186 changes: 78 additions & 108 deletions docs/example_notebooks/MNIST_OOD_detection.ipynb

Large diffs are not rendered by default.

128 changes: 85 additions & 43 deletions docs/example_notebooks/toy_example_EquineProtonet.ipynb

Large diffs are not rendered by default.

90 changes: 49 additions & 41 deletions docs/example_notebooks/toy_example_GP.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/index.md
@@ -1,7 +1,7 @@
# Establishing Quantified Uncertainty in Neural Networks
<p align="center"><img src="assets/equine_full_logo.svg" width="720"\></p>

[![Build Status](https://github.com/mit-ll-responsible-ai/equine/actions/workflows/python-package.yml/badge.svg?branch=main)](https://github.com/mit-ll-responsible-ai/equine/actions/workflows/Tests.yml)
[![Build Status](https://github.com/mit-ll-responsible-ai/equine/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/mit-ll-responsible-ai/equine/actions/workflows/Tests.yml)
![python_passing_tests](https://img.shields.io/badge/Tests%20Passed-100%25-green)
![python_coverage](https://img.shields.io/badge/Coverage-91%25-green)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
5 changes: 2 additions & 3 deletions src/equine/equine.py
@@ -4,6 +4,7 @@

from __future__ import annotations
import torch
from torch.utils.data import TensorDataset # type: ignore
from abc import ABC, abstractmethod
from typing import Any

@@ -32,9 +33,7 @@ def forward(self, X: torch.Tensor):
def predict(self, X: torch.Tensor) -> EquineOutput:
raise NotImplementedError

def train_model(
dataset: torch.utils.data.TensorDataset, **kwargs
) -> dict[str, Any]:
def train_model(self, dataset: TensorDataset, **kwargs) -> dict[str, Any]:
raise NotImplementedError

def save(self, path: str) -> None:
110 changes: 99 additions & 11 deletions src/equine/equine_gp.py
@@ -7,10 +7,12 @@

import icontract
import torch
from typing import Optional, Callable, Any
from torch.utils.data import TensorDataset, DataLoader # type: ignore
from typing import Optional, Callable, Any, Union
from tqdm import tqdm
from typeguard import typechecked
from datetime import datetime
from sklearn.model_selection import train_test_split

from .equine import Equine, EquineOutput
from .utils import generate_train_summary
@@ -244,12 +246,21 @@ class EquineGP(Equine):
"""

def __init__(
self, embedding_model: torch.nn.Module, emb_out_dim: int, num_classes: int
self,
embedding_model: torch.nn.Module,
emb_out_dim: int,
num_classes: int,
use_temperature: bool = False,
init_temperature: float = 1.0,
device: str = "cpu",
) -> None:
"""EquineGP constructor
:param embedding_model: Neural Network feature embedding
:param emb_out_dim: The number of deep features from the feature embedding
:param num_classes: The number of output classes this model predicts
:param use_temperature: whether to use temperature scaling after training
:param init_temperature: what to use as the initial temperature (1.0 has no effect)
"param device: either 'cuda' or 'cpu'
"""
super().__init__(embedding_model)
self.num_deep_features = emb_out_dim
@@ -260,7 +271,11 @@ def __init__(
self.mean_field_factor = 25
self.ridge_penalty = 1
self.feature_scale = 2

self.use_temperature = use_temperature
self.init_temperature = init_temperature
self.register_buffer(
"temperature", torch.Tensor(self.init_temperature * torch.ones(1))
)
self.model = _Laplace(
self.embedding_model,
self.num_deep_features,
@@ -272,23 +287,43 @@ def __init__(
self.mean_field_factor,
self.ridge_penalty,
)
self.device_type = device
self.device = torch.device(self.device_type)

def train_model(
self,
dataset: torch.utils.data.TensorDataset, # type: ignore
dataset: TensorDataset,
loss_fn: Callable,
opt: torch.optim.Optimizer,
num_epochs: int,
batch_size: int = 64,
) -> dict[str, Any]:
calib_frac: float = 0.1,
num_calibration_epochs: int = 2,
calibration_lr: float = 0.01,
) -> tuple[dict[str, Any], Union[DataLoader, None]]:
"""Train or fine-tune an EquineGP model
:param dataset: An iterable, pytorch TensorDataset
:param loss_fn: A pytorch loss function, e.g., torch.nn.CrossEntropyLoss()
:param opt: A pytorch optimizer, e.g., torch.optim.Adam()
:param num_epochs: The desired number of epochs to use for training,
"param batch_size: The number of samples to use per batch
:param batch_size: The number of samples to use per batch
:param calib_frac: fraction of training data to use in temperature scaling
:param num_calibration_epochs: The desired number of epochs to use for temperature scaling
:param calibration_lr: learning rate for temperature scaling
:return: A tuple containing the training history and a dataloader for the calibration data
"""
train_loader = torch.utils.data.DataLoader( # type: ignore

if self.use_temperature:
X, Y = dataset[:]
train_x, calib_x, train_y, calib_y = train_test_split(
X, Y, test_size=calib_frac, stratify=Y
) # TODO: Replace sklearn with torch call
dataset = TensorDataset(train_x, train_y)
self.temperature = torch.Tensor(
self.init_temperature * torch.ones(1)
).type_as(self.temperature)

train_loader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
self.model.set_training_params(
@@ -300,26 +335,72 @@ def train_model(
epoch_loss = 0.0
for i, (xs, labels) in enumerate(train_loader):
opt.zero_grad()
xs = xs.to(self.device)
labels = labels.to(self.device)
yhats = self.model(xs)
loss = loss_fn(yhats, labels.to(torch.long))
loss.backward()
opt.step()
epoch_loss += loss.item()
self.model.eval()

calibration_loader = None
if self.use_temperature:
dataset_calibration = TensorDataset(calib_x, calib_y)
calibration_loader = DataLoader(
dataset_calibration,
batch_size=batch_size,
shuffle=True,
drop_last=False,
)
self.calibrate_temperature(
calibration_loader, num_calibration_epochs, calibration_lr
)

_, train_y = dataset[:]
date_trained = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
self.train_summary = generate_train_summary(self, train_y, date_trained)

return self.train_summary
return self.train_summary, calibration_loader

def calibrate_temperature(
self,
calibration_loader: DataLoader,
num_calibration_epochs: int = 1,
calibration_lr: float = 0.01,
) -> None:
"""
Fine-tune the temperature after training. Note this function is also run at the conclusion of train_model
:param calibration_loader: data loader returned by train_model
:param num_calibration_epochs: number of epochs to tune temperature
:param calibration_lr: learning rate for temperature optimization
:return:
"""
self.temperature.requires_grad = True
loss_fn = torch.nn.functional.cross_entropy
optimizer = torch.optim.Adam([self.temperature], lr=calibration_lr)
for _ in range(num_calibration_epochs):
for (xs, labels) in calibration_loader:
optimizer.zero_grad()
xs = xs.to(self.device)
labels = labels.to(self.device)
with torch.no_grad():
logits = self.model(xs)
logits = logits / self.temperature
loss = loss_fn(logits, labels.to(torch.long))
loss.backward()
optimizer.step()
self.temperature.requires_grad = False

def forward(self, X: torch.Tensor) -> torch.Tensor:
"""EquineGP forward function, generates logits for classification
:param X: Input tensor for generating predictions
:return[torch.Tensor]: Output logits, scaled by the temperature
"""
preds = self.model(X)
return preds
X = X.to(self.device)
with torch.no_grad():
preds = self.model(X)
return preds / self.temperature

@icontract.ensure(
lambda result: all((0 <= result.ood_scores) & (result.ood_scores <= 1.0))
Expand All @@ -329,7 +410,11 @@ def predict(self, X: torch.Tensor) -> EquineOutput:
:param X: Input tensor
:return[EquineOutput] : Output object containing prediction probabilities and OOD scores
"""
preds = torch.softmax(self.model(X), dim=1)
X = X.to(self.device)
with torch.no_grad():
logits = self.model(X)
logits = logits / self.temperature
preds = torch.softmax(logits, dim=1)
equiprobable = torch.ones(self.num_outputs) / self.num_outputs
max_entropy = torch.sum(torch.special.entr(equiprobable))
ood_score = torch.sum(torch.special.entr(preds), dim=1) / max_entropy
Expand All @@ -345,6 +430,9 @@ def save(self, path: str) -> None:
model_settings = {
"emb_out_dim": self.num_deep_features,
"num_classes": self.num_outputs,
"use_temperature": self.use_temperature,
"init_temperature": self.temperature.item(),
"device": self.device_type,
}

jit_model = torch.jit.script(self.model.feature_extractor) # type: ignore
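To show how the new EquineGP calibration path fits together, here is a hedged usage sketch based only on the signatures visible in this diff; the embedding network, data, and hyperparameters are placeholders, and it assumes EquineGP is importable from the equine package:

import torch
from torch.utils.data import TensorDataset
from equine import EquineGP  # assumed import path

# Placeholder embedding network and synthetic data, for illustration only
embedder = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
)
X, y = torch.randn(200, 16), torch.randint(0, 3, (200,))

model = EquineGP(
    embedder, emb_out_dim=8, num_classes=3,
    use_temperature=True, init_temperature=1.0, device="cpu",
)
summary, calibration_loader = model.train_model(
    TensorDataset(X, y),
    loss_fn=torch.nn.CrossEntropyLoss(),
    opt=torch.optim.Adam(model.parameters(), lr=1e-3),
    num_epochs=5,
    calib_frac=0.1,
)
# The temperature can be re-tuned later using the returned calibration loader
model.calibrate_temperature(calibration_loader, num_calibration_epochs=2, calibration_lr=0.01)
output = model.predict(torch.randn(4, 16))  # EquineOutput: probabilities plus OOD scores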
62 changes: 60 additions & 2 deletions src/equine/equine_protonet.py
@@ -312,12 +312,16 @@ def __init__(
emb_out_dim: int,
cov_type: CovType = CovType.UNIT,
relative_mahal: bool = True,
use_temperature: bool = False,
init_temperature: float = 1.0,
) -> None:
"""EquineProtonet constructor
:param embedding_model: Neural Network feature embedding model
:param emb_out_dim: The number of output features from the embedding model
:param cov_type: The type of covariance to use when training the protonet [UNIT, DIAG, FULL]
:param relative_mahal: Use relative mahalanobis distance for OOD calculations. If false, uses standard mahalanobis distance instead
:param use_temperature: whether to use temperature scaling after training
:param init_temperature: what to use as the initial temperature (1.0 has no effect)
"""
super().__init__(embedding_model)
self.cov_type = cov_type
@@ -327,6 +331,11 @@ def __init__(
self.epsilon = DEFAULT_EPSILON
self.outlier_score_kde = None
self.model_summary = None
self.use_temperature = use_temperature
self.init_temperature = init_temperature
self.register_buffer(
"temperature", torch.Tensor(self.init_temperature * torch.ones(1))
)

self.model = _Protonet(
embedding_model,
@@ -355,7 +364,9 @@ def train_model(
episode_size: int = 100,
loss_fn: Callable = torch.nn.functional.cross_entropy,
opt_class: Callable = torch.optim.Adam,
) -> dict[str, Any]:
num_calibration_epochs: int = 2,
calibration_lr: float = 0.01,
) -> tuple[dict[str, Any], torch.Tensor, torch.Tensor]:
"""Train or fine-tune an EquineProtonet model
:param dataset: Input pytorch TensorDataset of training data for model
:param num_episodes: The desired number of episodes to use for training
@@ -365,9 +376,17 @@ def train_model(
:param episode_size: Number of examples to use per episode
:param loss_fn: A pytorch loss function, e.g., torch.nn.CrossEntropyLoss()
:param opt_class: A pytorch optimizer, e.g., torch.optim.Adam
:param num_calibration_epochs: The desired number of epochs to use for temperature scaling
:param calibration_lr: learning rate for temperature scaling
:return: A tuple containing the model summary, the held out calibration data, and the calibration labels
"""
self.train()

if self.use_temperature:
self.temperature = torch.Tensor(
self.init_temperature * torch.ones(1)
).type_as(self.temperature)

X, Y = dataset[:]

train_x, calib_x, train_y, calib_y = train_test_split(
@@ -405,9 +424,43 @@ def train_model(
ood_dists = self._compute_ood_dist(X_embed, pred_probs, dists)
self._fit_outlier_scores(ood_dists, calib_y)

if self.use_temperature:
self.calibrate_temperature(
calib_x, calib_y, num_calibration_epochs, calibration_lr
)

date_trained = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
self.train_summary = generate_train_summary(self, train_y, date_trained)
return self.train_summary
return self.train_summary, calib_x, calib_y

def calibrate_temperature(
self,
calib_x: torch.Tensor,
calib_y: torch.Tensor,
num_calibration_epochs: int = 1,
calibration_lr: float = 0.01,
) -> None:
"""
Fine-tune the temperature after training. Note this function is also run at the conclusion of train_model
:param calib_x: training data to be used for temperature calibration
:param calib_y: labels corresponding to calib_x
:param num_calibration_epochs: number of epochs to tune temperature
:param calibration_lr: learning rate for temperature optimization
:return: None
"""
self.temperature.requires_grad = True
optimizer = torch.optim.Adam([self.temperature], lr=calibration_lr)
for t in range(num_calibration_epochs):
optimizer.zero_grad()
with torch.no_grad():
pred_probs, dists = self.model(calib_x)
dists = dists / self.temperature
loss = torch.nn.functional.cross_entropy(
torch.neg(dists), calib_y.to(torch.long)
)
loss.backward()
optimizer.step()
self.temperature.requires_grad = False

@icontract.ensure(lambda self: self.model.support_embeddings is not None)
def _fit_outlier_scores(
@@ -480,6 +533,9 @@ def predict(self, X: torch.Tensor) -> EquineOutput:
if X_embed.shape == torch.Size([self.model.emb_out_dim]):
X_embed = X_embed.unsqueeze(dim=0) # Handle single examples
preds, dists = self.model(X)
if self.use_temperature:
dists = dists / self.temperature
preds = torch.softmax(torch.negative(dists), dim=1)
ood_dist = self._compute_ood_dist(X_embed, preds, dists)
ood_scores = self._compute_outlier_scores(ood_dist, preds)

@@ -524,6 +580,8 @@ def save(self, path: str) -> None:
model_settings = {
"cov_type": self.cov_type,
"emb_out_dim": self.emb_out_dim,
"use_temperature": self.use_temperature,
"init_temperature": self.temperature.item(),
}

jit_model = torch.jit.script(self.model.embedding_model)
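A matching sketch for EquineProtonet, again with placeholder model, data, and hyperparameters, and based only on the parameters mentioned in this diff (unlisted arguments keep their defaults): train_model now returns the held-out calibration split alongside the training summary, and that split can be fed back into calibrate_temperature.

import torch
from torch.utils.data import TensorDataset
from equine import EquineProtonet  # assumed import path

embedder = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
)
X, y = torch.randn(400, 16), torch.randint(0, 3, (400,))

model = EquineProtonet(embedder, emb_out_dim=8, use_temperature=True)
summary, calib_x, calib_y = model.train_model(TensorDataset(X, y), num_episodes=20)

# Re-tune the temperature on the held-out calibration split if desired
model.calibrate_temperature(calib_x, calib_y, num_calibration_epochs=2, calibration_lr=0.01)
output = model.predict(torch.randn(4, 16))  # softmax over temperature-scaled distances, plus OOD scores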
