La Paste
Create new paste
Pastes Archive
2024-04-09 13:13:19
copy
raw
download
/**
 * Minimal single-weight "network": learns one multiplicative weight so that
 * input * weight approximates the expected output.
 *
 * A data set is learned by training one temporary instance per sample and
 * averaging the weights those instances arrive at.
 */
class MRNN {

    /**
     * Current weight applied to inputs
     *
     * @var float
     */
    protected $weight = 0.01;

    /**
     * Error measured by the most recent training step. Starts at 1 so that
     * the loop in learn() is entered at least once.
     *
     * @var float
     */
    protected $lastError = 1;

    /**
     * Used both as the correction scale in train() and as the error
     * tolerance that stops learn()
     *
     * @var float
     */
    protected $smoothing = 0.0001;

    /**
     * Output produced by the most recent training step
     *
     * @var float
     */
    protected $actualResult = 0.01;

    /**
     * Weight delta applied by the most recent training step
     *
     * @var float
     */
    protected $correction = 0;

    /**
     * Number of training steps performed by the last learn() call
     *
     * @var int
     */
    protected $epoch = 0;

    /**
     * Collected training statistics: learn() stores epoch => error samples,
     * learnDataSet() appends one such array per trained sample
     *
     * @var array
     */
    protected $trainStats = array();

    /**
     * learn() records the current error every this many epochs
     *
     * @var int
     */
    protected $statEvery = 5000;

    /**
     * When true, learnDataSet() reports each trained weight via show_success()
     *
     * @var bool
     */
    protected $debug = false;

    /**
     * Name of the activation function in use: 'def' or 'sigmoid'
     *
     * @var string
     */
    protected $activationFunction = 'def';

    /**
     * Creates a new instance
     *
     * @param string $activationFunction 'def' or 'sigmoid'
     *
     * @throws Exception EX_WRONG_ACTFUNCTION on an unsupported name
     */
    public function __construct($activationFunction = 'def') {
        $this->setActivationFunc($activationFunction);
    }

    /**
     * Overrides the current weight
     *
     * @param float $weight
     *
     * @return void
     */
    public function setWeight($weight) {
        $this->weight = $weight;
    }

    /**
     * Validates and stores the activation function name
     *
     * @param string $type
     *
     * @throws Exception EX_WRONG_ACTFUNCTION on an unsupported name
     *
     * @return void
     */
    protected function setActivationFunc($type) {
        $supportedTypes = array(
            'def' => 'def',
            'sigmoid' => 'sigmoid'
        );

        if (!isset($supportedTypes[$type])) {
            throw new Exception('EX_WRONG_ACTFUNCTION');
        }

        $this->activationFunction = $type;
    }

    /**
     * Logistic function 1 / (1 + e^-x)
     *
     * @param float $value
     *
     * @return float
     */
    protected function sigmoid($value) {
        return(1 / (1 + exp(-$value)));
    }

    /**
     * Inverse of sigmoid() (logit). Only defined for values strictly between
     * 0 and 1, so a saturated sigmoid output cannot be inverted.
     *
     * @param float $value
     *
     * @return float
     */
    protected function unsigmoid($value) {
        return(log($value / (1 - $value)));
    }

    /**
     * Applies the current weight to an input value
     *
     * @param float $input
     *
     * @return float
     */
    public function processInputData($input) {
        return($input * $this->weight);
    }

    /**
     * Reverses processInputData(): divides an output by the current weight.
     * The weight must be non-zero.
     *
     * @param float $output
     *
     * @return float
     */
    public function restoreInputData($output) {
        return($output / $this->weight);
    }

    /**
     * Performs a single training step: measures the error for one sample and
     * nudges the weight by (error / actual result) * smoothing.
     *
     * NOTE: in 'def' mode a zero input (or zero weight) makes the actual
     * result zero, so the division below fails (DivisionByZeroError on PHP 8).
     *
     * @param float $input
     * @param float $expectedResult
     *
     * @return void
     */
    protected function train($input, $expectedResult) {
        switch ($this->activationFunction) {
            case 'def':
                $this->actualResult = $input * $this->weight;
                $this->lastError = $expectedResult - $this->actualResult;
                break;
            case 'sigmoid':
                $this->actualResult = $this->sigmoid($input * $this->weight);
                // the error is measured on the de-activated value
                $this->lastError = $expectedResult - $this->unsigmoid($this->actualResult);
                break;
            default:
                // unknown activation: nothing to train
                return;
        }

        $this->correction = ($this->lastError / $this->actualResult) * $this->smoothing;
        $this->weight += $this->correction;
    }

    /**
     * Repeats training steps on one sample until the absolute error is no
     * larger than the smoothing value. There is no upper bound on the number
     * of epochs, so a sample that never converges loops forever.
     *
     * @param float $input
     * @param float $expectedResult
     *
     * @return bool always true
     */
    protected function learn($input, $expectedResult) {
        $this->epoch = 0;

        while (abs($this->lastError) > $this->smoothing) {
            $this->train($input, $expectedResult);
            if (($this->epoch % $this->statEvery) == 0) {
                $this->trainStats[$this->epoch] = $this->lastError;
            }
            $this->epoch++;
        }

        return(true);
    }

    /**
     * Learns a whole data set. Each sample is trained on a fresh instance of
     * the current class; the resulting weights are averaged into this
     * instance's weight and each sample's stats are appended to trainStats.
     *
     * Inputs are taken from the array keys, so they are subject to PHP's
     * array key casting (for example floats are truncated to integers).
     *
     * @param array $dataSet input => expected result
     * @param bool $accel start each sample from the previously trained weight
     *                    (the first one starts from this instance's weight)
     *
     * @return bool true when at least one sample was trained, false when the
     *              data set is not an array, is empty, or nothing was trained
     */
    public function learnDataSet($dataSet, $accel = false) {
        if (!is_array($dataSet) || empty($dataSet)) {
            return(false);
        }

        $networkName = get_class($this);
        $prevWeight = $this->weight;
        $totalWeight = 0;
        $trainedCount = 0;

        foreach ($dataSet as $input => $expectedResult) {
            $neuron = new $networkName($this->activationFunction);
            if ($accel) {
                $neuron->setWeight($prevWeight);
            }

            if ($neuron->learn($input, $expectedResult)) {
                if ($this->debug) {
                    // show_success() is not defined in this class and has to be provided by the host application
                    show_success('Trained weight: ' . $neuron->getWeight() . ' on epoch ' . $neuron->getEpoch());
                }
                $totalWeight += $neuron->getWeight();
                $this->trainStats[] = $neuron->getTrainStats();
                $prevWeight = $neuron->getWeight();
                $trainedCount++;
            }
        }

        if ($trainedCount == 0) {
            return(false);
        }

        $this->weight = $totalWeight / $trainedCount;
        return(true);
    }

    /**
     * Returns the current weight
     *
     * @return float
     */
    public function getWeight() {
        return($this->weight);
    }

    /**
     * Returns the collected training statistics
     *
     * @return array
     */
    public function getTrainStats() {
        return($this->trainStats);
    }

    /**
     * Returns the error of the most recent training step
     *
     * @return float
     */
    protected function getLastError() {
        return($this->lastError);
    }

    /**
     * Returns the number of epochs used by the last learn() call
     *
     * @return int
     */
    protected function getEpoch() {
        return($this->epoch);
    }

}
↑