-
-
Notifications
You must be signed in to change notification settings - Fork 194
Expand file tree
/
Copy pathMonteCarlo.php
More file actions
124 lines (106 loc) · 3.59 KB
/
MonteCarlo.php
File metadata and controls
124 lines (106 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
<?php
namespace Rubix\ML\CrossValidation;
use Rubix\ML\Learner;
use Rubix\ML\Parallel;
use Rubix\ML\Estimator;
use Rubix\ML\Helpers\Stats;
use Rubix\ML\Backends\Serial;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Traits\Multiprocessing;
use Rubix\ML\CrossValidation\Metrics\Metric;
use Rubix\ML\Backends\Tasks\TrainAndValidate;
use Rubix\ML\Specifications\EstimatorIsCompatibleWithMetric;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;
/**
 * Monte Carlo
 *
 * Monte Carlo cross validation (or *repeated random subsampling*) is a technique that
 * averages the validation score of a learner over a user-defined number of simulations
 * where the learner is trained and tested on random splits of the dataset. The estimated
 * validation score approaches the actual validation score as the number of simulations
 * goes to infinity, however, only a tiny fraction of all possible simulations are needed
 * to produce a pretty good approximation.
 *
 * @category Machine Learning
 * @package Rubix/ML
 * @author Andrew DalPino
 */
class MonteCarlo implements Validator, Parallel
{
    use Multiprocessing;

    /**
     * The number of simulations i.e. random subsamplings of the dataset.
     *
     * @var int
     */
    protected int $simulations;

    /**
     * The hold out ratio i.e. the fraction of samples to use for testing.
     *
     * @var float
     */
    protected float $ratio;

    /**
     * @param int $simulations The number of random train/test splits to score (must be at least 1).
     * @param float $ratio The fraction of samples held out for testing (strictly between 0 and 1).
     * @throws InvalidArgumentException If either argument is out of range (a NAN ratio is rejected too).
     */
    public function __construct(int $simulations = 10, float $ratio = 0.2)
    {
        if ($simulations < 1) {
            throw new InvalidArgumentException('Number of simulations'
                . " must be greater than 0, $simulations given.");
        }

        // Phrased as a negated in-range test so that NAN, which fails every
        // comparison, is rejected instead of slipping past both bounds.
        if (!($ratio > 0.0 and $ratio < 1.0)) {
            throw new InvalidArgumentException('Ratio must be'
                . " between 0 and 1, $ratio given.");
        }

        $this->simulations = $simulations;
        $this->ratio = $ratio;

        // Default backend; test() enqueues one task per simulation on it.
        $this->backend = new Serial();
    }

    /**
     * Test the estimator with the supplied dataset and return a validation score.
     *
     * Each simulation re-randomizes $dataset and holds out a fraction of it for
     * testing. The return value of randomize() is not used, so the caller's dataset
     * is reordered as a side effect. The hold-out is stratified by label when the
     * label type is categorical. Each train/test pair is enqueued as a
     * TrainAndValidate task on the backend, and the mean of the returned scores is
     * the result.
     *
     * @param Learner $estimator
     * @param Labeled $dataset
     * @param Metric $metric
     * @throws RuntimeException If the dataset is too small for the hold out ratio, or a split yields an empty validation set.
     * @return float
     */
    public function test(Learner $estimator, Labeled $dataset, Metric $metric) : float
    {
        EstimatorIsCompatibleWithMetric::with($estimator, $metric)->check();

        if ($dataset->numSamples() * $this->ratio < 1) {
            throw new RuntimeException('Dataset does not contain'
                . ' enough records to create a validation set with a'
                . " hold out ratio of {$this->ratio}.");
        }

        $stratify = $dataset->labelType()->isCategorical();

        // Discard any tasks left over from a previous call before enqueueing new ones.
        $this->backend->flush();

        for ($i = 0; $i < $this->simulations; ++$i) {
            $dataset->randomize();

            [$testing, $training] = $stratify
                ? $dataset->stratifiedSplit($this->ratio)
                : $dataset->split($this->ratio);

            // The whole-dataset size check above does not guarantee a non-empty
            // hold-out once the split is stratified. NOTE(review): stratifiedSplit
            // appears to size the hold-out per label, so small classes may contribute
            // nothing. Refuse to score a learner on zero validation samples.
            if ($testing->numSamples() < 1) {
                throw new RuntimeException('Hold out ratio of'
                    . " {$this->ratio} produced an empty validation set,"
                    . ' use a larger ratio or a dataset with more samples per class.');
            }

            $this->backend->enqueue(
                new TrainAndValidate($estimator, $training, $testing, $metric)
            );
        }

        $scores = $this->backend->process();

        return Stats::mean($scores);
    }

    /**
     * Return the string representation of the object.
     *
     * @internal
     *
     * @return string
     */
    public function __toString() : string
    {
        return "Monte Carlo (simulations: {$this->simulations}, ratio: {$this->ratio})";
    }
}