Tutorial: Model Rewrite
This tutorial describes how to export a persistently modified checkpoint whose behavior is shifted along the learned feature direction (e.g. higher token probabilities for one target class). You can either use assisted parameter selection via decoder evaluation (recommended) or supply a manual feature factor and learning rate.
For running decoder evaluation itself, see Tutorial: Evaluation (intra-model).
Prerequisites
Before rewriting, you need:
- A trained GRADIEND model (i.e. `trainer.train()` has been run).
- Either decoder evaluation results from `trainer.evaluate_decoder(...)`, or manual `feature_factor` and `learning_rate` values to pass (see below).
- A target class id to strengthen or weaken (e.g. `"masc_nom"`).
Parameter selection: assisted vs manual
The rewrite applies an update of the form `base model + learning_rate × decoder(feature_factor)`. The strength and direction of the effect depend on the feature factor and the learning rate.
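To make the update rule concrete, here is a toy, self-contained sketch. The decoder direction and all numbers are made up for illustration; in GRADIEND the decoder is a trained network producing a weight-shaped update, not a fixed vector.

```python
# Toy sketch of the rewrite update rule:
#   rewritten = base + learning_rate * decoder(feature_factor)
# (Hypothetical toy decoder; not the actual GRADIEND implementation.)

def decoder(feature_factor):
    # Toy linear decoder: scales a fixed 3-element "feature direction"
    # by the scalar feature factor.
    direction = [0.5, -1.0, 2.0]
    return [feature_factor * d for d in direction]

def rewrite(base_weights, learning_rate, feature_factor):
    # Shift each base weight along the decoded direction.
    update = decoder(feature_factor)
    return [w + learning_rate * u for w, u in zip(base_weights, update)]

base = [1.0, 1.0, 1.0]
rewritten = rewrite(base, learning_rate=0.1, feature_factor=2.0)
print(rewritten)  # [1.1, 0.8, 1.4]
```

A negative `feature_factor` (or a negative `learning_rate`) shifts the weights in the opposite direction, which is what the weaken configuration exploits.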
- Manual parameters: call the model's `rewrite_base_model(learning_rate=..., feature_factor=...)` (on a `ModelWithGradiend` instance) with any values. This gives full control, but the outcome is ambiguous: different choices can over-strengthen, under-strengthen, or harm other classes.
- Assisted parameter selection (recommended): run `trainer.evaluate_decoder(...)` to sweep a grid of `(feature_factor, learning_rate)` pairs and score each candidate (e.g. by target-class probability or a combined metric). The trainer then uses the best config per class when you call `trainer.rewrite_base_model(decoder_results=..., target_class=...)`. This way, the chosen parameters are driven by your evaluation data and metric.
Use the trainer's `rewrite_base_model` with `decoder_results` when you want data-driven parameters; use the model's `rewrite_base_model` with explicit `learning_rate` and `feature_factor` when you want to experiment manually despite ambiguous outcomes.
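At its core, assisted selection is a grid sweep plus an argmax over a score. The following self-contained sketch illustrates that logic; the score function is a made-up stand-in for the decoder-evaluation metric, and the grid values are arbitrary:

```python
# Sketch of assisted parameter selection: sweep a (feature_factor,
# learning_rate) grid, score each candidate, keep the best config.
from itertools import product

def score(feature_factor, learning_rate):
    # Stand-in for the decoder-evaluation metric (e.g. target-class
    # probability); this toy version peaks at (2.0, 0.1).
    return -abs(feature_factor - 2.0) - abs(learning_rate - 0.1)

feature_factors = [-2.0, -1.0, 1.0, 2.0]
learning_rates = [0.01, 0.1, 0.5]

best = max(product(feature_factors, learning_rates),
           key=lambda cfg: score(*cfg))
print(best)  # (2.0, 0.1)
```

Picking parameters this way ties the rewrite strength to a measurable objective instead of guesswork, which is why the assisted path is recommended.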
What trainer.rewrite_base_model(...) does
When you pass `decoder_results` (and optionally `target_class`), the trainer looks up the best `(learning_rate, feature_factor)` from the decoder evaluation for the chosen class and direction, applies that update to the base model, and returns the rewritten model (e.g. a `BertForMaskedLM` or a causal LM).
- Strengthen (default): `increase_target_probabilities=True` increases probabilities for the target tokens of `target_class`.
- Weaken: `increase_target_probabilities=False` applies the weakening config; you must have run decoder evaluation with `increase_target_probabilities=False` first so that the corresponding summary exists (e.g. `"masc_nom_weaken"`).
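The key naming convention described above can be illustrated with a tiny hypothetical helper (not part of the library API):

```python
# Hypothetical helper mirroring the summary-key convention: the weaken
# direction stores results under "<class>_weaken".
def summary_key(target_class, increase_target_probabilities=True):
    if increase_target_probabilities:
        return target_class
    return f"{target_class}_weaken"

print(summary_key("masc_nom"))         # masc_nom
print(summary_key("masc_nom", False))  # masc_nom_weaken
```

This is why the weaken rewrite fails without a prior weaken-direction evaluation: the `"..._weaken"` summary it looks up simply does not exist yet.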
Basic usage (assisted parameters)
```python
# 1) Run decoder evaluation (strengthen direction by default)
dec_results = trainer.evaluate_decoder()

# 2) Rewrite for one class using the best config from the evaluation
changed_model = trainer.rewrite_base_model(
    decoder_results=dec_results,
    target_class="masc_nom",
)
```
You can then use `changed_model` for inference or further evaluation.
Strengthen vs weaken
Strengthen is the default: the basic usage above already applies the strengthening config, so no extra flag is needed.
Weaken: run decoder evaluation in weaken direction first, then pass the same flag to rewrite:
```python
# Evaluate weaken direction (produces keys like "masc_nom_weaken")
dec_results_weaken = trainer.evaluate_decoder(increase_target_probabilities=False)

# Rewrite using the weaken config
changed_model_weaken = trainer.rewrite_base_model(
    decoder_results=dec_results_weaken,
    target_class="masc_nom",
    increase_target_probabilities=False,
)
```
Choosing target classes
- `target_class="masc_nom"`: one rewritten model for that class.
- `target_class=["masc_nom", "fem_nom"]`: one rewritten model per class.
Use class ids that match your setup (e.g. from `target_classes` or your dataset’s feature class ids).
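The single-id versus list-of-ids behavior can be sketched as follows; the helper names are hypothetical, and the stand-in rewrite just returns a placeholder per class:

```python
# Hypothetical sketch: normalize target_class to a list and produce one
# rewritten-model result per class id.
def normalize_target_classes(target_class):
    if isinstance(target_class, str):
        return [target_class]
    return list(target_class)

def rewrite_per_class(target_class):
    # Stand-in for the per-class rewrite (placeholder values, no real models).
    return {c: f"model_for_{c}" for c in normalize_target_classes(target_class)}

print(rewrite_per_class("masc_nom"))
print(rewrite_per_class(["masc_nom", "fem_nom"]))
```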
Using cached decoder results
If `experiment_dir` is set and decoder evaluation was run with `use_cache=True`, you can omit `decoder_results`; the trainer then loads the decoder stats from the cache when available.
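Assuming the same trainer API as in the earlier snippets, the call would then simply omit `decoder_results` (illustrative sketch, not runnable standalone):

```python
# Requires experiment_dir to be set and a prior
# evaluate_decoder(use_cache=True) run so cached stats exist.
changed_model = trainer.rewrite_base_model(target_class="masc_nom")
```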
Saving rewritten model(s)
To save the rewritten checkpoint(s), pass `output_dir`:

```python
trainer.rewrite_base_model(
    decoder_results=dec_results,
    target_class="masc_nom",
    output_dir="./output/masc_nom_rewrite",
)
```
With multiple `target_class` entries, `experiment_dir` is used to derive the save paths.
Minimal end-to-end snippet
```python
trainer.train()
dec_results = trainer.evaluate_decoder()

# In-memory rewritten model (uses best config from decoder evaluation)
changed_model = trainer.rewrite_base_model(
    decoder_results=dec_results,
    target_class="masc_nom",
)

# Save rewritten model to disk
trainer.rewrite_base_model(
    decoder_results=dec_results,
    target_class="masc_nom",
    output_dir="./output/masc_nom_rewrite",
)
```
Next steps
- Tutorial: Evaluation (intra-model) — Encoder and decoder evaluation, including decoder config selection.
- Tutorial: Evaluation (inter-model) — Comparing runs (e.g. top-k overlap, heatmaps).
- Core classes and use cases — `ModelWithGradiend` and rewrite APIs.
- API reference — `TextPredictionTrainer`, `ModelWithGradiend`.