Skip to content
Snippets Groups Projects
Commit 44d044b1 authored by Eike Cochu's avatar Eike Cochu
Browse files

updated dtm analyzer, first version

parent 937fd46d
No related branches found
No related tags found
No related merge requests found
...@@ -12,14 +12,13 @@ import java.util.ArrayList; ...@@ -12,14 +12,13 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map;
import de.vipra.util.Config; import de.vipra.util.Config;
import de.vipra.util.Constants; import de.vipra.util.Constants;
import de.vipra.util.Constants.WindowResolution; import de.vipra.util.Constants.WindowResolution;
import de.vipra.util.CountMap;
import de.vipra.util.FileUtils; import de.vipra.util.FileUtils;
import de.vipra.util.ex.ConfigException; import de.vipra.util.ex.ConfigException;
...@@ -48,11 +47,12 @@ public class DTMSequenceIndex implements Closeable, Iterable<DTMSequenceIndex.DT ...@@ -48,11 +47,12 @@ public class DTMSequenceIndex implements Closeable, Iterable<DTMSequenceIndex.DT
} }
} }
private File file; private final File file;
private boolean readonly = false; private final boolean readonly = false;
private WindowResolution windowResolution; private final WindowResolution windowResolution;
private static List<DTMDateIndexEntry> entries; private static List<DTMDateIndexEntry> entries;
private static SimpleDateFormat df = new SimpleDateFormat(Constants.DATETIME_FORMAT); private static SimpleDateFormat df = new SimpleDateFormat(Constants.DATETIME_FORMAT);
private static CountMap<String> windowSizes = new CountMap<>();
public DTMSequenceIndex(File modelDir) throws IOException, ParseException, ConfigException { public DTMSequenceIndex(File modelDir) throws IOException, ParseException, ConfigException {
this(modelDir, false); this(modelDir, false);
...@@ -66,17 +66,21 @@ public class DTMSequenceIndex implements Closeable, Iterable<DTMSequenceIndex.DT ...@@ -66,17 +66,21 @@ public class DTMSequenceIndex implements Closeable, Iterable<DTMSequenceIndex.DT
if (entries == null || reread) { if (entries == null || reread) {
List<String> dates = FileUtils.readFile(file); List<String> dates = FileUtils.readFile(file);
entries = new ArrayList<>(dates.size()); entries = new ArrayList<>(dates.size());
for (String date : dates) { for (String date : dates)
entries.add(new DTMDateIndexEntry(df.parse(date), true, null)); add(df.parse(date));
}
} }
} else if (entries == null || reread) { } else if (entries == null || reread) {
entries = new ArrayList<>(); entries = new ArrayList<>();
} }
} }
/**
 * Convenience overload: adds a date-only index entry with no source line.
 * Delegates to add(Date, String) with a null line; per that method, a null
 * line marks the entry as date-only (see DTMDateIndexEntry construction).
 */
private void add(Date date) {
add(date, null);
}
public void add(Date date, String line) { public void add(Date date, String line) {
entries.add(new DTMDateIndexEntry(date, false, line)); entries.add(new DTMDateIndexEntry(date, line == null, line));
windowSizes.count(windowResolution.fromDate(date));
} }
@Override @Override
...@@ -89,21 +93,12 @@ public class DTMSequenceIndex implements Closeable, Iterable<DTMSequenceIndex.DT ...@@ -89,21 +93,12 @@ public class DTMSequenceIndex implements Closeable, Iterable<DTMSequenceIndex.DT
public void close() throws IOException { public void close() throws IOException {
if (readonly) if (readonly)
return; return;
Map<String, Integer> windowSizes = new HashMap<>();
// write date index // write date index
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file, false))); BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file, false)));
for (DTMDateIndexEntry entry : entries) { for (DTMDateIndexEntry entry : entries) {
writer.write(df.format(entry.date)); writer.write(df.format(entry.date));
writer.write(Constants.LINE_SEP); writer.write(Constants.LINE_SEP);
String window = windowResolution.fromDate(entry.date);
Integer count = windowSizes.get(window);
if (count == null) {
windowSizes.put(window, 1);
} else {
windowSizes.put(window, count + 1);
}
} }
writer.close(); writer.close();
...@@ -112,7 +107,7 @@ public class DTMSequenceIndex implements Closeable, Iterable<DTMSequenceIndex.DT ...@@ -112,7 +107,7 @@ public class DTMSequenceIndex implements Closeable, Iterable<DTMSequenceIndex.DT
writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(seqFile, false))); writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(seqFile, false)));
writer.write(Integer.toString(windowSizes.size())); writer.write(Integer.toString(windowSizes.size()));
writer.write(Constants.LINE_SEP); writer.write(Constants.LINE_SEP);
// write window sizes // write window sizes
String[] windows = windowSizes.keySet().toArray(new String[windowSizes.size()]); String[] windows = windowSizes.keySet().toArray(new String[windowSizes.size()]);
Arrays.sort(windows); Arrays.sort(windows);
...@@ -120,12 +115,12 @@ public class DTMSequenceIndex implements Closeable, Iterable<DTMSequenceIndex.DT ...@@ -120,12 +115,12 @@ public class DTMSequenceIndex implements Closeable, Iterable<DTMSequenceIndex.DT
writer.write(Integer.toString(windowSizes.get(window))); writer.write(Integer.toString(windowSizes.get(window)));
writer.write(Constants.LINE_SEP); writer.write(Constants.LINE_SEP);
} }
writer.close(); writer.close();
} }
public int size() { public int size() {
return entries.size(); return windowSizes.size();
} }
} }
...@@ -20,7 +20,7 @@ public class DTMVocabulary implements Closeable, Iterable<String> { ...@@ -20,7 +20,7 @@ public class DTMVocabulary implements Closeable, Iterable<String> {
private File file; private File file;
private static List<String> vocables; private static List<String> vocables;
private static Map<String, Integer> vocablesMap; private static Map<String, Integer> vocablesMap;
private static int nextIndex = 1; private static int nextIndex = 0;
public DTMVocabulary(File modelDir) throws IOException { public DTMVocabulary(File modelDir) throws IOException {
this(modelDir, false); this(modelDir, false);
......
...@@ -7,11 +7,10 @@ import java.io.IOException; ...@@ -7,11 +7,10 @@ import java.io.IOException;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.text.ParseException; import java.text.ParseException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
...@@ -27,8 +26,8 @@ import de.vipra.cmd.file.FilebaseIndex; ...@@ -27,8 +26,8 @@ import de.vipra.cmd.file.FilebaseIndex;
import de.vipra.util.Config; import de.vipra.util.Config;
import de.vipra.util.Constants; import de.vipra.util.Constants;
import de.vipra.util.CountMap; import de.vipra.util.CountMap;
import de.vipra.util.FileUtils;
import de.vipra.util.StringUtils; import de.vipra.util.StringUtils;
import de.vipra.util.Tuple;
import de.vipra.util.ex.ConfigException; import de.vipra.util.ex.ConfigException;
import de.vipra.util.ex.DatabaseException; import de.vipra.util.ex.DatabaseException;
import de.vipra.util.model.ArticleFull; import de.vipra.util.model.ArticleFull;
...@@ -141,7 +140,7 @@ public class DTMAnalyzer extends Analyzer { ...@@ -141,7 +140,7 @@ public class DTMAnalyzer extends Analyzer {
// TODO find out what proportions are good for and where to store // TODO find out what proportions are good for and where to store
// them // them
File gamFile = new File(outDir, "gam.dat"); File gamFile = new File(outDirSeq, "gam.dat");
in = new BufferedReader(new InputStreamReader(new FileInputStream(gamFile))); in = new BufferedReader(new InputStreamReader(new FileInputStream(gamFile)));
for (int idxArticle = 0; idxArticle < index.size(); idxArticle++) { for (int idxArticle = 0; idxArticle < index.size(); idxArticle++) {
...@@ -166,27 +165,24 @@ public class DTMAnalyzer extends Analyzer { ...@@ -166,27 +165,24 @@ public class DTMAnalyzer extends Analyzer {
// read topic definition files and create topics // read topic definition files and create topics
Map<Word, Topic> topicWordMap = new HashMap<>(vocab.size());
List<TopicFull> newTopics = new ArrayList<>(Constants.K_TOPICS);
List<Word> newWords = new ArrayList<>(vocab.size());
int sequencesCount = sequences.size(); int sequencesCount = sequences.size();
int wordCount = vocab.size();
// collects created topics
List<TopicFull> newTopics = new ArrayList<>(Constants.K_TOPICS);
// collects created words
List<Word> newWords = new ArrayList<>(wordCount);
// collect mapping between words and topics
@SuppressWarnings("unchecked")
Tuple<Double, Integer>[] wordTopicMapping = (Tuple<Double, Integer>[]) new Tuple[wordCount];
// for each topic file // for each topic file
for (int i = 0; i < Constants.K_TOPICS; i++) { for (int i = 0; i < Constants.K_TOPICS; i++) {
File seqFile = new File(outDirSeq, "topic-" + StringUtils.padNumber(i, 3) + "-var-e-log-prob.dat"); File seqFile = new File(outDirSeq, "topic-" + StringUtils.padNumber(i, 3) + "-var-e-log-prob.dat");
int lineCount = FileUtils.countLines(seqFile);
int wordsCount = lineCount / sequencesCount;
if (wordsCount * sequencesCount != lineCount) {
log.error("unexpected number of words per sequence");
continue;
}
// create new topic // create new topic
TopicFull newTopic = new TopicFull(); TopicFull newTopic = new TopicFull();
List<Sequence> newSequences = new ArrayList<>(sequencesCount); List<Sequence> newSequences = new ArrayList<>(sequencesCount);
List<TopicWord> newTopicWords = new ArrayList<>(wordsCount); List<TopicWord> newTopicWords = new ArrayList<>(wordCount);
newTopic.setSequences(newSequences); newTopic.setSequences(newSequences);
newTopic.setWords(newTopicWords); newTopic.setWords(newTopicWords);
newTopics.add(newTopic); newTopics.add(newTopic);
...@@ -196,9 +192,11 @@ public class DTMAnalyzer extends Analyzer { ...@@ -196,9 +192,11 @@ public class DTMAnalyzer extends Analyzer {
// read file lines into word x sequence matrix // read file lines into word x sequence matrix
// gather maximum likeliness per sequence and per word // gather maximum likeliness per sequence and per word
double[] maxSeqLikelinesses = new double[sequencesCount]; double[] maxSeqLikelinesses = new double[sequencesCount];
double[] maxWordLikelinesses = new double[wordsCount]; Arrays.fill(maxSeqLikelinesses, Integer.MIN_VALUE);
double[][] likelinesses = new double[wordsCount][sequencesCount]; double[] maxWordLikelinesses = new double[wordCount];
for (int idxWord = 0; idxWord < wordsCount; idxWord++) { Arrays.fill(maxWordLikelinesses, Integer.MIN_VALUE);
double[][] likelinesses = new double[wordCount][sequencesCount];
for (int idxWord = 0; idxWord < wordCount; idxWord++) {
for (int idxSeq = 0; idxSeq < sequencesCount; idxSeq++) { for (int idxSeq = 0; idxSeq < sequencesCount; idxSeq++) {
double likeliness = Double.parseDouble(in.readLine()); double likeliness = Double.parseDouble(in.readLine());
likelinesses[idxWord][idxSeq] = likeliness; likelinesses[idxWord][idxSeq] = likeliness;
...@@ -211,8 +209,18 @@ public class DTMAnalyzer extends Analyzer { ...@@ -211,8 +209,18 @@ public class DTMAnalyzer extends Analyzer {
in.close(); in.close();
// compare to current word <-> topic mapping, accept higher
// likeliness as better
for (int idxWord = 0; idxWord < maxWordLikelinesses.length; idxWord++) {
Tuple<Double, Integer> tuple = wordTopicMapping[idxWord];
if (tuple == null)
wordTopicMapping[idxWord] = new Tuple<>(maxWordLikelinesses[idxWord], i);
else if (maxWordLikelinesses[idxWord] > tuple.first())
tuple.setSecond(i);
}
// find maximum overall likeliness // find maximum overall likeliness
double maxOverallLikeliness = 0; double maxOverallLikeliness = Integer.MIN_VALUE;
for (double likeliness : maxSeqLikelinesses) { for (double likeliness : maxSeqLikelinesses) {
if (likeliness > maxOverallLikeliness) if (likeliness > maxOverallLikeliness)
maxOverallLikeliness = likeliness; maxOverallLikeliness = likeliness;
...@@ -220,7 +228,7 @@ public class DTMAnalyzer extends Analyzer { ...@@ -220,7 +228,7 @@ public class DTMAnalyzer extends Analyzer {
// static topic // static topic
// most likely words form the static topic over all sequences // most likely words form the static topic over all sequences
for (int idxWord = 0; idxWord < wordsCount; idxWord++) { for (int idxWord = 0; idxWord < wordCount; idxWord++) {
if (maxWordLikelinesses[idxWord] >= Constants.MINIMUM_RELATIVE_PROB * maxOverallLikeliness) { if (maxWordLikelinesses[idxWord] >= Constants.MINIMUM_RELATIVE_PROB * maxOverallLikeliness) {
Word newWord = new Word(vocab.get(idxWord)); Word newWord = new Word(vocab.get(idxWord));
newWords.add(newWord); newWords.add(newWord);
...@@ -234,8 +242,8 @@ public class DTMAnalyzer extends Analyzer { ...@@ -234,8 +242,8 @@ public class DTMAnalyzer extends Analyzer {
// the minimum relative word likeliness // the minimum relative word likeliness
for (int idxSeq = 0; idxSeq < sequencesCount; idxSeq++) { for (int idxSeq = 0; idxSeq < sequencesCount; idxSeq++) {
double maxLikeliness = maxSeqLikelinesses[idxSeq]; double maxLikeliness = maxSeqLikelinesses[idxSeq];
List<TopicWord> newSeqTopicWords = new ArrayList<>(wordsCount); List<TopicWord> newSeqTopicWords = new ArrayList<>(wordCount);
for (int idxWord = 0; idxWord < wordsCount; idxWord++) { for (int idxWord = 0; idxWord < wordCount; idxWord++) {
double likeliness = likelinesses[idxWord][idxSeq]; double likeliness = likelinesses[idxWord][idxSeq];
if (likeliness >= Constants.MINIMUM_RELATIVE_PROB * maxLikeliness) { if (likeliness >= Constants.MINIMUM_RELATIVE_PROB * maxLikeliness) {
Word newWord = new Word(vocab.get(idxWord)); Word newWord = new Word(vocab.get(idxWord));
...@@ -272,14 +280,18 @@ public class DTMAnalyzer extends Analyzer { ...@@ -272,14 +280,18 @@ public class DTMAnalyzer extends Analyzer {
// for each article in the model file // for each article in the model file
while ((line = in.readLine()) != null) { while ((line = in.readLine()) != null) {
// extract unique word ids and count // get topic id from word id, count topics
CountMap<Integer> countMap = new CountMap<>(); CountMap<Integer> countMap = new CountMap<>();
Matcher matcher = wordCountPattern.matcher(line); Matcher matcher = wordCountPattern.matcher(line);
double totalCount = 0; double totalCount = 0;
while (matcher.find()) { while (matcher.find()) {
int count = Integer.parseInt(matcher.group(2)); Integer wordId = Integer.parseInt(matcher.group(1));
countMap.count(Integer.parseInt(matcher.group(1)), count); Tuple<Double, Integer> wordTopicTuple = wordTopicMapping[wordId];
totalCount += count; if (wordTopicTuple != null) {
int count = Integer.parseInt(matcher.group(2));
countMap.count(wordTopicTuple.second(), count);
totalCount += count;
}
} }
// create list of topics refs referencing topics with counted // create list of topics refs referencing topics with counted
...@@ -290,14 +302,15 @@ public class DTMAnalyzer extends Analyzer { ...@@ -290,14 +302,15 @@ public class DTMAnalyzer extends Analyzer {
// check if topic above threshold // check if topic above threshold
if ((entry.getValue() / totalCount) >= Constants.TOPIC_THRESHOLD) { if ((entry.getValue() / totalCount) >= Constants.TOPIC_THRESHOLD) {
reducedCount += entry.getValue(); reducedCount += entry.getValue();
TopicFull topic = null; TopicFull topic = newTopics.get(entry.getKey());
// TODO find topic of this word // TODO words with low relative likeliness are ignored.
if (topic != null) { // topic references from this file are possibly wrong.
TopicRef ref = new TopicRef(); // fix this by checking if the word is actually accepted
ref.setCount(entry.getValue()); // by the referenced topic.
ref.setTopic(new Topic(topic.getId())); TopicRef ref = new TopicRef();
newTopicRefs.add(ref); ref.setCount(entry.getValue());
} ref.setTopic(new Topic(topic.getId()));
newTopicRefs.add(ref);
} }
} }
......
...@@ -4,5 +4,6 @@ db.name=test ...@@ -4,5 +4,6 @@ db.name=test
es.host=localhost es.host=localhost
es.port=9300 es.port=9300
tm.processor=corenlp tm.processor=corenlp
tm.analyzer=jgibb tm.analyzer=dtm
tm.dtmpath=/home/eike/repos/master/dtm_release/dtm/main tm.dtmpath=/home/eike/repos/master/ma-impl/dtm_release/dtm/main
\ No newline at end of file tm.windowresolution=monthly
\ No newline at end of file
...@@ -37,4 +37,12 @@ public class CountMap<T> { ...@@ -37,4 +37,12 @@ public class CountMap<T> {
return map.size(); return map.size();
} }
/**
 * Returns the current count for the given key, or null if the key has
 * never been counted (straight delegation to the backing map's get).
 */
public Integer get(T key) {
return map.get(key);
}
/**
 * Returns the set of keys that have been counted.
 * NOTE(review): this is the live keySet view of the internal map — removals
 * on it write through to the counts; consider wrapping in
 * Collections.unmodifiableSet if callers should not mutate internal state.
 */
public Set<T> keySet() {
return map.keySet();
}
} }
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment