Skip to content

PostPruneConfig

Config for post-training prune: select dimensions by weight-based importance (part), then prune.

When passed as training_args.post_prune_config, train() runs post_prune() automatically after training and saves the pruned model. You can also call post_prune() manually when not using this config.

inplace class-attribute instance-attribute

inplace = True

If True, prune the model in place.

mask class-attribute instance-attribute

mask = None

Optional bool mask of shape (input_dim,). Not serialized by to_dict().

part class-attribute instance-attribute

part = 'decoder-weight'

Importance source: 'encoder-weight' | 'decoder-weight' | 'decoder-bias' | 'decoder-sum'.

return_mask class-attribute instance-attribute

return_mask = False

If True, also return the combined mask from prune.

threshold class-attribute instance-attribute

threshold = None

Same as prune(): keep dims with importance >= threshold.

topk class-attribute instance-attribute

topk = None

Same as prune(): int (absolute, top-k dims) or float in (0,1] (relative). topk=1.0 (float) means no pruning. One of topk, threshold, or mask required.

__post_init__

__post_init__()
Source code in gradiend/trainer/core/pruning.py
def __post_init__(self) -> None:
    """Validate the configured fields immediately after dataclass init.

    Raises:
        ValueError: if no selection criterion is given, or ``part`` is not
            one of the supported importance sources.
        TypeError: if ``part`` is not a str, or ``inplace``/``return_mask``
            is not a bool.
    """
    # Without topk, threshold, or mask there is no way to select dimensions.
    has_criterion = not (self.topk is None and self.threshold is None and self.mask is None)
    if not has_criterion:
        raise ValueError("PostPruneConfig: at least one of topk, threshold, or mask must be set.")
    _validate_topk(self.topk, "topk")
    _validate_threshold(self.threshold, "threshold")
    part = self.part
    if not isinstance(part, str):
        raise TypeError(f"part must be str, got {type(part).__name__}")
    allowed_parts = ("encoder-weight", "decoder-weight", "decoder-bias", "decoder-sum")
    if part not in allowed_parts:
        raise ValueError(
            "part must be encoder-weight, decoder-weight, decoder-bias, or decoder-sum; "
            f"got {part!r}"
        )
    # Both flags must be genuine bools, not merely truthy values.
    for flag_name in ("inplace", "return_mask"):
        flag_value = getattr(self, flag_name)
        if not isinstance(flag_value, bool):
            raise TypeError(f"{flag_name} must be bool, got {type(flag_value).__name__}")

__str__

__str__()
Source code in gradiend/trainer/core/pruning.py
def __str__(self) -> str:
    """Return a concise, human-readable summary of the key settings."""
    parts = (
        f"topk={self.topk!r}",
        f"threshold={self.threshold!r}",
        f"part={self.part!r}",
        f"inplace={self.inplace}",
    )
    return "PostPruneConfig(" + ", ".join(parts) + ")"