Core classes and use cases

This guide explains the most important classes in GRADIEND and when to use them. For complete API details, see the API reference.


Overview: the GRADIEND workflow

The typical GRADIEND workflow involves three main components:

  1. Data creation — TextPredictionDataCreator and TextFilterConfig
  2. Training — TextPredictionTrainer and TrainingArguments
  3. Evaluation — Evaluator, EncoderEvaluator, DecoderEvaluator

Below, we explain each component and the key classes within them.


Data creation

TextPredictionDataCreator

Purpose: Build training and neutral datasets from raw text or existing data sources.

When to use:

  - You have raw text and want to extract feature-specific training data
  - You need to create masked text pairs for a specific feature (e.g., gender, grammatical case)
  - You want to generate neutral evaluation data

Key methods:

  - generate_training_data() — Creates training/validation/test splits with masked texts
  - generate_neutral_data() — Creates feature-neutral evaluation data

Example:

from gradiend import TextPredictionDataCreator, TextFilterConfig

creator = TextPredictionDataCreator(
    base_data=["The chef tasted the soup, then he added pepper."],
    feature_targets=[
        TextFilterConfig(targets=["he", "she", "it"], id="3SG"),
        TextFilterConfig(targets=["they"], id="3PL"),
    ],
)
training = creator.generate_training_data(max_size_per_class=10)
neutral = creator.generate_neutral_data(max_size=15)

See also: Data generation tutorial, Data handling guide


TextFilterConfig

Purpose: Define a feature class by specifying target tokens and optional linguistic constraints.

When to use:

  - Defining what tokens/patterns belong to a feature class
  - Using spaCy tags for grammatical filtering (e.g., gender, case, number)
  - Creating multiple classes for complex features (e.g., German gender–case with 12 classes)

Key attributes:

  - targets — List of tokens to match (e.g., ["he", "she", "it"])
  - spacy_tags — Dictionary of spaCy tags for filtering (e.g., {"Gender": "Masc", "Case": "Nom"})
  - id — Identifier for the class (used as column names and class labels)

Example:

from gradiend import TextFilterConfig

# Simple string matching
config = TextFilterConfig(targets=["he", "she"], id="3SG")

# With spaCy tags (German definite articles)
config = TextFilterConfig(
    targets=["der"],
    spacy_tags={"pos": "DET", "Case": "Nom", "Gender": "Masc"},
    id="masc_nom"
)

See also: Data generation tutorial


Training

TextPredictionTrainer

Purpose: Train a GRADIEND model to learn a feature from gradient differences.

When to use:

  - Training a GRADIEND encoder-decoder on your data
  - You want the main entry point for the GRADIEND training pipeline
  - You want to train, evaluate, and save a GRADIEND model

Key methods:

  - train() — Start training
  - evaluate_encoder() — Evaluate encoder correlation and separation
  - evaluate_decoder() — Evaluate the decoder's ability to modify the base model. By default, computes strengthen summaries only (keys e.g. "3SG"). Use increase_target_probabilities=False to compute weaken summaries only (keys e.g. "3SG_weaken"); only the dataset–feature-factor combinations required for the chosen direction are evaluated.
  - rewrite_base_model() — Rewrite base model(s) using decoder evaluation results for the given target_class(es), optionally saving to disk. By default strengthens the target class; set increase_target_probabilities=False to apply the weakening config (this requires having run evaluate_decoder(increase_target_probabilities=False) so that dec["<class>_weaken"] exists).
  - plot_training_convergence() — Visualize training progress

Example:

from gradiend import TextPredictionTrainer, TrainingArguments

args = TrainingArguments(
    train_batch_size=4,
    max_steps=100,
    learning_rate=1e-4,
)

trainer = TextPredictionTrainer(
    model="bert-base-uncased",
    data=training_data,
    eval_neutral_data=neutral_data,
    args=args,
)
trainer.train()

# Evaluate
enc_result = trainer.evaluate_encoder(plot=True)
dec_result = trainer.evaluate_decoder()
changed_model = trainer.rewrite_base_model(decoder_results=dec_result, target_class="masc_nom")  # strengthen (default)

See also: Start here, Training tutorial


TrainingArguments

Purpose: Configure training hyperparameters and experiment settings.

When to use:

  - Setting batch sizes, learning rates, training length
  - Configuring multi-seed training
  - Enabling pruning (pre/post)
  - Setting evaluation frequency

Key attributes:

  - train_batch_size, eval_batch_size — Batch sizes
  - max_steps, num_train_epochs — Training length
  - learning_rate — Learning rate for the GRADIEND model
  - eval_steps — How often to evaluate during training
  - max_seeds — Number of seeds for multi-seed training
  - pre_prune_config, post_prune_config — Pruning settings

Example:

from gradiend import TrainingArguments, PrePruneConfig

args = TrainingArguments(
    train_batch_size=8,
    max_steps=200,
    learning_rate=1e-4,
    eval_steps=20,
    max_seeds=3,  # Train 3 seeds and pick best
    pre_prune_config=PrePruneConfig(keep_top_k=1000),
)

See also: Training arguments guide, Pruning guide


TextPredictionConfig

Purpose: Configure data loading and preprocessing for TextPredictionTrainer.

When to use:

  - Loading data from HuggingFace datasets
  - Specifying column names for custom data formats
  - Configuring target classes and data splits

Key attributes:

  - data — Training data (DataFrame, dict, HF dataset ID, or file path)
  - hf_dataset — HuggingFace dataset ID
  - target_classes — Classes to train on (the pair is auto-inferred if len=2)
  - class_merge_map — Map base classes to merged classes (e.g. {"singular": ["1SG", "3SG"], "plural": ["1PL", "3PL"]}); when set, target_classes uses the merged names; when it has exactly two keys, target_classes can be omitted
  - masked_col, split_col — Column names
  - eval_neutral_data — Optional. Neutral evaluation data (DataFrame, path, or HuggingFace dataset ID). Used for encoder evaluation (neutral_dataset variant) and decoder evaluation (LMS). When omitted, decoder evaluation falls back to training-like data (the test split with factual masks filled in); target tokens are then ignored in LMS. See Evaluation (intra-model).

Example:

from gradiend import TextPredictionConfig

config = TextPredictionConfig(
    data="aieng-lab/de-gender-case-articles",
    target_classes=["masc_nom", "fem_nom"],
    masked_col="masked",
    split_col="split",
    eval_neutral_data="aieng-lab/wortschatz-leipzig-de-grammar-neutral",
)

See also: Data handling guide


Model classes

GradiendModel

Purpose: The core encoder-decoder model (weights-only, no base model context).

When to use:

  - Low-level access to GRADIEND encoder/decoder weights
  - Saving/loading GRADIEND models independently
  - Computing importance scores from weights

Key methods:

  - forward() — Encode gradients to latent space
  - forward_decoder() — Decode from latent space to gradient space
  - save_pretrained() — Save model weights and config
  - from_pretrained() — Load a saved model

Note: Most users should use ModelWithGradiend or TextPredictionTrainer instead, which handle the base model integration.

See also: API reference


ParamMappedGradiendModel

Purpose: GRADIEND model with parameter mapping for base-model gradients.

When to use:

  - Working with gradient dictionaries (parameter name → gradient tensor)
  - You need to map between GRADIEND's parameter space and base model parameters
  - Advanced use cases requiring direct gradient manipulation

Key difference from GradiendModel: Handles parameter name mapping, enabling gradient I/O as dictionaries.
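
The idea behind dictionary-based gradient I/O can be sketched in plain Python. This is an illustrative toy, not the library's implementation; the helper names flatten_grads and unflatten_grads are hypothetical, and plain lists stand in for gradient tensors:

```python
# Toy sketch of parameter-name mapping (hypothetical helpers, not the
# gradiend API): concatenate named gradients into one flat vector for the
# encoder, then split decoder output back into the named-parameter layout.

def flatten_grads(grad_dict, param_order):
    """Concatenate named gradient lists into one flat vector."""
    return [v for name in param_order for v in grad_dict[name]]

def unflatten_grads(flat, param_order, sizes):
    """Split a flat vector back into the named-parameter layout."""
    out, i = {}, 0
    for name in param_order:
        n = sizes[name]
        out[name] = flat[i:i + n]
        i += n
    return out

grads = {"encoder.weight": [0.1, 0.2, 0.3], "decoder.bias": [0.4]}
order = ["encoder.weight", "decoder.bias"]
sizes = {name: len(g) for name, g in grads.items()}

flat = flatten_grads(grads, order)              # one flat vector
restored = unflatten_grads(flat, order, sizes)  # round-trips exactly
```

The round trip is lossless as long as the parameter order and sizes are fixed, which is what makes gradient dictionaries interchangeable with flat vectors.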

See also: API reference


ModelWithGradiend

Purpose: Wrapper combining a base language model with a GRADIEND encoder-decoder.

When to use:

  - Loading a trained GRADIEND model with its base model
  - Applying decoder updates to modify the base model
  - Evaluating the encoder/decoder on a loaded model

Key methods:

  - encode() — Encode gradients to a latent feature value
  - rewrite_base_model() — Rewrite the base model by applying decoder updates (takes learning_rate, feature_factor). The trainer's rewrite_base_model(target_class=..., increase_target_probabilities=...) selects these from decoder evaluation and calls this method.
  - from_pretrained() — Load base model + GRADIEND model

Example:

from gradiend import ModelWithGradiend

model = ModelWithGradiend.from_pretrained(
    base_model="bert-base-uncased",
    gradiend_model="path/to/gradiend/model",
)

# Encode a gradient
feature_value = model.encode(gradient_dict)

# Rewrite the base model
modified_model = model.rewrite_base_model(
    learning_rate=1e-4,
    feature_factor=1.0,
    part='decoder'
)

See also: API reference, Saving & loading guide


Evaluation

Evaluator

Purpose: High-level evaluation coordinator bound to a trainer.

When to use:

  - Running encoder and decoder evaluation together
  - Convenient access to evaluation methods from a trainer
  - The default evaluation workflow

Key methods:

  - evaluate_encoder() — Run encoder evaluation (correlation, plots)
  - evaluate_decoder() — Run decoder evaluation (probability shifts)
  - Delegates plotting to Visualizer if configured

Note: TextPredictionTrainer already has an Evaluator instance, so you typically call trainer.evaluate_encoder() directly.

See also: API reference, Evaluation tutorial


EncoderEvaluator

Purpose: Evaluate how well the encoder separates feature classes.

When to use:

  - Computing correlation between encoded values and feature classes
  - Analyzing encoder separation on test/neutral data
  - Debugging encoder convergence

Key outputs:

  - Correlation coefficient (higher = better separation)
  - Encoded value distributions per class
  - Plots showing class separation

See also: API reference, Evaluation guide


DecoderEvaluator

Purpose: Evaluate how well the decoder can modify the base model.

When to use:

  - Testing whether decoder updates change model behavior as expected
  - Finding the optimal learning rate and feature factor
  - Measuring probability shifts for target tokens

Key outputs:

  - Grid search results over learning rates and feature factors
  - Probability shift scores
  - Language modeling scores (to ensure model quality is maintained)
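
The selection logic can be sketched as follows. The numbers and the threshold are invented for illustration; this is the spirit of the grid search, not the evaluator's actual code:

```python
# Toy grid search in the spirit of DecoderEvaluator: among
# (learning_rate, feature_factor) configs whose language-modeling score
# stays acceptable, pick the one with the largest probability shift.
# All scores below are made-up numbers.
results = {
    (1e-5, 1.0): {"prob_shift": 0.02, "lm_score": 0.99},
    (1e-4, 1.0): {"prob_shift": 0.15, "lm_score": 0.97},
    (1e-3, 1.0): {"prob_shift": 0.40, "lm_score": 0.62},  # model degraded
    (1e-4, 2.0): {"prob_shift": 0.22, "lm_score": 0.95},
}

LM_THRESHOLD = 0.9  # hypothetical cutoff for acceptable LM quality
ok = {cfg: r for cfg, r in results.items() if r["lm_score"] >= LM_THRESHOLD}
best = max(ok, key=lambda cfg: ok[cfg]["prob_shift"])
print(best)  # (0.0001, 2.0)
```

The key trade-off: a larger learning rate or feature factor produces bigger probability shifts but can degrade general language modeling, which is why both scores are tracked.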

See also: API reference, Evaluation tutorial


Utility classes

TextPreprocessConfig

Purpose: Configure text preprocessing for data creation: split into sentences, filter by length or characters.

When to use:

  - Splitting long texts (e.g. articles) into sentences before filtering
  - Dropping segments that are too short or too long (min_chars, max_chars)
  - Excluding segments containing certain characters (exclude_chars) or via a custom filter

Key attributes:

  - split_to_sentences — If True, split into sentences (regex or spaCy). Default: False.
  - min_chars — Drop segments shorter than this. Default: None (no minimum).
  - max_chars — Drop segments longer than this. Default: None (no maximum).
  - exclude_chars — Drop segments containing any of these characters. Default: None.
  - custom_filter — Optional callable (str) -> bool; keep only segments for which it returns True.

Default: There is no default instance. In TextPredictionDataCreator, preprocess=None (default) means no preprocessing: texts are used as-is. Pass a TextPreprocessConfig instance when you want sentence splitting or length/char filtering.
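
The filtering semantics described above can be sketched in a few lines of plain Python. This is an illustrative stand-in (the preprocess function is hypothetical), not the library's implementation:

```python
# Sketch of the TextPreprocessConfig filtering semantics (illustrative
# only): apply min/max length, excluded characters, and a custom filter.
def preprocess(texts, min_chars=None, max_chars=None,
               exclude_chars=None, custom_filter=None):
    kept = []
    for t in texts:
        if min_chars is not None and len(t) < min_chars:
            continue  # too short
        if max_chars is not None and len(t) > max_chars:
            continue  # too long
        if exclude_chars and any(c in t for c in exclude_chars):
            continue  # contains a forbidden character
        if custom_filter is not None and not custom_filter(t):
            continue  # rejected by the user-supplied predicate
        kept.append(t)
    return kept

texts = ["Short.", "A sentence of reasonable length.", "Has a % sign."]
print(preprocess(texts, min_chars=10, exclude_chars="%"))
# ['A sentence of reasonable length.']
```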

See also: Data generation tutorial, API: TextPreprocessConfig.


PrePruneConfig / PostPruneConfig

Purpose: Configure pruning to reduce model size and focus on important parameters.

When to use:

  - Reducing computational cost
  - Focusing on the most important parameters
  - Speeding up training

Key attributes:

  - keep_top_k — Number of parameters to keep
  - importance_metric — How to compute importance (e.g., "gradient_norm", "weight_magnitude")

Pre-pruning: Prune before training based on gradient statistics.
Post-pruning: Prune after training based on learned weights.
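
In both cases, the core of keep_top_k selection amounts to ranking parameters by an importance score and keeping the top k. A minimal sketch (the helper top_k_keep is hypothetical, and the scores are made-up numbers):

```python
# Sketch of keep_top_k pruning (illustrative, not the library's code):
# rank parameters by importance and keep only the k highest-scoring ones.
def top_k_keep(importance, k):
    """Return the set of parameter names to keep."""
    ranked = sorted(importance, key=importance.get, reverse=True)
    return set(ranked[:k])

# e.g. importance from gradient norms (made-up values)
importance = {"layer1.w": 0.9, "layer1.b": 0.1,
              "layer2.w": 0.7, "layer2.b": 0.05}
keep = top_k_keep(importance, k=2)
print(sorted(keep))  # ['layer1.w', 'layer2.w']
```

Pre-pruning computes such scores from gradient statistics before training; post-pruning computes them from the learned GRADIEND weights afterwards.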

See also: Pruning guide


Quick reference: which class to use?

| Task | Primary class | Secondary classes |
| --- | --- | --- |
| Create training data from text | TextPredictionDataCreator | TextFilterConfig, TextPreprocessConfig |
| Load precomputed data | TextPredictionTrainer (via data parameter) | TextPredictionConfig |
| Train a GRADIEND model | TextPredictionTrainer | TrainingArguments, TextPredictionConfig |
| Evaluate encoder | TextPredictionTrainer.evaluate_encoder() | EncoderEvaluator |
| Evaluate decoder | TextPredictionTrainer.evaluate_decoder() | DecoderEvaluator |
| Load a trained model | ModelWithGradiend.from_pretrained() | GradiendModel.from_pretrained() |
| Rewrite a model | ModelWithGradiend.rewrite_base_model() | TextPredictionTrainer.rewrite_base_model() |
| Configure pruning | PrePruneConfig, PostPruneConfig | Used via TrainingArguments |

Next steps