# Data handling (text prediction)

This guide describes how to provide text data to `TextPredictionTrainer` and which columns your datasets must have. It is specific to text prediction.
## Overview: three input paths

`TextPredictionTrainer` accepts data in three main ways:
| Path | Use when |
|---|---|
| Per-class HuggingFace | Your dataset is on HuggingFace with one config/subset per class (e.g. `aieng-lab/de-gender-case-articles`, `aieng-lab/gradiend_race_data`). Pass the dataset ID to `data`. |
| Per-class dict | You have a Python dict of DataFrames keyed by class name, or output from `TextPredictionDataCreator.generate_training_data(format="per_class")`. Pass it to `data`. |
| Merged | You have a single table (DataFrame, CSV, Parquet, or HF dataset) where each row has `label_class` and `label`. Pass a DataFrame or file path to `data`, or pass an HF dataset ID to `hf_dataset`. |
## 1. Per-class HuggingFace format

Use when your dataset is on HuggingFace and has one config/subset per class (e.g. `masc_nom`, `fem_nom` or `white`, `black`, `asian`).
### How to specify

```python
trainer = TextPredictionTrainer(
    model="bert-base-uncased",
    data="aieng-lab/de-gender-case-articles",  # HuggingFace dataset ID (string, not a file path)
    target_classes=["masc_nom", "fem_nom"],    # The pair to train on
    masked_col="masked",
    split_col="split",
    # Optional:
    all_classes=["masc_nom", "fem_nom", "neut_nom"],  # Limit which configs to load; default: all
    hf_splits=["train", "validation", "test"],        # Splits to include; default: all
)
```
### Required columns (per subset/config)

Each config/subset of the dataset must have:

| Column | Required | Description |
|---|---|---|
| `masked` | ✅ | Sentence with `[MASK]` at the target position. |
| `split` | ✅ | Dataset split: `train`, `validation` (or `val`), `test`. |
| Factual token | ✅ | One column holding the token that fills the mask for this class. See below. |
**Factual token column:** The trainer looks for a column in this order:

1. A column with the same name as the class (e.g. `masc_nom` in subset `masc_nom`), when `use_class_names_as_columns=True` (default)
2. Otherwise `label` or `token`
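This lookup order can be sketched as a small helper; `find_token_column` is a hypothetical function for illustration, not part of the library:

```python
def find_token_column(columns, class_name, use_class_names_as_columns=True):
    """Return the column holding the factual token for `class_name`.

    Hypothetical sketch of the lookup order described above,
    not the trainer's actual implementation.
    """
    # 1. Prefer a column named after the class itself.
    if use_class_names_as_columns and class_name in columns:
        return class_name
    # 2. Otherwise fall back to generic token columns.
    for fallback in ("label", "token"):
        if fallback in columns:
            return fallback
    raise ValueError(f"No factual-token column found for class {class_name!r}")
```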
### Two variants of per-class HuggingFace data

Per-class datasets come in two shapes. Both are valid; the trainer handles them automatically.
#### Variant A: Single token per row (one factual column)

Each subset has one factual column. The alternative token is derived by joining with the other target class's subset (same `masked` + `split`).

Example: `aieng-lab/de-gender-case-articles`

- Configs: `masc_nom`, `fem_nom`, `neut_nom`, etc.
- Each config has: `masked`, `split`, `label` (the factual token for that class)
- No columns for other classes; alternatives come from the other config's rows
```python
# gender_de example
trainer = TextPredictionTrainer(
    model="bert-base-german-cased",
    data="aieng-lab/de-gender-case-articles",
    target_classes=["masc_nom", "fem_nom"],
    masked_col="masked",
    split_col="split",
)
```
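The Variant A pairing (joining the other class's subset on `masked` + `split`) can be sketched with pandas; the frames and values below are illustrative, not real dataset contents:

```python
import pandas as pd

# Hypothetical Variant A subsets: each has a single factual `label` column.
df_masc = pd.DataFrame({
    "masked": ["[MASK] is a doctor."],
    "split": ["train"],
    "label": ["He"],
})
df_fem = pd.DataFrame({
    "masked": ["[MASK] is a doctor."],
    "split": ["train"],
    "label": ["She"],
})

# Pair rows by joining on (masked, split); each class's `label`
# becomes one side of the factual/alternative pair.
pairs = df_masc.merge(df_fem, on=["masked", "split"],
                      suffixes=("_masc", "_fem"))
```

Rows without a matching counterpart in the other subset simply drop out of the inner join, which is why both subsets must share `masked` and `split` values.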
#### Variant B: All tokens per row (one column per class)

Each subset has columns for every class, so factual and alternative tokens are in the same row. The alternative tokens are thus explicitly defined, which is useful when each target class has multiple target words that are not pairwise interchangeable.

Examples: `aieng-lab/gradiend_race_data`, `aieng-lab/gradiend_religion_data`

- Configs: `white`, `black`, `asian` (or `christian`, `jewish`, `muslim`)
- Each config has: `masked`, `split`, `white`, `black`, `asian` (columns = class names)
- The column matching the subset holds the factual token; other columns hold alternatives
```python
# race_religion example
trainer = TextPredictionTrainer(
    model="distilbert-base-cased",
    data="aieng-lab/gradiend_race_data",
    target_classes=["white", "black"],
    masked_col="masked",
    split_col="split",
)
```
## 2. Per-class dict format

Use when you have a Python dict mapping class name → DataFrame, e.g. from `TextPredictionDataCreator.generate_training_data(format="per_class")` or your own preparation.
### How to specify

```python
per_class_data = {
    "masc_nom": df_masc,  # DataFrame for class masc_nom
    "fem_nom": df_fem,    # DataFrame for class fem_nom
}

trainer = TextPredictionTrainer(
    model="bert-base-uncased",
    data=per_class_data,
    target_classes=["masc_nom", "fem_nom"],
    masked_col="masked",
    split_col="split",
)
```
### Required columns (per DataFrame)

Each DataFrame in the dict must have:

| Column | Required | Description |
|---|---|---|
| `masked` | ✅ | Sentence with `[MASK]` at the target position. |
| `split` | ✅ | Dataset split. Defaults to `"train"` if missing. |
| Factual token | ✅ | One column with the factual token: class-name column, `label`, `token`, or `source`/`target` (see below). |
**Factual token column** (lookup order):

1. Class-name column (e.g. `masc_nom` when the key is `"masc_nom"`), when `use_class_names_as_columns=True` (default)
2. `label`, `token`, or `source` (for pre-paired data)

**Alternative tokens:**

- If the DataFrame has columns for other classes: the alternative is taken from that column in the same row.
- If not (single token per class): the alternative is derived from the other class's DataFrame (the pair must be set via `target_classes`).

**Pre-paired rows (`source`/`target`):** If a DataFrame has `source` and `target` columns, each row is treated as a factual/alternative pair. Optional: `source_id`, `target_id` for class names.
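A minimal pre-paired DataFrame might look like this (all values illustrative):

```python
import pandas as pd

# Hypothetical pre-paired data: each row carries both tokens directly.
df_pairs = pd.DataFrame({
    "masked": ["[MASK] is a doctor.", "[MASK] writes code."],
    "split": ["train", "train"],
    "source": ["He", "He"],          # factual token
    "target": ["She", "She"],        # alternative token
    "source_id": ["masc", "masc"],   # optional: factual class name
    "target_id": ["fem", "fem"],     # optional: alternative class name
})
```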
### Configurable column names

Override the defaults with `TextPredictionConfig`:

| Parameter | Default | Description |
|---|---|---|
| `masked_col` | `"masked"` | Column for masked sentences |
| `split_col` | `"split"` | Column for split |
| `use_class_names_as_columns` | `True` | Use class names as token column names when present |
## 3. Merged format

Use when you have a single table where each row already has a class and label. This can be a DataFrame, a CSV/Parquet path, or a HuggingFace dataset (via `hf_dataset`).
### How to specify

DataFrame or file path:

```python
# DataFrame in memory
trainer = TextPredictionTrainer(
    model="bert-base-uncased",
    data=df,  # DataFrame with masked, split, label_class, label
    target_classes=["A", "B"],
    masked_col="masked",
    split_col="split",
    label_col="label",
    label_class_col="label_class",
)

# Local file
trainer = TextPredictionTrainer(
    model="bert-base-uncased",
    data="path/to/data.csv",  # or .parquet
    target_classes=["A", "B"],
    ...
)
```
HuggingFace (merged):

```python
trainer = TextPredictionTrainer(
    model="bert-base-uncased",
    hf_dataset="org/merged-dataset",  # Use hf_dataset for merged format
    hf_subset="subset_name",          # Optional: specific config
    hf_splits=["train", "validation", "test"],
    target_classes=["A", "B"],
    masked_col="masked",
    split_col="split",
    label_col="label",
    label_class_col="label_class",
)
```
### Required columns

| Column | Required | Description |
|---|---|---|
| `masked` | ✅ | Sentence with `[MASK]` at the target position |
| `split` | ✅ | Dataset split (`train`, `validation`, `test`, or `val`) |
| `label_class` | ✅ | The factual (source) class for this row (e.g. `"masc_nom"`) |
| `label` | ✅ | The factual token that fills the mask |
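A minimal merged table with these required columns can be built with pandas (values illustrative); the result can then be passed directly as `data`, or saved to CSV/Parquet first:

```python
import pandas as pd

# Hypothetical merged-format table: one factual class and token per row.
merged = pd.DataFrame({
    "masked": ["[MASK] is a doctor.", "[MASK] is a doctor."],
    "split": ["train", "train"],
    "label_class": ["masc", "fem"],  # factual class per row
    "label": ["He", "She"],          # factual token per row
})
```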
### Optional: explicit alternative columns

If your table has explicit factual/alternative pairs per row:

| Column | Required | Description |
|---|---|---|
| `alternative` | Optional | The counterfactual token for this row |
| `alternative_class` | Optional | The counterfactual class (used with `alternative`) |

When both `alternative` and `alternative_class` are present, the trainer uses them directly. Otherwise it infers the alternative from the other class in the pair (this requires `target_classes` with exactly two classes and a single token per class in the pair).
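The inference step can be sketched as follows; `infer_alternative` is a hypothetical helper that assumes exactly two target classes with one token each:

```python
def infer_alternative(row_class, target_classes, class_tokens):
    """Return (alternative_class, alternative_token) for a row's factual class.

    `class_tokens` maps each class in the pair to its single token,
    e.g. {"masc": "He", "fem": "She"}.  Hypothetical sketch, not library code.
    """
    if len(target_classes) != 2:
        raise ValueError("exactly two target classes required")
    # The alternative is simply the other class in the pair.
    alt_class = next(c for c in target_classes if c != row_class)
    return alt_class, class_tokens[alt_class]
```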
### Configurable column names

| Parameter | Default | Description |
|---|---|---|
| `masked_col` | `"masked"` | Column for masked sentences |
| `split_col` | `"split"` | Column for split |
| `label_col` | `"label"` | Column for factual token |
| `label_class_col` | `"label_class"` | Column for factual class |
| `alternative_col` | `"alternative"` | Column for alternative token (merged only) |
| `alternative_class_col` | `"alternative_class"` | Column for alternative class (merged only) |
## Import flow

The trainer decides how to load and convert data using this resolution order:

1. If `hf_dataset` is set: load that dataset (with `hf_subset`, `hf_splits`), then convert using merged column names.
2. Else if `data` is a string and not a file path: treat it as a HuggingFace ID → per-class load (multiple configs/subsets).
3. Else if `data` is a dict: per-class DataFrames; keys = class names.
4. Else if `data` is a file path (str or `Path`): load CSV/Parquet as merged format.
5. Else `data` is a DataFrame: use it as merged format in memory.
```text
Data source (HF ID / path / DataFrame / dict)
  → load or use in memory
  → merged_to_unified() or per_class_dict_to_unified()
  → unified DataFrame (pair-filtered)
  → TextTrainingDataset / GradientTrainingDataset
  → training batches
```
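The resolution order can be sketched as a dispatch function. This is a hypothetical illustration (e.g. it detects file paths by extension); the trainer's actual logic may differ:

```python
from pathlib import Path

import pandas as pd


def resolve_input_path(data=None, hf_dataset=None):
    """Classify how a data source would be loaded (sketch of the stated order)."""
    if hf_dataset is not None:
        return "merged-hf"            # merged format from HuggingFace
    if isinstance(data, (str, Path)):
        if str(data).endswith((".csv", ".parquet")):
            return "merged-file"      # local CSV/Parquet, merged format
        return "per-class-hf"         # HuggingFace ID, per-class configs
    if isinstance(data, dict):
        return "per-class-dict"       # class name -> DataFrame
    if isinstance(data, pd.DataFrame):
        return "merged-dataframe"     # in-memory merged table
    raise TypeError(f"unsupported data source: {type(data)!r}")
```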
## Training pair and target classes

Training uses transitions between two classes (e.g. masculine vs. feminine). Specify them with `target_classes`:

```python
target_classes=["masc_nom", "fem_nom"]  # Pair for training (currently must have exactly two elements!)
```
- When `all_classes` (usually determined automatically from the data) has exactly two elements, the trainer uses these two classes automatically as `target_classes`.
- **Merged classes:** Use `class_merge_map` to merge base classes (e.g. `{"singular": ["1SG", "3SG"], "plural": ["1PL", "3PL"]}`); `target_classes` then refers to the merged names. With exactly two merged keys, `target_classes` can be omitted.
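Applying a `class_merge_map` to row labels can be sketched like this (hypothetical helper, not the library's API):

```python
def apply_merge_map(labels, class_merge_map):
    """Map base class labels to merged class names (illustrative sketch).

    Labels not covered by the map are passed through unchanged.
    """
    # Invert {merged: [bases]} into {base: merged} for O(1) lookup.
    lookup = {base: merged
              for merged, bases in class_merge_map.items()
              for base in bases}
    return [lookup.get(lbl, lbl) for lbl in labels]
```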
## Data balancing and size caps

Balancing and caps rely on a feature-class identifier. After conversion, the trainer infers this from the class labels in your data.

- `TrainingArguments.train_max_size`: caps training samples; applied per feature class when available.
- `TrainingArguments.encoder_eval_max_size`, `decoder_eval_max_size_training_like`, etc.: similar per-class caps for evaluation.
- When class information is available, the scheduler oversamples smaller classes, and `train_max_size` applies per-class downsampling.
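Per-class downsampling under `train_max_size` could look roughly like this; the helper is illustrative, and the trainer's actual balancing logic may differ:

```python
import pandas as pd


def cap_per_class(df, class_col, max_total, seed=42):
    """Downsample so at most `max_total` rows remain, split evenly per class.

    Illustrative sketch only, not the trainer's implementation.
    """
    per_class = max_total // df[class_col].nunique()
    return (df.groupby(class_col, group_keys=False)
              .apply(lambda g: g.sample(min(len(g), per_class),
                                        random_state=seed)))
```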
## Splits

Datasets should provide `train`, `validation` (or `val`), and `test` splits. The trainer normalizes split names (e.g. `"val"` → `"validation"`). Support for fewer splits may be added later.
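The normalization can be sketched as:

```python
def normalize_split(name):
    """Map split-name aliases to canonical names ("val" -> "validation").

    Sketch of the described behavior, not the trainer's actual code.
    """
    return {"val": "validation"}.get(name, name)
```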
## Optional: neutral evaluation data (`eval_neutral_data`)

`TextPredictionTrainer` accepts optional neutral data via `eval_neutral_data` (DataFrame, path, or HuggingFace dataset ID). It is used for:

- **Encoder evaluation:** a separate `neutral_dataset` variant for encoding gradients from feature-independent text.
- **Decoder evaluation (LMS):** text used to compute the language modeling score (perplexity) without feature-related targets.

When `eval_neutral_data` is omitted or empty:

- Encoder evaluation still runs (without the `neutral_dataset` variant).
- Decoder evaluation falls back to training-like data (test split): each row's `text` is built by filling the mask with the factual token. Target tokens are automatically added to `decoder_eval_ignore_tokens` so they are ignored in the LMS. This works for quick runs; for best practice, provide true neutral data when available (e.g. `TextPredictionDataCreator.generate_neutral_data()` or HuggingFace datasets like `aieng-lab/wortschatz-leipzig-de-grammar-neutral`).
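The mask-filling fallback can be sketched as (illustrative helper, not library code):

```python
def fill_mask(masked, token, mask_token="[MASK]"):
    """Build decoder-eval text by replacing the mask with the factual token."""
    # Replace only the first occurrence, matching one mask per sentence.
    return masked.replace(mask_token, token, 1)
```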
See Evaluation (intra-model) for details.
## Generating data from raw text

To create training data from base corpora (Wikipedia, CSV, HuggingFace, or lists of strings), use `TextPredictionDataCreator` from `gradiend.data`. See the Data generation tutorial for full usage.