2-stage pruning to favor distributed inference (the local device computes half of the model and uploads the feature for further computing on stronger devices or the cloud).
by hou-yz, based on kuangliu/pytorch-cifar
improve inference speed and reduce intermediate feature sizes to favor distributed inference (the local device computes half of the model and uploads the feature for further computing on stronger devices or the cloud).
pruning stage-1: prune the whole model to increase inference speed and slightly reduce intermediate feature sizes.
pruning stage-2: (based on step-1’s model) for each split-point (where the intermediate feature is transferred to another device for further computation), specifically prune the layer just before the split-point to reduce intermediate feature sizes even more.
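to make the effect of stage-2 concrete, here is a rough sketch of how the intermediate feature size at a split-point depends on the channel count of the layer just before it. it is not code from this repo: the helper names `feature_sizes` and `prune_before_split` are hypothetical, the layer list is the vgg-16 configuration as i understand it from kuangliu/pytorch-cifar, and features are assumed to be uploaded as raw float32 with a 3x32x32 cifar-10 input.

```python
# sketch only: estimates the size of the feature that would be uploaded
# at each possible split-point of vgg-16 on cifar-10.

# vgg-16 layer configuration (numbers = conv output channels, 'M' = 2x2 max-pool).
# assumption: this matches the cfg used in kuangliu/pytorch-cifar.
VGG16_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512, 'M']

BYTES_PER_FLOAT32 = 4  # assumption: features are sent as uncompressed float32


def feature_sizes(cfg, input_hw=32):
    """Return (layer_index, channels, height_or_width, bytes) after every layer."""
    sizes = []
    hw = input_hw
    channels = 3  # rgb input
    for idx, item in enumerate(cfg):
        if item == 'M':
            hw //= 2          # max-pool halves the spatial size
        else:
            channels = item   # 3x3 conv with padding keeps the spatial size
        sizes.append((idx, channels, hw, channels * hw * hw * BYTES_PER_FLOAT32))
    return sizes


def prune_before_split(cfg, split_idx, keep_ratio):
    """Copy cfg, keeping only keep_ratio of the channels of the last conv
    layer at or before split_idx (the layer just before the split-point)."""
    new_cfg = list(cfg)
    conv_idx = split_idx
    while new_cfg[conv_idx] == 'M':   # step back over pooling layers
        conv_idx -= 1
    new_cfg[conv_idx] = max(1, int(new_cfg[conv_idx] * keep_ratio))
    return new_cfg


if __name__ == "__main__":
    split = 2  # split right after the first max-pool
    before = feature_sizes(VGG16_CFG)[split]
    after = feature_sizes(prune_before_split(VGG16_CFG, split, 0.25))[split]
    print("bytes to upload before stage-2:", before[3])
    print("bytes to upload after stage-2: ", after[3])
```

since the upload size is proportional to the channel count of the layer before the split-point, pruning only that layer reduces bandwidth directly, while the rest of the model keeps the channels left by stage-1. the keep ratio of 0.25 is an arbitrary illustration, not a result from this project.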
only supports python3 with pytorch > 0.3.1;
model trained on cifar-10, tested only on vgg-16.
also added auto-logging and auto chart-drawing.
python main.py --train # train from scratch
python main.py --resume # resume training
first, in step-1, you can prune the whole model by
python main.py --prune # prune the whole model
once you have finished step-1, you can then prune each layer individually (step-2) to minimize the bandwidth requirement with
python main.py --prune_layer # prune layers and save models separately
for logging and excel chart drawing, try
python main.py --test_pruned # test the pruned model and save *.json logs
python draw_chart.py
which automatically generates the chart.xlsx file.
requires openpyxl (open the chart in excel);
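the commands above can also be chained from a small driver script. this is only a sketch: the script names and flags are the ones listed above, but the ordering (train, step-1, step-2, test, chart) and the helper names `build_commands` and `run_pipeline` are assumptions, not part of this repo.

```python
import subprocess
import sys

# (script, flag) pairs taken from the commands listed above.
# assumption: they are meant to run in this order.
PIPELINE = [
    ("main.py", "--train"),         # train from scratch
    ("main.py", "--prune"),         # step-1: prune the whole model
    ("main.py", "--prune_layer"),   # step-2: prune the layer before each split-point
    ("main.py", "--test_pruned"),   # test the pruned models and save *.json logs
    ("draw_chart.py", None),        # generate chart.xlsx
]


def build_commands(python=sys.executable):
    """Return the argument list for every step of the pipeline."""
    commands = []
    for script, flag in PIPELINE:
        cmd = [python, script]
        if flag is not None:
            cmd.append(flag)
        commands.append(cmd)
    return commands


def run_pipeline(dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in build_commands():
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the pipeline at the first failing step
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pipeline(dry_run=True)
```

by default the script only prints the commands; call `run_pipeline(dry_run=False)` from the repo root to actually run them.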