-
-
Notifications
You must be signed in to change notification settings - Fork 194
Expand file tree
/
Copy pathCyclical.php
More file actions
145 lines (125 loc) · 3.46 KB
/
Cyclical.php
File metadata and controls
145 lines (125 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
<?php
namespace Rubix\ML\NeuralNet\Optimizers;
use Tensor\Tensor;
use Rubix\ML\NeuralNet\Parameter;
use Rubix\ML\Exceptions\InvalidArgumentException;
/**
 * Cyclical
 *
 * The Cyclical optimizer uses a global learning rate that oscillates linearly
 * (a triangular wave) between a lower and an upper bound. Each half cycle lasts
 * a designated number of steps. The amplitude of the wave (the distance between
 * the two bounds) is multiplied by the decay coefficient raised to the number of
 * steps taken so far. As a result, the effective peak of each cycle shrinks
 * toward the lower bound over time. Cyclical learning rates have been shown to
 * help escape bad local minima and saddle points, thus achieving lower training
 * loss.
 *
 * References:
 * [1] L. N. Smith. (2017). Cyclical Learning Rates for Training Neural Networks.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class Cyclical implements Optimizer
{
    /**
     * The lower bound on the learning rate.
     *
     * @var float
     */
    protected float $lower;

    /**
     * The upper bound on the learning rate before any decay is applied.
     *
     * @var float
     */
    protected float $upper;

    /**
     * The distance between the upper and lower bound, i.e. the undecayed
     * amplitude of the cycle.
     *
     * @var float
     */
    protected float $range;

    /**
     * The number of steps in each half cycle, i.e. how many steps it takes for
     * the rate to travel from the lower bound to its peak (or back). A full
     * cycle therefore spans twice this many steps.
     *
     * @var int
     */
    protected int $losses;

    /**
     * The exponential scaling factor applied to the amplitude at each step as decay.
     *
     * @var float
     */
    protected float $decay;

    /**
     * The number of steps taken so far.
     *
     * @var int
     */
    protected int $t = 0;

    /**
     * @param float $lower the minimum learning rate, must be greater than 0
     * @param float $upper the maximum (undecayed) learning rate, must not be less than $lower
     * @param int $losses the number of steps per half cycle, must be at least 1
     * @param float $decay the per-step amplitude decay factor, strictly between 0 and 1
     * @throws InvalidArgumentException
     */
    public function __construct(
        float $lower = 0.001,
        float $upper = 0.006,
        int $losses = 2000,
        float $decay = 0.99994
    ) {
        if ($lower <= 0.0) {
            throw new InvalidArgumentException('Lower bound must be'
                . " greater than 0, $lower given.");
        }

        if ($lower > $upper) {
            throw new InvalidArgumentException('Lower bound cannot be'
                . ' greater than the upper bound.');
        }

        if ($losses < 1) {
            throw new InvalidArgumentException('The number of steps per'
                . " cycle must be greater than 0, $losses given.");
        }

        if ($decay <= 0.0 or $decay >= 1.0) {
            throw new InvalidArgumentException('Decay must be between'
                . " 0 and 1, $decay given.");
        }

        $this->lower = $lower;
        $this->upper = $upper;
        $this->range = $upper - $lower;
        $this->losses = $losses;
        $this->decay = $decay;
    }

    /**
     * Take a step of gradient descent for a given parameter.
     *
     * The gradient is scaled by the current global learning rate. $param is
     * not consulted, because the rate does not depend on which parameter is
     * being updated.
     *
     * NOTE(review): $t advances once per call to this method. If the caller
     * invokes step() once per parameter per update, a "step" here is a
     * parameter update rather than a batch. Confirm against the caller.
     *
     * @internal
     *
     * @param Parameter $param
     * @param Tensor<int|float|array> $gradient
     * @return Tensor<int|float|array>
     */
    public function step(Parameter $param, Tensor $gradient) : Tensor
    {
        // 1-based index of the current cycle; each full cycle spans 2 * $losses steps.
        $cycle = floor(1 + $this->t / (2 * $this->losses));

        // Distance from the peak within the cycle, in [0, 1]: 1 at either end
        // of the cycle (lower bound), 0 at its midpoint (peak).
        $x = abs($this->t / $this->losses - 2 * $cycle + 1);

        // Exponential decay of the amplitude, so the peak shrinks toward the lower bound.
        $scale = $this->decay ** $this->t;

        $rate = $this->lower + $this->range * max(0, 1 - $x) * $scale;

        ++$this->t;

        return $gradient->multiply($rate);
    }

    /**
     * Return the string representation of the object.
     *
     * @internal
     *
     * @return string
     */
    public function __toString() : string
    {
        return "Cyclical (lower: {$this->lower}, upper: {$this->upper},"
            . " steps: {$this->losses}, decay: {$this->decay})";
    }
}