fit.luz_module_generator {luz}    R Documentation
Fit an nn_module
Description
Fit an nn_module that has been set up with setup().
Usage
## S3 method for class 'luz_module_generator'
fit(
object,
data,
epochs = 10,
callbacks = NULL,
valid_data = NULL,
accelerator = NULL,
verbose = NULL,
...,
dataloader_options = NULL
)
Arguments
object: An nn_module that has been set up with setup().
data: (dataloader, dataset or list) The training data: a dataloader
created with torch::dataloader(), a dataset created with torch::dataset(),
or a list. Dataloaders and datasets must return a list with at most 2
items. The first item is used as input to the module and the second is
used as the target for the loss function.
epochs: (int, or a vector of two ints) The number of epochs for training
the model. If a single value is provided, it is taken as max_epochs and
min_epochs is set to 0. If a vector of two numbers is provided, the first
value is min_epochs and the second is max_epochs. The minimum and maximum
number of epochs are included in the context object as ctx$min_epochs and
ctx$max_epochs, respectively.
callbacks: (list, optional) A list of callbacks defined with luz_callback()
that will be called during the training procedure. The callbacks
luz_callback_metrics(), luz_callback_progress() and
luz_callback_train_valid() are always added.
valid_data: (dataloader, dataset, list or scalar value; optional) A
dataloader created with torch::dataloader(), a dataset created with
torch::dataset(), or a list, used during the validation procedure.
Dataloaders and datasets must return a list of (input, target). If data is
a torch dataset or a list, you can instead supply a numeric value between
0 and 1; in that case, a random sample of data of that proportion is used
for validation.
accelerator: (accelerator, optional) An accelerator() object used to
configure device placement of components such as nn_modules, optimizers
and batches of data.
verbose: (logical, optional) Whether the fitting procedure should emit
output to the console during training. When NULL (the default), output is
produced only if interactive() is TRUE.
...: Currently unused.
dataloader_options: (list, optional) Options used when creating a
dataloader; see torch::dataloader(). By default, shuffle = TRUE for the
training data and batch_size = 32. An error is raised if
dataloader_options is not NULL and data is already a dataloader.
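The rule used for the epochs argument can be illustrated with a small base R sketch. The helper resolve_epochs() is hypothetical: it only mirrors the behaviour documented above and is not luz's internal code.

```r
# Hypothetical helper: maps `epochs` to min/max epochs following the rule
# documented for fit(). Not luz's actual implementation.
resolve_epochs <- function(epochs) {
  if (length(epochs) == 1) {
    # A single value is the maximum; the minimum is set to 0.
    list(min_epochs = 0, max_epochs = epochs)
  } else if (length(epochs) == 2) {
    # Two values are read as c(min_epochs, max_epochs).
    list(min_epochs = epochs[1], max_epochs = epochs[2])
  } else {
    stop("`epochs` must have length 1 or 2.")
  }
}

resolve_epochs(10)        # min_epochs = 0, max_epochs = 10
resolve_epochs(c(5, 20))  # min_epochs = 5, max_epochs = 20
```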
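When valid_data is a proportion, a random sample of data is used for validation. The base R sketch below shows one way such a split could be computed on row indices. The helper split_indices() is hypothetical, and both the rounding of the sample size and the removal of validation rows from the training rows are assumptions, not confirmed luz behaviour.

```r
# Hypothetical sketch of a proportion-based validation split.
# Assumptions: the validation size is round(p * n), and validation rows
# are excluded from the training rows. Not luz's internal code.
split_indices <- function(n, p) {
  stopifnot(is.numeric(p), length(p) == 1, p > 0, p < 1)
  # Sample without replacement, so indices are unique.
  valid_idx <- sample.int(n, size = round(p * n))
  list(
    train = setdiff(seq_len(n), valid_idx),
    valid = sort(valid_idx)
  )
}

set.seed(1)
idx <- split_indices(n = 100, p = 0.2)
length(idx$train)  # 80
length(idx$valid)  # 20
```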
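The defaults and the error condition described for dataloader_options can be sketched as follows. The helper resolve_dataloader_options() and its is_dataloader flag are hypothetical stand-ins for luz's own logic and class checks; the sketch covers the training data only and assumes user-supplied options simply override the defaults.

```r
# Hypothetical sketch: combine user options with the documented defaults
# for training data (batch_size = 32, shuffle = TRUE). `is_dataloader`
# stands in for a real check on the class of `data`.
resolve_dataloader_options <- function(dataloader_options, is_dataloader) {
  if (is_dataloader) {
    # Options cannot be applied to an existing dataloader.
    if (!is.null(dataloader_options)) {
      stop("`dataloader_options` must be NULL when `data` is already a dataloader.")
    }
    return(NULL)
  }
  defaults <- list(batch_size = 32, shuffle = TRUE)
  # Entries supplied by the user override defaults with the same name.
  utils::modifyList(defaults, as.list(dataloader_options))
}

resolve_dataloader_options(NULL, is_dataloader = FALSE)
# list(batch_size = 32, shuffle = TRUE)
resolve_dataloader_options(list(batch_size = 64), is_dataloader = FALSE)
# list(batch_size = 64, shuffle = TRUE)
```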
Value
A fitted object that can be saved with luz_save(), printed with print()
and plotted with plot().
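A minimal end-to-end sketch, assuming the torch and luz packages are installed. It trains a single nn_linear layer on random tensors, passes data as a list of tensors, uses 20% of it for validation through valid_data, and adds a small custom callback. The data, the model and the print_epoch callback are illustrative choices, not part of this help page.

```r
library(torch)
library(luz)

# Random regression data, for illustration only.
x <- torch_randn(200, 10)
y <- torch_randn(200, 1)

# Illustrative callback that reports the epoch number after each epoch.
print_epoch <- luz_callback(
  name = "print_epoch",
  on_epoch_end = function() {
    cat("Finished epoch", ctx$epoch, "\n")
  }
)

# Set up the module generator, then fix its hyperparameters.
module <- setup(nn_linear, loss = nnf_mse_loss, optimizer = optim_adam)
module <- set_hparams(module, in_features = 10, out_features = 1)
module <- set_opt_hparams(module, lr = 0.01)

fitted <- fit(
  module,
  data = list(x, y),                      # (input, target)
  epochs = c(2, 5),                       # min_epochs = 2, max_epochs = 5
  valid_data = 0.2,                       # 20% of `data` for validation
  callbacks = list(print_epoch()),
  accelerator = accelerator(cpu = TRUE),  # force training on the CPU
  dataloader_options = list(batch_size = 16),
  verbose = FALSE
)

preds <- predict(fitted, list(x))
path <- tempfile(fileext = ".rds")
luz_save(fitted, path)
```

Because verbose = FALSE, the progress output is suppressed, but the custom callback still prints one line per epoch.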
See Also
predict.luz_module_fitted() for how to create predictions.
setup() to find out how to create modules that can be trained with fit.
Other training: evaluate(), predict.luz_module_fitted(), setup()
[Package luz version 0.4.0 Index]