
Evaluation visualization

This guide documents all plot functions available for visualizing GRADIEND training and evaluation. It focuses on plot customization only — for how to run evaluation (encoder, decoder, metrics), see Tutorial: Evaluation (intra-model) and Tutorial: Evaluation (inter-model).


Plot overview

| Plot | Purpose | Entry point |
| --- | --- | --- |
| Training convergence | Mean encoded values and correlation over training steps | trainer.plot_training_convergence() |
| Encoder distributions | Split violins of encoded values by class/transition | trainer.plot_encoder_distributions() or evaluate_encoder(..., plot=True, plot_kwargs=...) |
| Encoder scatter | Interactive 1D scatter (jitter x, encoded y) for outlier inspection | trainer.plot_encoder_scatter() |
| Top-k overlap heatmap | Pairwise overlap of top-k weight sets across models | plot_topk_overlap_heatmap() |
| Top-k overlap Venn | Venn diagram of top-k set intersection (2–6 models) | plot_topk_overlap_venn() |

1. Training convergence plot

Shows how training metrics evolve over steps: mean encoded value per class/feature class and correlation. The best checkpoint step (by convergence metric) is marked with a vertical line.

Entry points

```python
# Via trainer (typical)
trainer.plot_training_convergence()

# Standalone (from model dir or pre-loaded stats)
from gradiend.visualizer.convergence import plot_training_convergence

plot_training_convergence(model_path="runs/experiment/model", show=True)
plot_training_convergence(training_stats=stats_dict, output="convergence.pdf")
```

(Figure: example training convergence plot)

Customization options

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| label_name_mapping | Dict[str, str] | None | Map class/feature-class ids to display labels (e.g. "masc_nom" → "Masc. Nom."). Use when raw ids are technical or hard to read. |
| plot_mean_by_class | bool | True | Include subplot for mean encoded value per class. |
| plot_mean_by_feature_class | bool \| None | None (auto) | Include subplot for mean encoded value per feature class. When None, defaults to False if redundant with mean-by-class (e.g. all_classes == target_classes or no identity transitions), otherwise True. |
| plot_correlation | bool | True | Include subplot for correlation over steps. |
| best_step | bool | True | Draw vertical line and mark best checkpoint step. |
| title | str or bool | True | True = use run_id, False = no title, string = custom title. |
| figsize | Tuple[float, float] | None | Figure size in inches. Default: (8, 3 * n_subplots). |
| output | str | None | Explicit output file path. |
| experiment_dir | str | None | Used to resolve default artifact path when output is not set. |
| show | bool | True | Whether to call plt.show(). |
| img_format | str | "png" | Image format (e.g. "pdf", "png"). Appended to output path. |

Use cases

  • Human-readable class labels (e.g. German article paradigm): pass label_name_mapping so legend shows "Masc. Nom." instead of masc_nom. See gender_de_detailed.py.
  • Many classes (e.g. identity transitions): with ≥6 series a single figure-level legend is drawn to the right of the plot (independent of subplot height). Use legend_ncol, legend_bbox_to_anchor, and legend_loc to adjust placement.
  • Publication-ready figure: set output, figsize, and img_format="png" or "pdf".
  • Minimal plot (only correlation): set plot_mean_by_class=False, plot_mean_by_feature_class=False.
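The label_name_mapping option is just a raw-id → display-name dict. A minimal, self-contained sketch of how such a mapping is typically applied; the fallback-to-raw-id behavior for unmapped ids is an assumption, and display_label is a hypothetical helper, not part of the library:

```python
# Hypothetical mapping for a German article paradigm run.
label_name_mapping = {
    "masc_nom": "Masc. Nom.",
    "fem_nom": "Fem. Nom.",
    "neut_nom": "Neut. Nom.",
}

def display_label(raw_id, mapping):
    # Assumed behavior: ids without a mapping fall back to the raw id.
    return mapping.get(raw_id, raw_id)

print(display_label("masc_nom", label_name_mapping))  # Masc. Nom.
print(display_label("masc_dat", label_name_mapping))  # masc_dat
```

Passing such a dict as trainer.plot_training_convergence(label_name_mapping=label_name_mapping) renames the legend entries accordingly.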

2. Encoder distribution plot

Grouped split violins showing the distribution of encoded values by transition/class. By default the plot shows only the target (training) transition(s) and neutral data; use target_and_neutral_only=False to include all transitions. Each group has left and right halves (e.g. masc→fem vs fem→masc in one split violin). When using evaluate_encoder(..., plot=True), any plot option can be forwarded via plot_kwargs (e.g. plot_kwargs=dict(target_and_neutral_only=False, show=False)).

Entry points

```python
# Via trainer (requires encoder_df from evaluate_encoder; plot options via plot_kwargs)
enc_eval = trainer.evaluate_encoder(max_size=100, return_df=True, plot=True, plot_kwargs={...})

# Direct call with pre-computed encoder_df
trainer.plot_encoder_distributions(encoder_df=enc_df, legend_name_mapping={...})
```

(Figure: example encoder distribution plot)

Customization options (via plot_kwargs or direct call)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| target_and_neutral_only | bool | True | Restrict the plot to the target (training) transition(s) and neutral data only. Set to False to show all transitions. Uses trainer.pair to determine the target transition(s). |
| legend_name_mapping | Dict[str, str] | None | Map raw legend labels to display names (e.g. "masc_nom -> fem_nom" → "M→F"). |
| legend_group_mapping | Dict[str, List[str]] | None | Group multiple transitions into one legend entry. Keys = display label; values = list of raw labels to merge. Groups are downsampled to balance counts. Example: {"der": ["masc_nom -> masc_nom", "fem_dat -> fem_dat"], "die": ["fem_nom -> fem_nom", "fem_acc -> fem_acc"]}. |
| paired_legend_labels | List[str] | None | Explicit order of legend labels. Consecutive pairs (0,1), (2,3), … form split violins. |
| violin_order | List[str] | None | Order of violin groups on the x-axis (by legend label name). |
| colors | Dict[str, str] | None | Map legend labels to hex colors. |
| cmap | str | "tab20" | Matplotlib colormap for palette. |
| legend_loc | str | "best" | Matplotlib legend location. When >6 entries and legend_bbox_to_anchor is not set, the legend is placed below the plot. |
| legend_ncol | int | 2 | Number of columns in the legend. |
| legend_bbox_to_anchor | Tuple[float, float] or None | None | (x, y) for legend. When >6 entries and None, the legend is placed below at (0.5, -0.06). |
| title | str or bool | True | True = use run_id, False = no title, string = custom title. |
| title_fontsize | float | None | Title font size. |
| label_fontsize | float | None | Axis tick label font size. |
| axis_label_fontsize | float | None | Axis label font size. |
| legend_fontsize | float | None | Legend text font size. |
| output | str | None | Explicit output path. |
| output_dir | str | None | Output directory when output and experiment_dir are not set. |
| show | bool | True | Whether to call plt.show(). |
| img_format | str | "png" | Image format (e.g. "pdf", "png"). |

Use cases

  • Show all transitions (e.g. multi-class pronoun setup): pass target_and_neutral_only=False so the plot includes every transition, not only the target pair and neutral.
  • Renaming labels for readability: legend_name_mapping={"masc_nom -> fem_nom": "M→F", "fem_nom -> masc_nom": "F→M"}. See gender_de_detailed.py.
  • Grouping by surface form (e.g. all “der” transitions together): legend_group_mapping={"der": ["masc_nom -> masc_nom", "fem_dat -> fem_dat", ...], "die": [...], "das": [...]}. See gender_de_detailed.py.
  • Many legend entries (>6): the legend is placed below the plot by default so the axes stay large. Override with legend_bbox_to_anchor=(x, y) and legend_loc if needed.
  • Custom colors for specific classes: colors={"M→F": "#1f77b4", "F→M": "#ff7f0e"}.
  • Paper-style plot: set legend_fontsize, axis_label_fontsize, output, img_format="png".
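The grouping-and-downsampling behavior of legend_group_mapping can be sketched in plain Python. This is an illustration only: the balance-to-smallest-group strategy is an assumption based on the description above, and group_and_balance is a hypothetical helper, not part of the library:

```python
import random

# Hypothetical raw transition labels mapped to surface-form groups,
# mirroring the legend_group_mapping example above.
legend_group_mapping = {
    "der": ["masc_nom -> masc_nom", "fem_dat -> fem_dat"],
    "die": ["fem_nom -> fem_nom", "fem_acc -> fem_acc"],
}

def group_and_balance(samples, mapping, seed=0):
    """Merge raw labels into groups, then downsample every group to the
    smallest group's size (assumed balancing strategy)."""
    rng = random.Random(seed)
    groups = {name: [s for s in samples if s["label"] in raw_labels]
              for name, raw_labels in mapping.items()}
    n = min(len(v) for v in groups.values())
    return {name: rng.sample(v, n) for name, v in groups.items()}

samples = (
    [{"label": "masc_nom -> masc_nom"}] * 30
    + [{"label": "fem_dat -> fem_dat"}] * 10
    + [{"label": "fem_nom -> fem_nom"}] * 25
)
balanced = group_and_balance(samples, legend_group_mapping)
# Both groups end up with the same number of samples.
```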

3. Encoder scatter plot

Interactive Plotly scatter: x = jitter, y = encoded value, colored by a chosen column. Intended for Jupyter to inspect outliers (hover shows point data).

Optional dependency: Plotly is required for this plot. Install with pip install plotly. If Plotly is not installed, the function returns None and logs a warning. See Installation.

Example: Jupyter notebook gradiend/examples/encoder_scatter.ipynb — trains on HuggingFace data (aieng-lab/gradiend_race_data), runs encoder evaluation, and shows the interactive scatter inline.

Entry point

```python
trainer.plot_encoder_scatter(encoder_df=enc_df)
```

Customization options

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| encoder_df | pd.DataFrame | None | Pre-computed encoder analysis. If None, calls trainer.analyze_encoder(). |
| color_by | str | "label" | Column used for point color. |
| hover_cols | List[str] | None | Columns shown on hover. Default: existing cols among text, label, encoded, source_id, target_id, type. |
| jitter_scale | float | 0.15 | Scale of random jitter on x-axis. |
| max_points | int | None | Max number of points; subsampling is stratified. |
| stratify_by | str | None | Column for stratified subsampling when max_points is set. Default: feature_class or color_by. |
| cmap | str | "tab20" | Matplotlib colormap for colors (matches encoder violins). |
| height | int | 500 | Figure height in pixels. |
| title | str | None | Plot title. |
| output_path | str | None | Path to save HTML. |
| output_dir | str | None | Directory for HTML when output_path and experiment_dir are not set. |
| show | bool | True | Whether to display the figure. |
| hover_text_max_chars | int | 50 | Max characters for text in hover; truncated around first [MASK] with .... |

Use cases

  • Outlier inspection in Jupyter: run with default show=True; hover over points to see text, label, etc.
  • Large datasets: set max_points=500 to avoid slow rendering; use stratify_by="feature_class" to keep class balance.
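The stratified subsampling triggered by max_points plus stratify_by can be sketched with pandas. This illustrates the idea only, not the library's implementation; stratified_subsample is a hypothetical helper:

```python
import pandas as pd

# Toy stand-in for an encoder_df; column names mirror the documented defaults.
df = pd.DataFrame({
    "encoded": range(100),
    "feature_class": ["masc"] * 80 + ["fem"] * 20,
})

def stratified_subsample(df, max_points, stratify_by, seed=0):
    # Hypothetical helper: sample each class proportionally so that
    # roughly max_points rows remain while class balance is preserved.
    frac = max_points / len(df)
    return df.groupby(stratify_by).sample(frac=frac, random_state=seed)

sub = stratified_subsample(df, max_points=50, stratify_by="feature_class")
# 50 rows remain: 40 "masc" and 10 "fem", preserving the 80/20 ratio.
```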

4. Top-k overlap heatmap

Heatmap of pairwise overlap between top-k weight index sets across multiple GRADIEND models. Rows and columns are the dict keys of models; use display labels (e.g. run_id or "3SG ↔ 3PL") as keys for readable axis labels. Cell value is overlap (raw count or normalized fraction).

Entry point

```python
from gradiend.visualizer.topk.pairwise_heatmap import plot_topk_overlap_heatmap

models = {t.run_id: t.get_model() for t in trainers}
plot_topk_overlap_heatmap(
    models,
    topk=1000,
    part="decoder-weight",
    output_path="topk_overlap_heatmap.png",
)
```
Running the snippet saves the heatmap to topk_overlap_heatmap.png; with show=True (the default) it is also displayed.

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| models | Dict[str, ModelWithGradiend] | required | Mapping from label to model. Keys are used as axis labels; use display labels (e.g. run_id or "3SG ↔ 3PL") as keys. |
| topk | int | 1000 | Number of top weights per model. |
| part | str | "decoder-weight" | Weight part for importance ranking: encoder-weight, decoder-weight, decoder-bias, or decoder-sum. |
| value | str | "intersection" | Cell value: "intersection" (raw \|A ∩ B\|) or "intersection_frac" (\|A ∩ B\| / k). Use "intersection_frac" for cross-experiment comparison. |
| order | str or List[str] | "input" | Order of models on axes: "input" (dict order), "name" (alphabetical), or explicit list. Ignored if pretty_groups is set. |
| cluster | bool | False | Reorder models by similarity (greedy) so similar models are adjacent. |
| annot | bool or str | "auto" | True = always annotate cells, False = never, "auto" = annotate only if ≤ 25 models. |
| fmt | str | None | Format string for annotations (e.g. "d", ".2f"). Default: "d" for intersection, ".2f" for fraction. |
| figsize | Tuple[float, float] | None | Figure size. Default: (max(14, n*0.4), max(14, n*0.4)). |
| cmap | str | "viridis" | Colormap for heatmap. |
| vmin, vmax | float | None | Colormap bounds. Default: [0, k] for intersection, [0, 1] for fraction. |
| scale | str | "linear" | Color scale: "linear", "log", "sqrt", or "power". |
| scale_gamma | float | None | Gamma for scale="power" (e.g. 0.5 for sqrt-like). |
| pretty_groups | Dict[str, List[str]] | None | Map group name → list of labels (dict keys). Groups are shown on top/right; keys must be disjoint. Uncovered keys go to "Other". |
| annot_fontsize | float | None | Font size for cell annotations. |
| tick_label_fontsize | float | None | Font size for axis tick labels. |
| group_label_fontsize | float | None | Font size for group labels (when pretty_groups is set). |
| cbar_pad | float | None | Padding between heatmap and colorbar. |
| title | str or bool | False | Plot title: string = custom title, True = auto-generated title, False = no title. |
| output_path | str | None | Path to save the figure. |
| show | bool | True | Whether to call plt.show(). |
| return_data | bool | True | Return overlap matrix and auxiliary data. |

Use cases

  • Compare many runs (e.g. German article paradigm): pass a models dict whose keys are the display labels (e.g. "der ↔ die", "3SG ↔ 3PL") and use pretty_groups to group by transition. See gender_de_detailed.py and multilingual_gradiend_demo.py.
  • Normalized comparison across experiments: value="intersection_frac".
  • Clustered layout: cluster=True to order models by similarity.
  • Publication: set output_path, figsize, tick_label_fontsize, annot_fontsize.
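The heatmap cells are simple set statistics over top-k index sets. A self-contained sketch of the two documented value modes; topk_indices is a hypothetical helper, and the magnitude-based ranking is an assumption about how importance is scored:

```python
import numpy as np

def topk_indices(weights, k):
    # Hypothetical helper: indices of the k largest-magnitude weights.
    # The library's actual importance ranking per `part` may differ.
    return set(np.argsort(np.abs(weights))[-k:].tolist())

rng = np.random.default_rng(0)
w_a = rng.normal(size=10_000)
w_b = w_a + rng.normal(scale=0.1, size=10_000)  # a strongly correlated "model"

k = 1000
a, b = topk_indices(w_a, k), topk_indices(w_b, k)
intersection = len(a & b)               # value="intersection": raw |A ∩ B|
intersection_frac = intersection / k    # value="intersection_frac": |A ∩ B| / k
```

Because intersection_frac is bounded by [0, 1] regardless of k, it is the robust choice when comparing experiments with different topk settings, as noted above.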

5. Top-k overlap Venn diagram

Venn diagram showing the intersection of top-k weight sets across 2–6 models. The dict keys of models are used as set labels; for consistent labeling, use the same display-label keys as for the heatmap. For 2–3 models, matplotlib-venn is used; for 4–6 models, the venn package.

Entry point

```python
from gradiend.visualizer.topk import plot_topk_overlap_venn

plot_topk_overlap_venn(
    models,
    topk=1000,
    part="decoder-weight",
    output_path="topk_overlap_venn.png",
)
```

Running the snippet saves the Venn diagram to topk_overlap_venn.png; depending on the number of models (2–6), a 2- to 6-set diagram is drawn.

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| models | Dict[str, ModelWithGradiend] | required | Mapping from label to model (2–6 entries). Keys are used as set labels; use display labels as keys for consistency with the heatmap. |
| topk | int | 1000 | Number of top weights per model. |
| part | str | "decoder-weight" | Weight part: encoder-weight, decoder-weight, decoder-bias, or decoder-sum. |
| output_path | str | None | Path to save the figure. |
| show | bool | True | Whether to call plt.show(). |

Return value

Dict with per_model (label → list of weight indices; keys match input models), intersection, union, topk, part. Useful for downstream analysis.
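A sketch of consuming this return value for downstream analysis. The result dict below is a hand-made stand-in (topk=5 for readability), not real output:

```python
# Hand-made stand-in for the return value of plot_topk_overlap_venn
# (real runs use e.g. topk=1000).
result = {
    "per_model": {
        "3SG ↔ 3PL": [1, 2, 3, 5, 8],
        "der ↔ die": [2, 3, 5, 7, 11],
    },
    "intersection": [2, 3, 5],
    "union": [1, 2, 3, 5, 7, 8, 11],
    "topk": 5,
    "part": "decoder-weight",
}

# Example downstream analysis: Jaccard similarity of the top-k sets.
jaccard = len(result["intersection"]) / len(result["union"])  # 3/7
```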

Dependencies

  • 2–3 models: pip install matplotlib-venn
  • 4–6 models: pip install venn

Image format and output paths

Plots that save to disk use img_format (e.g. "pdf", "png") when available. If experiment_dir is set via TrainingArguments, the trainer resolves default paths under that directory (e.g. [experiment_dir]/[run_id]/training_convergence.png). You can override with explicit output, output_path, or output_dir parameters.
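The fallback described above can be sketched as follows; resolve_output is a hypothetical helper, and the filename pattern is taken from the example in this section:

```python
from pathlib import Path

def resolve_output(experiment_dir, run_id, name, img_format="png", output=None):
    # Hypothetical helper: an explicit `output` wins; otherwise the path
    # is derived from experiment_dir, run_id, and img_format.
    if output is not None:
        return Path(output)
    return Path(experiment_dir) / run_id / f"{name}.{img_format}"

p = resolve_output("runs/experiment", "my_run", "training_convergence")
# runs/experiment/my_run/training_convergence.png
```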


Example references

| Script / notebook | Plots demonstrated |
| --- | --- |
| gender_de_detailed.py | Training convergence with label_name_mapping; encoder distributions with legend_group_mapping; top-k heatmap with value="intersection_frac"; top-k Venn per transition. |
| english_pronouns.ipynb | Full workflow: data creation → training (3SG vs 3PL) → encoder/decoder evaluation and probability-shifts plot. |
| gender_de.py | Basic training convergence, encoder distributions, decoder evaluation. |
| race_religion.py | Encoder distributions via evaluate_encoder(..., plot=True). |
| evaluation-inter-model | Top-k overlap concepts, heatmap and Venn usage. |