Source code for pylo.models.VeLO_RNN

"""
VeLO RNN: An implementation of the RNN component from Google's Versatile Learned Optimizer (VeLO) paper.

This module implements the per-tensor RNN component of the VeLO architecture as described
in Google's Versatile Learned Optimizer paper. The RNN processes tensor-specific features
and outputs control vectors and learning rate multipliers to adapt optimization behavior
for each tensor.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin


class LSTM(nn.Module):
    """
    Custom single-step LSTM cell used within the VeLO architecture.

    All four gates are produced by one linear layer applied to the
    concatenation of the input and the previous hidden state. A constant +1
    is added to the forget-gate pre-activation before the sigmoid.

    Attributes:
        linear (nn.Linear): Maps the ``input_size + hidden_size`` concatenated
            features to the ``4 * hidden_size`` gate pre-activations.

    Args:
        input_size (int): Size of input features
        hidden_size (int): Size of hidden state
    """
    def __init__(self, input_size, hidden_size):
        """
        Initialize the LSTM module with specified dimensions.

        Args:
            input_size (int): Size of input features
            hidden_size (int): Size of hidden state
        """
        super().__init__()
        # forward() feeds this layer cat(x, h_prev), whose last dimension is
        # input_size + hidden_size. Sizing it as 2 * input_size only happened
        # to work when input_size == hidden_size; the sum yields the same
        # shape in that case and is also correct when the two sizes differ.
        self.linear = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, prev_state):
        """
        Advance the cell by one step.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is ``input_size``
            prev_state (tuple): Previous hidden state and cell state
                ``(h_prev, c_prev)``, each with last dimension ``hidden_size``

        Returns:
            tuple: Tuple containing:
                - h (torch.Tensor): New hidden state
                - (h, c) (tuple): New state tuple for next iteration
        """
        h_prev, c_prev = prev_state
        combined = torch.cat((x, h_prev), dim=-1)
        gates = self.linear(combined)
        # The fused projection is split, in order, into the input gate, the
        # candidate values, the forget gate and the output gate.
        i, g, f, o = gates.chunk(4, -1)
        i = torch.sigmoid(i)
        # +1 shifts the forget gate toward "keep" before the sigmoid.
        f = torch.sigmoid(f + 1)
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c_prev + i * g
        h = o * torch.tanh(c)
        return h, (h, c)


class VeLORNN(nn.Module, PyTorchModelHubMixin, license="apache-2.0", tags=["learned-optimizer"]):
    """
    VeLO RNN module that processes tensor-specific features as described in
    Google's Versatile Learned Optimizer paper.

    This module implements the per-tensor RNN component of the VeLO
    architecture. It applies a feature mixing network (when enabled) followed
    by an LSTM step, then projects the LSTM output to a control vector and a
    learning rate multiplier.

    Note:
        The LSTM is always built for inputs of size ``lstm_hidden_size``. When
        ``mix_layers`` is False the features are passed to the LSTM unchanged,
        so they must already have ``lstm_hidden_size`` as their last dimension.
    """

    def __init__(self, input_size=30, lstm_hidden_size=512, param_inits=256, mix_layers=True):
        """
        Initialize the VeLORNN module.

        Args:
            input_size (int, optional): Dimension of input features.
                Defaults to 30.
            lstm_hidden_size (int, optional): Size of LSTM hidden state.
                Defaults to 512.
            param_inits (int, optional): Number of parameter initializations
                to control. Determines the dimension of the output control
                vector. Defaults to 256.
            mix_layers (bool, optional): Whether to use feature mixing layers
                before LSTM. Defaults to True.
        """
        super(VeLORNN, self).__init__()
        self.mix_layers = mix_layers

        # Feature mixing layers. mix_layer1 is constructed but never applied
        # in forward(); it is kept so the module's parameter set is unchanged
        # (presumably so existing checkpoints still load -- confirm).
        self.mix_layer1 = nn.Linear(input_size, lstm_hidden_size)
        self.mix_layer2 = nn.Linear(input_size, lstm_hidden_size)
        self.final_mix_layer = nn.Linear(input_size, lstm_hidden_size)

        self.lstm = LSTM(input_size=lstm_hidden_size, hidden_size=lstm_hidden_size)

        # Output heads applied to the LSTM output.
        self.rnn_to_controls = nn.Linear(lstm_hidden_size, param_inits)
        self.step_size = nn.Linear(lstm_hidden_size, 1)

        # Two zero-initialised (1, lstm_hidden_size) parameters; presumably
        # the initial (h, c) LSTM state -- confirm against the caller.
        self.lstm_init_state = nn.ParameterList(
            [
                nn.Parameter(torch.zeros(1, lstm_hidden_size)),
                nn.Parameter(torch.zeros(1, lstm_hidden_size)),
            ]
        )

    def forward(self, x, state):
        """
        Forward pass of the VeLORNN.

        Processes the features through the optional mixing layers and one
        LSTM step to produce control vectors and learning rate multipliers.

        Args:
            x (torch.Tensor): Input tensor containing tensor-specific
                features. When mixing is enabled, a maximum is taken over
                dim 0 (presumably one row per tensor -- confirm).
            state (tuple): Previous LSTM state (h, c)

        Returns:
            tuple: Tuple containing:
                - controls (torch.Tensor): Control vector for weighting
                  parameter initializations
                - lr_mult (torch.Tensor): Learning rate multiplier for
                  scaling step size (trailing size-1 dimension squeezed)
                - state (tuple): Updated LSTM state for next iteration
        """
        if self.mix_layers:
            # mix_layer1 is deliberately not applied here: per the original
            # author's note, that step is skipped in the original
            # implementation.
            mix_2 = F.relu(self.mix_layer2(x))
            # Maximum over dim 0, kept as a single row so that it broadcasts
            # onto every row of the projected features below.
            v, _ = torch.max(mix_2, dim=0, keepdim=True)
            x = self.final_mix_layer(x) + v

        rnn_out, state = self.lstm(x, state)
        controls = self.rnn_to_controls(rnn_out)
        lr_mult = torch.squeeze(self.step_size(rnn_out), -1)
        return controls, lr_mult, state