001 package org.maltparser.parser.guide.instance;
002
003 import java.io.BufferedReader;
004 import java.io.BufferedWriter;
005 import java.io.IOException;
006 import java.util.ArrayList;
007 import java.util.Collections;
008 import java.util.HashMap;
009 import java.util.HashSet;
010 import java.util.LinkedList;
011 import java.util.List;
012 import java.util.Map;
013 import java.util.Set;
014 import java.util.SortedMap;
015 import java.util.TreeMap;
016 import java.util.Map.Entry;
017 import java.util.regex.Pattern;
018
019 import org.maltparser.core.config.ConfigurationDir;
020 import org.maltparser.core.exception.MaltChainedException;
021 import org.maltparser.core.feature.FeatureException;
022 import org.maltparser.core.feature.FeatureVector;
023 import org.maltparser.core.feature.function.FeatureFunction;
024 import org.maltparser.core.feature.function.Modifiable;
025 import org.maltparser.core.feature.value.SingleFeatureValue;
026 import org.maltparser.core.syntaxgraph.DependencyStructure;
027 import org.maltparser.parser.guide.ClassifierGuide;
028 import org.maltparser.parser.guide.GuideException;
029 import org.maltparser.parser.guide.Model;
030 import org.maltparser.parser.history.action.SingleDecision;
031
032 /**
033 * This class implements a decision tree model. The class is recursive and an
034 * instance of the class can be a root model or belong to an other decision tree
035 * model. Every node in the decision tree is represented by an instance of the
036 * class. Node can be in one of the three states branch model, leaf model or not
037 * decided. A branch model has several sub decision tree models and a leaf model
038 * owns an atomic model that is used to classify instances. When a decision tree
039 * model is in the not decided state it has both sub decision trees and an
040 * atomic model. It can be in the not decided state during training before it is
041 * tested by cross validation if the sub decision tree models provide better
042 * accuracy than the atomic model.
043 *
044 *
045 * @author Kjell Winblad
046 */
047 public class DecisionTreeModel implements InstanceModel {
048
/*
 * The leaf nodes need an int index that is unique among all leaf nodes,
 * because they own an AtomicModel which requires such an index.
 */
private static int leafModelIndexConter = 0;

// Reserved branch key for the catch-all ("other") branch of a node.
private final static int OTHER_BRANCH_ID = 1000000;// Integer.MAX_VALUE;

// The number of divisions used when doing the cross validation test
private int numberOfCrossValidationSplits = 10;
/*
 * Cross validation accuracy is calculated for every node during training.
 * It holds the sentinel value below (-1.0) while not yet calculated.
 */
private final static double CROSS_VALIDATION_ACCURACY_NOT_SET_VALUE = -1.0;
private double crossValidationAccuracy = CROSS_VALIDATION_ACCURACY_NOT_SET_VALUE;
// The parent model (the guide model or tree node that owns this node)
private Model parent = null;
// An ordered list of features to divide on; the first entry is this level's divide feature
private LinkedList<FeatureFunction> divideFeatures = null;
/*
 * The branches of the tree. Is set to null if this is a leaf node.
 */
private SortedMap<Integer, DecisionTreeModel> branches = null;

/*
 * This model is used if this is a leaf node. Is set to null if this is a
 * branch node. An undecided node (during training) has both.
 */
private AtomicModel leafModel = null;
// Number of training instances added
private int frequency = 0;
/*
 * Min number of instances for a sub node to exist. All sub nodes with fewer
 * instances will be concatenated into one ("other") sub node.
 */
private int divideThreshold = 0;
// The feature vector for this problem
private FeatureVector featureVector;

// Cached feature vector for sub nodes (the current divide feature removed)
private FeatureVector subFeatureVector = null;

// Used to indicate that the modelIndex field is not set
private static final int MODEL_INDEX_NOT_SET = Integer.MIN_VALUE;
/*
 * Model index is the identifier used to distinguish this model from other
 * models at the same level. It is not used in the root model, which keeps
 * the value MODEL_INDEX_NOT_SET.
 */
private int modelIndex = MODEL_INDEX_NOT_SET;
// Indexes of the columns used to divide on
private ArrayList<Integer> divideFeatureIndexVector;

// Option flags: automatic (gain-ratio) split ordering / forced division of the root
private boolean automaticSplit = false;
private boolean treeForceDivide = false;
105
106 /**
107 * Constructs a feature divide model.
108 *
109 * @param featureVector
110 * the feature vector used by the decision tree model
111 * @param parent
112 * the parent guide model.
113 * @throws MaltChainedException
114 */
115 public DecisionTreeModel(FeatureVector featureVector, Model parent)
116 throws MaltChainedException {
117
118 this.featureVector = featureVector;
119 this.divideFeatures = new LinkedList<FeatureFunction>();
120 setParent(parent);
121 setFrequency(0);
122 initDecisionTreeParam();
123
124 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
125
126 // Prepare for training
127
128 branches = new TreeMap<Integer, DecisionTreeModel>();
129 leafModel = new AtomicModel(-1, featureVector, this);
130
131 } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
132 load();
133 }
134 }
135
136 /*
137 * This constructor is used from within objects of the class to create sub decision tree models
138 *
139 *
140 */
141 private DecisionTreeModel(int modelIndex, FeatureVector featureVector,
142 Model parent, LinkedList<FeatureFunction> divideFeatures,
143 int divideThreshold) throws MaltChainedException {
144
145 this.featureVector = featureVector;
146
147 setParent(parent);
148 setFrequency(0);
149
150 this.modelIndex = modelIndex;
151 this.divideFeatures = divideFeatures;
152 this.divideThreshold = divideThreshold;
153
154 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
155
156 //Create the divide feature index vector
157 if (divideFeatures.size() > 0) {
158
159 divideFeatureIndexVector = new ArrayList<Integer>();
160 for (int i = 0; i < featureVector.size(); i++) {
161 if (featureVector.get(i).equals(divideFeatures.get(0))) {
162 divideFeatureIndexVector.add(i);
163 }
164 }
165
166 }
167 leafModelIndexConter++;
168
169
170 // Prepare for training
171 branches = new TreeMap<Integer, DecisionTreeModel>();
172 leafModel = new AtomicModel(-1, featureVector, this);
173
174 } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
175 load();
176 }
177 }
178
179 /**
180 * Loads the feature divide model settings .fsm file.
181 *
182 * @throws MaltChainedException
183 */
184 protected void load() throws MaltChainedException {
185
186 ConfigurationDir configDir = getGuide().getConfiguration()
187 .getConfigurationDir();
188
189
190 // load the dsm file
191
192 try {
193
194 final BufferedReader in = new BufferedReader(
195 configDir.getInputStreamReaderFromConfigFile(getModelName()
196 + ".dsm"));
197 final Pattern tabPattern = Pattern.compile("\t");
198
199 boolean first = true;
200 while (true) {
201 String line = in.readLine();
202 if (line == null)
203 break;
204 String[] cols = tabPattern.split(line);
205 if (cols.length != 2) {
206 throw new GuideException("");
207 }
208 int code = -1;
209 int freq = 0;
210 try {
211 code = Integer.parseInt(cols[0]);
212 freq = Integer.parseInt(cols[1]);
213 } catch (NumberFormatException e) {
214 throw new GuideException(
215 "Could not convert a string value into an integer value when loading the feature divide model settings (.fsm). ",
216 e);
217 }
218
219 if (code == MODEL_INDEX_NOT_SET) {
220 if (!first)
221 throw new GuideException(
222 "Error in config file '"
223 + getModelName()
224 + ".dsm"
225 + "'. If the index in the .dsm file is MODEL_INDEX_NOT_SET it should be the first.");
226
227 first = false;
228 // It is a leaf node
229 // Create atomic model for the leaf node
230 leafModel = new AtomicModel(-1, featureVector, this);
231
232 // setIsLeafNode();
233
234 } else {
235 if (first) {
236 // Create the branches holder
237
238 branches = new TreeMap<Integer, DecisionTreeModel>();
239
240 // setIsBranchNode();
241
242 first = false;
243 }
244
245 if (branches == null)
246 throw new GuideException(
247 "Error in config file '"
248 + getModelName()
249 + ".dsm"
250 + "'. If MODEL_INDEX_NOT_SET is the first model index in the .dsm file it should be the only.");
251
252 if (code == OTHER_BRANCH_ID)
253 branches.put(code, new DecisionTreeModel(code,
254 featureVector, this,
255 new LinkedList<FeatureFunction>(),
256 divideThreshold));
257 else
258 branches.put(code, new DecisionTreeModel(code,
259 getSubFeatureVector(), this,
260 createNextLevelDivideFeatures(),
261 divideThreshold));
262
263 branches.get(code).setFrequency(freq);
264
265 setFrequency(getFrequency() + freq);
266
267 }
268
269 }
270 in.close();
271
272 } catch (IOException e) {
273 throw new GuideException(
274 "Could not read from the guide model settings file '"
275 + getModelName() + ".dsm" + "', when "
276 + "loading the guide model settings. ", e);
277 }
278
279 }
280
281 private void initDecisionTreeParam() throws MaltChainedException {
282 String treeSplitColumns = getGuide().getConfiguration().getOptionValue(
283 "guide", "tree_split_columns").toString();
284 String treeSplitStructures = getGuide().getConfiguration()
285 .getOptionValue("guide", "tree_split_structures").toString();
286
287 automaticSplit = getGuide().getConfiguration()
288 .getOptionValue("guide", "tree_automatic_split_order").toString().equals("yes");
289
290 treeForceDivide = getGuide().getConfiguration()
291 .getOptionValue("guide", "tree_force_divide").toString().equals("yes");
292
293 if(automaticSplit){
294 divideFeatures = new LinkedList<FeatureFunction>();
295 for(FeatureFunction feature:featureVector){
296 if(feature.getFeatureValue() instanceof SingleFeatureValue)
297 divideFeatures.add(feature);
298 }
299
300
301 }else{
302
303 if (treeSplitColumns == null || treeSplitColumns.length() == 0) {
304 throw new GuideException(
305 "The option '--guide-tree_split_columns' cannot be found, when initializing the decision tree model. ");
306 }
307
308 if (treeSplitStructures == null || treeSplitStructures.length() == 0) {
309 throw new GuideException(
310 "The option '--guide-tree_split_structures' cannot be found, when initializing the decision tree model. ");
311 }
312
313 String[] treeSplitColumnsArray = treeSplitColumns.split("@");
314 String[] treeSplitStructuresArray = treeSplitStructures.split("@");
315
316 if (treeSplitColumnsArray.length != treeSplitStructuresArray.length)
317 throw new GuideException(
318 "The option '--guide-tree_split_structures' and '--guide-tree_split_columns' must be followed by a ; separated lists of the same length");
319
320 try {
321
322 for (int n = 0; n < treeSplitColumnsArray.length; n++) {
323
324 final String spec = "InputColumn("
325 + treeSplitColumnsArray[n].trim() + ", "
326 + treeSplitStructuresArray[n].trim() + ")";
327
328 divideFeatures.addLast(featureVector.getFeatureModel()
329 .identifyFeature(spec));
330 }
331
332 } catch (FeatureException e) {
333 throw new GuideException("The data split feature 'InputColumn("
334 + getGuide().getConfiguration().getOptionValue("guide",
335 "data_split_column").toString()
336 + ", "
337 + getGuide().getConfiguration().getOptionValue("guide",
338 "data_split_structure").toString()
339 + ") cannot be initialized. ", e);
340 }
341
342 for (FeatureFunction divideFeature : divideFeatures) {
343 if (!(divideFeature instanceof Modifiable)) {
344 throw new GuideException("The data split feature 'InputColumn("
345 + getGuide().getConfiguration().getOptionValue("guide",
346 "data_split_column").toString()
347 + ", "
348 + getGuide().getConfiguration().getOptionValue("guide",
349 "data_split_structure").toString()
350 + ") does not implement Modifiable interface. ");
351 }
352 }
353
354 divideFeatureIndexVector = new ArrayList<Integer>();
355 for (int i = 0; i < featureVector.size(); i++) {
356
357 if (featureVector.get(i).equals(divideFeatures.get(0))) {
358
359 divideFeatureIndexVector.add(i);
360 }
361 }
362
363 if (divideFeatureIndexVector.size() == 0) {
364 throw new GuideException(
365 "Could not match the given divide features to any of the available features.");
366 }
367
368
369
370 }
371
372 try {
373
374 String treeSplitTreshold = getGuide().getConfiguration()
375 .getOptionValue("guide", "tree_split_threshold").toString();
376
377 if (treeSplitTreshold != null && treeSplitTreshold.length() > 0) {
378
379 divideThreshold = Integer.parseInt(treeSplitTreshold);
380
381 } else {
382 divideThreshold = 0;
383 }
384 } catch (NumberFormatException e) {
385 throw new GuideException(
386 "The --guide-tree_split_threshold option is not an integer value. ",
387 e);
388 }
389
390 try {
391
392 String treeNumberOfCrossValidationDivisions = getGuide()
393 .getConfiguration().getOptionValue("guide",
394 "tree_number_of_cross_validation_divisions")
395 .toString();
396
397 if (treeNumberOfCrossValidationDivisions != null
398 && treeNumberOfCrossValidationDivisions.length() > 0) {
399
400 numberOfCrossValidationSplits = Integer
401 .parseInt(treeNumberOfCrossValidationDivisions);
402
403 } else {
404 divideThreshold = 0;
405 }
406 } catch (NumberFormatException e) {
407 throw new GuideException(
408 "The --guide-tree_number_of_cross_validation_divisions option is not an integer value. ",
409 e);
410 }
411
412 }
413
414 @Override
415 public void addInstance(SingleDecision decision)
416 throws MaltChainedException {
417
418 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
419 throw new GuideException("Can only add instance during learning. ");
420 } else if (divideFeatures.size() > 0) {
421 //FeatureFunction divideFeature = divideFeatures.getFirst();
422
423 for (FeatureFunction divideFeature : divideFeatures) {
424 if (!(divideFeature.getFeatureValue() instanceof SingleFeatureValue)) {
425 throw new GuideException(
426 "The divide feature does not have a single value. ");
427 }
428 // Is this necessary?
429 divideFeature.update();
430 }
431 leafModel.addInstance(decision);
432
433 //Update statistics data
434 updateStatistics(decision);
435
436
437 } else {
438 // Model has already been decided. It is a leaf node
439 if (branches != null)
440 setIsLeafNode();
441
442 leafModel.addInstance(decision);
443
444 //Update statistics data
445 updateStatistics(decision);
446
447 }
448
449
450
451
452 }
453
454 /*
455 private class StatisticsItem{
456
457 private int columnValue;
458
459 private int classValue;
460
461 public StatisticsItem(int columnValue, int classValue) {
462 super();
463 this.columnValue = columnValue;
464 this.classValue = classValue;
465 }
466
467 public int getColumnValue() {
468 return columnValue;
469 }
470
471 public int getClassValue() {
472 return classValue;
473 }
474
475 @Override
476 public int hashCode() {
477 return new Integer(columnValue/2).hashCode() + new Integer(classValue/2).hashCode();
478 }
479
480 @Override
481 public boolean equals(Object obj) {
482
483 StatisticsItem compItem = (StatisticsItem)obj;
484
485 return compItem.getClassValue()==this.getClassValue() && compItem.getColumnValue()==this.getColumnValue();
486 }
487 }
488 */
489
490 /*
491 * Helper method used for automatic division by gain ratio
492 * @param n
493 * @return
494 */
495 private double log2(double n){
496 return Math.log(n)/Math.log(2.0);
497 }
498
/*
 * Statistics gathered per training instance for the automatic (gain-ratio)
 * split ordering. For every divide feature we track how often each feature
 * value occurs, and how often each (feature value, class) pair occurs; the
 * class counts are tracked once globally. All maps are created lazily on
 * the first call to updateStatistics.
 */
//private HashMap<FeatureFunction, HashMap<StatisticsItem, Integer>> statisticsForDivideFatureMap = null;
// Keys are class ids; values count how many instances had that class.
private HashMap<Integer,Integer> classIdToCountMap = null;

// Per divide feature: feature value code -> occurrence count.
private HashMap<FeatureFunction, HashMap<Integer,Integer>> featureIdToCountMap = null;

//private HashMap<FeatureFunction, HashMap<Integer,Integer>> classIdToCountMap = new HashMap<FeatureFunction, HashMap<Integer,Integer>>();

// Per divide feature: feature value code -> (class id -> occurrence count).
private HashMap<FeatureFunction, HashMap<Integer,HashMap<Integer,Integer>>> featureIdToClassIdToCountMap = null;
514
515 private void updateStatistics(SingleDecision decision)
516 throws MaltChainedException {
517
518 // if(statisticsForDivideFatureMap==null){
519 // statisticsForDivideFatureMap = new HashMap<FeatureFunction,
520 // HashMap<StatisticsItem, Integer>>();
521 //
522 // for(FeatureFunction columnsDivideFeature : divideFeatures)
523 // statisticsForDivideFatureMap.put(columnsDivideFeature, new
524 // HashMap<StatisticsItem, Integer>());
525 // }
526 //
527 //
528 // int instanceClass = decision.getDecisionCode();
529 //
530 // Integer classCount = classCountStatistics.get(instanceClass);
531 //
532 // if(classCount==null){
533 // classCount=0;
534 // }
535 //
536 // classCountStatistics.put(instanceClass, classCount+1);
537 //
538 // for(FeatureFunction columnsDivideFeature : featureVector){
539 //
540 // int featureCode =
541 // ((SingleFeatureValue)columnsDivideFeature.getFeatureValue()).getCode();
542 // HashMap<StatisticsItem, Integer> statisticsMap =
543 // statisticsForDivideFatureMap.get(columnsDivideFeature);
544 // if(statisticsMap!=null){
545 //
546 // StatisticsItem item = new StatisticsItem(featureCode, instanceClass);
547 //
548 // Integer count = statisticsMap.get(item);
549 //
550 // if(count==null){
551 // //Add the statistic item to the map
552 // count = 0;
553 // }
554 //
555 // statisticsMap.put(item, count + 1);
556 //
557 // }
558 //
559 // }
560
561 // If it is not done initialize the statistics maps
562 if (featureIdToCountMap == null) {
563
564 featureIdToCountMap = new HashMap<FeatureFunction, HashMap<Integer, Integer>>();
565
566 for (FeatureFunction columnsDivideFeature : divideFeatures)
567 featureIdToCountMap.put(columnsDivideFeature,
568 new HashMap<Integer, Integer>());
569
570
571 featureIdToClassIdToCountMap = new HashMap<FeatureFunction, HashMap<Integer, HashMap<Integer, Integer>>>();
572
573 for (FeatureFunction columnsDivideFeature : divideFeatures)
574 featureIdToClassIdToCountMap.put(columnsDivideFeature,
575 new HashMap<Integer, HashMap<Integer, Integer>>());
576
577 classIdToCountMap = new HashMap<Integer, Integer>();
578
579 }
580
581 int instanceClass = decision.getDecisionCode();
582
583 // Increase classCountStatistics
584
585 Integer classCount = classIdToCountMap.get(instanceClass);
586
587 if (classCount == null) {
588 classCount = 0;
589 }
590
591 classIdToCountMap.put(instanceClass, classCount + 1);
592
593 // Increase featureIdToCountMap
594
595 for (FeatureFunction columnsDivideFeature : divideFeatures) {
596
597 int featureCode = ((SingleFeatureValue) columnsDivideFeature
598 .getFeatureValue()).getCode();
599
600 HashMap<Integer, Integer> statisticsMap = featureIdToCountMap
601 .get(columnsDivideFeature);
602
603 Integer count = statisticsMap.get(featureCode);
604
605 if (count == null) {
606 // Add the statistic item to the map
607 count = 0;
608 }
609
610 statisticsMap.put(featureCode, count + 1);
611
612 }
613
614 // Increase featureIdToClassIdToCountMap
615
616 for (FeatureFunction columnsDivideFeature : divideFeatures) {
617
618 int featureCode = ((SingleFeatureValue) columnsDivideFeature
619 .getFeatureValue()).getCode();
620
621 HashMap<Integer, HashMap<Integer, Integer>> featureIdToclassIdToCountMapTmp = featureIdToClassIdToCountMap
622 .get(columnsDivideFeature);
623
624 HashMap<Integer, Integer> classIdToCountMapTmp = featureIdToclassIdToCountMapTmp.get(featureCode);
625
626 if (classIdToCountMapTmp == null) {
627 // Add the statistic item to the map
628 classIdToCountMapTmp = new HashMap<Integer, Integer>();
629
630 featureIdToclassIdToCountMapTmp.put(featureCode, classIdToCountMapTmp);
631 }
632
633 Integer count = classIdToCountMapTmp.get(instanceClass);
634
635 if (count == null) {
636 // Add the statistic item to the map
637 count = 0;
638 }
639
640 classIdToCountMapTmp.put(instanceClass, count + 1);
641
642 }
643
644 }
645
646 @SuppressWarnings("unchecked")
647 private LinkedList<FeatureFunction> createNextLevelDivideFeatures() {
648
649 LinkedList<FeatureFunction> nextLevelDivideFeatures = (LinkedList<FeatureFunction>) divideFeatures
650 .clone();
651
652 nextLevelDivideFeatures.removeFirst();
653
654 return nextLevelDivideFeatures;
655 }
656
657 /*
658 * Removes the current divide feature from the feature vector so it is not
659 * present in the sub node
660 */
661 private FeatureVector getSubFeatureVector() {
662
663 if (subFeatureVector != null)
664 return subFeatureVector;
665
666 FeatureFunction divideFeature = divideFeatures.getFirst();
667
668 ArrayList<Integer> divideFeatureIndexVector = new ArrayList<Integer>();
669 for (int i = 0; i < featureVector.size(); i++) {
670 if (featureVector.get(i).equals(divideFeature)) {
671 divideFeatureIndexVector.add(i);
672 }
673 }
674
675 FeatureVector divideFeatureVector = (FeatureVector) featureVector
676 .clone();
677
678 for (Integer i : divideFeatureIndexVector) {
679 divideFeatureVector.remove(divideFeatureVector.get(i));
680 }
681
682 subFeatureVector = divideFeatureVector;
683
684 return divideFeatureVector;
685 }
686
687 @Override
688 public FeatureVector extract() throws MaltChainedException {
689
690 return getCurrentAtomicModel().extract();
691
692 }
693
694 /*
695 * Returns the atomic model that is effected by this parsing step
696 */
697 private AtomicModel getCurrentAtomicModel() throws MaltChainedException {
698
699 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
700 throw new GuideException("Can only predict during parsing. ");
701 }
702
703 if (branches == null && leafModel != null)
704 return leafModel;
705
706 FeatureFunction divideFeature = divideFeatures.getFirst();
707
708 if (!(divideFeature.getFeatureValue() instanceof SingleFeatureValue)) {
709 throw new GuideException(
710 "The divide feature does not have a single value. ");
711 }
712
713 if (branches != null
714 && branches.containsKey(((SingleFeatureValue) divideFeature
715 .getFeatureValue()).getCode())) {
716 return branches.get(
717 ((SingleFeatureValue) divideFeature.getFeatureValue())
718 .getCode()).getCurrentAtomicModel();
719 } else if (branches.containsKey(OTHER_BRANCH_ID)
720 && branches.get(OTHER_BRANCH_ID).getFrequency() > 0) {
721 return branches.get(OTHER_BRANCH_ID).getCurrentAtomicModel();
722 } else {
723 getGuide()
724 .getConfiguration()
725 .getConfigLogger()
726 .info(
727 "Could not predict the next parser decision because there is "
728 + "no divide or master model that covers the divide value '"
729 + ((SingleFeatureValue) divideFeature
730 .getFeatureValue()).getCode()
731 + "', as default"
732 + " class code '1' is used. ");
733 }
734 return null;
735 }
736
737 /**
738 * Increase the frequency by 1
739 */
740 public void increaseFrequency() {
741 frequency++;
742 }
743
744 public void decreaseFrequency() {
745 frequency--;
746 }
747
/**
 * Predicts the next parser decision by routing through the tree: first the
 * branch matching the current divide feature value, then the catch-all
 * OTHER branch, then the leaf model; when nothing covers the value, logs a
 * warning and falls back to class code 1.
 *
 * @param decision receives the predicted decision
 * @return the sub model's prediction result, or true after the fallback
 * @throws MaltChainedException in batch mode, or when the divide feature
 *             does not carry a single value
 */
@Override
public boolean predict(SingleDecision decision) throws MaltChainedException {

	if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
		throw new GuideException("Can only predict during parsing. ");
	} else if (divideFeatures.size() > 0
			&& !(divideFeatures.getFirst().getFeatureValue() instanceof SingleFeatureValue)) {
		throw new GuideException(
				"The divide feature does not have a single value. ");
	}

	// NOTE(review): the getFirst() calls below assume that branches != null
	// implies divideFeatures is non-empty — presumably guaranteed by the
	// constructors/load(); confirm before restructuring this method.
	if (branches != null
			&& branches.containsKey(((SingleFeatureValue) divideFeatures
					.getFirst().getFeatureValue()).getCode())) {
		// Exact branch for the current divide feature value.
		return branches.get(
				((SingleFeatureValue) divideFeatures.getFirst()
						.getFeatureValue()).getCode()).predict(decision);
	} else if (branches != null && branches.containsKey(OTHER_BRANCH_ID)) {
		// Fall back to the catch-all branch.
		return branches.get(OTHER_BRANCH_ID).predict(decision);
	} else if (leafModel != null) {
		// Leaf node: let the atomic model predict.
		return leafModel.predict(decision);
	} else {
		// No model covers this value: warn and use the default class code.
		getGuide()
				.getConfiguration()
				.getConfigLogger()
				.info(
						"Could not predict the next parser decision because there is "
								+ "no divide or master model that covers the divide value '"
								+ ((SingleFeatureValue) divideFeatures
										.getFirst().getFeatureValue())
										.getCode() + "', as default"
								+ " class code '1' is used. ");

		decision.addDecision(1); // default prediction
		// classCodeTable.getEmptyKBestList().addKBestItem(1);
	}
	return true;
}
791
792 @Override
793 public FeatureVector predictExtract(SingleDecision decision)
794 throws MaltChainedException {
795 return getCurrentAtomicModel().predictExtract(decision);
796 }
797
/*
 * Decides whether this node becomes a branch or a leaf by comparing cross
 * validation accuracies (recursively, depth first), and returns this
 * node's resulting accuracy. The result is cached in
 * crossValidationAccuracy so the expensive test runs at most once.
 */
private double decideNodeType() throws MaltChainedException {

	// Already decided: return the cached accuracy.
	if (crossValidationAccuracy != CROSS_VALIDATION_ACCURACY_NOT_SET_VALUE)
		return crossValidationAccuracy;

	// Root node (no model index): announce the start of the pruning pass.
	if (modelIndex == MODEL_INDEX_NOT_SET)
		if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
			getGuide().getConfiguration().getConfigLogger().info(
					"Starting deph first pruning of the decision tree\n");
		}

	long start = System.currentTimeMillis();

	double leafModelCrossValidationAccuracy = 0.0;

	// With tree_force_divide, the leaf accuracy stays 0.0 so the branch
	// representation always wins below.
	if (treeForceDivide)
		if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
			getGuide().getConfiguration().getConfigLogger().info(
					"Skipping cross validation of the root node since the flag tree_force_divide is set to yes. " +
					"The cross validation score for the root node is set to zero.\n");
		}

	if (!treeForceDivide)
		leafModelCrossValidationAccuracy = leafModel.getMethod()
				.crossValidate(featureVector, numberOfCrossValidationSplits);

	long stop = System.currentTimeMillis();

	if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
		getGuide().getConfiguration().getConfigLogger().info(
				"Cross Validation Time: " + (stop - start) + " ms"
						+ " for model " + getModelName() + "\n");
	}

	if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
		getGuide().getConfiguration().getConfigLogger().info(
				"Cross Validation Accuracy as leaf node = "
						+ leafModelCrossValidationAccuracy + " for model "
						+ getModelName() + "\n");
	}

	// Already decided to be a leaf (branches were removed earlier).
	if (branches == null && leafModel != null) {

		crossValidationAccuracy = leafModelCrossValidationAccuracy;

		return crossValidationAccuracy;

	}

	int totalFrequency = 0;
	double totalAccuracyCount = 0.0;
	// Recursively decide the sub nodes and accumulate their accuracies,
	// weighted by the number of instances each branch received.
	for (DecisionTreeModel b : branches.values()) {

		double bAccuracy = b.decideNodeType();

		totalFrequency = totalFrequency + b.getFrequency();

		totalAccuracyCount = totalAccuracyCount + bAccuracy
				* b.getFrequency();

	}

	// Weighted average accuracy of the branch representation.
	double branchModelCrossValidationAccuracy = totalAccuracyCount
			/ totalFrequency;

	if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
		getGuide().getConfiguration().getConfigLogger().info(
				"Total Cross Validation Accuracy for branches = "
						+ branchModelCrossValidationAccuracy
						+ " for model " + getModelName() + "\n");
	}

	// Finally decide which representation this node keeps: on a tie the
	// leaf wins (strict '>'), which prefers the simpler model.
	if (branchModelCrossValidationAccuracy > leafModelCrossValidationAccuracy) {

		setIsBranchNode();

		crossValidationAccuracy = branchModelCrossValidationAccuracy;

		return crossValidationAccuracy;

	} else {

		setIsLeafNode();

		crossValidationAccuracy = leafModelCrossValidationAccuracy;

		return crossValidationAccuracy;

	}

}
897
898 @Override
899 public void train() throws MaltChainedException {
900
901 // Decide node type
902 // This operation is more expensive than the training itself
903 decideNodeType();
904
905 // Do the training depending on which type of node this is
906 if (branches == null && leafModel != null) {
907
908 // If it is a leaf node
909
910 leafModel.train();
911
912 save();
913
914 leafModel.terminate();
915
916 } else {
917 // It is a branch node
918
919 for (DecisionTreeModel b : branches.values())
920 b.train();
921
922 save();
923
924 for (DecisionTreeModel b : branches.values())
925 b.terminate();
926
927 }
928 terminate();
929
930 }
931
932 /**
933 * Saves the decision tree model settings .dsm file.
934 *
935 * @throws MaltChainedException
936 */
937 private void save() throws MaltChainedException {
938 try {
939
940 final BufferedWriter out = new BufferedWriter(getGuide()
941 .getConfiguration().getConfigurationDir()
942 .getOutputStreamWriter(getModelName() + ".dsm"));
943
944 if (branches != null) {
945 for (DecisionTreeModel b : branches.values()) {
946 out.write(b.getModelIndex() + "\t" + b.getFrequency()
947 + "\n");
948 }
949 } else {
950 out.write(MODEL_INDEX_NOT_SET + "\t" + getFrequency() + "\n");
951 }
952
953 out.close();
954
955 } catch (IOException e) {
956 throw new GuideException(
957 "Could not write to the guide model settings file '"
958 + getModelName() + ".dsm"
959 + "' or the name mapping file '" + getModelName()
960 + ".nmf" + "', when "
961 + "saving the guide model settings to files. ", e);
962 }
963 }
964
965 @Override
966 public void finalizeSentence(DependencyStructure dependencyGraph)
967 throws MaltChainedException {
968
969 if (branches != null) {
970
971 for (DecisionTreeModel b : branches.values()) {
972 b.finalizeSentence(dependencyGraph);
973 }
974
975 } else if (leafModel != null) {
976
977 leafModel.finalizeSentence(dependencyGraph);
978
979 } else {
980
981 throw new GuideException(
982 "The feature divide models cannot be found. ");
983
984 }
985
986 }
987
988 @Override
989 public ClassifierGuide getGuide() {
990 return parent.getGuide();
991 }
992
/**
 * Returns this model's name: the parent's model name plus an underscore
 * suffix with this node's index (no suffix for the root, whose index is
 * MODEL_INDEX_NOT_SET).
 *
 * @throws GuideException (wrapped as MaltChainedException) when the parent
 *             model cannot be found
 */
@Override
public String getModelName() throws MaltChainedException {
	try {
		// NOTE(review): catching NullPointerException to detect a missing
		// parent is fragile — an NPE thrown from deeper inside
		// parent.getModelName() would be misreported as a missing parent; an
		// explicit null check would be safer. Left as-is to preserve the
		// current wrapping behavior for deep NPEs.
		return parent.getModelName()
				+ (modelIndex == MODEL_INDEX_NOT_SET ? ""
						: ("_" + modelIndex));
	} catch (NullPointerException e) {
		throw new GuideException(
				"The parent guide model cannot be found. ", e);
	}
}
1005
1006 /*
1007 * This is called to define this node as to be in the leaf state. It sets branches to null.
1008 */
1009 private void setIsLeafNode() throws MaltChainedException {
1010
1011 if (branches == null && leafModel != null)
1012 return;
1013
1014 if (branches != null && leafModel != null) {
1015
1016 for (DecisionTreeModel t : branches.values())
1017 t.terminate();
1018
1019 branches = null;
1020
1021 } else
1022 throw new MaltChainedException(
1023 "Can't set a node that have aleready been set to a leaf node.");
1024
1025 }
1026 /*
1027 * This is called to define this node as to be in the branch state. It sets leafModel to null.
1028 */
1029 private void setIsBranchNode() throws MaltChainedException {
1030 if (branches != null && leafModel != null) {
1031
1032 leafModel.terminate();
1033
1034 leafModel = null;
1035
1036 } else
1037 throw new MaltChainedException(
1038 "Can't set a node that have aleready been set to a branch node.");
1039
1040 }
1041
1042
/**
 * Called when all training instances have been collected. Forwards the
 * signal to the atomic model and then decides whether this node is split:
 * each divide-feature value with at least {@code divideThreshold}
 * instances gets its own branch, the remaining values are pooled into a
 * shared {@code OTHER_BRANCH_ID} branch, and the signal is recursed into
 * every new branch. If no split is worthwhile the node becomes a leaf.
 *
 * @throws MaltChainedException if the atomic model is missing, if the
 *             divide feature cannot be matched, or if a sub-model fails
 */
@Override
public void noMoreInstances() throws MaltChainedException {

	if (leafModel == null)
		throw new GuideException(
				"The model in tree node is null in a state where it is not allowed");

	leafModel.noMoreInstances();

	// No divide features left: this node can never be split further.
	if (divideFeatures.size() == 0)
		setIsLeafNode();

	if (branches != null) {

		if(automaticSplit){

			// Reorder the candidates by gain ratio, then re-resolve the
			// column index/indices of the (new) best feature within the
			// feature vector.
			divideFeatures = createGainRatioSplitList(divideFeatures);

			divideFeatureIndexVector = new ArrayList<Integer>();
			for (int i = 0; i < featureVector.size(); i++) {

				if (featureVector.get(i).equals(divideFeatures.get(0))) {

					divideFeatureIndexVector.add(i);
				}
			}

			if (divideFeatureIndexVector.size() == 0) {
				throw new GuideException(
						"Could not match the given divide features to any of the available features.");
			}

		}

		FeatureFunction divideFeature = divideFeatures.getFirst();

		divideFeature.updateCardinality();

		// NOTE(review): second noMoreInstances() call on the same atomic
		// model (already called above) -- presumably idempotent; confirm.
		leafModel.noMoreInstances();

		// Per-value instance counts for the chosen divide feature.
		Map<Integer, Integer> divideFeatureIdToCountMap = leafModel
				.getMethod().createFeatureIdToCountMap(
						divideFeatureIndexVector);

		int totalInOther = 0;

		Set<Integer> featureIdsToCreateSeparateBranchesForSet = new HashSet<Integer>();

		List<Integer> removeFromDivideFeatureIdToCountMap = new LinkedList<Integer>();

		// Values with enough instances get their own branch; the rest are
		// accumulated into the "other" pool.
		for (Entry<Integer, Integer> entry : divideFeatureIdToCountMap
				.entrySet())
			if (entry.getValue() >= divideThreshold) {
				featureIdsToCreateSeparateBranchesForSet
						.add(entry.getKey());
			} else {
				removeFromDivideFeatureIdToCountMap.add(entry.getKey());
				totalInOther = totalInOther + entry.getValue();
			}

		// Removal deferred to avoid mutating the map while iterating it.
		for (int removeIndex : removeFromDivideFeatureIdToCountMap)
			divideFeatureIdToCountMap.remove(removeIndex);

		boolean otherExists = false;

		if (totalInOther > 0)
			otherExists = true;

		if ((totalInOther < divideThreshold && featureIdsToCreateSeparateBranchesForSet
				.size() <= 1)
				|| featureIdsToCreateSeparateBranchesForSet.size() == 0) {
			// Not enough instances to justify a split: make this a leaf node
			setIsLeafNode();
		} else {

			// If total in other is less then divideThreshold then add the
			// smallest of the other parts to other
			if (otherExists && totalInOther < divideThreshold) {
				int smallestSoFar = Integer.MAX_VALUE;
				int smallestSoFarId = Integer.MAX_VALUE;
				for (Entry<Integer, Integer> entry : divideFeatureIdToCountMap
						.entrySet()) {
					if (entry.getValue() < smallestSoFar) {
						smallestSoFar = entry.getValue();
						smallestSoFarId = entry.getKey();
					}
				}

				featureIdsToCreateSeparateBranchesForSet
						.remove(smallestSoFarId);
			}

			// Create new files for all feature ids with count value greater
			// than divideThreshold and one for the
			// other branch
			leafModel.getMethod().divideByFeatureSet(
					featureIdsToCreateSeparateBranchesForSet,
					divideFeatureIndexVector, "" + OTHER_BRANCH_ID);

			for (int id : featureIdsToCreateSeparateBranchesForSet) {
				DecisionTreeModel newBranch = new DecisionTreeModel(id,
						getSubFeatureVector(), this,
						createNextLevelDivideFeatures(), divideThreshold);
				branches.put(id, newBranch);

			}
			if (otherExists) {
				// The "other" branch keeps the full feature vector and gets
				// no further divide features, so it cannot be split again.
				DecisionTreeModel newBranch = new DecisionTreeModel(
						OTHER_BRANCH_ID, featureVector, this,
						new LinkedList<FeatureFunction>(), divideThreshold);
				branches.put(OTHER_BRANCH_ID, newBranch);

			}

			// Recurse: let each new branch finish its own training phase.
			for (DecisionTreeModel b : branches.values())
				b.noMoreInstances();

		}

	}

}
1165
1166 @Override
1167 public void terminate() throws MaltChainedException {
1168 if (branches != null) {
1169 for (DecisionTreeModel branch : branches.values()) {
1170 branch.terminate();
1171 }
1172 branches = null;
1173 }
1174 if (leafModel != null) {
1175 leafModel.terminate();
1176 leafModel = null;
1177 }
1178
1179 }
1180
/**
 * Sets the parent model of this decision tree node.
 *
 * @param parent the enclosing guide model
 */
public void setParent(Model parent) {
	this.parent = parent;
}
1184
/**
 * Returns the parent model of this decision tree node, or null if none
 * has been set.
 */
public Model getParent() {
	return parent;
}
1188
/**
 * Sets the instance frequency of this node (used as the total-count
 * denominator in the gain-ratio computation).
 *
 * @param frequency the number of training instances at this node
 */
public void setFrequency(int frequency) {
	this.frequency = frequency;
}
1192
/**
 * Returns the number of training instances observed at this node.
 */
public int getFrequency() {
	return frequency;
}
1196
/**
 * Returns the index of this node among its siblings, or
 * MODEL_INDEX_NOT_SET when no index has been assigned (see getModelName()).
 */
public int getModelIndex() {
	return modelIndex;
}
1200
1201
1202 private LinkedList<FeatureFunction> createGainRatioSplitList(LinkedList<FeatureFunction> divideFeatures) {
1203
1204 if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
1205
1206 getGuide().getConfiguration().getConfigLogger().info(
1207 "Start calculating gain ratio for all posible divide features");
1208 }
1209
1210 //Calculate the root entropy
1211
1212 double total = 0;
1213
1214 for(int count: classIdToCountMap.values()){
1215 double fraction = ((double)count) / getFrequency();
1216 total = total + fraction*log2(fraction);
1217 }
1218
1219 double rootEntropy = -total;
1220
1221
1222 class FeatureFunctionInformationGainPair implements Comparable<FeatureFunctionInformationGainPair>{
1223 double informationGain;
1224 FeatureFunction featureFunction;
1225 double splitInfo;
1226
1227 public FeatureFunctionInformationGainPair(
1228 FeatureFunction featureFunction) {
1229 super();
1230 this.featureFunction = featureFunction;
1231 }
1232
1233 public double getGainRatio(){
1234 return informationGain/splitInfo;
1235 }
1236
1237 @Override
1238 public int compareTo(FeatureFunctionInformationGainPair o) {
1239
1240 int result = 0;
1241
1242 if((this.getGainRatio() - o.getGainRatio()) <0)
1243 result = -1;
1244 else if ((this.getGainRatio() - o.getGainRatio()) >0)
1245 result = 1;
1246
1247 return result;
1248 }
1249 }
1250
1251 ArrayList<FeatureFunctionInformationGainPair> gainRatioList = new ArrayList<FeatureFunctionInformationGainPair>();
1252
1253 for(FeatureFunction f: divideFeatures)
1254 gainRatioList.add(new FeatureFunctionInformationGainPair(f));
1255
1256 //For all divide features calculate the gain ratio
1257
1258 for(FeatureFunctionInformationGainPair p : gainRatioList){
1259
1260 HashMap<Integer, Integer> featureIdToCountMapTmp = featureIdToCountMap.get(p.featureFunction);
1261
1262 HashMap<Integer, HashMap<Integer, Integer>> featureIdToClassIdToCountMapTmp = featureIdToClassIdToCountMap.get(p.featureFunction);
1263
1264 double sum = 0;
1265
1266 for(Entry<Integer, Integer> entry:featureIdToCountMapTmp.entrySet()){
1267 int featureId = entry.getKey();
1268 int numberOfElementsWithFeatureId = entry.getValue();
1269 HashMap<Integer, Integer> classIdToCountMapTmp = featureIdToClassIdToCountMapTmp.get(featureId);
1270
1271 double sumImpurityMesure = 0;
1272 int totalElementsWithIdAndClass = 0;
1273 for(int elementsWithIdAndClass : classIdToCountMapTmp.values()){
1274
1275 double fractionOfInstancesBelongingToClass = ((double)elementsWithIdAndClass)/numberOfElementsWithFeatureId;
1276
1277 totalElementsWithIdAndClass = totalElementsWithIdAndClass + elementsWithIdAndClass;
1278
1279 sumImpurityMesure= sumImpurityMesure+fractionOfInstancesBelongingToClass*log2(fractionOfInstancesBelongingToClass);
1280
1281 }
1282
1283 double impurityMesure = -sumImpurityMesure;
1284
1285 sum = sum + (((double)numberOfElementsWithFeatureId)/getFrequency())*impurityMesure;
1286
1287 }
1288 p.informationGain = rootEntropy - sum;
1289
1290 //Calculate split info
1291
1292 double splitInfoTotal = 0;
1293
1294 for(int nrOfElementsWithFeatureId:featureIdToCountMapTmp.values()){
1295 double fractionOfTotal = ((double)nrOfElementsWithFeatureId)/getFrequency();
1296 splitInfoTotal = splitInfoTotal + fractionOfTotal*log2(fractionOfTotal);
1297 }
1298 p.splitInfo= splitInfoTotal;
1299
1300
1301 }
1302 Collections.sort(gainRatioList);
1303
1304
1305
1306 //Log the result if info is enabled
1307 if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
1308
1309 getGuide().getConfiguration().getConfigLogger().info(
1310 "Gain ratio calculation finished the result follows:\n");
1311 getGuide().getConfiguration().getConfigLogger().info(
1312 "Divide Feature\tGain Ratio\tInformation Gain\tSplit Info\n");
1313
1314 for(FeatureFunctionInformationGainPair p :gainRatioList)
1315 getGuide().getConfiguration().getConfigLogger().info(
1316 p.featureFunction + "\t" + p.getGainRatio() + "\t" + p.informationGain + "\t" + p.splitInfo +"\n");
1317 }
1318
1319 LinkedList<FeatureFunction> divideFeaturesNew = new LinkedList<FeatureFunction>();
1320
1321 for(FeatureFunctionInformationGainPair p :gainRatioList)
1322 divideFeaturesNew.add(p.featureFunction);
1323
1324
1325 return divideFeaturesNew;
1326
1327 }
1328
1329 }