ezpz.integrations
WandbPredictionProgressCallback
Bases: WandbCallback
Custom WandbCallback to log model predictions during training.
This callback logs model predictions and labels to a wandb.Table at each logging step during training. It allows visualization of the model predictions as training progresses.
Attributes:
| Name | Type | Description |
|---|---|---|
| `trainer` | | Hugging Face Trainer instance. |
| `tokenizer` | | Tokenizer associated with the model. |
| `sample_dataset` | | Subset of the validation dataset for predictions. |
| `num_samples` | | Number of samples to select from validation for predictions. |
| `freq` | | Frequency of logging (epochs). |
Source code in src/ezpz/integrations.py
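The attribute descriptions suggest two pieces of bookkeeping: a fixed subset of the validation data, and a logging frequency measured in epochs. The sketch below illustrates that idea with the standard library only. It is not the ezpz implementation: `select_samples` and `should_log` are hypothetical helpers, and both the "first `num_samples` items" rule and the "epoch divisible by `freq`" rule are assumptions made for illustration.

```python
from typing import Any, List, Sequence


def select_samples(val_dataset: Sequence[Any], num_samples: int = 100) -> List[Any]:
    """Hypothetical helper: keep at most the first `num_samples` items."""
    return list(val_dataset[:num_samples])


def should_log(epoch: float, freq: int = 2) -> bool:
    """Hypothetical helper: log only on epochs that are multiples of `freq`."""
    return epoch % freq == 0


# Stand-in for a validation dataset with 250 examples.
val_dataset = [f"example-{i}" for i in range(250)]

# Assumed behaviour: the subset is fixed once, when the callback is created.
sample_dataset = select_samples(val_dataset, num_samples=100)

# Assumed behaviour: predictions are logged only on every `freq`-th epoch.
logged_epochs = [epoch for epoch in range(1, 7) if should_log(epoch, freq=2)]
```

Under these assumptions, a fixed subset keeps the logged tables comparable from one epoch to the next, and a larger `freq` trades fewer logged tables for less prediction overhead.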
__init__(trainer, tokenizer, val_dataset, num_samples=100, freq=2)
Initializes the WandbPredictionProgressCallback instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `trainer` | `Any` | Hugging Face Trainer instance. | required |
| `tokenizer` | `Any` | Tokenizer associated with the model. | required |
| `val_dataset` | `Any` | Validation dataset. | required |
| `num_samples` | `int` | Number of samples to select from validation for generating predictions. Defaults to 100. | `100` |
| `freq` | `int` | Frequency of logging. Defaults to 2. | `2` |
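As a usage sketch, the constructor can be called with the parameters listed above and the result registered on the trainer. `attach_prediction_callback` is a hypothetical helper, not part of ezpz. It assumes `trainer` is a Hugging Face `Trainer` (which provides `add_callback`) and that `val_dataset` is in whatever form the callback expects.

```python
from typing import Any


def attach_prediction_callback(trainer: Any, tokenizer: Any, val_dataset: Any) -> Any:
    """Hypothetical helper: build the callback and register it on the trainer."""
    # Imported inside the function so that defining this helper
    # does not itself require ezpz to be installed.
    from ezpz.integrations import WandbPredictionProgressCallback

    callback = WandbPredictionProgressCallback(
        trainer=trainer,
        tokenizer=tokenizer,
        val_dataset=val_dataset,
        num_samples=10,  # fewer than the documented default of 100
        freq=2,          # the documented default
    )
    # Register the callback so the trainer invokes it during training.
    trainer.add_callback(callback)
    return callback
```

The callback takes the trainer as an argument, so it has to be created after the trainer exists and then attached, rather than being passed in when the trainer is constructed.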