task_dataset {mlr3torch} | R Documentation
Creates a torch dataset from an mlr3 Task.
The resulting dataset's $.getbatch() method returns a list with elements x, y and .index:
- x is a list of tensors whose content is defined by the parameter feature_ingress_tokens.
- y is the target variable; its content is defined by the parameter target_batchgetter.
- .index contains the indices of the batch's rows in the task's data.
The data is returned on the device specified by the parameter device.
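A feature batchgetter (such as batchgetter_num in the example on this page) turns the selected feature columns of the batch rows into one tensor of x. The sketch below defines a hypothetical batchgetter_matrix (not part of mlr3torch), assuming feature batchgetters are called like target_batchgetter, with data and device. It also shows one way to choose the device value with torch's cuda_is_available().

```r
library(torch)

# Hypothetical feature batchgetter (not part of mlr3torch). Assumption: it is
# called with the feature columns of the batch rows (data) and the device.
# It returns one float tensor of shape (number of rows, number of features).
batchgetter_matrix = function(data, device) {
  torch_tensor(as.matrix(data), dtype = torch_float32(), device = device)
}

# Use a GPU if one is available, otherwise fall back to the CPU.
device = if (cuda_is_available()) "cuda" else "cpu"

x_sepal = batchgetter_matrix(iris[1:10, c("Sepal.Length", "Sepal.Width")], device)
x_sepal$shape   # 10 2
x_sepal$device
```

The resulting shape (rows, 2) matches the shape = c(NA, 2) declared for the two-feature ingress tokens in the example, where NA stands for the batch size.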
task_dataset(task, feature_ingress_tokens, target_batchgetter = NULL, device)
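Argument | Description
task | (Task) The mlr3 task from which the dataset is created.
feature_ingress_tokens | (named list of TorchIngressToken) One token per element of x; each token specifies the features it uses, the batchgetter that converts them into a tensor, and the tensor's shape.
target_batchgetter | (function or NULL) A function with arguments data and device that converts the target column(s) into y. Default: NULL.
device | (character(1)) The device on which the tensors are created, e.g. "cpu" or "cuda".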
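The example on this page converts the target into a float tensor of shape (n, 1). For classification losses such as nn_cross_entropy_loss(), which typically expect integer class codes, a target batchgetter could instead return a long tensor. The following is a minimal sketch of a hypothetical target_batchgetter_classif (not part of mlr3torch) with the same (data, device) signature, assuming data holds a single factor target column. It is called directly on a small data.frame that stands in for the batch data task_dataset would pass.

```r
library(torch)

# Hypothetical target batchgetter for classification (not part of mlr3torch).
# Assumption: data contains the target column of the batch rows as a factor.
target_batchgetter_classif = function(data, device) {
  # as.integer() on a factor returns its 1-based level codes
  torch_tensor(as.integer(data[[1L]]), dtype = torch_long(), device = device)
}

# Rows 1, 51 and 101 of iris are setosa, versicolor and virginica
batch_data = data.frame(Species = iris$Species[c(1, 51, 101)])
y = target_batchgetter_classif(batch_data, "cpu")
y$shape   # 3
```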
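library(mlr3)
library(mlr3torch)
library(torch)

# Example: one ingress token per pair of numeric iris features
task = tsk("iris")
sepal_ingress = TorchIngressToken(
  features = c("Sepal.Length", "Sepal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
petal_ingress = TorchIngressToken(
  features = c("Petal.Length", "Petal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
# batch$x will contain one tensor per token, named sepal and petal
ingress_tokens = list(sepal = sepal_ingress, petal = petal_ingress)

# Convert the target (the factor Species) into a float tensor of shape (n, 1);
# as.integer() makes the use of the factor's level codes explicit
target_batchgetter = function(data, device) {
  torch_tensor(data = as.integer(data[[1L]]), dtype = torch_float32(), device = device)$unsqueeze(2)
}
dataset = task_dataset(task, ingress_tokens, target_batchgetter, device = "cpu")

# Retrieve the first 10 rows of the task as a single batch
batch = dataset$.getbatch(1:10)
batch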