GradiendModel
Bases: Module
GRADIEND (GRADIent ENcoder Decoder) model implementation (weights-only): maps model gradients to a low-dimensional latent space and back.
Proposed by Drechsel et al. 2025 (https://arxiv.org/abs/2502.01406).
This class holds ONLY the neural components (encoder/decoder) plus utilities that depend solely on GRADIEND parameters:
- forward / forward_encoder (tensor input space)
- weight-derived importance scores (encoder/decoder/decoder-bias/decoder-sum)
- an internal prune primitive that physically reduces input_dim (slices weights; no mapping logic)
- save_pretrained / from_pretrained for weights, architecture, and metadata
Saving:
- Weights: model.safetensors if available, else pytorch_model.bin
- Config: config.json (format_version=0)
- Run info: training.json (optional; written if kwargs contains "training")
Use ParamMappedGradiendModel when you need a parameter mapping or dict-of-gradients I/O.
Initialize a weights-only GRADIEND model (i.e., a GRADIEND encoder-decoder without base-model context).
Activation functions (case-insensitive): tanh, relu, leakyrelu, gelu, silu, elu, sigmoid, smht (hardtanh), id (identity). Defaults (paper): encoder tanh, decoder id.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_dim` | `int` | Size of the GRADIEND input space (total selected gradient entries). | *required* |
| `latent_dim` | `int` | Size of the latent bottleneck. | *required* |
| `activation_encoder` | `str` | Encoder activation name (case-insensitive). | `'tanh'` |
| `activation_decoder` | `str` | Decoder activation name. If falsy, uses the encoder activation but with decoder-appropriate defaults via `get_activation`. | `'id'` |
| `bias_decoder` | `bool` | Whether the decoder linear layer uses a bias term. | `True` |
| `torch_dtype` | `dtype` | dtype used for model parameters. | `float32` |
| `device` | `Optional[device]` | Optional default device for both encoder and decoder when specific devices are not provided. | `None` |
| `device_encoder` | `Optional[device]` | Device for encoder parameters. | `None` |
| `device_decoder` | `Optional[device]` | Device for decoder parameters. | `None` |
| `lazy_init` | `bool` | If True, do not create encoder/decoder weights here. Build them later via `prune` (with pruned size) or `_build_encoder_decoder` (full size). | `False` |
| `**kwargs` | `Any` | Additional metadata stored in kwargs. | `{}` |
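The case-insensitive activation lookup described above can be sketched as a small table (a minimal illustration assuming PyTorch; the actual `get_activation` helper in gradiend may differ, e.g. in how decoder-appropriate defaults are chosen):

```python
import torch.nn as nn

# Hypothetical resolver mirroring the documented names (case-insensitive).
# Per the docstring above, "smht" maps to Hardtanh and "id" to Identity.
_ACTIVATIONS = {
    "tanh": nn.Tanh,
    "relu": nn.ReLU,
    "leakyrelu": nn.LeakyReLU,
    "gelu": nn.GELU,
    "silu": nn.SiLU,
    "elu": nn.ELU,
    "sigmoid": nn.Sigmoid,
    "smht": nn.Hardtanh,
    "id": nn.Identity,
}

def resolve_activation(name: str) -> nn.Module:
    """Return an activation module for a case-insensitive name."""
    try:
        return _ACTIVATIONS[name.lower()]()
    except KeyError:
        raise ValueError(f"Unknown activation: {name!r}")
```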
Source code in gradiend/model/model.py
base_model_id
property
Base model identifier stored in kwargs.
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `base_model` is missing from kwargs. |
decoder_norm
property
L2 norm of the decoder weight matrix.
Returns:

| Type | Description |
|---|---|
| `float` | Scalar float of the decoder's weight L2 norm. |
encoder_norm
property
L2 norm of the encoder weight matrix.
Returns:

| Type | Description |
|---|---|
| `float` | Scalar float of the encoder's weight L2 norm. |
__len__
Return the current input_dim after pruning.
Returns:

| Type | Description |
|---|---|
| `int` | Current input_dim as an integer. |
__str__
Source code in gradiend/model/model.py
_build_encoder_decoder
Instantiate encoder and decoder with given input_dim. Used for lazy init: either after prune (with pruned size) or before first use (with full size).
Source code in gradiend/model/model.py
_ensure_built
Build encoder/decoder with current input_dim if not yet built (lazy init).
_ensure_input
Source code in gradiend/model/model.py
_prune_input_dims
INTERNAL: physically prune input_dim and output_dim by slicing encoder/decoder weights.
- encoder: slice columns (latent_dim, input_dim) -> (latent_dim, new_in)
- decoder: slice rows (input_dim, latent_dim) -> (new_in, latent_dim)
- bias: slice entries (input_dim,) -> (new_in,)
Source code in gradiend/model/model.py
_require_built
Raise if encoder/decoder are not yet built (lazy init).
Source code in gradiend/model/model.py
_serialize_kwargs
Serialize kwargs, filtering out non-JSON objects.
Source code in gradiend/model/model.py
cpu
Move encoder and decoder to CPU.
cuda
Move encoder and decoder to CUDA.
Source code in gradiend/model/model.py
forward
Forward pass for tensor input already in GRADIEND input space.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | 1D tensor of shape (input_dim,) representing a flattened gradient vector in GRADIEND input space. | *required* |
| `return_encoded` | `bool` | If True, also return the latent encoding. | `False` |
Returns:

| Type | Description |
|---|---|
| `Union[Tensor, Tuple[Tensor, Tensor]]` | If return_encoded is False: decoded tensor of shape (input_dim,). If return_encoded is True: tuple (decoded, encoded), where decoded has shape (input_dim,) and encoded has shape (latent_dim,). |
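The documented forward pass can be sketched as a plain PyTorch module (an illustration, not the actual implementation; a single-Linear encoder/decoder and the default tanh/identity activations are assumptions based on this page):

```python
import torch
import torch.nn as nn

# Minimal sketch of the documented encoder-decoder forward pass.
class TinyGradiend(nn.Module):
    def __init__(self, input_dim: int, latent_dim: int):
        super().__init__()
        self.encoder = nn.Linear(input_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, input_dim)  # bias_decoder=True
        self.act_enc = nn.Tanh()  # paper default; decoder uses identity

    def forward(self, x: torch.Tensor, return_encoded: bool = False):
        encoded = self.act_enc(self.encoder(x))  # shape (latent_dim,)
        decoded = self.decoder(encoded)          # shape (input_dim,)
        return (decoded, encoded) if return_encoded else decoded
```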
Source code in gradiend/model/model.py
forward_encoder
Encoder-only forward for tensor input.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | 1D tensor of shape (input_dim,) in GRADIEND input space. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Encoded tensor of shape (latent_dim,). |
Source code in gradiend/model/model.py
from_pretrained
classmethod
Load weights + config.json (weights-only). ParamMappedGradiendModel overrides to also load mapping.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `load_directory` | `str` | Directory containing model files. | *required* |
| `device_encoder` | `Optional[device]` | Optional device override for encoder parameters. | `None` |
| `device_decoder` | `Optional[device]` | Optional device override for decoder parameters. | `None` |
| `torch_dtype` | `Optional[dtype]` | Optional dtype override. If None, uses the dtype stored in config.json. | `None` |
Returns:

| Type | Description |
|---|---|
| `GradiendModel` | Instantiated GradiendModel with loaded weights and metadata. |
Source code in gradiend/model/model.py
get_topk_weights
Return the top-k input indices by importance score.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `part` | `str` | Importance source passed to get_weight_importance. Options: "encoder-weight", "decoder-weight", "decoder-bias", "decoder-sum". | `'decoder-weight'` |
| `topk` | `Union[int, float]` | Number of indices to return (clipped to input_dim) or a proportion in (0, 1]. | `1000` |
Returns:

| Type | Description |
|---|---|
| `List[int]` | List of input indices (length k) sorted by descending importance. |
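The selection rule can be sketched as follows (an illustration, assuming a float is interpreted as a fraction of input_dim; `topk_indices` is a hypothetical helper, not part of the gradiend API):

```python
import torch

def topk_indices(importance: torch.Tensor, topk) -> list:
    """Top-k input indices by descending importance.

    topk may be an int (clipped to input_dim) or a float in (0, 1],
    mirroring the documented behavior."""
    n = importance.numel()
    k = int(round(topk * n)) if isinstance(topk, float) else min(topk, n)
    return torch.argsort(importance, descending=True)[:k].tolist()
```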
Source code in gradiend/model/model.py
get_update_vector
Return a flattened weight-derived update vector.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `part` | `str` | Which component to use for the update vector: "decoder-weight" (flattened decoder weight vector); "decoder-bias" (decoder bias vector); "decoder-sum" (decoder weight vector + bias); "encoder-weight" (flattened encoder weight vector). | `'decoder-weight'` |
Returns:

| Type | Description |
|---|---|
| `Tensor` | 1D tensor in GRADIEND input space derived from the requested component. |
Source code in gradiend/model/model.py
get_weight_importance
Importance per GRADIEND input dimension (length = input_dim), on CPU.

Parameters:
- part: Which component to use for importance aggregation:
    - "encoder-weight": L1 norm over encoder weight columns
    - "decoder-weight": L1 norm over decoder weight rows
    - "decoder-bias": absolute decoder bias
    - "decoder-sum": absolute(sum(weight_row) + bias)
Returns:

| Type | Description |
|---|---|
| `Tensor` | 1D CPU float tensor of length input_dim, where higher means more influential according to the chosen aggregation. |
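The four aggregations can be sketched directly from the documented shapes (encoder `(latent_dim, input_dim)`, decoder `(input_dim, latent_dim)`, bias `(input_dim,)`); this is an illustration, not the actual implementation:

```python
import torch

def weight_importance(enc_w, dec_w, dec_b, part="decoder-weight"):
    """Sketch of the documented importance aggregations.

    enc_w: (latent_dim, input_dim), dec_w: (input_dim, latent_dim),
    dec_b: (input_dim,). Returns a 1D CPU tensor of length input_dim."""
    if part == "encoder-weight":
        return enc_w.abs().sum(dim=0).cpu()       # L1 over encoder columns
    if part == "decoder-weight":
        return dec_w.abs().sum(dim=1).cpu()       # L1 over decoder rows
    if part == "decoder-bias":
        return dec_b.abs().cpu()                  # absolute decoder bias
    if part == "decoder-sum":
        return (dec_w.sum(dim=1) + dec_b).abs().cpu()  # |sum(row) + bias|
    raise ValueError(f"Unknown part: {part!r}")
```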
Source code in gradiend/model/model.py
prune
`prune(*, topk=None, threshold=None, mask=None, part='decoder-weight', importance=None, inplace=False, return_mask=False)`
Physically prune the model (reduce input_dim) by selecting important input dimensions.
Selection order: mask -> threshold -> topk.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `topk` | `Union[int, float, None]` | int (absolute) or float in (0, 1] (relative fraction among remaining dims). | `None` |
| `threshold` | `Optional[float]` | Keep dims with importance >= threshold. | `None` |
| `mask` | `Optional[Tensor]` | Optional bool tensor of shape (input_dim,) in the current input space. | `None` |
| `part` | `str` | 'encoder-weight' \| 'decoder-weight' \| 'decoder-bias' \| 'decoder-sum' (used when importance is None). | `'decoder-weight'` |
| `importance` | `Optional[Tensor]` | Optional 1D tensor of length input_dim; used instead of get_weight_importance(part) when provided. | `None` |
| `inplace` | `bool` | Modify this instance if True, else return a deepcopy. | `False` |
| `return_mask` | `bool` | If True, also return the final combined_mask (in the original input space). | `False` |
Returns:

| Type | Description |
|---|---|
| `Union[GradiendModel, Tuple[GradiendModel, Tensor]]` | If return_mask is False: the pruned GradiendModel (self or a deepcopy, depending on inplace). If return_mask is True: tuple (model, combined_mask), where combined_mask is a bool tensor of shape (old_input_dim,) indicating kept dimensions. |
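The physical slicing that prune ultimately performs (see also `_prune_input_dims` above) can be sketched as follows; this is an illustration only, omitting importance computation, the mask → threshold → topk selection order, and the deepcopy semantics:

```python
import torch

def prune_by_mask(enc_w, dec_w, dec_b, keep_mask):
    """Slice weights to the kept input dimensions (no mapping logic).

    keep_mask: bool tensor of shape (input_dim,).
    enc_w: (latent_dim, input_dim) -> (latent_dim, new_in)  [columns]
    dec_w: (input_dim, latent_dim) -> (new_in, latent_dim)  [rows]
    dec_b: (input_dim,)            -> (new_in,)             [entries]"""
    enc_w = enc_w[:, keep_mask]
    dec_w = dec_w[keep_mask, :]
    dec_b = dec_b[keep_mask] if dec_b is not None else None
    return enc_w, dec_w, dec_b
```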
Source code in gradiend/model/model.py
pruned_length
Return the current input_dim after pruning.
Returns:

| Type | Description |
|---|---|
| `int` | Current input_dim as an integer. |
save_pretrained
Save weights + config.json (+ optional training.json).
Notes:
- safetensors is used if available, unless use_safetensors=False.
- Training info: if kwargs contains "training", it is written to training.json and removed from the config metadata.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `save_directory` | `str` | Folder to write model files into. | *required* |
| `use_safetensors` | `Optional[bool]` | If True, require safetensors. If False, force the PyTorch bin format. If None, prefer safetensors when available. | `None` |
| `**kwargs` | `Any` | Extra metadata to store in config.json. | `{}` |
Returns:

| Type | Description |
|---|---|
| `None` | None. |
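The documented file layout can be sketched with the PyTorch-bin fallback (an illustration; the real method prefers model.safetensors when safetensors is installed and also records the full architecture in config.json):

```python
import json
import os
import torch

def save_pretrained_sketch(state_dict, config, save_directory):
    """Write pytorch_model.bin + config.json (+ optional training.json).

    Mirrors the documented layout: format_version=0 is added to the
    config, and any "training" entry is split out into training.json."""
    os.makedirs(save_directory, exist_ok=True)
    config = dict(config, format_version=0)     # copy; don't mutate caller
    training = config.pop("training", None)     # split run info out
    torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
    with open(os.path.join(save_directory, "config.json"), "w") as f:
        json.dump(config, f, indent=2)
    if training is not None:
        with open(os.path.join(save_directory, "training.json"), "w") as f:
            json.dump(training, f, indent=2)
```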
Source code in gradiend/model/model.py
to
Move encoder and decoder to the requested devices.
- If device_encoder or device_decoder is provided, moves only those submodules.
- If device is provided (and no split devices), moves both to that device.
- If device_encoder/device_decoder is None, leaves that submodule's placement unchanged.
- When encoder/decoder are not yet built (lazy init), only updates target device attributes.
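The device-resolution rule in the bullets above can be sketched as a small pure function (a hypothetical helper for illustration; device strings stand in for torch.device objects):

```python
def resolve_targets(device=None, device_encoder=None, device_decoder=None,
                    current_encoder="cpu", current_decoder="cpu"):
    """Return the (encoder, decoder) target devices per the documented rule:
    split devices move only the named submodules; otherwise a plain device
    moves both; with no arguments, placement is left unchanged."""
    if device_encoder is None and device_decoder is None:
        if device is not None:
            return device, device
        return current_encoder, current_decoder
    enc = device_encoder if device_encoder is not None else current_encoder
    dec = device_decoder if device_decoder is not None else current_decoder
    return enc, dec
```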