Core classes and use cases
This guide explains the most important classes in GRADIEND and when to use them. For complete API details, see the API reference.
Overview: the GRADIEND workflow
The typical GRADIEND workflow involves three main components:
- Data creation → TextPredictionDataCreator and TextFilterConfig
- Training → TextPredictionTrainer and TrainingArguments
- Evaluation → Evaluator, EncoderEvaluator, DecoderEvaluator
Below, we explain each component and the key classes within them.
Data creation
TextPredictionDataCreator
Purpose: Build training and neutral datasets from raw text or existing data sources.
When to use:
- You have raw text and want to extract feature-specific training data
- You need to create masked text pairs for a specific feature (e.g., gender, grammatical case)
- You want to generate neutral evaluation data
Key methods:
- generate_training_data() — Creates training/validation/test splits with masked texts
- generate_neutral_data() — Creates feature-neutral evaluation data
Example:
```python
from gradiend import TextPredictionDataCreator, TextFilterConfig

creator = TextPredictionDataCreator(
    base_data=["The chef tasted the soup, then he added pepper."],
    feature_targets=[
        TextFilterConfig(targets=["he", "she", "it"], id="3SG"),
        TextFilterConfig(targets=["they"], id="3PL"),
    ],
)

training = creator.generate_training_data(max_size_per_class=10)
neutral = creator.generate_neutral_data(max_size=15)
```
See also: Data generation tutorial, Data handling guide
TextFilterConfig
Purpose: Define a feature class by specifying target tokens and optional linguistic constraints.
When to use:
- Defining what tokens/patterns belong to a feature class
- Using spaCy tags for grammatical filtering (e.g., gender, case, number)
- Creating multiple classes for complex features (e.g., German gender–case with 12 classes)
Key attributes:
- targets — List of tokens to match (e.g., ["he", "she", "it"])
- spacy_tags — Dictionary of spaCy tags for filtering (e.g., {"Gender": "Masc", "Case": "Nom"})
- id — Identifier for the class (used as column names and class labels)
Example:
```python
from gradiend import TextFilterConfig

# Simple string matching
config = TextFilterConfig(targets=["he", "she"], id="3SG")

# With spaCy tags (German definite articles)
config = TextFilterConfig(
    targets=["der"],
    spacy_tags={"pos": "DET", "Case": "Nom", "Gender": "Masc"},
    id="masc_nom",
)
```
See also: Data generation tutorial
Training
TextPredictionTrainer
Purpose: Train a GRADIEND model to learn a feature from gradient differences.
When to use:
- Training a GRADIEND encoder-decoder on your data
- The main entry point for the GRADIEND training pipeline
- You want to train, evaluate, and save a GRADIEND model
Key methods:
- train() — Start training
- evaluate_encoder() — Evaluate encoder correlation and separation
- evaluate_decoder() — Evaluate the decoder's ability to modify the base model. By default, this computes strengthen summaries only (keys such as "3SG"); pass increase_target_probabilities=False to compute weaken summaries only (keys such as "3SG_weaken"). Only the dataset and feature-factor combinations required for the chosen direction are evaluated.
- rewrite_base_model() — Rewrite one or more base models using decoder evaluation results for the given target class(es), optionally saving them to disk. By default this strengthens the target class; set increase_target_probabilities=False to apply the weakening config (this requires having run evaluate_decoder(increase_target_probabilities=False) so that dec["<class>_weaken"] exists).
- plot_training_convergence() — Visualize training progress
Example:
```python
from gradiend import TextPredictionTrainer, TrainingArguments

args = TrainingArguments(
    train_batch_size=4,
    max_steps=100,
    learning_rate=1e-4,
)

trainer = TextPredictionTrainer(
    model="bert-base-uncased",
    data=training_data,
    eval_neutral_data=neutral_data,
    args=args,
)

trainer.train()

# Evaluate
enc_result = trainer.evaluate_encoder(plot=True)
dec_result = trainer.evaluate_decoder()

# Strengthen the target class (default)
changed_model = trainer.rewrite_base_model(decoder_results=dec_result, target_class="masc_nom")
```
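The strengthen/weaken key convention used by evaluate_decoder() and rewrite_base_model() can be illustrated with a small standalone helper. Note that the dictionary contents below are invented for demonstration; only the key naming ("<class>" vs. "<class>_weaken") follows the convention described above.

```python
# Hypothetical decoder-result dict; values are made up for illustration.
dec = {
    "3SG": {"learning_rate": 1e-4, "feature_factor": 1.0},
    "3SG_weaken": {"learning_rate": 5e-5, "feature_factor": -1.0},
}

def select_config(dec_results, target_class, increase_target_probabilities=True):
    """Mirror the key convention: strengthening uses the bare class id,
    weakening uses '<class>_weaken'."""
    key = target_class if increase_target_probabilities else f"{target_class}_weaken"
    return dec_results[key]

assert select_config(dec, "3SG")["feature_factor"] == 1.0
assert select_config(dec, "3SG", increase_target_probabilities=False)["feature_factor"] == -1.0
```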
See also: Start here, Training tutorial
TrainingArguments
Purpose: Configure training hyperparameters and experiment settings.
When to use:
- Setting batch sizes, learning rates, and training length
- Configuring multi-seed training
- Enabling pruning (pre/post)
- Setting evaluation frequency
Key attributes:
- train_batch_size, eval_batch_size — Batch sizes
- max_steps, num_train_epochs — Training length
- learning_rate — Learning rate for GRADIEND model
- eval_steps — How often to evaluate during training
- max_seeds — Number of seeds for multi-seed training
- pre_prune_config, post_prune_config — Pruning settings
Example:
```python
from gradiend import TrainingArguments, PrePruneConfig

args = TrainingArguments(
    train_batch_size=8,
    max_steps=200,
    learning_rate=1e-4,
    eval_steps=20,
    max_seeds=3,  # Train 3 seeds and pick best
    pre_prune_config=PrePruneConfig(keep_top_k=1000),
)
```
See also: Training arguments guide, Pruning guide
TextPredictionConfig
Purpose: Configure data loading and preprocessing for TextPredictionTrainer.
When to use:
- Loading data from HuggingFace datasets
- Specifying column names for custom data formats
- Configuring target classes and data splits
Key attributes:
- data — Training data (DataFrame, dict, HF dataset ID, or file path)
- hf_dataset — HuggingFace dataset ID
- target_classes — Classes to train on (pair is auto-inferred if len=2)
- class_merge_map — Map base classes to merged classes (e.g. {"singular": ["1SG", "3SG"], "plural": ["1PL", "3PL"]}); when set, target_classes uses merged names; when exactly two keys, target_classes can be omitted
- masked_col, split_col — Column names
- eval_neutral_data — Optional. Neutral evaluation data (DataFrame, path, or HuggingFace dataset ID). Used for encoder (neutral_dataset variant) and decoder evaluation (LMS). When omitted, decoder evaluation falls back to training-like data (test split with factual masks filled in); target tokens are then ignored in LMS. See Evaluation (intra-model).
Example:
```python
from gradiend import TextPredictionConfig

config = TextPredictionConfig(
    data="aieng-lab/de-gender-case-articles",
    target_classes=["masc_nom", "fem_nom"],
    masked_col="masked",
    split_col="split",
    eval_neutral_data="aieng-lab/wortschatz-leipzig-de-grammar-neutral",
)
```
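The effect of class_merge_map can be sketched in plain Python. This standalone snippet only illustrates the mapping direction (base class id → merged class id); the actual relabeling happens inside TextPredictionConfig.

```python
# Example merge map from the attribute description above.
class_merge_map = {"singular": ["1SG", "3SG"], "plural": ["1PL", "3PL"]}

# Invert it: base class id -> merged class id.
base_to_merged = {
    base: merged
    for merged, bases in class_merge_map.items()
    for base in bases
}

labels = ["3SG", "1PL", "1SG"]
merged_labels = [base_to_merged[label] for label in labels]
assert merged_labels == ["singular", "plural", "singular"]
```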
See also: Data handling guide
Model classes
GradiendModel
Purpose: The core encoder-decoder model (weights-only, no base model context).
When to use:
- Low-level access to GRADIEND encoder/decoder weights
- Saving/loading GRADIEND models independently
- Computing importance scores from weights
Key methods:
- forward() — Encode gradients to latent space
- forward_decoder() — Decode from latent space to gradient space
- save_pretrained() — Save model weights and config
- from_pretrained() — Load a saved model
Note: Most users should use ModelWithGradiend or TextPredictionTrainer instead, which handle the base model integration.
See also: API reference
ParamMappedGradiendModel
Purpose: GRADIEND model with parameter mapping for base-model gradients.
When to use:
- Working with gradient dictionaries (parameter name → gradient tensor)
- Needing to map between GRADIEND's parameter space and base model parameters
- Advanced use cases requiring direct gradient manipulation
Key difference from GradiendModel: Handles parameter name mapping, enabling gradient I/O as dictionaries.
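The gradient-dictionary I/O that ParamMappedGradiendModel enables can be pictured as a flatten/unflatten round trip between named gradients and a single vector. The snippet below is a toy illustration with made-up parameter names and plain lists in place of tensors; it is not the library's implementation.

```python
# Hypothetical gradient dict: parameter name -> gradient values.
grads = {"encoder.layer.0.weight": [0.1, -0.2], "encoder.layer.0.bias": [0.05]}

def flatten_gradients(grad_dict):
    """Concatenate named gradients into one flat vector, remembering slices."""
    flat, slices, offset = [], {}, 0
    for name, values in grad_dict.items():
        flat.extend(values)
        slices[name] = (offset, offset + len(values))
        offset += len(values)
    return flat, slices

def unflatten_gradients(flat, slices):
    """Recover the named-gradient dict from the flat vector."""
    return {name: flat[a:b] for name, (a, b) in slices.items()}

flat, slices = flatten_gradients(grads)
assert unflatten_gradients(flat, slices) == grads
```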
See also: API reference
ModelWithGradiend
Purpose: Wrapper combining a base language model with a GRADIEND encoder-decoder.
When to use:
- Loading a trained GRADIEND model with its base model
- Applying decoder updates to modify the base model
- Evaluating the encoder/decoder on a loaded model
Key methods:
- encode() — Encode gradients to latent feature value
- rewrite_base_model() — Rewrite the base model by applying decoder updates (takes learning_rate, feature_factor). The trainer’s rewrite_base_model(target_class=..., increase_target_probabilities=...) selects these from decoder evaluation and calls this method.
- from_pretrained() — Load base model + GRADIEND model
Example:
```python
from gradiend import ModelWithGradiend

model = ModelWithGradiend.from_pretrained(
    base_model="bert-base-uncased",
    gradiend_model="path/to/gradiend/model",
)

# Encode a gradient
feature_value = model.encode(gradient_dict)

# Rewrite the base model
modified_model = model.rewrite_base_model(
    learning_rate=1e-4,
    feature_factor=1.0,
    part='decoder',
)
```
See also: API reference, Saving & loading guide
Evaluation
Evaluator
Purpose: High-level evaluation coordinator bound to a trainer.
When to use:
- Running encoder and decoder evaluation together
- Convenient access to evaluation methods from a trainer
- The default evaluation workflow
Key methods:
- evaluate_encoder() — Run encoder evaluation (correlation, plots)
- evaluate_decoder() — Run decoder evaluation (probability shifts)
Plotting is delegated to a Visualizer if one is configured.
Note: TextPredictionTrainer already has an Evaluator instance, so you typically call trainer.evaluate_encoder() directly.
See also: API reference, Evaluation tutorial
EncoderEvaluator
Purpose: Evaluate how well the encoder separates feature classes.
When to use:
- Computing correlation between encoded values and feature classes
- Analyzing encoder separation on test/neutral data
- Debugging encoder convergence
Key outputs:
- Correlation coefficient (higher = better separation)
- Encoded value distributions per class
- Plots showing class separation
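As a rough intuition for the correlation output: well-separated encoded values for two classes yield a Pearson correlation near 1 against the class labels. The toy values below are invented, and GRADIEND computes its correlation internally; this is only a plain Pearson correlation for illustration.

```python
import math

def pearson(xs, ys):
    """Plain Pearson correlation between two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)

# Well-separated encoded values for two classes (labels 0/1)
encoded = [-0.9, -0.8, -0.95, 0.85, 0.9, 0.8]
labels = [0, 0, 0, 1, 1, 1]
r = pearson(encoded, labels)
assert r > 0.99  # near-perfect separation
```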
See also: API reference, Evaluation guide
DecoderEvaluator
Purpose: Evaluate how well the decoder can modify the base model.
When to use:
- Testing whether decoder updates change model behavior as expected
- Finding the optimal learning rate and feature factor
- Measuring probability shifts for target tokens
Key outputs:
- Grid search results over learning rates and feature factors
- Probability shift scores
- Language modeling scores (to ensure model quality is maintained)
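The grid-search idea can be sketched generically: score every (learning rate, feature factor) combination and keep the best. The score function below is a toy stand-in, not GRADIEND's actual probability-shift metric.

```python
import itertools

def grid_search(score_fn, learning_rates, feature_factors):
    """Return the (lr, factor) pair with the highest score."""
    return max(
        itertools.product(learning_rates, feature_factors),
        key=lambda cfg: score_fn(*cfg),
    )

# Toy score that peaks at lr=1e-4, factor=1.0 (purely illustrative)
def toy_score(lr, factor):
    return -abs(lr - 1e-4) * 1e4 - abs(factor - 1.0)

best_lr, best_factor = grid_search(toy_score, [1e-5, 1e-4, 1e-3], [-1.0, 1.0])
assert (best_lr, best_factor) == (1e-4, 1.0)
```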
See also: API reference, Evaluation tutorial
Utility classes
TextPreprocessConfig
Purpose: Configure text preprocessing for data creation: split into sentences, filter by length or characters.
When to use:
- Splitting long texts (e.g. articles) into sentences before filtering
- Dropping segments that are too short or too long (min_chars, max_chars)
- Excluding segments containing certain characters (exclude_chars) or via a custom filter
Key attributes:
- split_to_sentences — If True, split on sentences (regex or spaCy). Default False.
- min_chars — Drop segments shorter than this. Default None (no minimum).
- max_chars — Drop segments longer than this. Default None (no maximum).
- exclude_chars — Drop segments containing any of these characters. Default None.
- custom_filter — Optional callable (str) -> bool; keep only segments where it returns True.
Default: There is no default instance. In TextPredictionDataCreator, preprocess=None (default) means no preprocessing: texts are used as-is. Pass a TextPreprocessConfig instance when you want sentence splitting or length/char filtering.
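The length and character filters can be pictured with a standalone predicate. This sketch mirrors the attribute semantics described above (min_chars, max_chars, exclude_chars), but it is not the library implementation.

```python
def keep_segment(text, min_chars=None, max_chars=None, exclude_chars=None):
    """Return True if the segment passes all configured filters."""
    if min_chars is not None and len(text) < min_chars:
        return False
    if max_chars is not None and len(text) > max_chars:
        return False
    if exclude_chars is not None and any(c in text for c in exclude_chars):
        return False
    return True

segments = ["Hi.", "The chef tasted the soup.", "Weird @segment."]
kept = [s for s in segments if keep_segment(s, min_chars=5, exclude_chars="@")]
assert kept == ["The chef tasted the soup."]
```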
See also: Data generation tutorial, API: TextPreprocessConfig.
PrePruneConfig / PostPruneConfig
Purpose: Configure pruning to reduce model size and focus on important parameters.
When to use:
- Reducing computational cost
- Focusing on the most important parameters
- Speeding up training
Key attributes:
- keep_top_k — Number of parameters to keep
- importance_metric — How to compute importance (e.g., "gradient_norm", "weight_magnitude")
Pre-pruning: Prune before training based on gradient statistics.
Post-pruning: Prune after training based on learned weights.
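The keep_top_k idea amounts to ranking parameters by an importance score and keeping the k highest. This standalone sketch works on a plain list of scores rather than on model parameters, so it only illustrates the selection step.

```python
def top_k_mask(importance_scores, keep_top_k):
    """Return the (sorted) indices of the keep_top_k highest-importance entries."""
    ranked = sorted(
        range(len(importance_scores)),
        key=lambda i: importance_scores[i],
        reverse=True,
    )
    return sorted(ranked[:keep_top_k])

scores = [0.1, 0.9, 0.3, 0.7]
assert top_k_mask(scores, 2) == [1, 3]  # keep the two highest-scoring entries
```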
See also: Pruning guide
Quick reference: which class to use?
| Task | Primary class | Secondary classes |
|---|---|---|
| Create training data from text | TextPredictionDataCreator | TextFilterConfig, TextPreprocessConfig |
| Load precomputed data | TextPredictionTrainer (via data parameter) | TextPredictionConfig |
| Train a GRADIEND model | TextPredictionTrainer | TrainingArguments, TextPredictionConfig |
| Evaluate encoder | TextPredictionTrainer.evaluate_encoder() | EncoderEvaluator |
| Evaluate decoder | TextPredictionTrainer.evaluate_decoder() | DecoderEvaluator |
| Load a trained model | ModelWithGradiend.from_pretrained() | GradiendModel.from_pretrained() |
| Rewrite a model | ModelWithGradiend.rewrite_base_model() | TextPredictionTrainer.rewrite_base_model() |
| Configure pruning | PrePruneConfig, PostPruneConfig | Used via TrainingArguments |
Next steps
- Start here — Run a complete example; script also in gradiend/examples on GitHub
- API reference — Complete API documentation
- Tutorials — Step-by-step workflows
- Guides — Detailed guides for specific topics