# Source code for pylo.optim.AdafacLO_naive

"""AdafacLO_naive: An MLP learned optimizer.

This is a PyTorch implementation of small_fc_lopt from: https://arxiv.org/abs/2203.11860

This code is adapted from the JAX implementation at: https://github.com/google/learned_optimization/blob/main/learned_optimization/learned_optimizers/adafac_mlp_lopt.py
"""
from collections import OrderedDict
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer

from pylo.models.Meta_MLP import MetaMLP


def init_factors(p):
    """Create zero-initialised Adafactor accumulators for parameter ``p``.

    Every accumulator carries an extra trailing axis of size 3 (presumably one
    slot per adafactor decay). When ``factored_dims`` reports two dims
    ``(d1, d0)`` (second-largest, largest), the second moment is kept as a row
    accumulator (dim ``d0`` dropped) and a column accumulator (dim ``d1``
    dropped). Otherwise a single full accumulator of shape ``p.shape + (3,)``
    is kept.

    Returns:
        ``(v_row, v_col, v_full)`` float32 tensors on the default (CPU) device;
        the slots that are not used for this parameter are empty tensors.
    """

    def _empty():
        return torch.tensor([], dtype=torch.float32)

    dims = factored_dims(p.shape)
    full_shape = tuple(p.shape) + (3,)

    if dims is None:
        return _empty(), _empty(), torch.zeros(full_shape, dtype=torch.float32)

    d1, d0 = dims
    row_shape = full_shape[:d0] + full_shape[d0 + 1:]
    col_shape = full_shape[:d1] + full_shape[d1 + 1:]
    return (
        torch.zeros(row_shape, dtype=torch.float32),
        torch.zeros(col_shape, dtype=torch.float32),
        _empty(),
    )


def safe_rsqrt(x):
    """Return ``1 / sqrt(x)`` with ``x`` floored at 1e-9 so zeros and negatives stay finite.

    The floor is built with ``x``'s own dtype and device, so the comparison
    never triggers a device or dtype mismatch.
    """
    floor = x.new_tensor(1e-9)
    return torch.maximum(x, floor).rsqrt()


def update_factors(
    v_col, v_row, v_full, g, g_shape, decay_rate=0.9, epsilon: float = 1e-30
):
    """Advance Adafactor second-moment accumulators and precondition ``g``.

    Args:
        v_col, v_row: factored accumulators (as produced by ``init_factors``),
            used when ``g_shape`` has at least two dims. Each has a trailing
            decay axis.
        v_full: full accumulator, used when ``g_shape`` has fewer than two dims.
        g: gradient with one extra trailing singleton axis, i.e. shape
            ``g_shape + (1,)``.
        g_shape: shape of the original parameter (without the trailing axis).
        decay_rate: decay rate(s). This may be a Python float, a 1-D tensor of
            ``k`` decays, or a tensor of shape ``(1, ..., 1, k)``. Only its
            values are used; they are broadcast along the trailing axis.
        epsilon: added to ``g * g`` before accumulation.

    Returns:
        ``(new_v_col, new_v_row, new_v_full, y)``. Accumulator slots that do
        not apply to this shape are empty float32 tensors on ``g``'s device.
        ``y`` is ``g`` scaled by the reciprocal root of the (factored or full)
        second moment, broadcast along the trailing decay axis.
    """
    f_dims = factored_dims(g_shape)

    # Flatten decays to (k,) so they broadcast against the trailing axis of
    # both the accumulators and g; this also lets a plain float work.
    decay = torch.as_tensor(decay_rate, device=g.device).reshape(-1)
    mixing = 1.0 - decay

    # g keeps its trailing singleton axis and broadcasts against the k decays,
    # so there is no need to materialise k copies of the gradient.
    grad_sqr = g * g + epsilon

    def _empty():
        return torch.tensor([], dtype=torch.float32, device=g.device)

    if f_dims is None:
        # Diagonal (non-factored) preconditioner.
        new_v = decay * v_full + mixing * grad_sqr
        y = g * safe_rsqrt(new_v + 1e-9)
        return _empty(), _empty(), new_v, y

    d1, d0 = f_dims
    new_v_row = decay * v_row + mixing * torch.mean(grad_sqr, dim=d0)
    new_v_col = decay * v_col + mixing * torch.mean(grad_sqr, dim=d1)

    # v_row has lost dim d0, so d1 shifts down by one when it came after d0.
    reduced_d1 = d1 - 1 if d1 > d0 else d1
    row_col_mean = torch.mean(new_v_row, dim=reduced_d1, keepdim=True)

    row_factor = safe_rsqrt(new_v_row / (row_col_mean + 1e-9))
    col_factor = safe_rsqrt(new_v_col)
    y = g * row_factor.unsqueeze(d0) * col_factor.unsqueeze(d1)
    return new_v_col, new_v_row, _empty(), y


def tanh_embedding(x):
    """Embed a step count as ``tanh(x / timescale - 1)`` at 11 fixed timescales.

    Args:
        x: step count. This may be a Python number or a tensor of any shape;
            tensors are converted without the copy-construct warning that
            ``torch.tensor(tensor)`` emits.

    Returns:
        float32 tensor of shape ``x.shape + (11,)``, which is ``(11,)`` for a
        scalar. It lives on ``x``'s device when ``x`` is a tensor, and on CPU
        otherwise.
    """
    x = torch.as_tensor(x, dtype=torch.float32)
    timescales = torch.tensor(
        [1, 3, 10, 30, 100, 300, 1000, 3000, 10000, 30000, 100000],
        dtype=torch.float32,
        device=x.device,
    )
    # Trailing axis so every element of x is divided by every timescale;
    # for a scalar x this yields exactly the 11-vector of the original.
    return torch.tanh(x.unsqueeze(-1) / timescales - 1.0)


def second_moment_normalizer(x, axis, eps=1e-5):
    """Scale ``x`` by the inverse root-mean-square taken over ``axis``.

    Computes ``x / sqrt(eps + mean(x**2, axis))``. The mean is kept as a
    size-1 dimension so it broadcasts back over ``x``.
    """
    squared = x.pow(2)
    denom = squared.mean(dim=axis, keepdim=True).add(eps)
    scale = denom.rsqrt()
    return x * scale


def factored_dims(shape):
    if len(shape) < 2:
        return None
    sorted_dims = np.argsort(shape)
    return int(sorted_dims[-2]), int(sorted_dims[-1])


def decay_to_param(x):
    """Map a decay rate ``x`` to parameter space: ``log(1 - x) / 10``.

    This is the inverse of ``param_to_decay``.
    """
    log_complement = torch.log(1 - x)
    return log_complement / 10.0


def param_to_decay(x):
    """Map a parameter-space value back to a decay rate: ``1 - exp(10 * x)``.

    This is the inverse of ``decay_to_param``.
    """
    growth = torch.exp(x * 10.0)
    return 1 - growth


class AdafacLO_naive(Optimizer):
    """Adafactor-feature MLP learned optimizer (PyTorch port).

    Per-parameter state:
    - momentum accumulators, one per momentum decay;
    - an RMS accumulator;
    - Adafactor second-moment accumulators, factored for tensors with at
      least two dims and full otherwise.

    Per element, the optimizer builds a feature vector from the gradient, the
    parameter and these accumulators, and normalises it with
    ``second_moment_normalizer``. It then appends ``tanh_embedding`` of the
    step count and feeds the result to a pretrained ``MetaMLP``. The network
    outputs ``(direction, magnitude)``, which give the update
    ``direction * exp(magnitude * exp_mult) * step_mult``. That update is
    added to the parameter with weight ``-lr``.

    The effective decays are the ``initial_*_decays`` shifted by the learned
    ``*_decays`` offsets in ``log(1 - decay) / 10`` space, then clipped to
    ``[0, 1]``.

    Unused or stored-only arguments:
    - ``input_size``, ``hidden_size`` and ``hidden_layers`` are not read here;
      the network is loaded from ``hf_key``.
    - ``concat_weights``, ``make_separate_weights``, ``split_weights``,
      ``clip_grad`` and ``mup_lrs`` are stored in the param-group defaults
      but are not used by ``step``.
    """

    def __init__(
        self,
        params,
        momentum_decays=(0.15216392, 0.14245212, 0.06812963),
        rms_decays=(0.01079706,),
        adafactor_decays=(0.18621896, -0.10864615, -0.06185547),
        lr=1.0,
        exp_mult=0.001,
        step_mult=0.01,
        input_size=39,
        hidden_size=32,
        hidden_layers=1,
        initial_momentum_decays=(0.9, 0.99, 0.999),
        initial_rms_decays=(0.999,),
        initial_adafactor_decays=(0.9, 0.99, 0.999),
        max_grad_norm=None,
        concat_weights=True,
        make_separate_weights=False,
        split_weights=False,
        clip_grad=False,
        weight_decay=0.0,
        mup_lrs=None,
        hf_key: Optional[str] = "btherien/mulo",
    ):
        # Device for the decay tensors and the meta-network. Per-parameter
        # state lives on each parameter's own device (see _init_state).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        defaults = dict(
            lr=lr,
            exp_mult=exp_mult,
            step_mult=step_mult,
            initial_momentum_decays=self._effective_decays(
                initial_momentum_decays, momentum_decays, self.device
            ),
            initial_rms_decays=self._effective_decays(
                initial_rms_decays, rms_decays, self.device
            ),
            initial_adafactor_decays=self._effective_decays(
                initial_adafactor_decays, adafactor_decays, self.device
            ),
            concat_weights=concat_weights,
            make_separate_weights=make_separate_weights,
            split_weights=split_weights,
            clip_grad=clip_grad,
            weight_decay=weight_decay,
            mup_lrs=mup_lrs,
            max_grad_norm=max_grad_norm,
        )
        super().__init__(params, defaults)
        self.network = MetaMLP.from_pretrained(hf_key).to(self.device)

    @staticmethod
    def _effective_decays(initial, offsets, device):
        """Shift ``initial`` decays by learned ``offsets`` in parameter space and clip to [0, 1]."""
        offsets = torch.as_tensor(offsets).to(device)
        base = torch.as_tensor(initial, device=device)
        decays = param_to_decay(decay_to_param(base) + offsets)
        return torch.clip(decays.detach(), 0.0, 1.0)

    @staticmethod
    def _init_state(state, p):
        """Populate an empty per-parameter ``state`` dict on ``p``'s device."""
        device = p.device
        state["mom"] = torch.zeros(p.shape + (3,), device=device)
        state["rms"] = torch.zeros(p.shape + (1,), device=device)
        row, col, full = init_factors(p)
        state["fac_vec_row"] = row.to(device)
        state["fac_vec_col"] = col.to(device)
        state["fac_vec_v"] = full.to(device)

    @staticmethod
    def _adafactor_features(p_shape, mom, fac_vec_row, fac_vec_col, fac_vec_v):
        """Return the five Adafactor-derived feature tensors appended to the network input.

        Factored case (see ``factored_dims``): the row and column accumulators
        broadcast to full shape, their rsqrt, and momentum preconditioned by
        the normalised factors. Full case: the full accumulator fills both the
        row and column slots, so the feature count is the same either way.
        """
        f_dims = factored_dims(p_shape)
        if f_dims is None:
            inv = torch.rsqrt(fac_vec_v + 1e-8)
            fac_mom_mult = mom * torch.pow(fac_vec_v + 1e-6, -0.5)
            return [fac_vec_v, fac_vec_v, inv, inv, fac_mom_mult]

        d1, d0 = f_dims
        reps_row = [1] * (1 + len(p_shape))
        reps_col = [1] * (1 + len(p_shape))
        reps_row[d0] = p_shape[d0]
        reps_col[d1] = p_shape[d1]
        row_feat = fac_vec_row.unsqueeze(d0).repeat(reps_row)
        col_feat = fac_vec_col.unsqueeze(d1).repeat(reps_col)

        # Same row/col normalisation as the preconditioner in update_factors.
        reduced_d1 = d1 - 1 if d1 > d0 else d1
        row_col_mean = fac_vec_row.mean(dim=reduced_d1, keepdim=True)
        row_factor = safe_rsqrt(fac_vec_row / (row_col_mean + 1e-9))
        col_factor = safe_rsqrt(fac_vec_col)
        fac_mom_mult = mom * row_factor.unsqueeze(d0) * col_factor.unsqueeze(d1)
        return [
            row_feat,
            col_feat,
            torch.rsqrt(row_feat + 1e-8),
            torch.rsqrt(col_feat + 1e-8),
            fac_mom_mult,
        ]

    def _update_param(self, p, group):
        """Update ``p``'s accumulators, query the network, and apply the step in place."""
        device = p.device
        p_shape = p.shape
        ndim = len(p_shape)
        # (1, ..., 1, -1): lines per-decay vectors up with the trailing feature axis.
        bcast = (1,) * ndim + (-1,)
        beta_m = group["initial_momentum_decays"].to(device).reshape(bcast)
        beta_rms = group["initial_rms_decays"].to(device).reshape(bcast)
        beta_adafactor = group["initial_adafactor_decays"].to(device).reshape(bcast)

        state = self.state[p]
        if not state:
            self._init_state(state, p)

        batch_p = p.unsqueeze(-1)
        batch_g = p.grad.unsqueeze(-1)

        mom = state["mom"]
        rms = state["rms"]
        mom.mul_(beta_m).add_((1 - beta_m) * batch_g)
        rms.mul_(beta_rms).add_((1 - beta_rms) * (batch_g**2))

        (
            state["fac_vec_col"],
            state["fac_vec_row"],
            state["fac_vec_v"],
            fac_g,
        ) = update_factors(
            state["fac_vec_col"],
            state["fac_vec_row"],
            state["fac_vec_v"],
            batch_g,
            p_shape,
            beta_adafactor,
        )

        rsqrt = torch.rsqrt(rms + 1e-6)
        inps = [batch_g, batch_p, mom, rms, mom * rsqrt, rsqrt, fac_g]
        inps.extend(
            self._adafactor_features(
                p_shape,
                mom,
                state["fac_vec_row"],
                state["fac_vec_col"],
                state["fac_vec_v"],
            )
        )
        # Normalise each feature over all parameter axes (not the feature axis).
        inps = second_moment_normalizer(
            torch.cat(inps, dim=-1), axis=list(range(ndim))
        )

        step_feature = tanh_embedding(group["step"] - 1).to(device)
        step_feature = step_feature.reshape(bcast).expand(*p_shape, -1)
        inp_stack = torch.cat([inps, step_feature], dim=-1)

        # The network lives on self.device; move inputs there and the output back.
        out = self.network(inp_stack.to(self.device)).to(device)
        direction, magnitude = out.split(1, dim=-1)
        update = (
            direction * torch.exp(magnitude * group["exp_mult"]) * group["step_mult"]
        ).squeeze(-1)
        p.add_(update, alpha=-group["lr"])
        weight_decay = group["weight_decay"]
        if weight_decay > 0:
            # Decoupled weight decay, applied to the already-updated parameter.
            p.add_(p, alpha=-weight_decay * group["lr"])

    @torch.no_grad()
    def step(self, loss=None, closure=None):
        """Apply one learned-optimizer update to every parameter that has a gradient.

        Args:
            loss: If callable, it is treated as a closure, following the
                ``torch.optim.Optimizer.step(closure)`` convention. Otherwise
                it is returned unchanged.
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is called with gradients enabled before the update.

        Returns:
            The closure's result if a closure was given, otherwise ``loss``.
        """
        if closure is None and callable(loss):
            closure, loss = loss, None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            group["step"] = group.get("step", 0) + 1
            max_grad_norm = group["max_grad_norm"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                if max_grad_norm is not None:
                    # NOTE: clips each parameter's gradient individually, not the global norm.
                    torch.nn.utils.clip_grad_norm_(p, max_grad_norm)
                self._update_param(p, group)
        return loss