<?php
|
|
|
|
// (c) Copyright by authors of the Tiki Wiki CMS Groupware Project
|
|
//
|
|
// All Rights Reserved. See copyright.txt for details and a complete list of authors.
|
|
// Licensed under the GNU LESSER GENERAL PUBLIC LICENSE. See license.txt for details.
|
|
// $Id$
|
|
|
|
/**
|
|
* MachineLearningLib
|
|
*
|
|
* @uses TikiDb_Bridge
|
|
*/
|
|
class MachineLearningLib extends TikiDb_Bridge
|
|
{
|
|
private $table;
|
|
|
|
/**
|
|
*
|
|
*/
|
|
public function __construct()
|
|
{
|
|
$this->table = $this->table('tiki_machine_learning_models');
|
|
}
|
|
|
|
public function get_models()
|
|
{
|
|
$models = $this->table->fetchAll();
|
|
return array_map([$this, 'deserialize'], $models);
|
|
}
|
|
|
|
public function get_model($mlmId)
|
|
{
|
|
$model = $this->table->fetchFullRow(['mlmId' => $mlmId]);
|
|
if (! $model) {
|
|
return false;
|
|
}
|
|
$model = $this->deserialize($model);
|
|
$model['instances'] = $this->hydrate($model['payload']);
|
|
return $model;
|
|
}
|
|
|
|
public function set_model($mlmId, $data)
|
|
{
|
|
$data = $this->serialize($data);
|
|
return $this->table->insertOrUpdate($data, ['mlmId' => $mlmId]);
|
|
}
|
|
|
|
public function delete_model($mlmId)
|
|
{
|
|
$this->table->delete(['mlmId' => $mlmId]);
|
|
TikiLib::lib('cache')->invalidate($mlmId, 'mlmodel');
|
|
return true;
|
|
}
|
|
|
|
public function hydrate($payload)
|
|
{
|
|
$instances = [];
|
|
$payload = json_decode($payload);
|
|
foreach ($payload as $learner) {
|
|
$instances[] = $this->hydrate_single($learner->class, $learner->args);
|
|
}
|
|
return $instances;
|
|
}
|
|
|
|
public function hydrate_single($class, $args)
|
|
{
|
|
if (empty($class)) {
|
|
return [
|
|
'learner' => null,
|
|
'class' => null,
|
|
'instance' => null,
|
|
'serialized_args' => null
|
|
];
|
|
}
|
|
$ref = new ReflectionClass('Rubix\ML\\' . $class);
|
|
$instance_args = [];
|
|
if ($args) {
|
|
foreach ($args as $arg) {
|
|
if ($arg->input_type == 'layers' && ! empty($arg->value)) {
|
|
$layers = [];
|
|
foreach ($arg->value as $argv) {
|
|
$instance = $this->hydrate_single($argv->class, $argv->args);
|
|
$layers[] = $instance['instance'];
|
|
}
|
|
$instance_args[] = $layers;
|
|
} elseif ($arg->input_type == 'rubix' && ! empty($arg->value)) {
|
|
$iargs = $arg->value;
|
|
$instance = $this->hydrate_single($iargs->class, $iargs->args);
|
|
$instance_args[] = $instance['instance'];
|
|
} else {
|
|
$instance_args[] = $arg->value;
|
|
}
|
|
}
|
|
}
|
|
try {
|
|
$instance = $ref->newInstanceArgs($instance_args);
|
|
} catch (TypeError $e) {
|
|
Feedback::error(tr('Error instantiating %0 with arguments %1: %2', $class, print_r($instance_args, 1), $e->getMessage()));
|
|
$instance = tr('(error instantiating)');
|
|
}
|
|
return [
|
|
'learner' => preg_replace('/^[^\\\\]*\\\\/', '', $class),
|
|
'class' => $class,
|
|
'instance' => $instance,
|
|
'serialized_args' => json_encode(['class' => $class, 'args' => $args])
|
|
];
|
|
}
|
|
|
|
public function train($model, $test = false)
|
|
{
|
|
$learner = null;
|
|
$transformers = [];
|
|
foreach ($model['instances'] as $row) {
|
|
$instance = $row['instance'];
|
|
if ($instance instanceof Rubix\ML\Transformers\Transformer) {
|
|
$transformers[] = $instance;
|
|
} elseif ($instance instanceof Rubix\ML\Learner) {
|
|
$learner = $instance;
|
|
} else {
|
|
throw new Exception(tr('Not implemented: %0', get_class($instance)));
|
|
}
|
|
}
|
|
$estimator = new Rubix\ML\Pipeline($transformers, $learner);
|
|
|
|
$samples = [];
|
|
$labels = [];
|
|
|
|
$trklib = TikiLib::lib('trk');
|
|
$items = $trklib->list_items($model['sourceTrackerId'], 0, $test ? 10 : -1);
|
|
$definition = Tracker_Definition::get($model['sourceTrackerId']);
|
|
foreach ($items['data'] as $item) {
|
|
$item = Tracker_Item::fromId($item['itemId']);
|
|
switch ($model['labelField']) {
|
|
case "itemId":
|
|
$label = (string)$item->getId();
|
|
break;
|
|
case "itemTitle":
|
|
$label = $trklib->get_isMain_value($model['sourceTrackerId'], $item->getId());
|
|
break;
|
|
default:
|
|
$label = "";
|
|
}
|
|
$sample = [];
|
|
foreach ($model['trackerFields'] as $fieldId) {
|
|
$field = $definition->getField($fieldId);
|
|
if (empty($field)) {
|
|
continue;
|
|
}
|
|
$field = $item->prepareFieldOutput($field);
|
|
$value = $trklib->field_render_value([
|
|
'field' => $field,
|
|
'itemId' => $item->getId(),
|
|
]);
|
|
if (empty($value) && $model['ignoreEmpty']) {
|
|
continue 2;
|
|
}
|
|
$sample[] = $value;
|
|
if ($model['labelField'] == $fieldId) {
|
|
$label = is_numeric($value) ? floatval($value) : $value;
|
|
}
|
|
}
|
|
if (empty($sample)) {
|
|
continue;
|
|
}
|
|
$samples[] = $sample;
|
|
$labels[] = $label;
|
|
}
|
|
|
|
if (empty($samples) || empty($labels)) {
|
|
throw new Exception(tr("No data found in data source. Check your model settings."));
|
|
}
|
|
|
|
if ($model['labelField']) {
|
|
$dataset = Rubix\ML\Datasets\Labeled::build($samples, $labels);
|
|
} else {
|
|
$dataset = Rubix\ML\Datasets\Unlabeled::build($samples);
|
|
}
|
|
$estimator->train($dataset);
|
|
|
|
if (! $test) {
|
|
$serialized = serialize($estimator);
|
|
TikiLib::lib('cache')->cacheItem($model['mlmId'], $serialized, 'mlmodel');
|
|
if ($serialized != TikiLib::lib('cache')->getCached($model['mlmId'], 'mlmodel')) {
|
|
throw new Exception(tr('Model trained but could not be saved to cache. Check cache storage permissions.'));
|
|
}
|
|
}
|
|
}
|
|
|
|
public function probaSample($model, $processedFields)
|
|
{
|
|
$sample = [];
|
|
foreach ($processedFields as $field) {
|
|
$value = TikiLib::lib('trk')->field_render_value([
|
|
'field' => $field,
|
|
]);
|
|
$sample[] = $value;
|
|
}
|
|
|
|
$estimator = $this->getTrainedModel($model);
|
|
|
|
$result = $estimator->probaSample($sample);
|
|
$result = array_filter($result);
|
|
arsort($result);
|
|
|
|
return $result;
|
|
}
|
|
|
|
public function predictSample($model, $processedFields)
|
|
{
|
|
$sample = [];
|
|
foreach ($processedFields as $field) {
|
|
$value = TikiLib::lib('trk')->field_render_value([
|
|
'field' => $field,
|
|
]);
|
|
$sample[] = $value;
|
|
}
|
|
|
|
$estimator = $this->getTrainedModel($model);
|
|
|
|
$result = $estimator->predictSample($sample);
|
|
return $result;
|
|
}
|
|
|
|
public function isRegressor($model)
|
|
{
|
|
$estimator = $this->getTrainedModel($model);
|
|
return $estimator->type() == Rubix\ML\EstimatorType::regressor();
|
|
}
|
|
|
|
public function ensureModelTrained($model)
|
|
{
|
|
$this->getTrainedModel($model);
|
|
}
|
|
|
|
public function predefined($template)
|
|
{
|
|
switch ($template) {
|
|
case 'MLT':
|
|
return json_encode([
|
|
[
|
|
"class" => "Transformers\\TextNormalizer",
|
|
"args" => []
|
|
], [
|
|
"class" => "Transformers\\StopWordFilter",
|
|
"args" => [
|
|
[
|
|
"name" => "stopWords",
|
|
"default" => [],
|
|
"arg_type" => "array",
|
|
"input_type" => "text",
|
|
"value" => ["i","me","my","myself","we","our","ours","ourselves","you","your","yours","yourself","yourselves","he","him","his","himself","she","her","hers","herself","it","its","itself","they","them","their","theirs","themselves","what","which","who","whom","this","that","these","those","am","is","are","was","were","be","been","being","have","has","had","having","do","does","did","doing","a","an","the","and","but","if","or","because","as","until","while","of","at","by","for","with","about","against","between","into","through","during","before","after","above","below","to","from","up","down","in","out","on","off","over","under","again","further","then","once","here","there","when","where","why","how","all","any","both","each","few","more","most","other","some","such","no","nor","not","only","own","same","so","than","too","very","s","t","can","will","just","don","should","now"]
|
|
]
|
|
]
|
|
], [
|
|
"class" => "Transformers\\WordCountVectorizer",
|
|
"args" => [
|
|
[
|
|
"name" => "maxVocabulary",
|
|
"default" => PHP_INT_MAX,
|
|
"arg_type" => "int",
|
|
"input_type" => "text",
|
|
"value" => 10000
|
|
], [
|
|
"name" => "minDocumentFrequency",
|
|
"default" => 1,
|
|
"arg_type" => "int",
|
|
"input_type" => "text",
|
|
"value" => "1"
|
|
], [
|
|
"name" => "maxDocumentFrequency",
|
|
"default" => PHP_INT_MAX,
|
|
"arg_type" => "int",
|
|
"input_type" => "text",
|
|
"value" => 500
|
|
], [
|
|
"name" => "tokenizer",
|
|
"default" => null,
|
|
"arg_type" => "Rubix\\ML\\Other\\Tokenizers\\Tokenizer",
|
|
"input_type" => "rubix",
|
|
"value" => null
|
|
]
|
|
]
|
|
], [
|
|
"class" => "Transformers\\BM25Transformer",
|
|
"args" => [
|
|
[
|
|
"name" => "alpha",
|
|
"default" => 1.2,
|
|
"arg_type" => "float",
|
|
"input_type" => "text",
|
|
"value" => 1.2
|
|
], [
|
|
"name" => "beta",
|
|
"default" => 0.75,
|
|
"arg_type" => "float",
|
|
"input_type" => "text",
|
|
"value" => 0.75
|
|
]
|
|
]
|
|
], [
|
|
"class" => "Classifiers\\KDNeighbors",
|
|
"args" => [
|
|
[
|
|
"name" => "k",
|
|
"default" => 5,
|
|
"arg_type" => "int",
|
|
"input_type" => "text",
|
|
"value" => 20
|
|
], [
|
|
"name" => "weighted",
|
|
"default" => true,
|
|
"arg_type" => "bool",
|
|
"input_type" => "text",
|
|
"value" => "true"
|
|
], [
|
|
"name" => "tree",
|
|
"default" => null,
|
|
"arg_type" => "Rubix\\ML\\Graph\\Trees\\Spatial",
|
|
"input_type" => "rubix",
|
|
"value" => [
|
|
"class" => "Graph\\Trees\\BallTree",
|
|
"args" => [
|
|
[
|
|
"name" => "maxLeafSize",
|
|
"default" => 30,
|
|
"arg_type" => "int",
|
|
"input_type" => "text",
|
|
"value" => 20
|
|
], [
|
|
"name" => "kernel",
|
|
"default" => null,
|
|
"arg_type" => "Rubix\\ML\\Kernels\\Distance\\Distance",
|
|
"input_type" => "rubix",
|
|
"value" => [
|
|
"class" => "Kernels\\Distance\\Cosine",
|
|
"args" => []
|
|
]
|
|
]
|
|
]
|
|
]
|
|
]
|
|
]
|
|
]
|
|
]);
|
|
default:
|
|
return '';
|
|
}
|
|
}
|
|
|
|
protected function getTrainedModel($model)
|
|
{
|
|
$estimator = TikiLib::lib('cache')->getSerialized($model['mlmId'], 'mlmodel');
|
|
if (! $estimator || ! $estimator->trained()) {
|
|
throw new Exception(tr('Model was not trained.'));
|
|
}
|
|
return $estimator;
|
|
}
|
|
|
|
protected function serialize($model)
|
|
{
|
|
if (is_array($model['trackerFields'])) {
|
|
$model['trackerFields'] = implode(',', $model['trackerFields']);
|
|
}
|
|
return $model;
|
|
}
|
|
|
|
protected function deserialize($model)
|
|
{
|
|
$model['trackerFields'] = explode(',', $model['trackerFields']);
|
|
if (empty($model['payload'])) {
|
|
$model['payload'] = '[]';
|
|
}
|
|
return $model;
|
|
}
|
|
}
|