table = $this->table('tiki_machine_learning_models');
    }

    /**
     * Fetch every stored model row, each passed through deserialize()
     * (trackerFields split into an array, empty payload normalised to '[]').
     *
     * @return array[]
     */
    public function get_models()
    {
        $models = $this->table->fetchAll();
        return array_map([$this, 'deserialize'], $models);
    }

    /**
     * Load one model by id and hydrate its JSON payload into Rubix instances.
     *
     * @param int|string $mlmId
     * @return array|false Deserialized row plus an 'instances' key (see hydrate()),
     *                     or false when no row matches.
     */
    public function get_model($mlmId)
    {
        $model = $this->table->fetchFullRow(['mlmId' => $mlmId]);
        if (! $model) {
            return false;
        }
        $model = $this->deserialize($model);
        $model['instances'] = $this->hydrate($model['payload']);
        return $model;
    }

    /**
     * Insert or update a model row keyed by mlmId.
     *
     * @param int|string $mlmId
     * @param array      $data  Row data; an array 'trackerFields' is stored comma-separated.
     * @return mixed Whatever the table's insertOrUpdate() returns.
     */
    public function set_model($mlmId, $data)
    {
        $data = $this->serialize($data);
        return $this->table->insertOrUpdate($data, ['mlmId' => $mlmId]);
    }

    /**
     * Delete a model row and invalidate its trained estimator in the 'mlmodel' cache.
     *
     * @param int|string $mlmId
     * @return bool Always true.
     */
    public function delete_model($mlmId)
    {
        $this->table->delete(['mlmId' => $mlmId]);
        TikiLib::lib('cache')->invalidate($mlmId, 'mlmodel');
        return true;
    }

    /**
     * Turn a JSON payload (list of {class, args} objects) into hydrate_single() results.
     *
     * @param string|null $payload JSON text.
     * @return array[] One entry per learner/transformer in payload order; [] when the
     *                 payload does not decode to an array/object.
     */
    public function hydrate($payload)
    {
        $learners = json_decode((string) $payload);
        // Invalid or empty JSON decodes to null; iterating that would emit a warning.
        if (! is_array($learners) && ! is_object($learners)) {
            return [];
        }

        $instances = [];
        foreach ($learners as $learner) {
            $instances[] = $this->hydrate_single($learner->class ?? null, $learner->args ?? null);
        }
        return $instances;
    }

    /**
     * Instantiate a single Rubix ML class (relative to the Rubix\ML namespace) from its
     * decoded argument descriptors, recursing into 'layers' and 'rubix' typed arguments.
     *
     * @param string|null $class Class name relative to Rubix\ML, e.g. "Classifiers\KDNeighbors".
     * @param array|null  $args  Decoded argument descriptors (objects with input_type/value).
     * @return array{learner: ?string, class: ?string, instance: mixed, serialized_args: ?string}
     *         'learner' is $class with its first namespace segment stripped. 'instance' is
     *         null when $class is empty, and the translated string '(error instantiating)'
     *         when the constructor raised a TypeError (an error Feedback is also emitted).
     * @throws ReflectionException When the class does not exist.
     */
    public function hydrate_single($class, $args)
    {
        if (empty($class)) {
            return [
                'learner' => null,
                'class' => null,
                'instance' => null,
                'serialized_args' => null,
            ];
        }

        $ref = new ReflectionClass('Rubix\ML\\' . $class);

        $instance_args = [];
        if ($args) {
            foreach ($args as $arg) {
                $inputType = $arg->input_type ?? null;
                $value = $arg->value ?? null;

                if ($inputType === 'layers' && ! empty($value)) {
                    // A list of nested class descriptors, each instantiated recursively.
                    $layers = [];
                    foreach ($value as $argv) {
                        $layer = $this->hydrate_single($argv->class ?? null, $argv->args ?? null);
                        $layers[] = $layer['instance'];
                    }
                    $instance_args[] = $layers;
                } elseif ($inputType === 'rubix' && ! empty($value)) {
                    // A single nested class descriptor (e.g. a tree or kernel).
                    $nested = $this->hydrate_single($value->class ?? null, $value->args ?? null);
                    $instance_args[] = $nested['instance'];
                } else {
                    $instance_args[] = $value;
                }
            }
        }

        try {
            $instance = $ref->newInstanceArgs($instance_args);
        } catch (TypeError $e) {
            Feedback::error(tr('Error instantiating %0 with arguments %1: %2', $class, print_r($instance_args, 1), $e->getMessage()));
            $instance = tr('(error instantiating)');
        }

        return [
            'learner' => preg_replace('/^[^\\\\]*\\\\/', '', $class),
            'class' => $class,
            'instance' => $instance,
            'serialized_args' => json_encode(['class' => $class, 'args' => $args]),
        ];
    }

    /**
     * Train the model's pipeline on its source tracker items and, unless testing,
     * store the serialized estimator in the 'mlmodel' cache.
     *
     * @param array $model Model as returned by get_model() (needs 'instances',
     *                     'sourceTrackerId', 'trackerFields', 'labelField', 'ignoreEmpty', 'mlmId').
     * @param bool  $test  When true, only the first 10 items are used and nothing is cached.
     * @throws Exception When an instance is neither Transformer nor Learner, no Learner is
     *                   configured, no samples were collected, or the cache write did not stick.
     */
    public function train($model, $test = false)
    {
        $estimator = $this->buildPipelineEstimator($model['instances']);
        [$samples, $labels] = $this->collectTrainingSamples($model, $test);

        if (empty($samples) || empty($labels)) {
            throw new Exception(tr("No data found in data source. Check your model settings."));
        }

        if ($model['labelField']) {
            $dataset = Rubix\ML\Datasets\Labeled::build($samples, $labels);
        } else {
            $dataset = Rubix\ML\Datasets\Unlabeled::build($samples);
        }

        $estimator->train($dataset);

        if (! $test) {
            $cachelib = TikiLib::lib('cache');
            $serialized = serialize($estimator);
            $cachelib->cacheItem($model['mlmId'], $serialized, 'mlmodel');
            // Read back to detect a cache backend that silently failed to store the item.
            if ($serialized !== $cachelib->getCached($model['mlmId'], 'mlmodel')) {
                throw new Exception(tr('Model trained but could not be saved to cache. Check cache storage permissions.'));
            }
        }
    }

    /**
     * Build a Rubix Pipeline from hydrated instances: Transformers in order, last Learner wins.
     *
     * @param array[] $instances Entries from hydrate()/hydrate_single().
     * @return Rubix\ML\Pipeline
     * @throws Exception On an unsupported/failed instance or when no Learner is present.
     */
    private function buildPipelineEstimator($instances)
    {
        $learner = null;
        $transformers = [];
        foreach ($instances as $row) {
            $instance = $row['instance'];
            if ($instance instanceof Rubix\ML\Transformers\Transformer) {
                $transformers[] = $instance;
            } elseif ($instance instanceof Rubix\ML\Learner) {
                $learner = $instance;
            } else {
                // hydrate_single() yields null for class-less rows and a string when
                // instantiation failed; get_class() would throw TypeError on those.
                $type = is_object($instance) ? get_class($instance) : gettype($instance);
                throw new Exception(tr('Not implemented: %0', $type));
            }
        }

        if ($learner === null) {
            // Pipeline requires an estimator; passing null would raise a TypeError.
            throw new Exception(tr('No learner configured for this model.'));
        }

        return new Rubix\ML\Pipeline($transformers, $learner);
    }

    /**
     * Read the source tracker's items and turn the configured fields into samples/labels.
     *
     * @param array $model See train().
     * @param bool  $test  Limit to 10 items when true.
     * @return array{0: array[], 1: array} [samples, labels], index-aligned.
     */
    private function collectTrainingSamples($model, $test)
    {
        $trklib = TikiLib::lib('trk');
        $items = $trklib->list_items($model['sourceTrackerId'], 0, $test ? 10 : -1);
        $definition = Tracker_Definition::get($model['sourceTrackerId']);

        $samples = [];
        $labels = [];
        foreach ($items['data'] as $itemData) {
            $item = Tracker_Item::fromId($itemData['itemId']);

            switch ($model['labelField']) {
                case "itemId":
                    $label = (string)$item->getId();
                    break;
                case "itemTitle":
                    $label = $trklib->get_isMain_value($model['sourceTrackerId'], $item->getId());
                    break;
                default:
                    // Overwritten below when labelField is one of the tracker fields.
                    $label = "";
            }

            $sample = [];
            foreach ($model['trackerFields'] as $fieldId) {
                $field = $definition->getField($fieldId);
                if (empty($field)) {
                    continue;
                }
                $field = $item->prepareFieldOutput($field);
                $value = $trklib->field_render_value([
                    'field' => $field,
                    'itemId' => $item->getId(),
                ]);
                // empty() treats "0"/0 as empty; a numeric zero is real data and must not
                // cause the whole item to be skipped.
                if ($model['ignoreEmpty'] && empty($value) && ! is_numeric($value)) {
                    continue 2;
                }
                $sample[] = $value;
                if ($model['labelField'] == $fieldId) {
                    $label = is_numeric($value) ? floatval($value) : $value;
                }
            }

            if (empty($sample)) {
                continue;
            }
            $samples[] = $sample;
            $labels[] = $label;
        }

        return [$samples, $labels];
    }

    /**
     * Render each processed field into a single sample row for prediction.
     *
     * @param array $processedFields Field definitions already prepared with values.
     * @return array
     */
    private function buildPredictionSample($processedFields)
    {
        $trklib = TikiLib::lib('trk');
        $sample = [];
        foreach ($processedFields as $field) {
            $sample[] = $trklib->field_render_value([
                'field' => $field,
            ]);
        }
        return $sample;
    }

    /**
     * Class probabilities for one sample, zero entries removed, highest first (keys kept).
     *
     * @param array $model
     * @param array $processedFields
     * @return array
     * @throws Exception When the model has not been trained (see getTrainedModel()).
     */
    public function probaSample($model, $processedFields)
    {
        $sample = $this->buildPredictionSample($processedFields);
        $estimator = $this->getTrainedModel($model);
        $result = $estimator->probaSample($sample);
        $result = array_filter($result);
        arsort($result);
        return $result;
    }

    /**
     * Predict the label/value for one sample.
     *
     * @param array $model
     * @param array $processedFields
     * @return mixed
     * @throws Exception When the model has not been trained (see getTrainedModel()).
     */
    public function predictSample($model, $processedFields)
    {
        $sample = $this->buildPredictionSample($processedFields);
        $estimator = $this->getTrainedModel($model);
        return $estimator->predictSample($sample);
    }

    /**
     * Whether the trained estimator's type equals the regressor estimator type.
     *
     * @param array $model
     * @return bool
     * @throws Exception When the model has not been trained.
     */
    public function isRegressor($model)
    {
        $estimator = $this->getTrainedModel($model);
        return $estimator->type() == Rubix\ML\EstimatorType::regressor();
    }

    /**
     * Throw unless a trained estimator for this model is present in the cache.
     *
     * @param array $model
     * @throws Exception When the model has not been trained.
     */
    public function ensureModelTrained($model)
    {
        $this->getTrainedModel($model);
    }

    /**
     * JSON payload for a named predefined pipeline template.
     *
     * @param string $template Currently only 'MLT' is known.
     * @return string JSON list of {class, args} descriptors, or '' for unknown templates.
     */
    public function predefined($template)
    {
        switch ($template) {
            case 'MLT':
                return json_encode([
                    [
                        "class" => "Transformers\\TextNormalizer",
                        "args" => []
                    ],
                    [
                        "class" => "Transformers\\StopWordFilter",
                        "args" => [
                            [
                                "name" => "stopWords",
                                "default" => [],
                                "arg_type" => "array",
                                "input_type" => "text",
                                "value" => ["i","me","my","myself","we","our","ours","ourselves","you","your","yours","yourself","yourselves","he","him","his","himself","she","her","hers","herself","it","its","itself","they","them","their","theirs","themselves","what","which","who","whom","this","that","these","those","am","is","are","was","were","be","been","being","have","has","had","having","do","does","did","doing","a","an","the","and","but","if","or","because","as","until","while","of","at","by","for","with","about","against","between","into","through","during","before","after","above","below","to","from","up","down","in","out","on","off","over","under","again","further","then","once","here","there","when","where","why","how","all","any","both","each","few","more","most","other","some","such","no","nor","not","only","own","same","so","than","too","very","s","t","can","will","just","don","should","now"]
                            ]
                        ]
                    ],
                    [
                        "class" => "Transformers\\WordCountVectorizer",
                        "args" => [
                            [
                                "name" => "maxVocabulary",
                                "default" => PHP_INT_MAX,
                                "arg_type" => "int",
                                "input_type" => "text",
                                "value" => 10000
                            ],
                            [
                                "name" => "minDocumentFrequency",
                                "default" => 1,
                                "arg_type" => "int",
                                "input_type" => "text",
                                "value" => "1"
                            ],
                            [
                                "name" => "maxDocumentFrequency",
                                "default" => PHP_INT_MAX,
                                "arg_type" => "int",
                                "input_type" => "text",
                                "value" => 500
                            ],
                            [
                                "name" => "tokenizer",
                                "default" => null,
                                "arg_type" => "Rubix\\ML\\Other\\Tokenizers\\Tokenizer",
                                "input_type" => "rubix",
                                "value" => null
                            ]
                        ]
                    ],
                    [
                        "class" => "Transformers\\BM25Transformer",
                        "args" => [
                            [
                                "name" => "alpha",
                                "default" => 1.2,
                                "arg_type" => "float",
                                "input_type" => "text",
                                "value" => 1.2
                            ],
                            [
                                "name" => "beta",
                                "default" => 0.75,
                                "arg_type" => "float",
                                "input_type" => "text",
                                "value" => 0.75
                            ]
                        ]
                    ],
                    [
                        "class" => "Classifiers\\KDNeighbors",
                        "args" => [
                            [
                                "name" => "k",
                                "default" => 5,
                                "arg_type" => "int",
                                "input_type" => "text",
                                "value" => 20
                            ],
                            [
                                "name" => "weighted",
                                "default" => true,
                                "arg_type" => "bool",
                                "input_type" => "text",
                                "value" => "true"
                            ],
                            [
                                "name" => "tree",
                                "default" => null,
                                "arg_type" => "Rubix\\ML\\Graph\\Trees\\Spatial",
                                "input_type" => "rubix",
                                "value" => [
                                    "class" => "Graph\\Trees\\BallTree",
                                    "args" => [
                                        [
                                            "name" => "maxLeafSize",
                                            "default" => 30,
                                            "arg_type" => "int",
                                            "input_type" => "text",
                                            "value" => 20
                                        ],
                                        [
                                            "name" => "kernel",
                                            "default" => null,
                                            "arg_type" => "Rubix\\ML\\Kernels\\Distance\\Distance",
                                            "input_type" => "rubix",
                                            "value" => [
                                                "class" => "Kernels\\Distance\\Cosine",
                                                "args" => []
                                            ]
                                        ]
                                    ]
                                ]
                            ]
                        ]
                    ]
                ]);
            default:
                return '';
        }
    }

    /**
     * Fetch the serialized estimator for this model from the 'mlmodel' cache.
     *
     * @param array $model Needs 'mlmId'.
     * @return mixed The cached estimator.
     * @throws Exception When nothing is cached or the estimator reports it is not trained.
     */
    protected function getTrainedModel($model)
    {
        $estimator = TikiLib::lib('cache')->getSerialized($model['mlmId'], 'mlmodel');
        if (! $estimator || ! $estimator->trained()) {
            throw new Exception(tr('Model was not trained.'));
        }
        return $estimator;
    }

    /**
     * Prepare a model array for storage: an array 'trackerFields' becomes a comma list.
     * Rows without 'trackerFields' (partial updates) pass through untouched.
     *
     * @param array $model
     * @return array
     */
    protected function serialize($model)
    {
        if (isset($model['trackerFields']) && is_array($model['trackerFields'])) {
            $model['trackerFields'] = implode(',', $model['trackerFields']);
        }
        return $model;
    }

    /**
     * Inverse of serialize(): split 'trackerFields' into a list (blank ids dropped) and
     * default an empty 'payload' to the JSON empty list '[]'.
     *
     * @param array $model Stored row.
     * @return array
     */
    protected function deserialize($model)
    {
        // Cast guards against null (explode(null) is deprecated in PHP 8.1); filtering
        // drops the [''] that explode() returns for an empty string.
        $fields = explode(',', (string) ($model['trackerFields'] ?? ''));
        $model['trackerFields'] = array_values(array_filter($fields, 'strlen'));
        if (empty($model['payload'])) {
            $model['payload'] = '[]';
        }
        return $model;
    }
}