TextPredictionConfig

Bases: TrainerConfig

Configuration for TextPredictionTrainer.

Unified data contract: the internal representation uses the columns masked, split, factual_class, alternative_class, factual, alternative, and transition. Training uses only rows whose transition is c1→c2 or c2→c1 for the configured pair (c1, c2).
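The row filter described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: it assumes rows are dictionaries and that transition is stored as a string of the form "c1→c2". The helper names pair_transitions and filter_training_rows are hypothetical.

```python
from typing import Dict, List, Set, Tuple


def pair_transitions(pair: Tuple[str, str]) -> Set[str]:
    """Return both transition labels for a class pair (assumed 'a→b' string format)."""
    c1, c2 = pair
    return {f"{c1}→{c2}", f"{c2}→{c1}"}


def filter_training_rows(
    rows: List[Dict[str, str]], pair: Tuple[str, str]
) -> List[Dict[str, str]]:
    """Keep only rows whose transition belongs to the configured pair."""
    allowed = pair_transitions(pair)
    return [row for row in rows if row["transition"] in allowed]


# Hypothetical unified rows; only the columns needed for the filter are shown.
rows = [
    {"masked": "[MASK] went home.", "transition": "male→female"},
    {"masked": "[MASK] stayed.", "transition": "female→male"},
    {"masked": "[MASK] left.", "transition": "male→neutral"},
]
training_rows = filter_training_rows(rows, ("male", "female"))
```

With the pair ("male", "female"), the third row is dropped because its transition involves a class outside the pair.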

Data input:
- data as Dict[str, DataFrame]: per-class format. Key = factual_class; each df has masked, split, and columns = class names (factual = df[factual_class], alternative = df[alternative_class]). If a df has no column for another class (single-token-per-class, e.g. Gender), target is inferred as the other class's token for the configured pair.
- data as DataFrame: merged format with label_class_col, label_col, and optionally target_col, target_class_col. If target columns are omitted, pair is required and target = other class's token.
- hf_dataset: load from HuggingFace (merged format), then convert to unified.
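As a rough sketch of how per-class input might map onto the unified columns, the snippet below uses lists of dictionaries in place of DataFrames. It is an illustration under stated assumptions, not the library's conversion code: the function name to_unified is hypothetical, the transition is assumed to be the string "factual_class→alternative_class", and the fallback for a missing alternative column (taking the token from the first row of the other class) is one possible reading of "the other class's token".

```python
from typing import Dict, List, Tuple

Row = Dict[str, str]


def to_unified(data: Dict[str, List[Row]], pair: Tuple[str, str]) -> List[Row]:
    """Convert per-class rows to unified rows for one class pair (hypothetical helper)."""
    c1, c2 = pair
    unified: List[Row] = []
    for factual_class, alternative_class in ((c1, c2), (c2, c1)):
        other_rows = data.get(alternative_class, [])
        for row in data.get(factual_class, []):
            if alternative_class in row:
                # The row carries its own alternative token.
                alternative = row[alternative_class]
            elif other_rows:
                # Assumed fallback for single-token-per-class data:
                # reuse the token found in the other class's rows.
                alternative = other_rows[0][alternative_class]
            else:
                continue  # no way to infer a target for this row
            unified.append({
                "masked": row["masked"],
                "split": row["split"],
                "factual_class": factual_class,
                "alternative_class": alternative_class,
                "factual": row[factual_class],
                "alternative": alternative,
                "transition": f"{factual_class}→{alternative_class}",
            })
    return unified


# Hypothetical single-token-per-class input: each class only has its own column.
per_class = {
    "male": [{"masked": "[MASK] is here.", "split": "train", "male": "he"}],
    "female": [{"masked": "[MASK] left.", "split": "test", "female": "she"}],
}
unified_rows = to_unified(per_class, ("male", "female"))
```

Each input row yields one unified row whose alternative token comes from the opposite class of the pair.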

Attributes:

Name Type Description
run_id Optional[str]

Optional run identifier (subdir and display).

data Optional[Union[DataFrame, Dict[str, DataFrame], str, Path]]

Per-class dict (label_class -> DataFrame) or merged DataFrame.

hf_dataset Optional[str]

HuggingFace dataset ID; loads merged format and converts to unified.

hf_subset Optional[Union[str, List[str]]]

Subset name(s) to load when using hf_dataset.

hf_splits Optional[List[str]]

Splits to include (e.g. ['train', 'validation', 'test']).

target_classes Optional[List[str]]

Target classes for training. Pair is automatically inferred when len(target_classes) == 2.

all_classes Optional[List[str]]

All classes available in the dataset. If None (default), inferred from data. When loading from HuggingFace datasets, None means load all configs/subsets.

masked_col, split_col

Column names. For merged format also label_col, label_class_col, and optionally target_col, target_class_col.

use_class_names_as_columns bool

For per-class data, use class name as column name for tokens.

all_classes class-attribute instance-attribute

all_classes = None

All classes available in the dataset. If None (default), inferred from data. When loading from HuggingFace datasets, None means load all configs/subsets.

alternative_class_col class-attribute instance-attribute

alternative_class_col = 'alternative_class'

alternative_col class-attribute instance-attribute

alternative_col = 'alternative'

class_merge_map class-attribute instance-attribute

class_merge_map = None

class_merge_transition_groups class-attribute instance-attribute

class_merge_transition_groups = None

data class-attribute instance-attribute

data = None

decoder_eval_ignore_tokens class-attribute instance-attribute

decoder_eval_ignore_tokens = None

decoder_eval_lms_max_samples class-attribute instance-attribute

decoder_eval_lms_max_samples = None

decoder_eval_prob_on_other_class class-attribute instance-attribute

decoder_eval_prob_on_other_class = True

decoder_eval_restrict_to_target_classes class-attribute instance-attribute

decoder_eval_restrict_to_target_classes = True

decoder_eval_targets class-attribute instance-attribute

decoder_eval_targets = None

eval_neutral_data class-attribute instance-attribute

eval_neutral_data = None

eval_neutral_max_rows class-attribute instance-attribute

eval_neutral_max_rows = None

hf_dataset class-attribute instance-attribute

hf_dataset = None

hf_splits class-attribute instance-attribute

hf_splits = None

hf_subset class-attribute instance-attribute

hf_subset = None

label_class_col class-attribute instance-attribute

label_class_col = 'label_class'

label_col class-attribute instance-attribute

label_col = 'label'

masked_col class-attribute instance-attribute

masked_col = 'masked'

max_counterfactuals_per_sentence class-attribute instance-attribute

max_counterfactuals_per_sentence = 1

n_features class-attribute instance-attribute

n_features = 1

random_state class-attribute instance-attribute

random_state = None

run_id class-attribute instance-attribute

run_id = None

split_col class-attribute instance-attribute

split_col = 'split'

target_classes class-attribute instance-attribute

target_classes = None

Target classes for training. Pair is automatically inferred when len(target_classes) == 2.
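The pair inference rule can be illustrated with a small stand-alone function. This is a sketch of the stated rule only; infer_pair is a hypothetical name, and returning None when the list does not have exactly two entries is an assumption about behavior the description does not spell out.

```python
from typing import List, Optional, Tuple


def infer_pair(target_classes: Optional[List[str]]) -> Optional[Tuple[str, str]]:
    """Return (c1, c2) when exactly two target classes are given, else None."""
    if target_classes is not None and len(target_classes) == 2:
        return (target_classes[0], target_classes[1])
    # Assumption: with zero, one, or more than two classes no pair is inferred.
    return None


pair = infer_pair(["male", "female"])
```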

use_class_names_as_columns class-attribute instance-attribute

use_class_names_as_columns = True

__str__

__str__()
Source code in gradiend/trainer/text/prediction/trainer.py
def __str__(self) -> str:
    return (
        f"TextPredictionConfig(img_format={self.img_format!r}, target_classes={self.target_classes!r}, "
        f"masked_col={self.masked_col!r}, n_features={self.n_features})"
    )