# Source code for pylo.models.VeLO_MLP
"""
VeLO-MLP: An implementation based on Google's Versatile Learned Optimizer (VeLO) paper.
Some of the following code is adapted from https://github.com/google/learned_optimization/blob/main/learned_optimization/research/general_lopt/hyper_v2.py
This module implements the VeLO-MLP architecture, a neural network model that
serves as a learned optimizer as described in the VeLO paper from Google Research.
The model maintains two sets of parameters:
1. Storage parameters that hold a collection of possible parameter values
2. Actual parameters that are used during forward computation
The model can update its actual parameters based on a control vector through
the update_params method, following the versatile parameter adaptation approach
introduced in the VeLO paper.
"""
from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
class VeLOMLP(
    nn.Module, PyTorchModelHubMixin, license="apache-2.0", tags=["learned-optimizer"]
):
    """
    MLP whose working weights are blended from a bank of stored weights.

    Modeled on the MLP used in Google's Versatile Learned Optimizer (VeLO).
    Two parallel parameter sets are registered on the module:

    - Storage parameters (attribute names ending in an underscore) carry an
      extra leading dimension of size ``param_inits``.
    - Working parameters (the same names without the underscore) have the
      plain per-layer shapes and are the only ones read by ``forward``.

    ``update_params`` overwrites the working parameters from the storage
    parameters according to a control vector.
    """

    def __init__(
        self,
        param_inits=256,
        input_size=30,
        hidden_size=4,
        hidden_layers=1,
        output_size=3,
    ):
        """
        Build the storage bank and the working MLP parameters.

        Weights are drawn from a standard normal distribution and biases
        start at zero, for both the storage and the working sets.

        Args:
            param_inits (int, optional): Size of the leading dimension of
                every storage parameter. Defaults to 256.
            input_size (int, optional): Number of input features.
                Defaults to 30.
            hidden_size (int, optional): Width of every hidden layer.
                Defaults to 4.
            hidden_layers (int, optional): Number of hidden-to-hidden layers
                applied after the input layer. Defaults to 1.
            output_size (int, optional): Number of output features.
                Defaults to 3.
        """
        super().__init__()
        self.hidden_layers = hidden_layers

        # Storage bank: every tensor has a leading axis of length param_inits.
        # Attribute assignment order is kept stable because it determines the
        # parameter registration order (and therefore the state_dict layout).
        self.input_weights_ = nn.Parameter(
            torch.randn(param_inits, hidden_size, input_size)
        )
        self.input_bias_ = nn.Parameter(torch.zeros(param_inits, hidden_size))
        self.hidden_weights_ = nn.ParameterList(
            [
                nn.Parameter(torch.randn(param_inits, hidden_size, hidden_size))
                for _ in range(hidden_layers)
            ]
        )
        self.hidden_bias_ = nn.ParameterList(
            [
                nn.Parameter(torch.zeros(param_inits, hidden_size))
                for _ in range(hidden_layers)
            ]
        )
        self.output_weights_ = nn.Parameter(
            torch.randn(param_inits, output_size, hidden_size)
        )
        self.output_bias_ = nn.Parameter(torch.zeros(param_inits, output_size))

        # Working parameters: the tensors actually consumed by forward().
        self.input_weights = nn.Parameter(torch.randn(hidden_size, input_size))
        self.input_bias = nn.Parameter(torch.zeros(hidden_size))
        self.hidden_weights = nn.ParameterList(
            [
                nn.Parameter(torch.randn(hidden_size, hidden_size))
                for _ in range(hidden_layers)
            ]
        )
        self.hidden_bias = nn.ParameterList(
            [nn.Parameter(torch.zeros(hidden_size)) for _ in range(hidden_layers)]
        )
        self.output_weights = nn.Parameter(torch.randn(output_size, hidden_size))
        self.output_bias = nn.Parameter(torch.zeros(output_size))

    def update_params(self, control):
        """
        Overwrite the working parameters from the storage bank.

        Each working tensor is replaced in place by
        ``mean(control * storage, dim=0) * 100.0``, where ``control`` is
        broadcast across the trailing dimensions of the storage tensor.
        The copies go through ``.data``, so they are not tracked by autograd.

        Args:
            control (torch.Tensor): Per-entry coefficients for the storage
                bank. It is indexed as a 1-D tensor and must broadcast
                against the leading (``param_inits``) dimension of the
                storage parameters.

        Returns:
            None
        """
        # Add trailing singleton axes so control lines up with the storage
        # tensors: 3-D for weight banks, 2-D for bias banks.
        weight_coeff = control[:, None, None]
        bias_coeff = control[:, None]

        def blend_into(live, bank, coeff):
            # Collapse the bank's leading axis and write the result in place.
            live.data.copy_((coeff * bank.data).mean(0) * 100.0)

        blend_into(self.input_weights, self.input_weights_, weight_coeff)
        blend_into(self.input_bias, self.input_bias_, bias_coeff)

        for layer in range(self.hidden_layers):
            blend_into(
                self.hidden_weights[layer], self.hidden_weights_[layer], weight_coeff
            )
            blend_into(self.hidden_bias[layer], self.hidden_bias_[layer], bias_coeff)

        blend_into(self.output_weights, self.output_weights_, weight_coeff)
        blend_into(self.output_bias, self.output_bias_, bias_coeff)

    def forward(self, x):
        """
        Run the MLP using the current working parameters.

        The input layer and every hidden layer are followed by a ReLU; the
        output layer is purely linear.

        Args:
            x (torch.Tensor): Input tensor whose last dimension matches the
                input layer's feature size.

        Returns:
            torch.Tensor: Output of the final linear layer.
        """
        activations = F.linear(x, self.input_weights, self.input_bias)
        activations = F.relu(activations)

        for layer_weight, layer_bias in zip(self.hidden_weights, self.hidden_bias):
            activations = F.relu(F.linear(activations, layer_weight, layer_bias))

        return F.linear(activations, self.output_weights, self.output_bias)