PrePruneConfig
Config for the pre-prune step: importances are computed as the gradient mean over n samples, then dimensions are pruned by topk or threshold.
By default the dataset passed to pre_prune() is used; set dataset to use different data for this step.
batch_size
Batch size for gradient computation (samples per chunk; one gradient per sample).
dataset
Optional override: use this dataset instead of the one passed to pre_prune().
feature_class_key
Key in the item dict used for stratification (there must be exactly two target classes).
source
Which gradient to average: 'factual' | 'alternative' | 'diff'.
target_feature_class_ids
Class IDs to stratify over (neutral/identity are ignored). If None and pre_prune(..., definition=trainer) is used, the trainer's target feature class IDs are used automatically.
threshold
Same as prune(): keep dims with importance >= threshold.
topk
Same as prune(): an int keeps the top-k dims (absolute); a float in (0, 1] keeps that fraction of dims (relative). topk=1.0 (float) means no pruning. One of topk or threshold is required.
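The topk/threshold semantics described above can be sketched as a standalone selection function. This is illustrative only: select_dims is not part of the library's API, and the assumption that topk and threshold are mutually exclusive is mine.

```python
import numpy as np

def select_dims(importance, topk=None, threshold=None):
    """Return the indices of dimensions to keep, given per-dim importances.

    Mirrors the documented semantics: threshold keeps dims with
    importance >= threshold; int topk keeps the k most important dims;
    float topk in (0, 1] keeps that fraction (1.0 = keep everything).
    """
    importance = np.asarray(importance)
    # Assumption: exactly one of topk / threshold may be given.
    if (topk is None) == (threshold is None):
        raise ValueError("exactly one of topk or threshold is required")
    if threshold is not None:
        # threshold mode: keep dims with importance >= threshold
        return np.flatnonzero(importance >= threshold)
    if isinstance(topk, float):
        if not 0.0 < topk <= 1.0:
            raise ValueError("float topk must be in (0, 1]")
        k = max(1, round(topk * importance.size))  # relative top-k
    else:
        k = int(topk)  # absolute top-k
    # indices of the k largest importances, returned in ascending order
    return np.sort(np.argsort(importance)[::-1][:k])
```

With importances `[0.1, 0.9, 0.5, 0.3]`, `topk=2` keeps dims 1 and 2, `threshold=0.3` keeps dims 1, 2, and 3, and `topk=1.0` keeps all four.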
use_cached_gradients
If True, gradients may be reused from the same cache key as training. Defaults to False.
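The attributes above can be mirrored as a small dataclass sketch. This is not the library's actual class definition: the default values (batch_size, feature_class_key, source) are assumptions, as is the validation in __post_init__.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Union

@dataclass
class PrePruneConfigSketch:
    """Hypothetical mirror of PrePruneConfig; types and defaults are assumed."""
    batch_size: int = 32                      # assumed default
    dataset: Optional[object] = None          # override dataset for this step
    feature_class_key: str = "feature_class"  # assumed default key name
    source: str = "diff"                      # assumed default; see docs above
    target_feature_class_ids: Optional[Sequence[int]] = None
    threshold: Optional[float] = None
    topk: Optional[Union[int, float]] = None
    use_cached_gradients: bool = False

    def __post_init__(self) -> None:
        if self.source not in ("factual", "alternative", "diff"):
            raise ValueError("source must be 'factual', 'alternative' or 'diff'")
        # documented requirement: one of topk or threshold must be given
        if self.topk is None and self.threshold is None:
            raise ValueError("one of topk or threshold is required")
```

A typical construction would then look like `PrePruneConfigSketch(topk=0.1, source="diff")`, i.e. keep the top 10% of dims by the averaged diff gradient.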