sbrlModel2arcCBARuleModel {qCBA} | R Documentation |
sbrlModel2arcCBARuleModel Converts a model created by sbrl so that it can be passed to qCBA
Description
Creates an instance of the CBAmodel class from the arc package. The resulting instance can then be passed to qcba.
Usage
sbrlModel2arcCBARuleModel(
sbrl_model,
cutPoints,
rawDataset,
classAtt,
attTypes
)
Arguments
sbrl_model |
object returned by sbrl::sbrl() |
cutPoints |
specification of cutpoints applied on the data before they were passed to sbrl |
rawDataset |
the raw data (before discretization). This dataset is used to guess attribute types if attTypes is not passed |
classAtt |
the name of the class attribute |
attTypes |
vector of attribute types of the original data. If set to NULL, you need to pass rawDataset. |
Examples
# Example: train an SBRL model on a two-class subset of iris, convert it with
# sbrlModel2arcCBARuleModel() and postprocess the converted model with qcba().
# Both rCBA and sbrl are optional dependencies; if either is missing, the
# example only prints a hint and does nothing else.
if (!requireNamespace("rCBA", quietly = TRUE)) {
  # No return() here: this code runs at top level, where return() is an error.
  message("Please install rCBA to allow for sbrl model conversion")
} else if (!requireNamespace("sbrl", quietly = TRUE)) {
  message("Please install sbrl to allow for postprocessing of sbrl models")
} else {
  # This will run only outside a CRAN test, if the environment variable
  # NOT_CRAN is set to true. This environment variable is set by devtools.
  if (identical(Sys.getenv("NOT_CRAN"), "true")) {
    library(sbrl)
    library(rCBA)

    # sbrl handles only binary problems, iris has 3 target classes -
    # remove one class
    set.seed(111)
    allData <- datasets::iris[sample(nrow(datasets::iris)), ]
    classToExclude <- "versicolor"
    allData <- allData[allData$Species != classToExclude, ]
    # drop the removed level
    allData$Species <- droplevels(allData$Species)

    trainFold <- allData[1:50, ]
    testFold <- allData[51:nrow(allData), ]

    # rename the class column to the fixed name stored in sbrlFixedLabel
    sbrlFixedLabel <- "label"
    origLabel <- "Species"
    orignames <- colnames(trainFold)
    orignames[which(orignames == origLabel)] <- sbrlFixedLabel
    colnames(trainFold) <- orignames
    colnames(testFold) <- orignames

    # to recode label to binary values:
    # first create dict mapping from original distinct class values to 0,1
    origval <- levels(as.factor(trainFold$label))
    newval <- c(0, 1)
    dict <- data.frame(origval, newval)
    # then apply dict to train and test fold
    trainFold$label <- dict[match(trainFold$label, dict$origval), 2]
    testFold$label <- dict[match(testFold$label, dict$origval), 2]

    # discretize training data
    trainFoldDiscTemp <- discrNumeric(trainFold, sbrlFixedLabel)
    trainFoldDiscCutpoints <- trainFoldDiscTemp$cutp
    trainFoldDisc <- as.data.frame(
      lapply(trainFoldDiscTemp$Disc.data, as.factor)
    )
    # discretize test data with the cutpoints learnt on the training data
    testFoldDisc <- applyCuts(
      testFold, trainFoldDiscCutpoints,
      infinite_bounds = TRUE, labels = TRUE
    )

    # SBRL 1.4 crashes if features contain a space
    # even if these features are converted to factors,
    # to circumvent this, it is necessary to replace spaces
    trainFoldDisc <- as.data.frame(
      lapply(trainFoldDisc, function(x) gsub(" ", "", as.character(x)))
    )
    for (name in names(trainFoldDisc)) {
      trainFoldDisc[name] <- as.factor(trainFoldDisc[, name])
    }

    # learn sbrl model, rule_minlen is increased to demonstrate the effect
    # of postprocessing
    sbrl_model <- sbrl(
      trainFoldDisc,
      iters = 20000, pos_sign = "0", neg_sign = "1",
      rule_minlen = 3, rule_maxlen = 5,
      minsupport_pos = 0.05, minsupport_neg = 0.05,
      lambda = 20.0, eta = 5.0, nchain = 25
    )

    # apply sbrl model on a test fold
    yhat <- predict(sbrl_model, testFoldDisc)
    yvals <- as.integer(yhat$V1 > 0.5)
    sbrl_acc <- mean(as.integer(yvals == testFoldDisc$label))

    message("SBRL RESULT")
    # print() rather than message(): message() coerces its arguments with
    # as.character(), which does not give a readable rendering of a model.
    print(sbrl_model)

    rm_sbrl <- sbrlModel2arcCBARuleModel(
      sbrl_model, trainFoldDiscCutpoints, trainFold, sbrlFixedLabel
    )
    message(paste(
      "sbrl acc=", sbrl_acc,
      ", sbrl rule count=", nrow(sbrl_model$rs),
      ",\navg condition count (incl. default rule)",
      sum(rm_sbrl@rules@lhs@data) / length(rm_sbrl@rules)
    ))

    # postprocess the converted model with QCBA and evaluate on the test fold
    rmQCBA_sbrl <- qcba(cbaRuleModel = rm_sbrl, datadf = trainFold)
    prediction <- predict(rmQCBA_sbrl, testFold)
    acc_qcba_sbrl <- CBARuleModelAccuracy(
      prediction, testFold[[rmQCBA_sbrl@classAtt]]
    )
    # Average number of conditions per rule: sum the per-rule counts before
    # dividing, so that a single number (not one value per rule) is reported.
    avg_rule_length <-
      sum(rmQCBA_sbrl@rules$condition_count) / nrow(rmQCBA_sbrl@rules)

    message("RESULT of QCBA postprocessing of SBRL")
    # rules is indexed with $ and nrow() above, i.e. it is tabular; print()
    # shows it as a table, whereas message() would deparse its columns.
    print(rmQCBA_sbrl@rules)
    message(paste(
      "QCBA after SBRL acc=", acc_qcba_sbrl,
      ", rule count=", rmQCBA_sbrl@ruleCount,
      ", avg condition count (incl. default rule)", avg_rule_length
    ))

    unlink("tdata_R.label") # delete temp files created by SBRL
    unlink("tdata_R.out")
  }
}
[Package qCBA version 1.0 Index]