Skip to content

PrePruneConfig

Configuration for the pre-prune step: compute the mean gradient over n samples, then prune dimensions by topk or threshold.

By default use the dataset passed to pre_prune(); set dataset to use different data for this step.

batch_size class-attribute instance-attribute

batch_size = 8

Batch size for gradient computation (samples per chunk; one gradient per sample).

dataset class-attribute instance-attribute

dataset = None

Optional override: use this dataset instead of the one passed to pre_prune().

feature_class_key class-attribute instance-attribute

feature_class_key = 'feature_class_id'

Key in item dict for stratification (must have exactly two target classes).

n_samples instance-attribute

n_samples

Total number of samples to use for the gradient mean.

source class-attribute instance-attribute

source = 'factual'

Which gradient to average: 'factual' | 'alternative' | 'diff'.

target_feature_class_ids class-attribute instance-attribute

target_feature_class_ids = None

Class IDs to stratify over (neutral/identity are ignored). If None and pre_prune(..., definition=trainer) is used, the trainer's target feature class IDs are used automatically.

threshold class-attribute instance-attribute

threshold = None

Same as prune(): keep dims with importance >= threshold.

topk class-attribute instance-attribute

topk = None

Same as prune(): int (absolute, top-k dims) or float in (0,1] (relative). topk=1.0 (float) means no pruning. One of topk or threshold required.

use_cached_gradients class-attribute instance-attribute

use_cached_gradients = False

If True, the same gradient cache key as training may be reused for this step. Defaults to False.

__post_init__

__post_init__()
Source code in gradiend/trainer/core/pruning.py
def __post_init__(self) -> None:
    """Validate all config fields after dataclass initialization.

    Raises:
        ValueError: if neither ``topk`` nor ``threshold`` is set, or if an
            int field is below its minimum, or if ``source`` is not one of
            the allowed values.
        TypeError: if a field has the wrong type.
    """
    # At least one pruning criterion is required; both may be set.
    if self.topk is None and self.threshold is None:
        raise ValueError("PrePruneConfig: at least one of topk or threshold must be set.")
    # bool subclasses int, so a plain isinstance(..., int) check would let
    # n_samples=True slip through and validate as 1 — reject bool explicitly.
    if isinstance(self.n_samples, bool) or not isinstance(self.n_samples, int):
        raise TypeError(f"n_samples must be int, got {type(self.n_samples).__name__}")
    if self.n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {self.n_samples}")
    # Shared validators (same rules as prune()) handle type/range of each.
    _validate_topk(self.topk, "topk")
    _validate_threshold(self.threshold, "threshold")
    if not isinstance(self.source, str):
        raise TypeError(f"source must be str, got {type(self.source).__name__}")
    if self.source not in ("factual", "alternative", "diff"):
        raise ValueError(f"source must be one of factual, alternative, diff; got {self.source!r}")
    # Same bool-vs-int caveat as n_samples above.
    if isinstance(self.batch_size, bool) or not isinstance(self.batch_size, int):
        raise TypeError(f"batch_size must be int, got {type(self.batch_size).__name__}")
    if self.batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
    if not isinstance(self.feature_class_key, str):
        raise TypeError(f"feature_class_key must be str, got {type(self.feature_class_key).__name__}")
    if not isinstance(self.use_cached_gradients, bool):
        raise TypeError(f"use_cached_gradients must be bool, got {type(self.use_cached_gradients).__name__}")

__str__

__str__()
Source code in gradiend/trainer/core/pruning.py
def __str__(self) -> str:
    """Compact human-readable summary of the key pruning settings."""
    summary = (
        "PrePruneConfig("
        f"n_samples={self.n_samples}, "
        f"topk={self.topk!r}, "
        f"threshold={self.threshold!r}, "
        f"source={self.source!r})"
    )
    return summary