Model Pruning: A New Approach
Are we on the verge of a revolution in reducing the number of parameters in language models without a significant loss of their capabilities?
Yesterday, I had the pleasure of going through a notebook on model pruning from Pere Martra's book "Large Language Models: Apply and Implement Strategies for Large Language Models" (Apress)*. After some minor modifications, I was able to reduce the meditsolutions/Llama-3.1-MedIT-SUN-8B model to 4.8B parameters. Fine-tuning the smaller model restored most of the "capabilities" of the original model. But let's start from the beginning.
What is model pruning?
Model pruning reduces the number of model parameters while trying to preserve the model's capabilities, with the goal of speeding up inference. The method identifies "neurons" whose contribution to the final prediction is minimal or negligible. These neurons are then cut out of the model by removing their indices (the corresponding rows or columns) from the weights of the affected layers, which shrinks the weight matrices without a significant loss of quality. In his method, Pere Martra proposed identifying unnecessary neurons by the maximum absolute values of the weights in the MLP layers of models based on the Llama architecture.
```python
import torch


def compute_neuron_pair_importance(gate_weight, up_weight):
    """
    Compute neuron pair importance scores (Maximum Absolute Weight).

    Args:
    - gate_weight: Weight matrix from the gate_proj layer.
    - up_weight: Weight matrix from the up_proj layer.

    Returns:
    - importance_scores: Importance scores for each neuron pair.
    """
    # Each row of gate_proj/up_proj belongs to one intermediate neuron,
    # so reducing over dim=1 yields one score per neuron.
    gate_max_abs = torch.max(torch.abs(gate_weight), dim=1).values
    up_max_abs = torch.max(torch.abs(up_weight), dim=1).values
    importance_scores = gate_max_abs + up_max_abs
    return importance_scores
```
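To make the "remove their indices" step concrete, here is a minimal sketch. It assumes the standard Llama MLP layout, where the block computes down_proj(SiLU(gate_proj(x)) * up_proj(x)): gate_proj and up_proj have weights of shape (intermediate_size, hidden_size) and down_proj has shape (hidden_size, intermediate_size), so neuron i lives in row i of the first two and in column i of the third. The names `ToyLlamaMLP` and `prune_mlp`, and the rule of keeping the k highest-scoring neurons, are hypothetical choices for illustration, not the notebook's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_neuron_pair_importance(gate_weight, up_weight):
    # Maximum Absolute Weight score, same as the function above.
    gate_max_abs = torch.max(torch.abs(gate_weight), dim=1).values
    up_max_abs = torch.max(torch.abs(up_weight), dim=1).values
    return gate_max_abs + up_max_abs


class ToyLlamaMLP(nn.Module):
    """Hypothetical stand-in for a Llama MLP block (no biases, SiLU gating)."""

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


@torch.no_grad()
def prune_mlp(mlp, keep):
    """Return a smaller MLP with only the `keep` highest-scoring neurons."""
    scores = compute_neuron_pair_importance(mlp.gate_proj.weight, mlp.up_proj.weight)
    # Indices of the neurons to keep, sorted so they stay in their original order.
    idx = torch.topk(scores, keep).indices.sort().values
    pruned = ToyLlamaMLP(mlp.gate_proj.in_features, keep)
    # Rows of gate_proj/up_proj and columns of down_proj belong to the same neuron.
    pruned.gate_proj.weight.copy_(mlp.gate_proj.weight[idx, :])
    pruned.up_proj.weight.copy_(mlp.up_proj.weight[idx, :])
    pruned.down_proj.weight.copy_(mlp.down_proj.weight[:, idx])
    return pruned, idx


torch.manual_seed(0)
mlp = ToyLlamaMLP(hidden_size=8, intermediate_size=32)
pruned, kept = prune_mlp(mlp, keep=12)
print(pruned.gate_proj.weight.shape, pruned.down_proj.weight.shape)
```

The pruned block computes exactly what the original block would compute if the dropped neurons' activations were zeroed, which is why removing whole neuron pairs keeps the layer shapes consistent.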
I added my two cents by slightly modifying the method so that the importance score also takes the minimum values into account: for each neuron, it adds the largest weight and the absolute value of the smallest weight.
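```python
import torch


def compute_neuron_pair_importance(gate_weight, up_weight):
    """Importance = per-row max + |per-row min|, summed over gate_proj and up_proj."""
    # Reward neurons with large weights on both the positive and the negative side.
    gate_score = torch.max(gate_weight, dim=1).values + torch.abs(torch.min(gate_weight, dim=1).values)
    up_score = torch.max(up_weight, dim=1).values + torch.abs(torch.min(up_weight, dim=1).values)
    importance_scores = gate_score + up_score
    return importance_scores
```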
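The two scoring rules can rank neurons differently. The toy comparison below uses two hypothetical neurons (rows) and zeroes the up_proj weights so that only gate_proj contributes; `max_abs_score` and `max_plus_abs_min_score` are illustrative labels for the original and the modified rule, not names from the notebook. A neuron with weights spanning from -0.9 to 0.5 loses to a neuron with a single 1.0 weight under the original rule, but wins under the modified one.

```python
import torch


def max_abs_score(gate_weight, up_weight):
    # Original rule: largest absolute weight per row (neuron).
    gate = torch.max(torch.abs(gate_weight), dim=1).values
    up = torch.max(torch.abs(up_weight), dim=1).values
    return gate + up


def max_plus_abs_min_score(gate_weight, up_weight):
    # Modified rule: per-row maximum plus the absolute value of the per-row minimum.
    gate = torch.max(gate_weight, dim=1).values + torch.abs(torch.min(gate_weight, dim=1).values)
    up = torch.max(up_weight, dim=1).values + torch.abs(torch.min(up_weight, dim=1).values)
    return gate + up


# Two hypothetical neurons; up_proj is zeroed so only gate_proj contributes.
gate = torch.tensor([[0.5, -0.9, 0.1],
                     [1.0, 0.05, 0.02]])
up = torch.zeros_like(gate)

print(max_abs_score(gate, up))           # neuron 1 ranks higher
print(max_plus_abs_min_score(gate, up))  # neuron 0 ranks higher
```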
Results
With this modification, I pruned the meditsolutions/Llama-3.1-MedIT-SUN-8B model down to 4.8B parameters (a reduction of about 40%!). Right after pruning, the loss is about 2.1, and further fine-tuning quickly restores the model's ability to generate coherent sequences.
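The 4.8B figure can be sanity-checked with a back-of-the-envelope parameter count. This sketch assumes the model keeps the Llama 3.1 8B configuration (hidden size 4096, intermediate size 14336, 32 layers, 32 attention heads with 8 key/value heads, vocabulary of 128256, untied input and output embeddings) and that only the MLP neurons are pruned. The helper names and the round-down in `pruned_intermediate` are my assumptions, not taken from the notebook.

```python
def llama_param_count(hidden=4096, intermediate=14336, layers=32, vocab=128256,
                      n_heads=32, n_kv_heads=8, tied_embeddings=False):
    head_dim = hidden // n_heads
    # q_proj and o_proj are hidden x hidden; k_proj and v_proj are smaller with GQA.
    attn = 2 * hidden * hidden + 2 * hidden * (n_kv_heads * head_dim)
    mlp = 3 * hidden * intermediate  # gate_proj, up_proj, down_proj
    norms = 2 * hidden  # two RMSNorm weight vectors per decoder layer
    embed = vocab * hidden * (1 if tied_embeddings else 2)  # embed_tokens (+ lm_head)
    return layers * (attn + mlp + norms) + embed + hidden  # + final norm


def pruned_intermediate(intermediate, prune_percent):
    # Number of neurons left per layer after removing prune_percent of them.
    return intermediate - int(prune_percent * intermediate)


base = llama_param_count()
for p in (0.57, 0.58):
    n = llama_param_count(intermediate=pruned_intermediate(14336, p))
    print(f"prune {p:.2f}: {n / 1e9:.2f}B params, {1 - n / base:.1%} smaller")
```

Under these assumptions the original model has about 8.03B parameters, and pruning 57% or 58% of the MLP neurons leaves about 4.82B or 4.76B, consistent with the 4.8B and roughly 40% reduction reported above.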
Discussion
Of course, this is just the beginning of exploring this method. Still, after initial analysis and experiments, it's clear that it has enormous potential for shrinking the models we deploy in production, lowering both costs and inference time. This opens the door to further innovation in small language models (SLMs) that can imitate their larger counterparts.
The final method applies to models with a Llama-like architecture that uses a gated MLP, such as Llama, Phi, Mistral, Qwen, SmolLM, and others.
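Before pruning a new model, it can help to check that its MLP blocks expose the three projections the scoring function expects. This sketch assumes the Hugging Face transformers naming used by the Llama, Mistral and Qwen2 implementations (`gate_proj`, `up_proj`, `down_proj` as separate `nn.Linear` layers); `find_gated_mlps` and the toy classes are hypothetical. Models that fuse the gate and up projections into one layer, such as the `gate_up_proj` in the transformers Phi-3 implementation, are not matched and would need that weight split first.

```python
import torch.nn as nn

GATED_MLP_NAMES = ("gate_proj", "up_proj", "down_proj")


def find_gated_mlps(model):
    """Return (name, module) pairs that have Linear gate_proj, up_proj and down_proj."""
    found = []
    for name, module in model.named_modules():
        if all(isinstance(getattr(module, attr, None), nn.Linear) for attr in GATED_MLP_NAMES):
            found.append((name, module))
    return found


class ToyMLP(nn.Module):
    # Separate gate/up projections: compatible with the pruning method.
    def __init__(self):
        super().__init__()
        self.gate_proj = nn.Linear(4, 8, bias=False)
        self.up_proj = nn.Linear(4, 8, bias=False)
        self.down_proj = nn.Linear(8, 4, bias=False)


class ToyFusedMLP(nn.Module):
    # Fused gate/up projection: not detected without splitting the weight.
    def __init__(self):
        super().__init__()
        self.gate_up_proj = nn.Linear(4, 16, bias=False)
        self.down_proj = nn.Linear(8, 4, bias=False)


model = nn.ModuleDict({"separate": ToyMLP(), "fused": ToyFusedMLP()})
print([name for name, _ in find_gated_mlps(model)])
```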
If you want to experiment, here is the notebook: Large-Language-Model-Notebooks-Course/6-PRUNING/6_3_pruning_structured_llama3.2-1b_OK.ipynb. To reproduce the result, run the notebook with a pruning percentage of about 0.57 to 0.58.
Note from Pere Martra
The pruning section was created after the first edition of the book was published. It is not included in the book's original content but is intended to supplement and expand on the topics covered.