"""
VeLO: An implementation of VeLO from https://arxiv.org/abs/2211.09760.
Some of the following code is adapted from https://github.com/google/learned_optimization/blob/main/learned_optimization/research/general_lopt/hyper_v2.py
"""
import torch
import torch.nn as nn
import numpy as np
from torch.optim import Optimizer
import torch.nn.functional as F
from collections import OrderedDict
from pylo.models.VeLO_MLP import VeLOMLP
from pylo.models.VeLO_RNN import VeLORNN
def init_factors(p):
    """Allocate zeroed second-moment accumulators for parameter ``p``.

    Every accumulator carries a trailing axis of size 3 (one slot per decay).
    When ``factored_dims(p.shape)`` returns ``(d1, d0)``, a row accumulator
    (accumulator shape with axis ``d0`` removed) and a column accumulator
    (axis ``d1`` removed) are returned and the full accumulator is empty.
    Otherwise only the full accumulator, shaped ``p.shape + (3,)``, is allocated.

    Returns:
        ``(v_row, v_col, v_full)`` as float32 CPU tensors; unused entries are empty.
    """
    acc_shape = p.shape + (3,)
    dims = factored_dims(p.shape)
    if dims is None:
        return (
            torch.tensor([], dtype=torch.float32),
            torch.tensor([], dtype=torch.float32),
            torch.zeros(acc_shape, dtype=torch.float32),
        )
    d1, d0 = dims
    row_shape = tuple(size for axis, size in enumerate(acc_shape) if axis != d0)
    col_shape = tuple(size for axis, size in enumerate(acc_shape) if axis != d1)
    row_acc = torch.zeros(row_shape, dtype=torch.float32)
    col_acc = torch.zeros(col_shape, dtype=torch.float32)
    return row_acc, col_acc, torch.tensor([], dtype=torch.float32)
def update_factors(
    v_col,
    v_row,
    v_full,
    g,
    g_shape,
    decay_rate: "torch.Tensor | float" = 0.9,
    epsilon: float = 1e-30,
):
    """Advance the second-moment accumulators and return a preconditioned gradient.

    Args:
        v_col, v_row: factored column / row accumulators. They are used when
            ``factored_dims(g_shape)`` is not None; otherwise they are ignored.
        v_full: unfactored accumulator, used when ``factored_dims(g_shape)`` is None.
        g: gradient with one trailing axis appended to ``g_shape``. The
            optimizer step passes ``grad.unsqueeze(-1)``; confirm for other callers.
        g_shape: shape of the parameter, without the trailing axis.
        decay_rate: EMA decays. Either a tensor whose last axis holds one decay
            per accumulator slot (leading singleton axes for broadcasting), or
            a Python float meaning a single slot.
        epsilon: added to the squared gradient before it is accumulated.

    Returns:
        ``(new_v_col, new_v_row, new_v_full, preconditioned_g)``. Note the
        column-before-row order. Accumulators not used for this shape are
        returned as empty float32 tensors on ``g``'s device.
    """
    if not torch.is_tensor(decay_rate):
        # A float (including the default) has no ``.shape``; promote it to a
        # one-slot tensor that broadcasts like the tensors the optimizer passes.
        decay_rate = torch.full(
            (1,) * (len(g_shape) + 1),
            float(decay_rate),
            dtype=torch.float32,
            device=g.device,
        )
    f_dims = factored_dims(g_shape)
    mixing_rate = 1.0 - decay_rate
    # One copy of the gradient per decay slot along the trailing axis.
    g = g.repeat([1] * len(g_shape) + [decay_rate.shape[-1]])
    grad_sqr = g * g + epsilon

    if f_dims is None:
        new_v = decay_rate * v_full + mixing_rate * grad_sqr
        y = g * safe_rsqrt(new_v + 1e-9)
        return (
            torch.empty(0, dtype=torch.float32, device=g.device),
            torch.empty(0, dtype=torch.float32, device=g.device),
            new_v,
            y,
        )

    d1, d0 = f_dims
    # Row/column accumulators have one axis fewer than g; drop one leading
    # singleton axis of the decays so they broadcast against them.
    decay_rate, mixing_rate = decay_rate.squeeze(0), mixing_rate.squeeze(0)
    new_v_row = decay_rate * v_row + mixing_rate * torch.mean(grad_sqr, dim=d0)
    new_v_col = decay_rate * v_col + mixing_rate * torch.mean(grad_sqr, dim=d1)
    # Axis d1 shifts left by one in v_row when d0 (already removed) precedes it.
    reduced_d1 = d1 - 1 if d1 > d0 else d1
    row_col_mean = torch.mean(new_v_row, dim=reduced_d1, keepdim=True)
    row_factor = safe_rsqrt(new_v_row / (row_col_mean + 1e-9))
    col_factor = safe_rsqrt(new_v_col)
    y = g * row_factor.unsqueeze(d0) * col_factor.unsqueeze(d1)
    return (
        new_v_col,
        new_v_row,
        torch.empty(0, dtype=torch.float32, device=g.device),
        y,
    )
def tanh_embedding(x):
    """Embed a step count ``x`` as ``tanh(x / t - 1)`` for 11 fixed time scales.

    ``x`` may be a Python number or a tensor. Tensors are converted without a
    copy-construct warning, and the time scales are created on ``x``'s device
    so CUDA inputs work.

    Returns:
        float32 tensor of shape ``x.shape`` broadcast against ``(11,)``.
    """
    x = torch.as_tensor(x, dtype=torch.float32)
    timescales = torch.tensor(
        [1, 3, 10, 30, 100, 300, 1000, 3000, 10000, 30000, 100000],
        dtype=torch.float32,
        device=x.device,
    )
    return torch.tanh(x / timescales - 1.0)
def second_moment_normalizer(x, axis, eps=1e-5):
    """Divide ``x`` by the root of its mean square over ``axis`` (plus ``eps``).

    The mean is taken with ``keepdim=True``, so the result has ``x``'s shape.
    """
    second_moment = x.square().mean(dim=axis, keepdim=True)
    scale = torch.rsqrt(second_moment + eps)
    return x * scale
def factored_dims(shape):
    """Choose the two axes used for factored second-moment statistics.

    Returns ``None`` for shapes with fewer than two axes. Otherwise it returns
    the pair callers unpack as ``(d1, d0)``. Normally this is (second-largest
    axis, largest axis) by size.

    When the two largest sizes tie, the order is swapped, except for 4-D
    shapes whose tied axes come out of ``np.argsort`` as 0 then 1.
    """
    if len(shape) < 2:
        return None
    order = np.argsort(shape)
    second, largest = int(order[-2]), int(order[-1])
    if shape[second] != shape[largest]:
        return second, largest
    if len(shape) == 4 and second == 0 and largest == 1:
        return second, largest
    return largest, second
def decay_to_param(x):
    """Map a decay tensor to its additive parameterisation ``log(1 - x) / 10``."""
    complement = 1 - x
    return complement.log() / 10.0
def param_to_decay(x):
    """Inverse of ``decay_to_param``: return ``1 - exp(10 * x)``."""
    scaled = x * 10.0
    return 1 - scaled.exp()
def safe_rsqrt(x):
    """Reciprocal square root with the input floored at ``1e-9`` (avoids inf/NaN)."""
    floor = torch.tensor(1e-9)
    return torch.maximum(x, floor).rsqrt()
def clip_log_abs(v, scale=1.0):
    """Return ``0.5 * clamp(log(1e-8 + |v * scale|), -5, 5)``."""
    log_mag = torch.log(1e-8 + (v * scale).abs())
    return log_mag.clamp(-5, 5) * 0.5
def sorted_values(dd):
    """Return the values of mapping ``dd`` as a tuple, ordered by key.

    An empty mapping yields ``()``. The previous zip-based version raised
    IndexError in that case.
    """
    return tuple(dd[key] for key in sorted(dd))
def fractional_tanh_embed(x):
    """Embed a training fraction ``x`` against nine fixed thresholds.

    For each threshold ``t`` this computes ``tanh(10 * (x - t))``. The results
    are stacked along a new leading axis, giving shape ``(9,)`` for a scalar ``x``.
    """
    thresholds = torch.tensor(
        [0.03, 0.1, 0.2, 0.4, 0.6, 0.8, 0.9, 1.0, 1.1], dtype=torch.float32
    )
    features = []
    for t in thresholds:
        features.append(torch.tanh((x - t) * 10))
    return torch.stack(features)
class BufferLossAccumulators:
    """Bias-corrected exponential moving averages of the loss at 10 time scales.

    The state is a plain dict of tensors, so it can be stored inside an
    optimizer ``state_dict``. ``update`` returns a new dict instead of mutating
    the old one.
    """

    def __init__(self, device):
        # Device on which every accumulator tensor is allocated.
        self.device = device

    def init(self, num_steps):
        """Return a fresh accumulator state for a run of ``num_steps`` steps.

        The 10 EMA time constants (named ``halflife`` below) are log-spaced
        from 10 to ``num_steps``. The decays are ``exp(-1 / halflife)``.
        """
        # Pass log10(num_steps) as a Python float: older PyTorch releases do
        # not accept a tensor for the ``end`` argument of ``torch.logspace``.
        log_end = torch.log10(torch.tensor(num_steps, dtype=torch.float32)).item()
        halflife = torch.logspace(1, log_end, 10)
        decays = torch.exp(-1.0 / halflife)
        n = len(decays)
        return {
            "means": torch.zeros(n, dtype=torch.float32, device=self.device),
            "iteration": torch.tensor(0, dtype=torch.int32, device=self.device),
            # Large sentinel so the first corrected mean becomes the minimum.
            "running_min": 999999999999.0
            * torch.ones(n, dtype=torch.float32, device=self.device),
            "decays": decays.to(self.device),
        }

    def update(self, state, loss):
        """Fold ``loss`` into the EMAs and return the next state.

        ``loss`` may be a tensor or a Python number. It is detached, so the
        stored means never keep an autograd graph alive. Spikes are clipped to
        twice the largest bias-corrected mean (on the first call, to twice the
        loss itself).
        """
        loss = torch.as_tensor(loss, dtype=torch.float32, device=self.device).detach()
        jdecays = state["decays"]
        cor_mean = state["means"] / (1 - jdecays ** (state["iteration"] + 1))
        approx_max = torch.max(cor_mean)
        approx_max = torch.where(state["iteration"] == 0, loss, approx_max)
        loss = torch.minimum(torch.abs(approx_max) * 2, loss)
        means = state["means"] * jdecays + loss * (1.0 - jdecays)
        cor_mean = means / (1 - jdecays ** (state["iteration"] + 1))
        running_min = torch.minimum(state["running_min"], cor_mean)
        return {
            "means": means,
            "iteration": state["iteration"] + 1,
            "running_min": running_min,
            "decays": state["decays"],
        }

    def features(self, state):
        """Return 9 loss features in ``[-1, 1]``.

        Each feature compares a corrected mean with the next-slower one,
        relative to its running minimum. All features are zero while
        ``iteration <= 2``.

        Expects ``state["iteration"] >= 1``; at 0 the bias correction divides by zero.
        """
        jdecays = state["decays"]
        cor_mean = state["means"] / (1 - jdecays ** state["iteration"])
        approx_max = cor_mean[1:]
        cor_mean = cor_mean[0:-1]
        running_min = state["running_min"][0:-1]
        den = torch.maximum(torch.tensor(1e-8), (approx_max - running_min))
        pre_center = (cor_mean - running_min) / den
        feature1 = torch.clamp(pre_center - 1.0, -1, 1)
        return torch.where(state["iteration"] <= 2, feature1 * 0, feature1)
def lstm_features_for_tensor(p, g, m, rms, fraction_trained, loss_features, device):
    """Build the 1-D per-tensor feature vector fed to the VeLO LSTM.

    Args:
        p: parameter tensor. Its RMS normalises ``m`` and ``rms``.
        g: gradient of ``p``. Accepted for interface compatibility; it does not
            contribute to any feature.
        m: momentum accumulators, shaped ``p.shape + (3,)`` in VeLO_naive.
        rms: squared-gradient accumulator, shaped ``p.shape + (1,)`` in VeLO_naive.
        fraction_trained: fraction of the run completed (step / num_steps).
        loss_features: loss features, e.g. the 1-D output of
            ``BufferLossAccumulators.features``.
        device: device for the features created here.

    Returns:
        Concatenation, in key order, of fraction_left, loss_features,
        mean_rms, rank, var_m and var_rms.
    """
    # Make the statistics invariant to the scale of the parameter.
    norm_mult = torch.rsqrt(torch.clamp(torch.mean(p**2), min=1e-9))
    m = m * norm_mult
    rms = rms * norm_mult
    leading_axis = list(range(0, len(p.shape)))

    mean_m = torch.mean(m, dim=leading_axis, keepdim=True)
    var_m = torch.mean((m - mean_m) ** 2, dim=leading_axis)
    mean_rms = torch.mean(rms, dim=leading_axis, keepdim=True)
    # NOTE(review): centred on mean_m rather than mean_rms. This appears to
    # mirror the reference implementation the pretrained LSTM was trained
    # with; confirm against the reference before changing it.
    var_rms = torch.mean((rms - mean_m) ** 2, dim=leading_axis)

    # One-hot count of non-trivial axes over 5 classes. Counts >= 5 give an
    # all-zero vector instead of making F.one_hot raise.
    n_rank = sum(1 for dim in p.shape if dim > 1)
    rank = torch.zeros(5, dtype=torch.float32, device=device)
    if n_rank < 5:
        rank[n_rank] = 1.0

    inputs = {
        "fraction_left": fractional_tanh_embed(fraction_trained).to(device),
        "loss_features": loss_features,
        "var_m": clip_log_abs(var_m, scale=10.0),
        "mean_rms": clip_log_abs(mean_rms.reshape(-1), scale=10.0),
        "var_rms": clip_log_abs(var_rms, scale=10.0),
        "rank": rank,
    }
    # Features are concatenated in sorted-key order; promote 0-d values to 1-d.
    values = [v if v.dim() == 1 else v.unsqueeze(0) for v in sorted_values(inputs)]
    return torch.cat(values, dim=0)
class VeLO_naive(Optimizer):
    """VeLO learned optimizer, evaluated one parameter tensor at a time.

    Each call to :meth:`step` (which takes the current loss, not a closure):

    1. updates the multi-timescale loss EMAs (``BufferLossAccumulators``);
    2. runs the pretrained ``VeLORNN`` once over one feature vector per
       tracked parameter tensor, producing per-tensor control parameters and
       learning-rate multipliers;
    3. for every tracked tensor with a gradient, updates its momentum, RMS and
       factored accumulators, loads that tensor's control parameters into the
       pretrained ``VeLOMLP``, and applies the MLP's direction/magnitude output.

    Optimizer state is allocated on the device of the first parameter.
    """

    # Param-group entries holding decay tensors that must live on self.device.
    _DECAY_KEYS = (
        "initial_momentum_decays",
        "initial_rms_decays",
        "initial_adafactor_decays",
    )

    def __init__(
        self,
        params,
        momentum_decays=(0.0, 0.0, 0.0),
        rms_decays=(0.0,),
        adafactor_decays=(0.0, 0.0, 0.0),
        lr=1e-3,
        exp_mult=0.001,
        step_mult=0.001,
        input_size=30,
        hidden_size=4,
        hidden_layers=1,
        initial_momentum_decays=(0.9, 0.99, 0.999),
        lstm_input_size=30,
        lstm_hidden_size=512,
        param_inits=256,
        num_steps=10000,
        initial_rms_decays=(0.999,),
        initial_adafactor_decays=(0.9, 0.99, 0.999),
        concat_weights=True,
        make_separate_weights=False,
        split_weights=False,
        weight_decay=0.0,
        clip_grad=False,
        mup_lrs=None,
        hf_key_rnn="Pauljanson002/VeLO_RNN",
        hf_key_mlp="Pauljanson002/VeLO_MLP",
    ):
        """Create the optimizer and load the pretrained networks.

        Args:
            params: parameters or param-group dicts, as for any torch Optimizer.
            momentum_decays, rms_decays, adafactor_decays: offsets added to the
                matching ``initial_*_decays`` in ``decay_to_param`` space. The
                result is mapped back and clipped to ``[0, 1]``.
            lr: every update is applied as ``p -= lr * update``.
            exp_mult, step_mult: ``update = param_scale * direction *
                exp(magnitude * exp_mult) * step_mult``, further scaled by the
                LSTM's per-tensor learning-rate multiplier.
            num_steps: expected run length; used for the training fraction fed
                to the LSTM and for the loss-EMA time scales.
            weight_decay: if positive, ``p`` is scaled by
                ``1 - weight_decay * lr`` after each update.
            hf_key_rnn, hf_key_mlp: keys passed to ``from_pretrained``.
            lstm_input_size: accepted for compatibility; not used here.
            Remaining arguments are stored in the param-group defaults.
        """

        def offset_decays(initial, offsets):
            # Shift the initial decays in decay_to_param space, map back,
            # then clip into [0, 1].
            shifted = param_to_decay(
                decay_to_param(torch.as_tensor(initial, dtype=torch.float32))
                + torch.as_tensor(offsets, dtype=torch.float32)
            )
            return torch.clip(shifted, 0.0, 1.0)

        defaults = dict(
            lr=lr,
            exp_mult=exp_mult,
            step_mult=step_mult,
            initial_momentum_decays=offset_decays(
                initial_momentum_decays, momentum_decays
            ),
            lstm_hidden_size=lstm_hidden_size,
            initial_rms_decays=offset_decays(initial_rms_decays, rms_decays),
            initial_adafactor_decays=offset_decays(
                initial_adafactor_decays, adafactor_decays
            ),
            param_inits=param_inits,
            concat_weights=concat_weights,
            make_separate_weights=make_separate_weights,
            input_size=input_size,
            hidden_size=hidden_size,
            hidden_layers=hidden_layers,
            split_weights=split_weights,
            clip_grad=clip_grad,
            mup_lrs=mup_lrs,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)
        # Keep all state next to the parameters. Choosing CUDA merely because
        # it is available would mismatch CPU-resident parameters.
        self.device = self._infer_device()
        self._move_group_tensors()
        self.buffer_loss_fns = BufferLossAccumulators(self.device)
        self.loss_buffer = self.buffer_loss_fns.init(num_steps)
        self.num_steps = num_steps
        self.rnn = VeLORNN.from_pretrained(hf_key_rnn)
        self.lstm_init_state = self.rnn.lstm_init_state
        self.rnn.to(self.device)
        self.network_stack = VeLOMLP.from_pretrained(hf_key_mlp)
        self.network_stack.to(self.device)
        # The pretrained networks are only evaluated, never trained.
        for module in (self.network_stack, self.rnn):
            for param in module.parameters():
                param.requires_grad = False
        self.init_state()

    def _infer_device(self):
        """Return the first parameter's device, or CUDA/CPU when there is none."""
        for group in self.param_groups:
            for p in group["params"]:
                return p.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _move_group_tensors(self):
        """Move the per-group decay tensors onto ``self.device``."""
        for group in self.param_groups:
            for key in self._DECAY_KEYS:
                value = group.get(key)
                if torch.is_tensor(value):
                    group[key] = value.to(self.device)

    @torch.no_grad()
    def init_state(self):
        """Allocate state for trainable parameters and reset the LSTM state.

        Resets each group's step counter to 0. Every parameter with
        ``requires_grad`` gets a ``layer_idx`` (its position among tracked
        tensors), momentum/RMS buffers and factored accumulators (when its
        state is still empty). The LSTM hidden state gets one row per tracked
        tensor.
        """
        n_tracked = 0
        for group in self.param_groups:
            group["step"] = 0
            for p in group["params"]:
                if not p.requires_grad:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    state["layer_idx"] = n_tracked
                    state["mom"] = torch.zeros(p.shape + (3,), device=self.device)
                    state["rms"] = torch.zeros(p.shape + (1,), device=self.device)
                    row, col, full = init_factors(p)
                    state["fac_vec_row"] = row.to(self.device)
                    state["fac_vec_col"] = col.to(self.device)
                    state["fac_vec_v"] = full.to(self.device)
                n_tracked += 1
        self.lstm_hidden_state = (
            self.lstm_init_state[0].repeat(n_tracked, 1).to(self.device),
            self.lstm_init_state[1].repeat(n_tracked, 1).to(self.device),
        )

    def _tracked_state(self, p):
        """Return ``p``'s state if ``init_state`` allocated it, else None."""
        state = self.state.get(p)
        if not state or "layer_idx" not in state:
            return None
        return state

    @torch.no_grad()
    def collect_rnn_outputs(self, to_lstm_from_loss):
        """Run the LSTM over all tracked tensors and advance its hidden state.

        Tensors are visited in the same order ``init_state`` assigned
        ``layer_idx``, so row ``-(1 + layer_idx)`` of each output belongs to
        that tensor after the flip below.

        Returns:
            ``(control_params, lr_mult)`` from the LSTM, or ``(None, None)``
            when no tensor is tracked.
        """
        rnn_inputs = []
        for group in self.param_groups:
            fraction_trained = group["step"] / self.num_steps
            for p in group["params"]:
                state = self._tracked_state(p)
                if state is None:
                    # Untracked tensors have no hidden-state row.
                    continue
                if p.grad is None:
                    # lstm_features_for_tensor derives no feature from the
                    # gradient, so a placeholder keeps the rows aligned.
                    grad = torch.zeros_like(p)
                else:
                    grad = torch.clip(p.grad, -1000.0, 1000.0)
                rnn_inputs.append(
                    lstm_features_for_tensor(
                        p,
                        grad,
                        state["mom"],
                        state["rms"],
                        fraction_trained,
                        to_lstm_from_loss,
                        self.device,
                    )
                )
        if not rnn_inputs:
            return None, None
        rnn_inputs = torch.flip(torch.stack(rnn_inputs), [0])
        control_params, lr_mult, self.lstm_hidden_state = self.rnn(
            rnn_inputs, self.lstm_hidden_state
        )
        return control_params, lr_mult

    def state_dict(self):
        """Return the standard optimizer state plus VeLO's extra state.

        The extra keys are ``loss_buffer``, ``lstm_hidden_state`` and ``num_steps``.
        """
        state_dict = super().state_dict()
        state_dict["loss_buffer"] = self.loss_buffer
        state_dict["lstm_hidden_state"] = self.lstm_hidden_state
        state_dict["num_steps"] = self.num_steps
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`.

        The caller's dict is not modified, so it can be loaded again.
        Restored tensors are moved to ``self.device``.
        """
        state_dict = dict(state_dict)
        loss_buffer = state_dict.pop("loss_buffer")
        lstm_hidden_state = state_dict.pop("lstm_hidden_state")
        num_steps = state_dict.pop("num_steps")
        super().load_state_dict(state_dict)
        self._move_group_tensors()
        self.loss_buffer = {
            key: value.to(self.device) if torch.is_tensor(value) else value
            for key, value in loss_buffer.items()
        }
        self.lstm_hidden_state = tuple(h.to(self.device) for h in lstm_hidden_state)
        self.num_steps = num_steps

    def _update_state_and_features(self, p, state, group):
        """Update ``p``'s accumulators in place and return its MLP inputs.

        Momentum and RMS buffers are updated in place. The factored
        accumulators in ``state`` are replaced. The return value has shape
        ``p.shape + (F,)``, where each of the F features is second-moment
        normalised over the parameter axes.
        """
        p_shape = p.shape
        axis = list(range(len(p_shape)))
        # Broadcast each 1-D decay vector across every parameter axis.
        lead = (1,) * len(p_shape)
        beta_m = group["initial_momentum_decays"].reshape(lead + (-1,))
        beta_rms = group["initial_rms_decays"].reshape(lead + (-1,))
        beta_adafactor = group["initial_adafactor_decays"].reshape(lead + (-1,))

        grad = torch.clip(p.grad, -1000.0, 1000.0)
        batch_p = p.unsqueeze(-1)
        batch_g = grad.unsqueeze(-1)
        clipped_g = torch.clip(batch_g, -0.1, 0.1)

        mom = state["mom"]
        rms = state["rms"]
        mom.mul_(beta_m).add_((1 - beta_m) * batch_g)
        rms.mul_(beta_rms).add_((1 - beta_rms) * (batch_g**2))
        (
            state["fac_vec_col"],
            state["fac_vec_row"],
            state["fac_vec_v"],
            fac_g,
        ) = update_factors(
            state["fac_vec_col"],
            state["fac_vec_row"],
            state["fac_vec_v"],
            batch_g,
            p_shape,
            beta_adafactor,
        )
        fac_vec_col = state["fac_vec_col"]
        fac_vec_row = state["fac_vec_row"]
        fac_vec_v = state["fac_vec_v"]

        rsqrt = torch.rsqrt(rms + 1e-6)
        inps = [
            batch_g,
            clipped_g,
            batch_p,
            mom,
            rms,
            mom * rsqrt,
            rsqrt,
            fac_g,
            batch_g * rsqrt,
        ]
        f_dims = factored_dims(p_shape)
        if f_dims is not None:
            d1, d0 = f_dims
            rp_row = [1] * (1 + len(p_shape))
            rp_col = [1] * (1 + len(p_shape))
            rp_row[d0] = p_shape[d0]
            rp_col[d1] = p_shape[d1]
            # Expand the reduced accumulators back to full parameter shape.
            row_feat = fac_vec_row.unsqueeze(d0).repeat(rp_row)
            col_feat = fac_vec_col.unsqueeze(d1).repeat(rp_col)
            inps.extend(
                [
                    row_feat,
                    col_feat,
                    torch.rsqrt(row_feat + 1e-8),
                    torch.rsqrt(col_feat + 1e-8),
                ]
            )
            # Axis d1 shifts left by one in fac_vec_row when d0 precedes it.
            reduced_d1 = d1 - 1 if d1 > d0 else d1
            row_col_mean = fac_vec_row.mean(dim=reduced_d1, keepdim=True)
            row_factor = safe_rsqrt(fac_vec_row / (row_col_mean + 1e-9))
            col_factor = safe_rsqrt(fac_vec_col)
            inps.append(mom * row_factor.unsqueeze(d0) * col_factor.unsqueeze(d1))
        else:
            # Unfactored tensors reuse the full accumulator for the row/col slots.
            inps.extend(
                [
                    fac_vec_v,
                    fac_vec_v,
                    torch.rsqrt(fac_vec_v + 1e-8),
                    torch.rsqrt(fac_vec_v + 1e-8),
                ]
            )
            inps.append(mom * torch.pow(fac_vec_v, -0.5))
        inps = [second_moment_normalizer(i, axis=axis) for i in inps]
        return torch.cat(inps, dim=-1)

    @torch.no_grad()
    def step(self, loss):
        """Apply one VeLO update to every tracked parameter that has a gradient.

        Args:
            loss: current training loss, forwarded to
                ``BufferLossAccumulators.update``.
        """
        self.loss_buffer = self.buffer_loss_fns.update(self.loss_buffer, loss)
        to_lstm_from_loss = self.buffer_loss_fns.features(self.loss_buffer)
        control_params, lr_mult = self.collect_rnn_outputs(to_lstm_from_loss)
        for group in self.param_groups:
            exp_mult = group["exp_mult"]
            step_mult = group["step_mult"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            group["step"] += 1
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self._tracked_state(p)
                if state is None:
                    continue
                layer_idx = state["layer_idx"]
                inps = self._update_state_and_features(p, state, group)
                # LSTM outputs were produced for the flipped tensor order.
                self.network_stack.update_params(control_params[-(1 + layer_idx)])
                direction, magnitude, _ = self.network_stack(inps).split(1, dim=-1)
                param_scale = torch.sqrt(torch.mean(p**2) + 1e-9)
                update = param_scale * (
                    direction * torch.exp(magnitude * exp_mult) * step_mult
                ).squeeze(-1)
                update = lr_mult[-(1 + layer_idx)] * update
                p.add_(update, alpha=-lr)
                if weight_decay > 0:
                    # Decoupled decay, applied after the update.
                    p.add_(p, alpha=-weight_decay * lr)