Evaluation visualization
This guide documents all plot functions available for visualizing GRADIEND training and evaluation. It focuses on plot customization only — for how to run evaluation (encoder, decoder, metrics), see Tutorial: Evaluation (intra-model) and Tutorial: Evaluation (inter-model).
Plot overview
| Plot | Purpose | Entry point |
|---|---|---|
| Training convergence | Mean encoded values and correlation over training steps | trainer.plot_training_convergence() |
| Encoder distributions | Split violins of encoded values by class/transition | trainer.plot_encoder_distributions() or evaluate_encoder(..., plot=True, plot_kwargs=...) |
| Encoder scatter | Interactive 1D scatter (jitter x, encoded y) for outlier inspection | trainer.plot_encoder_scatter() |
| Top-k overlap heatmap | Pairwise overlap of top-k weight sets across models | plot_topk_overlap_heatmap() |
| Top-k overlap Venn | Venn diagram of top-k set intersection (2–6 models) | plot_topk_overlap_venn() |
1. Training convergence plot
Shows how training metrics evolve over steps: mean encoded value per class/feature class and correlation. The best checkpoint step (by convergence metric) is marked with a vertical line.
Entry points
# Via trainer (typical)
trainer.plot_training_convergence()
# Standalone (from model dir or pre-loaded stats)
from gradiend.visualizer.convergence import plot_training_convergence
plot_training_convergence(model_path="runs/experiment/model", show=True)
plot_training_convergence(training_stats=stats_dict, output="convergence.pdf")

Customization options
| Parameter | Type | Default | Description |
|---|---|---|---|
| label_name_mapping | Dict[str, str] | None | Map class/feature-class ids to display labels (e.g. "masc_nom" → "Masc. Nom."). Use when raw ids are technical or hard to read. |
| plot_mean_by_class | bool | True | Include subplot for mean encoded value per class. |
| plot_mean_by_feature_class | bool \| None | None (auto) | Include subplot for mean encoded value per feature class. When None, defaults to False if redundant with mean-by-class (e.g. all_classes == target_classes or no identity transitions), otherwise True. |
| plot_correlation | bool | True | Include subplot for correlation over steps. |
| best_step | bool | True | Draw a vertical line marking the best checkpoint step. |
| title | str or bool | True | True = use run_id, False = no title, string = custom title. |
| figsize | Tuple[float, float] | None | Figure size in inches. Default: (8, 3 * n_subplots). |
| output | str | None | Explicit output file path. |
| experiment_dir | str | None | Used to resolve the default artifact path when output is not set. |
| show | bool | True | Whether to call plt.show(). |
| img_format | str | "png" | Image format (e.g. "pdf", "png"). Appended to the output path. |
Use cases
- Human-readable class labels (e.g. German article paradigm): pass label_name_mapping so the legend shows "Masc. Nom." instead of masc_nom. See gender_de_detailed.py.
- Many classes (e.g. identity transitions): with ≥6 series a single figure-level legend is drawn to the right of the plot (independent of subplot height). Use legend_ncol, legend_bbox_to_anchor, and legend_loc to adjust placement.
- Publication-ready figure: set output, figsize, and img_format="png" or "pdf".
- Minimal plot (only correlation): set plot_mean_by_class=False, plot_mean_by_feature_class=False.
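For a minimal, correlation-only figure, the options above can be bundled into one call. A sketch, assuming `trainer` is a trained GRADIEND trainer (the call itself is commented out since it needs a completed training run):

```python
# Kwargs taken from the table above; values here are example choices.
convergence_kwargs = dict(
    plot_mean_by_class=False,
    plot_mean_by_feature_class=False,  # leaves only the correlation subplot
    figsize=(8, 3),
    output="figures/convergence",
    img_format="pdf",
    show=False,
)
# trainer.plot_training_convergence(**convergence_kwargs)
```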
2. Encoder distribution plot
Grouped split violins showing the distribution of encoded values by transition/class. By default the plot shows only the target (training) transition(s) and neutral data; use target_and_neutral_only=False to include all transitions. Each group has left and right halves (e.g. masc→fem vs fem→masc in one split violin). When using evaluate_encoder(..., plot=True), any plot option can be forwarded via plot_kwargs (e.g. plot_kwargs=dict(target_and_neutral_only=False, show=False)).
Entry points
# Via trainer (requires encoder_df from evaluate_encoder; plot options via plot_kwargs)
enc_eval = trainer.evaluate_encoder(max_size=100, return_df=True, plot=True, plot_kwargs={...})
# Direct call with pre-computed encoder_df
trainer.plot_encoder_distributions(encoder_df=enc_df, legend_name_mapping={...})

Customization options (via plot_kwargs or direct call)
| Parameter | Type | Default | Description |
|---|---|---|---|
| target_and_neutral_only | bool | True | Restrict the plot to the target (training) transition(s) and neutral data only. Set to False to show all transitions. Uses trainer.pair to determine the target transition(s). |
| legend_name_mapping | Dict[str, str] | None | Map raw legend labels to display names (e.g. "masc_nom -> fem_nom" → "M→F"). |
| legend_group_mapping | Dict[str, List[str]] | None | Group multiple transitions into one legend entry. Keys = display label; values = list of raw labels to merge. Groups are downsampled to balance counts. Example: {"der": ["masc_nom -> masc_nom", "fem_dat -> fem_dat"], "die": ["fem_nom -> fem_nom", "fem_acc -> fem_acc"]}. |
| paired_legend_labels | List[str] | None | Explicit order of legend labels. Consecutive pairs (0,1), (2,3), … form split violins. |
| violin_order | List[str] | None | Order of violin groups on the x-axis (by legend label name). |
| colors | Dict[str, str] | None | Map legend labels to hex colors. |
| cmap | str | "tab20" | Matplotlib colormap for the palette. |
| legend_loc | str | "best" | Matplotlib legend location. When there are >6 entries and legend_bbox_to_anchor is not set, the legend is placed below the plot. |
| legend_ncol | int | 2 | Number of columns in the legend. |
| legend_bbox_to_anchor | Tuple[float, float] or None | None | (x, y) anchor for the legend. When there are >6 entries and this is None, the legend is placed below at (0.5, -0.06). |
| title | str or bool | True | True = use run_id, False = no title, string = custom title. |
| title_fontsize | float | None | Title font size. |
| label_fontsize | float | None | Axis tick label font size. |
| axis_label_fontsize | float | None | Axis label font size. |
| legend_fontsize | float | None | Legend text font size. |
| output | str | None | Explicit output path. |
| output_dir | str | None | Output directory when output and experiment_dir are not set. |
| show | bool | True | Whether to call plt.show(). |
| img_format | str | "png" | Image format (e.g. "pdf", "png"). |
Use cases
- Show all transitions (e.g. multi-class pronoun setup): pass target_and_neutral_only=False so the plot includes every transition, not only the target pair and neutral data.
- Renaming labels for readability: legend_name_mapping={"masc_nom -> fem_nom": "M→F", "fem_nom -> masc_nom": "F→M"}. See gender_de_detailed.py.
- Grouping by surface form (e.g. all "der" transitions together): legend_group_mapping={"der": ["masc_nom -> masc_nom", "fem_dat -> fem_dat", ...], "die": [...], "das": [...]}. See gender_de_detailed.py.
- Many legend entries (>6): the legend is placed below the plot by default so the axes stay large. Override with legend_bbox_to_anchor=(x, y) and legend_loc if needed.
- Custom colors for specific classes: colors={"M→F": "#1f77b4", "F→M": "#ff7f0e"}.
- Paper-style plot: set legend_fontsize, axis_label_fontsize, output, and img_format="png".
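The merge-and-balance behavior of legend_group_mapping can be sketched in plain Python. This is toy data illustrating the idea (merge raw labels per group, then downsample every group to the smallest group size); the actual logic lives inside the library:

```python
import random

# Toy encoded values per raw legend label (made-up numbers).
random.seed(0)
samples = {
    "masc_nom -> masc_nom": [0.9] * 10,
    "fem_dat -> fem_dat": [0.8] * 6,
    "fem_nom -> fem_nom": [0.1] * 5,
}
legend_group_mapping = {
    "der": ["masc_nom -> masc_nom", "fem_dat -> fem_dat"],
    "die": ["fem_nom -> fem_nom"],
}

# Merge raw labels into display groups ...
merged = {
    group: [v for raw in raws for v in samples[raw]]
    for group, raws in legend_group_mapping.items()
}
# ... then downsample each group to the smallest group's size.
n = min(len(values) for values in merged.values())
balanced = {g: random.sample(v, n) for g, v in merged.items()}
print({g: len(v) for g, v in balanced.items()})  # {'der': 5, 'die': 5}
```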
3. Encoder scatter plot
Interactive Plotly scatter: x = jitter, y = encoded value, colored by a chosen column. Intended for Jupyter to inspect outliers (hover shows point data).
Optional dependency: Plotly is required for this plot. Install with pip install plotly. If Plotly is not installed, the function returns None and logs a warning. See Installation.
Example: Jupyter notebook gradiend/examples/encoder_scatter.ipynb — trains on HuggingFace data (aieng-lab/gradiend_race_data), runs encoder evaluation, and shows the interactive scatter inline.
Entry point
# Interactive scatter in Jupyter; encoder_df is computed via trainer.analyze_encoder() if not given
trainer.plot_encoder_scatter()

Customization options
| Parameter | Type | Default | Description |
|---|---|---|---|
| encoder_df | pd.DataFrame | None | Pre-computed encoder analysis. If None, calls trainer.analyze_encoder(). |
| color_by | str | "label" | Column used for point color. |
| hover_cols | List[str] | None | Columns shown on hover. Default: existing columns among text, label, encoded, source_id, target_id, type. |
| jitter_scale | float | 0.15 | Scale of the random jitter on the x-axis. |
| max_points | int | None | Maximum number of points; subsampling is stratified. |
| stratify_by | str | None | Column for stratified subsampling when max_points is set. Default: feature_class or color_by. |
| cmap | str | "tab20" | Matplotlib colormap for colors (matches the encoder violins). |
| height | int | 500 | Figure height in pixels. |
| title | str | None | Plot title. |
| output_path | str | None | Path to save the HTML file. |
| output_dir | str | None | Directory for the HTML file when output_path and experiment_dir are not set. |
| show | bool | True | Whether to display the figure. |
| hover_text_max_chars | int | 50 | Max characters of text shown on hover; truncated around the first [MASK] with "...". |
Use cases
- Outlier inspection in Jupyter: run with the default show=True; hover over points to see text, label, etc.
- Large datasets: set max_points=500 to avoid slow rendering; use stratify_by="feature_class" to keep class balance.
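What stratified subsampling does here can be sketched with pandas. This is a simplified stand-in (the real implementation is the library's), using a toy DataFrame shaped like the columns in the table above:

```python
import pandas as pd

# Toy stand-in for encoder_df with a feature_class column.
df = pd.DataFrame({
    "encoded": [0.1, 0.9, 0.2, 0.8, 0.15, 0.85, 0.3, 0.7],
    "feature_class": ["A", "B", "A", "B", "A", "B", "A", "B"],
})

def stratified_subsample(df, max_points, stratify_by):
    """Take an (approximately) equal number of rows per stratum."""
    per_class = max(1, max_points // df[stratify_by].nunique())
    parts = [
        group.sample(min(len(group), per_class), random_state=0)
        for _, group in df.groupby(stratify_by)
    ]
    return pd.concat(parts)

sub = stratified_subsample(df, max_points=4, stratify_by="feature_class")
print(sub["feature_class"].value_counts().to_dict())  # balanced: 2 per class
```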
4. Top-k overlap heatmap
Heatmap of pairwise overlap between top-k weight index sets across multiple GRADIEND models. Rows and columns are the dict keys of models; use display labels (e.g. run_id or "3SG ↔ 3PL") as keys for readable axis labels. Cell value is overlap (raw count or normalized fraction).
Entry point
from gradiend.visualizer.topk.pairwise_heatmap import plot_topk_overlap_heatmap
models = {t.run_id: t.get_model() for t in trainers}
plot_topk_overlap_heatmap(
models,
topk=1000,
part="decoder-weight",
output_path="topk_overlap_heatmap.png",
)
Run the code above to generate the heatmap. Save with output_path="topk_overlap_heatmap.png" to view.
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| models | Dict[str, ModelWithGradiend] | required | Mapping from label to model. Keys are used as axis labels; use display labels (e.g. run_id or "3SG ↔ 3PL") as keys. |
| topk | int | 1000 | Number of top weights per model. |
| part | str | "decoder-weight" | Weight part for importance ranking: encoder-weight, decoder-weight, decoder-bias, or decoder-sum. |
| value | str | "intersection" | Cell value: "intersection" (raw \|A ∩ B\|) or "intersection_frac" (\|A ∩ B\| / k). Use "intersection_frac" for cross-experiment comparison. |
| order | str or List[str] | "input" | Order of models on the axes: "input" (dict order), "name" (alphabetical), or an explicit list. Ignored if pretty_groups is set. |
| cluster | bool | False | Reorder models by similarity (greedy) so similar models are adjacent. |
| annot | bool or str | "auto" | True = always annotate cells, False = never, "auto" = annotate only if ≤ 25 models. |
| fmt | str | None | Format string for annotations (e.g. "d", ".2f"). Default: "d" for intersection, ".2f" for fraction. |
| figsize | Tuple[float, float] | None | Figure size. Default: (max(14, n*0.4), max(14, n*0.4)). |
| cmap | str | "viridis" | Colormap for the heatmap. |
| vmin, vmax | float | None | Colormap bounds. Default: [0, k] for intersection, [0, 1] for fraction. |
| scale | str | "linear" | Color scale: "linear", "log", "sqrt", or "power". |
| scale_gamma | float | None | Gamma for scale="power" (e.g. 0.5 for sqrt-like). |
| pretty_groups | Dict[str, List[str]] | None | Map group name → list of labels (dict keys). Groups are shown on the top/right; keys must be disjoint. Uncovered keys go to "Other". |
| annot_fontsize | float | None | Font size for cell annotations. |
| tick_label_fontsize | float | None | Font size for axis tick labels. |
| group_label_fontsize | float | None | Font size for group labels (when pretty_groups is set). |
| cbar_pad | float | None | Padding between the heatmap and the colorbar. |
| title | str or bool | False | Plot title. Default title is auto-generated. |
| output_path | str | None | Path to save the figure. |
| show | bool | True | Whether to call plt.show(). |
| return_data | bool | True | Return the overlap matrix and auxiliary data. |
Use cases
- Compare many runs (e.g. German article paradigm): pass a models dict whose keys are the display labels (e.g. "der ↔ die", "3SG ↔ 3PL") and use pretty_groups to group by transition. See gender_de_detailed.py and multilingual_gradiend_demo.py.
- Normalized comparison across experiments: value="intersection_frac".
- Clustered layout: cluster=True to order models by similarity.
- Publication: set output_path, figsize, tick_label_fontsize, annot_fontsize.
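The statistic behind each heatmap cell can be reproduced in a few lines. A sketch on toy top-k index sets (real sets come from ranking weights by importance for the chosen part):

```python
# Toy top-k weight index sets per model label (made-up indices).
topk_sets = {
    "run_a": {0, 1, 2, 3},
    "run_b": {2, 3, 4, 5},
    "run_c": {0, 2, 4, 6},
}
k = 4  # corresponds to the topk parameter

# value="intersection" -> raw |A ∩ B|; value="intersection_frac" -> |A ∩ B| / k
intersection = {
    (a, b): len(sa & sb)
    for a, sa in topk_sets.items()
    for b, sb in topk_sets.items()
}
frac = {pair: count / k for pair, count in intersection.items()}
print(frac[("run_a", "run_b")])  # 0.5
```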
5. Top-k overlap Venn diagram
Venn diagram showing the intersection of top-k weight sets across 2–6 models. Dict keys of models are used as set labels; for labeling consistent with the heatmap, use the same display labels as keys. For 2–3 models matplotlib-venn is used; for 4–6, the venn package.
Entry point
from gradiend.visualizer.topk import plot_topk_overlap_venn
plot_topk_overlap_venn(
models,
topk=1000,
part="decoder-weight",
output_path="topk_overlap_venn.png",
)
Run the code above to generate Venn diagrams (e.g. 3-set or 5-set). Save with output_path="venn.png" to view.
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| models | Dict[str, ModelWithGradiend] | required | Mapping from label to model (2–6 entries). Keys are used as set labels; use display labels as keys for consistency with the heatmap. |
| topk | int | 1000 | Number of top weights per model. |
| part | str | "decoder-weight" | Weight part: encoder-weight, decoder-weight, decoder-bias, or decoder-sum. |
| output_path | str | None | Path to save the figure. |
| show | bool | True | Whether to call plt.show(). |
Return value
Dict with per_model (label → list of weight indices; keys match input models), intersection, union, topk, part. Useful for downstream analysis.
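The return value lends itself to downstream set statistics. A minimal sketch on toy data shaped like the documented dict (the indices here are made up, not from a real run):

```python
# Toy stand-in for the return dict described above.
result = {
    "per_model": {"run_a": [0, 1, 2, 3], "run_b": [2, 3, 4, 5]},
    "intersection": [2, 3],
    "union": [0, 1, 2, 3, 4, 5],
    "topk": 4,
    "part": "decoder-weight",
}

# e.g. Jaccard similarity of the top-k sets across all models
jaccard = len(result["intersection"]) / len(result["union"])
print(round(jaccard, 3))  # 0.333
```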
Dependencies
- 2–3 models: pip install matplotlib-venn
- 4–6 models: pip install venn
Image format and output paths
Plots that save to disk use img_format (e.g. "pdf", "png") when available. If experiment_dir is set via TrainingArguments, the trainer resolves default paths under that directory (e.g. [experiment_dir]/[run_id]/training_convergence.png). You can override with explicit output, output_path, or output_dir parameters.
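The default path resolution described above amounts to a simple join. A sketch with example values (the directory and run names here are assumptions, not real defaults):

```python
import os

# [experiment_dir]/[run_id]/[artifact].[img_format], as described above.
experiment_dir = "runs/experiment"   # from TrainingArguments (example value)
run_id = "gender_de"                 # example run id
img_format = "png"                   # appended to the output path

default_path = os.path.join(experiment_dir, run_id, f"training_convergence.{img_format}")
print(default_path)  # e.g. runs/experiment/gender_de/training_convergence.png
```

An explicit output, output_path, or output_dir parameter takes precedence over this default.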
Example references
| Script / notebook | Plots demonstrated |
|---|---|
| gender_de_detailed.py | Training convergence with label_name_mapping; encoder distributions with legend_group_mapping; top-k heatmap with value="intersection_frac"; top-k Venn per transition. |
| english_pronouns.ipynb | Full workflow: data creation → training (3SG vs 3PL) → encoder/decoder evaluation and probability-shifts plot. |
| gender_de.py | Basic training convergence, encoder distributions, decoder evaluation. |
| race_religion.py | Encoder distributions via evaluate_encoder(..., plot=True). |
| evaluation-inter-model | Top-k overlap concepts, heatmap and Venn usage. |