Scaling analysis

  • from unit_scaling.utils import analyse_module
    • uses torch.fx symbolic tracing to record the computation graph
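
A minimal usage sketch, following the call pattern from the library README; the MLP module here is made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from unit_scaling.utils import analyse_module

# Toy module to analyse; the fx tracer records each op and annotates it
# with the std ("scale") of its forward output and backward gradient.
class MLP(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.up = nn.Linear(d, 4 * d)
        self.down = nn.Linear(4 * d, d)

    def forward(self, x):
        return self.down(F.gelu(self.up(x)))

batch, d = 256, 1024
example_input = torch.randn(batch, d, requires_grad=True)
example_backward = torch.randn(batch, d)  # gradient seeded at the output

# Returns the module's source annotated with per-op scale statistics.
annotated = analyse_module(MLP(d), example_input, example_backward)
print(annotated)
```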

Details to think about

Checking correctness

  • implement unit scaling and mup separately
    • but require unit scaling to be active for mup to work
    • check unit scaling by
      • scaling analysis
      • testing individual operators (test_functional.py, see the sketch below)
    • check mup by coordinate checking + an HP sweep
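
A sketch of such an operator-level test. The test and the assert_unit_scale helper are hypothetical, and it assumes unit_scaling.functional.linear mirrors torch.nn.functional.linear's signature; the point is just that the forward output and the input/weight gradients should all come out with std ≈ 1 when the inputs are unit-scaled.

```python
import torch
import unit_scaling.functional as U

def assert_unit_scale(t: torch.Tensor, tol: float = 0.1) -> None:
    """Hypothetical helper: a tensor is 'unit-scaled' if its std is ~1."""
    assert abs(float(t.std()) - 1.0) < tol, f"std={float(t.std()):.3f}"

def test_scaled_linear_is_unit_scaled() -> None:
    torch.manual_seed(0)
    batch, d_in, d_out = 2**12, 2**10, 2**10
    x = torch.randn(batch, d_in, requires_grad=True)
    w = torch.randn(d_out, d_in, requires_grad=True)  # unit-scaled init

    y = U.linear(x, w, bias=None)    # scaled matmul from the library
    assert_unit_scale(y)             # forward output ~ unit std

    y.backward(torch.randn_like(y))  # seed a unit-scaled output gradient
    assert_unit_scale(x.grad)        # input gradient ~ unit std
    assert_unit_scale(w.grad)        # weight gradient ~ unit std
```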

Unit scaling the model

  • Need to scale the loss

    • e.g. cross-entropy requires the gradients to be scaled up
  • matmuls

  • residual-adds

    • should scale the residual add so the sum stays unit-variance (CLT: the variances of the two branches add)
    • need to explicitly split between the residual stream (the part that will be modified by the current block) and the skip (the original input that stays unchanged)
    • residual_split scales the backward (happens at the beginning of the residual block's forward)
    • residual_add scales the forward (happens at the end of the residual block's forward)
    • can be combined into residual_apply(fn) for cleaner code (see the sketch after this list)
    • the taus guiding the scaling of the residual layers depend on the network structure
      • unit-scaling implements the rule for pre-norm architectures, e.g. Llama (Appendix F.4)
      • exact equations are described in Appendix G.2.2
  • embedding layer and last layer

    • need to be explicitly scaled differently (not subject to the scale_fwd = scale_bwd constraint)
  • constrain the scales of operations to use the same forward factor (for the output) and backward factor (for the input gradient), except for the input and output layers, where the two are free to differ

    • the weight gradients can still be given their own scaling factor due to the cut-edge rule (i.e. they don’t affect the rest of the computational graph)
  • mup_scaling_depth
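
A minimal sketch of the residual split/add pattern above, in the shape of a pre-norm FFN block. The block class is made up, and the exact signatures (the tau keyword, the (residual, skip) return order, and module/function names uu.Linear, uu.LayerNorm, U.gelu) are my reading of the library and should be checked against the actual API; the tau value itself would come from the pre-norm rule in Appendix G.2.2.

```python
import torch.nn as nn
import unit_scaling as uu
import unit_scaling.functional as U

class PreNormFFNBlock(nn.Module):
    """Pre-norm residual FFN block using the split/add pattern (sketch)."""

    def __init__(self, d: int, tau: float):
        super().__init__()
        self.tau = tau                # branch weight from the depth rule
        self.norm = uu.LayerNorm(d)
        self.up = uu.Linear(d, 4 * d)
        self.down = uu.Linear(4 * d, d)

    def forward(self, x):
        # residual_split: rescales the *backward* gradients flowing into the
        # two branches (applied at the start of the block's forward).
        residual, skip = U.residual_split(x, tau=self.tau)
        residual = self.down(U.gelu(self.up(self.norm(residual))))
        # residual_add: rescales the *forward* sum so it stays unit-variance
        # (CLT), applied at the end of the block's forward.
        return U.residual_add(residual, skip, tau=self.tau)
```

The loss side is analogous: the library's scaled cross-entropy in unit_scaling.functional performs the gradient up-scaling mentioned in the first bullet, so the backward entering the network starts out unit-scale.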

Mupifying the model

  • Implemented in unit_scaling/parameter.py

    • Parameters are tagged with their role in the model (as a “bias”, “norm” parameter, “weight” or “output”).
      • The library achieves this by extending torch.nn.Parameter with an additional property mup_type (which could potentially be added to / removed from the state_dict with hooks).
      • This property is required for every parameter in a u-µP model. Given this, and information on the overall depth of the model, the library applies the learning rate rules of Table 2 as a pre-optimizer transformation that modifies the learning rate for each parameter.
        • This allows standard PyTorch optimizers to be used without modification.
  • DepthSequential class to keep track of the depth

  • removing the trainable parameters in the norms

  • The multipliers are the mult parameters in the library

    • https://github.com/graphcore-research/unit-scaling/issues/62
    • can be merged into scale for MHA
    • need to be implemented for
      • ffn-activation, attn-softmax, loss softmax
      • residual_mult and residual_attn_ratio
        • Appendix G of the u-µP paper
        • If both hyperparameters are set to 1.0, the total contribution of embedding, attention and MLP layers are all equal.
  • The optimizer

    • implements LR scaling based on parameter type (mup_type) and depth
    • implements independent AdamW (weight decay decoupled from the learning rate); see the sketch below
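
A condensed sketch of that pre-optimizer transformation, to make the mechanism concrete. Everything below is illustrative: lr_scale and make_param_groups are hypothetical helpers, the exponents are placeholders for the actual Table 2 rules, and independent weight decay is emulated by dividing the decay coefficient by the group's learning rate (only exact without an LR schedule); the library's own optimizer wrappers handle this properly.

```python
import torch
from torch import nn

def lr_scale(param: nn.Parameter, depth: int) -> float:
    """Per-parameter LR factor keyed on mup_type (illustrative rule only;
    the real per-type exponents come from Table 2 of the u-µP paper)."""
    mup_type = getattr(param, "mup_type", None)
    if mup_type == "weight":
        fan_in = param.shape[-1]
        return (fan_in * depth) ** -0.5  # placeholder width/depth scaling
    return 1.0                           # all other types: base LR in this sketch

def make_param_groups(model: nn.Module, base_lr: float, depth: int,
                      independent_wd: float = 0.0):
    """Pre-optimizer transformation: one param group per parameter, so a
    standard PyTorch optimizer can be used without modification."""
    groups = []
    for p in model.parameters():
        lr = base_lr * lr_scale(p, depth)
        groups.append({
            "params": [p],
            "lr": lr,
            # Independent AdamW: PyTorch's AdamW decays by lr * weight_decay,
            # so dividing by lr keeps the effective decay independent of lr
            # (only exact if lr is not changed by a schedule).
            "weight_decay": independent_wd / lr if independent_wd else 0.0,
        })
    return groups

# Usage: tag parameters with their role, then build a stock AdamW.
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
for p in model.parameters():
    # The library does this via its Parameter subclass carrying mup_type;
    # here we simply attach the attribute for illustration.
    p.mup_type = "weight" if p.ndim == 2 else "bias"

opt = torch.optim.AdamW(
    make_param_groups(model, base_lr=1e-2, depth=2, independent_wd=1e-4)
)
```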