diff --git a/jgibblda/src/jgibblda/Constants.java b/jgibblda/src/jgibblda/Constants.java index 93b104457a1df0037d9f8dbd29771019b8419ab4..c65e53a8cd51d4ea405acdbdccd52f467195b272 100644 --- a/jgibblda/src/jgibblda/Constants.java +++ b/jgibblda/src/jgibblda/Constants.java @@ -31,7 +31,7 @@ package jgibblda; public class Constants { public static final long BUFFER_SIZE_LONG = 1000000; public static final short BUFFER_SIZE_SHORT = 512; - + public static final int MODEL_STATUS_UNKNOWN = 0; public static final int MODEL_STATUS_EST = 1; public static final int MODEL_STATUS_ESTC = 2; diff --git a/jgibblda/src/jgibblda/Conversion.java b/jgibblda/src/jgibblda/Conversion.java index 879871b40f47e5bc0fbdd7c434687e380c449e9f..5834e771b429b6f240e6bf64817c6008e49cd5c7 100644 --- a/jgibblda/src/jgibblda/Conversion.java +++ b/jgibblda/src/jgibblda/Conversion.java @@ -29,13 +29,12 @@ package jgibblda; public class Conversion { - public static String ZeroPad( int number, int width ) - { - StringBuffer result = new StringBuffer(""); - for( int i = 0; i < width-Integer.toString(number).length(); i++ ) - result.append( "0" ); - result.append( Integer.toString(number) ); - - return result.toString(); + public static String ZeroPad(int number, int width) { + StringBuffer result = new StringBuffer(""); + for (int i = 0; i < width - Integer.toString(number).length(); i++) + result.append("0"); + result.append(Integer.toString(number)); + + return result.toString(); } } diff --git a/jgibblda/src/jgibblda/Dictionary.java b/jgibblda/src/jgibblda/Dictionary.java index 842e3f66711153ef93341b747d032a50a292b6df..f8550fdf87649a9cca2c7b506f725b181f239a9c 100644 --- a/jgibblda/src/jgibblda/Dictionary.java +++ b/jgibblda/src/jgibblda/Dictionary.java @@ -39,129 +39,127 @@ import java.util.Map; import java.util.StringTokenizer; public class Dictionary { - public Map<String,Integer> word2id; + public Map<String, Integer> word2id; public Map<Integer, String> id2word; - - //-------------------------------------------------- + + // -------------------------------------------------- // constructors - //-------------------------------------------------- - - public Dictionary(){ + // -------------------------------------------------- + + public Dictionary() { word2id = new HashMap<String, Integer>(); id2word = new HashMap<Integer, String>(); } - - //--------------------------------------------------- + + // --------------------------------------------------- // get/set methods - //--------------------------------------------------- - - public String getWord(int id){ + // --------------------------------------------------- + + public String getWord(int id) { return id2word.get(id); } - - public Integer getID (String word){ + + public Integer getID(String word) { return word2id.get(word); } - - //---------------------------------------------------- + + // ---------------------------------------------------- // checking methods - //---------------------------------------------------- + // ---------------------------------------------------- /** * check if this dictionary contains a specified word */ - public boolean contains(String word){ + public boolean contains(String word) { return word2id.containsKey(word); } - - public boolean contains(int id){ + + public boolean contains(int id) { return id2word.containsKey(id); } - //--------------------------------------------------- + + // --------------------------------------------------- // manupulating methods - //--------------------------------------------------- + // --------------------------------------------------- /** - * add a word into this dictionary - * return the corresponding id + * add a word into this dictionary return the corresponding id */ - public int addWord(String word){ - if (!contains(word)){ + public int addWord(String word) { + if (!contains(word)) { int id = word2id.size(); - + word2id.put(word, id); - id2word.put(id,word); - + id2word.put(id, word); + return id; - } - else return getID(word); + } else + return getID(word); } - - //--------------------------------------------------- + + // --------------------------------------------------- // I/O methods - //--------------------------------------------------- + // --------------------------------------------------- /** * read dictionary from file */ - public boolean readWordMap(String wordMapFile){ - try{ - BufferedReader reader = new BufferedReader(new InputStreamReader( - new FileInputStream(wordMapFile), "UTF-8")); + public boolean readWordMap(String wordMapFile) { + try { + BufferedReader reader = new BufferedReader( + new InputStreamReader(new FileInputStream(wordMapFile), "UTF-8")); String line; - - //read the number of words - line = reader.readLine(); + + // read the number of words + line = reader.readLine(); int nwords = Integer.parseInt(line); - - //read map - for (int i = 0; i < nwords; ++i){ + + // read map + for (int i = 0; i < nwords; ++i) { line = reader.readLine(); StringTokenizer tknr = new StringTokenizer(line, " \t\n\r"); - - if (tknr.countTokens() != 2) continue; - + + if (tknr.countTokens() != 2) + continue; + String word = tknr.nextToken(); String id = tknr.nextToken(); int intID = Integer.parseInt(id); - + id2word.put(intID, word); word2id.put(word, intID); } - + reader.close(); return true; - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Error while reading dictionary:" + e.getMessage()); e.printStackTrace(); return false; - } + } } - - public boolean writeWordMap(String wordMapFile){ - try{ - BufferedWriter writer = new BufferedWriter(new OutputStreamWriter( - new FileOutputStream(wordMapFile), "UTF-8")); - - //write number of words + + public boolean writeWordMap(String wordMapFile) { + try { + BufferedWriter writer = new BufferedWriter( + new OutputStreamWriter(new FileOutputStream(wordMapFile), "UTF-8")); + + // write number of words writer.write(word2id.size() + "\n"); - - //write word to id + + // write word to id Iterator<String> it = word2id.keySet().iterator(); - while (it.hasNext()){ + while (it.hasNext()) { String key = it.next(); Integer value = word2id.get(key); - + writer.write(key + " " + value + "\n"); } - + writer.close(); return true; - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Error while writing word map " + e.getMessage()); e.printStackTrace(); return false; } - - + } } diff --git a/jgibblda/src/jgibblda/Document.java b/jgibblda/src/jgibblda/Document.java index 679d568a98691f5a82f98c8bc1dca5b7b100e7a8..c9faf020f3fff5b48802a265e11c98c2557c2aa9 100644 --- a/jgibblda/src/jgibblda/Document.java +++ b/jgibblda/src/jgibblda/Document.java @@ -32,62 +32,62 @@ import java.util.Vector; public class Document { - //---------------------------------------------------- - //Instance Variables - //---------------------------------------------------- - public int [] words; + // ---------------------------------------------------- + // Instance Variables + // ---------------------------------------------------- + public int[] words; public String rawStr; public int length; - - //---------------------------------------------------- - //Constructors - //---------------------------------------------------- - public Document(){ + + // ---------------------------------------------------- + // Constructors + // ---------------------------------------------------- + public Document() { words = null; rawStr = ""; length = 0; } - - public Document(int length){ + + public Document(int length) { this.length = length; rawStr = ""; words = new int[length]; } - - public Document(int length, int [] words){ + + public Document(int length, int[] words) { this.length = length; rawStr = ""; - + this.words = new int[length]; - for (int i =0 ; i < length; ++i){ + for (int i = 0; i < length; ++i) { this.words[i] = words[i]; } } - - public Document(int length, int [] words, String rawStr){ + + public Document(int length, int[] words, String rawStr) { this.length = length; this.rawStr = rawStr; - + this.words = new int[length]; - for (int i =0 ; i < length; ++i){ + for (int i = 0; i < length; ++i) { this.words[i] = words[i]; } } - - public Document(Vector<Integer> doc){ + + public Document(Vector<Integer> doc) { this.length = doc.size(); rawStr = ""; this.words = new int[length]; - for (int i = 0; i < length; i++){ + for (int i = 0; i < length; i++) { this.words[i] = doc.get(i); } } - - public Document(Vector<Integer> doc, String rawStr){ + + public Document(Vector<Integer> doc, String rawStr) { this.length = doc.size(); this.rawStr = rawStr; this.words = new int[length]; - for (int i = 0; i < length; ++i){ + for (int i = 0; i < length; ++i) { this.words[i] = doc.get(i); } } diff --git a/jgibblda/src/jgibblda/Estimator.java b/jgibblda/src/jgibblda/Estimator.java index 24f9b85efcd5ad7c4260ace6b9c614fa76f82693..c26a3ecdeb36a677098c46ac4824afa21ff562f7 100644 --- a/jgibblda/src/jgibblda/Estimator.java +++ b/jgibblda/src/jgibblda/Estimator.java @@ -32,55 +32,54 @@ import java.io.File; import java.util.Vector; public class Estimator { - + // output model protected Model trnModel; LDACmdOption option; - - public boolean init(LDACmdOption option){ + + public boolean init(LDACmdOption option) { this.option = option; trnModel = new Model(); - - if (option.est){ + + if (option.est) { if (!trnModel.initNewModel(option)) return false; trnModel.data.localDict.writeWordMap(option.dir + File.separator + option.wordMapFileName); - } - else if (option.estc){ + } else if (option.estc) { if (!trnModel.initEstimatedModel(option)) return false; } - + return true; } - - public void estimate(){ + + public void estimate() { System.out.println("Sampling " + trnModel.niters + " iteration!"); - + int lastIter = trnModel.liter; - for (trnModel.liter = lastIter + 1; trnModel.liter < trnModel.niters + lastIter; trnModel.liter++){ + for (trnModel.liter = lastIter + 1; trnModel.liter < trnModel.niters + lastIter; trnModel.liter++) { System.out.println("Iteration " + trnModel.liter + " ..."); - + // for all z_i - for (int m = 0; m < trnModel.M; m++){ - for (int n = 0; n < trnModel.data.docs[m].length; n++){ + for (int m = 0; m < trnModel.M; m++) { + for (int n = 0; n < trnModel.data.docs[m].length; n++) { // z_i = z[m][n] // sample from p(z_i|z_-i, w) int topic = sampling(m, n); trnModel.z[m].set(n, topic); - }// end for each word - }// end for each document - - if (option.savestep > 0){ - if (trnModel.liter % option.savestep == 0){ + } // end for each word + } // end for each document + + if (option.savestep > 0) { + if (trnModel.liter % option.savestep == 0) { System.out.println("Saving the model at iteration " + trnModel.liter + " ..."); computeTheta(); computePhi(); trnModel.saveModel("model-" + Conversion.ZeroPad(trnModel.liter, 5)); } } - }// end iterations - + } // end iterations + System.out.println("Gibbs sampling completed!\n"); System.out.println("Saving the final model!\n"); computeTheta(); @@ -88,66 +87,71 @@ public class Estimator { trnModel.liter--; trnModel.saveModel("model-final"); } - + /** * Do sampling - * @param m document number - * @param n word number + * + * @param m + * document number + * @param n + * word number * @return topic id */ - public int sampling(int m, int n){ + public int sampling(int m, int n) { // remove z_i from the count variable int topic = trnModel.z[m].get(n); int w = trnModel.data.docs[m].words[n]; - + trnModel.nw[w][topic] -= 1; trnModel.nd[m][topic] -= 1; trnModel.nwsum[topic] -= 1; trnModel.ndsum[m] -= 1; - + double Vbeta = trnModel.V * trnModel.beta; double Kalpha = trnModel.K * trnModel.alpha; - - //do multinominal sampling via cumulative method - for (int k = 0; k < trnModel.K; k++){ - trnModel.p[k] = (trnModel.nw[w][k] + trnModel.beta)/(trnModel.nwsum[k] + Vbeta) * - (trnModel.nd[m][k] + trnModel.alpha)/(trnModel.ndsum[m] + Kalpha); + + // do multinominal sampling via cumulative method + for (int k = 0; k < trnModel.K; k++) { + trnModel.p[k] = (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + Vbeta) + * (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + Kalpha); } - + // cumulate multinomial parameters - for (int k = 1; k < trnModel.K; k++){ + for (int k = 1; k < trnModel.K; k++) { trnModel.p[k] += trnModel.p[k - 1]; } - + // scaled sample because of unnormalized p[] double u = Math.random() * trnModel.p[trnModel.K - 1]; - - for (topic = 0; topic < trnModel.K; topic++){ - if (trnModel.p[topic] > u) //sample topic w.r.t distribution p + + for (topic = 0; topic < trnModel.K; topic++) { + if (trnModel.p[topic] > u) // sample topic w.r.t distribution p break; } - + // add newly estimated z_i to count variables trnModel.nw[w][topic] += 1; trnModel.nd[m][topic] += 1; trnModel.nwsum[topic] += 1; trnModel.ndsum[m] += 1; - - return topic; + + return topic; } - - public void computeTheta(){ - for (int m = 0; m < trnModel.M; m++){ - for (int k = 0; k < trnModel.K; k++){ - trnModel.theta[m][k] = (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + trnModel.K * trnModel.alpha); + + public void computeTheta() { + for (int m = 0; m < trnModel.M; m++) { + for (int k = 0; k < trnModel.K; k++) { + trnModel.theta[m][k] = (trnModel.nd[m][k] + trnModel.alpha) + / (trnModel.ndsum[m] + trnModel.K * trnModel.alpha); } } } - - public void computePhi(){ - for (int k = 0; k < trnModel.K; k++){ - for (int w = 0; w < trnModel.V; w++){ - trnModel.phi[k][w] = (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + trnModel.V * trnModel.beta); + + public void computePhi() { + for (int k = 0; k < trnModel.K; k++) { + for (int w = 0; w < trnModel.V; w++) { + trnModel.phi[k][w] = (trnModel.nw[w][k] + trnModel.beta) + / (trnModel.nwsum[k] + trnModel.V * trnModel.beta); } } } diff --git a/jgibblda/src/jgibblda/Inferencer.java b/jgibblda/src/jgibblda/Inferencer.java index 2248db6627760f98f3e5d5b471d5f787ed98b03e..3811c582e7f1db4bf50671a8e89beb5334a233a0 100644 --- a/jgibblda/src/jgibblda/Inferencer.java +++ b/jgibblda/src/jgibblda/Inferencer.java @@ -35,115 +35,114 @@ import java.io.InputStreamReader; import java.util.StringTokenizer; import java.util.Vector; -public class Inferencer { +public class Inferencer { // Train model public Model trnModel; public Dictionary globalDict; private LDACmdOption option; - + private Model newModel; public int niters = 100; - - //----------------------------------------------------- + + // ----------------------------------------------------- // Init method - //----------------------------------------------------- - public boolean init(LDACmdOption option){ + // ----------------------------------------------------- + public boolean init(LDACmdOption option) { this.option = option; trnModel = new Model(); - + if (!trnModel.initEstimatedModel(option)) - return false; - + return false; + globalDict = trnModel.data.localDict; computeTrnTheta(); computeTrnPhi(); - + return true; } - - //inference new model ~ getting data from a specified dataset - public Model inference( LDADataset newData){ + + // inference new model ~ getting data from a specified dataset + public Model inference(LDADataset newData) { System.out.println("init new model"); - Model newModel = new Model(); - - newModel.initNewModel(option, newData, trnModel); - this.newModel = newModel; - - System.out.println("Sampling " + niters + " iteration for inference!"); - for (newModel.liter = 1; newModel.liter <= niters; newModel.liter++){ - //System.out.println("Iteration " + newModel.liter + " ..."); - + Model newModel = new Model(); + + newModel.initNewModel(option, newData, trnModel); + this.newModel = newModel; + + System.out.println("Sampling " + niters + " iteration for inference!"); + for (newModel.liter = 1; newModel.liter <= niters; newModel.liter++) { + // System.out.println("Iteration " + newModel.liter + " ..."); + // for all newz_i - for (int m = 0; m < newModel.M; ++m){ - for (int n = 0; n < newModel.data.docs[m].length; n++){ + for (int m = 0; m < newModel.M; ++m) { + for (int n = 0; n < newModel.data.docs[m].length; n++) { // (newz_i = newz[m][n] // sample from p(z_i|z_-1,w) int topic = infSampling(m, n); newModel.z[m].set(n, topic); } - }//end foreach new doc - - }// end iterations - + } // end foreach new doc + + } // end iterations + System.out.println("Gibbs sampling for inference completed!"); - + computeNewTheta(); computeNewPhi(); newModel.liter--; return this.newModel; } - - public Model inference(String [] strs){ - //System.out.println("inference"); + + public Model inference(String[] strs) { + // System.out.println("inference"); Model newModel = new Model(); - - //System.out.println("read dataset"); + + // System.out.println("read dataset"); LDADataset dataset = LDADataset.readDataSet(strs, globalDict); - + return inference(dataset); } - - //inference new model ~ getting dataset from file specified in option - public Model inference(){ - //System.out.println("inference"); - + + // inference new model ~ getting dataset from file specified in option + public Model inference() { + // System.out.println("inference"); + newModel = new Model(); - if (!newModel.initNewModel(option, trnModel)) return null; - + if (!newModel.initNewModel(option, trnModel)) + return null; + System.out.println("Sampling " + niters + " iteration for inference!"); - - for (newModel.liter = 1; newModel.liter <= niters; newModel.liter++){ - //System.out.println("Iteration " + newModel.liter + " ..."); - + + for (newModel.liter = 1; newModel.liter <= niters; newModel.liter++) { + // System.out.println("Iteration " + newModel.liter + " ..."); + // for all newz_i - for (int m = 0; m < newModel.M; ++m){ - for (int n = 0; n < newModel.data.docs[m].length; n++){ + for (int m = 0; m < newModel.M; ++m) { + for (int n = 0; n < newModel.data.docs[m].length; n++) { // (newz_i = newz[m][n] // sample from p(z_i|z_-1,w) int topic = infSampling(m, n); newModel.z[m].set(n, topic); } - }//end foreach new doc - - }// end iterations - - System.out.println("Gibbs sampling for inference completed!"); + } // end foreach new doc + + } // end iterations + + System.out.println("Gibbs sampling for inference completed!"); System.out.println("Saving the inference outputs!"); - + computeNewTheta(); computeNewPhi(); newModel.liter--; - newModel.saveModel(newModel.dfile + "." + newModel.modelName); - + newModel.saveModel(newModel.dfile + "." + newModel.modelName); + return newModel; } - + /** - * do sampling for inference - * m: document number - * n: word number? + * do sampling for inference m: document number n: word number? */ - protected int infSampling(int m, int n){ + protected int infSampling(int m, int n) { // remove z_i from the count variables int topic = newModel.z[m].get(n); int _w = newModel.data.docs[m].words[n]; @@ -152,70 +151,75 @@ public class Inferencer { newModel.nd[m][topic] -= 1; newModel.nwsum[topic] -= 1; newModel.ndsum[m] -= 1; - + double Vbeta = trnModel.V * newModel.beta; double Kalpha = trnModel.K * newModel.alpha; - - // do multinomial sampling via cummulative method - for (int k = 0; k < newModel.K; k++){ - newModel.p[k] = (trnModel.nw[w][k] + newModel.nw[_w][k] + newModel.beta)/(trnModel.nwsum[k] + newModel.nwsum[k] + Vbeta) * - (newModel.nd[m][k] + newModel.alpha)/(newModel.ndsum[m] + Kalpha); + + // do multinomial sampling via cummulative method + for (int k = 0; k < newModel.K; k++) { + newModel.p[k] = (trnModel.nw[w][k] + newModel.nw[_w][k] + newModel.beta) + / (trnModel.nwsum[k] + newModel.nwsum[k] + Vbeta) * (newModel.nd[m][k] + newModel.alpha) + / (newModel.ndsum[m] + Kalpha); } - + // cummulate multinomial parameters - for (int k = 1; k < newModel.K; k++){ + for (int k = 1; k < newModel.K; k++) { newModel.p[k] += newModel.p[k - 1]; } - + // scaled sample because of unnormalized p[] double u = Math.random() * newModel.p[newModel.K - 1]; - - for (topic = 0; topic < newModel.K; topic++){ + + for (topic = 0; topic < newModel.K; topic++) { if (newModel.p[topic] > u) break; } - + // add newly estimated z_i to count variables newModel.nw[_w][topic] += 1; newModel.nd[m][topic] += 1; newModel.nwsum[topic] += 1; newModel.ndsum[m] += 1; - + return topic; } - - protected void computeNewTheta(){ - for (int m = 0; m < newModel.M; m++){ - for (int k = 0; k < newModel.K; k++){ - newModel.theta[m][k] = (newModel.nd[m][k] + newModel.alpha) / (newModel.ndsum[m] + newModel.K * newModel.alpha); - }//end foreach topic - }//end foreach new document + + protected void computeNewTheta() { + for (int m = 0; m < newModel.M; m++) { + for (int k = 0; k < newModel.K; k++) { + newModel.theta[m][k] = (newModel.nd[m][k] + newModel.alpha) + / (newModel.ndsum[m] + newModel.K * newModel.alpha); + } // end foreach topic + } // end foreach new document } - - protected void computeNewPhi(){ - for (int k = 0; k < newModel.K; k++){ - for (int _w = 0; _w < newModel.V; _w++){ + + protected void computeNewPhi() { + for (int k = 0; k < newModel.K; k++) { + for (int _w = 0; _w < newModel.V; _w++) { Integer id = newModel.data.lid2gid.get(_w); - - if (id != null){ - newModel.phi[k][_w] = (trnModel.nw[id][k] + newModel.nw[_w][k] + newModel.beta) / (newModel.nwsum[k] + newModel.nwsum[k] + trnModel.V * newModel.beta); + + if (id != null) { + newModel.phi[k][_w] = (trnModel.nw[id][k] + newModel.nw[_w][k] + newModel.beta) + / (newModel.nwsum[k] + newModel.nwsum[k] + trnModel.V * newModel.beta); } - }//end foreach word - }// end foreach topic + } // end foreach word + } // end foreach topic } - - protected void computeTrnTheta(){ - for (int m = 0; m < trnModel.M; m++){ - for (int k = 0; k < trnModel.K; k++){ - trnModel.theta[m][k] = (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + trnModel.K * trnModel.alpha); + + protected void computeTrnTheta() { + for (int m = 0; m < trnModel.M; m++) { + for (int k = 0; k < trnModel.K; k++) { + trnModel.theta[m][k] = (trnModel.nd[m][k] + trnModel.alpha) + / (trnModel.ndsum[m] + trnModel.K * trnModel.alpha); } } } - - protected void computeTrnPhi(){ - for (int k = 0; k < trnModel.K; k++){ - for (int w = 0; w < trnModel.V; w++){ - trnModel.phi[k][w] = (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + trnModel.V * trnModel.beta); + + protected void computeTrnPhi() { + for (int k = 0; k < trnModel.K; k++) { + for (int w = 0; w < trnModel.V; w++) { + trnModel.phi[k][w] = (trnModel.nw[w][k] + trnModel.beta) + / (trnModel.nwsum[k] + trnModel.V * trnModel.beta); } } } diff --git a/jgibblda/src/jgibblda/LDA.java b/jgibblda/src/jgibblda/LDA.java index c6ca2a26cf2d3e3f7163eeb25044f637dc410ec9..cd950a21f259a71e626926d021b688dc68e586dd 100644 --- a/jgibblda/src/jgibblda/LDA.java +++ b/jgibblda/src/jgibblda/LDA.java @@ -31,54 +31,51 @@ package jgibblda; import org.kohsuke.args4j.*; public class LDA { - - public static void main(String args[]){ + + public static void main(String args[]) { LDACmdOption option = new LDACmdOption(); CmdLineParser parser = new CmdLineParser(option); - + try { - if (args.length == 0){ + if (args.length == 0) { showHelp(parser); return; } - + parser.parseArgument(args); - - if (option.est || option.estc){ + + if (option.est || option.estc) { Estimator estimator = new Estimator(); estimator.init(option); estimator.estimate(); - } - else if (option.inf){ + } else if (option.inf) { Inferencer inferencer = new Inferencer(); inferencer.init(option); - + Model newModel = inferencer.inference(); - - for (int i = 0; i < newModel.phi.length; ++i){ - //phi: K * V - System.out.println("-----------------------\ntopic" + i + " : "); - for (int j = 0; j < 10; ++j){ + + for (int i = 0; i < newModel.phi.length; ++i) { + // phi: K * V + System.out.println("-----------------------\ntopic" + i + " : "); + for (int j = 0; j < 10; ++j) { System.out.println(inferencer.globalDict.id2word.get(j) + "\t" + newModel.phi[i][j]); } } } - } - catch (CmdLineException cle){ + } catch (CmdLineException cle) { System.out.println("Command line error: " + cle.getMessage()); showHelp(parser); return; - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Error in main: " + e.getMessage()); e.printStackTrace(); return; } } - - public static void showHelp(CmdLineParser parser){ + + public static void showHelp(CmdLineParser parser) { System.out.println("LDA [options ...] [arguments...]"); parser.printUsage(System.out); } - + } diff --git a/jgibblda/src/jgibblda/LDACmdOption.java b/jgibblda/src/jgibblda/LDACmdOption.java index bc330beef00a417f8b69cd72b1a24815926ee347..6d29a5e572cc61b454a21d084fac9c95cb6c3294 100644 --- a/jgibblda/src/jgibblda/LDACmdOption.java +++ b/jgibblda/src/jgibblda/LDACmdOption.java @@ -3,46 +3,46 @@ package jgibblda; import org.kohsuke.args4j.*; public class LDACmdOption { - - @Option(name="-est", usage="Specify whether we want to estimate model from scratch") + + @Option(name = "-est", usage = "Specify whether we want to estimate model from scratch") public boolean est = false; - - @Option(name="-estc", usage="Specify whether we want to continue the last estimation") + + @Option(name = "-estc", usage = "Specify whether we want to continue the last estimation") public boolean estc = false; - - @Option(name="-inf", usage="Specify whether we want to do inference") + + @Option(name = "-inf", usage = "Specify whether we want to do inference") public boolean inf = true; - - @Option(name="-dir", usage="Specify directory") + + @Option(name = "-dir", usage = "Specify directory") public String dir = ""; - - @Option(name="-dfile", usage="Specify data file") + + @Option(name = "-dfile", usage = "Specify data file") public String dfile = ""; - - @Option(name="-model", usage="Specify the model name") + + @Option(name = "-model", usage = "Specify the model name") public String modelName = ""; - - @Option(name="-alpha", usage="Specify alpha") + + @Option(name = "-alpha", usage = "Specify alpha") public double alpha = -1.0; - - @Option(name="-beta", usage="Specify beta") + + @Option(name = "-beta", usage = "Specify beta") public double beta = -1.0; - - @Option(name="-ntopics", usage="Specify the number of topics") + + @Option(name = "-ntopics", usage = "Specify the number of topics") public int K = 100; - - @Option(name="-niters", usage="Specify the number of iterations") + + @Option(name = "-niters", usage = "Specify the number of iterations") public int niters = 1000; - - @Option(name="-savestep", usage="Specify the number of steps to save the model since the last save") + + @Option(name = "-savestep", usage = "Specify the number of steps to save the model since the last save") public int savestep = 100; - - @Option(name="-twords", usage="Specify the number of most likely words to be printed for each topic") + + @Option(name = "-twords", usage = "Specify the number of most likely words to be printed for each topic") public int twords = 100; - - @Option(name="-withrawdata", usage="Specify whether we include raw data in the input") + + @Option(name = "-withrawdata", usage = "Specify whether we include raw data in the input") public boolean withrawdata = false; - - @Option(name="-wordmap", usage="Specify the wordmap file") + + @Option(name = "-wordmap", usage = "Specify the wordmap file") public String wordMapFileName = "wordmap.txt"; } diff --git a/jgibblda/src/jgibblda/LDADataset.java b/jgibblda/src/jgibblda/LDADataset.java index d56f96b3b961782466b6384d5ed40dd867e3abbf..d820974505abade6554247324afeabbc28fb967e 100644 --- a/jgibblda/src/jgibblda/LDADataset.java +++ b/jgibblda/src/jgibblda/LDADataset.java @@ -35,241 +35,258 @@ import java.util.Map; import java.util.Vector; public class LDADataset { - //--------------------------------------------------------------- + // --------------------------------------------------------------- // Instance Variables - //--------------------------------------------------------------- - - public Dictionary localDict; // local dictionary - public Document [] docs; // a list of documents - public int M; // number of documents - public int V; // number of words - - // map from local coordinates (id) to global ones + // --------------------------------------------------------------- + + public Dictionary localDict; // local dictionary + public Document[] docs; // a list of documents + public int M; // number of documents + public int V; // number of words + + // map from local coordinates (id) to global ones // null if the global dictionary is not set - public Map<Integer, Integer> lid2gid; - - //link to a global dictionary (optional), null for train data, not null for test data - public Dictionary globalDict; - - //-------------------------------------------------------------- + public Map<Integer, Integer> lid2gid; + + // link to a global dictionary (optional), null for train data, not null for + // test data + public Dictionary globalDict; + + // -------------------------------------------------------------- // Constructor - //-------------------------------------------------------------- - public LDADataset(){ + // -------------------------------------------------------------- + public LDADataset() { localDict = new Dictionary(); M = 0; V = 0; docs = null; - + globalDict = null; lid2gid = null; } - - public LDADataset(int M){ + + public LDADataset(int M) { localDict = new Dictionary(); this.M = M; this.V = 0; - docs = new Document[M]; - + docs = new Document[M]; + globalDict = null; lid2gid = null; } - - public LDADataset(int M, Dictionary globalDict){ - localDict = new Dictionary(); + + public LDADataset(int M, Dictionary globalDict) { + localDict = new Dictionary(); this.M = M; this.V = 0; - docs = new Document[M]; - + docs = new Document[M]; + this.globalDict = globalDict; lid2gid = new HashMap<Integer, Integer>(); } - - //------------------------------------------------------------- - //Public Instance Methods - //------------------------------------------------------------- + + // ------------------------------------------------------------- + // Public Instance Methods + // ------------------------------------------------------------- /** - * set the document at the index idx if idx is greater than 0 and less than M - * @param doc document to be set - * @param idx index in the document array - */ - public void setDoc(Document doc, int idx){ - if (0 <= idx && idx < M){ + * set the document at the index idx if idx is greater than 0 and less than + * M + * + * @param doc + * document to be set + * @param idx + * index in the document array + */ + public void setDoc(Document doc, int idx) { + if (0 <= idx && idx < M) { docs[idx] = doc; } } + /** - * set the document at the index idx if idx is greater than 0 and less than M - * @param str string contains doc - * @param idx index in the document array + * set the document at the index idx if idx is greater than 0 and less than + * M + * + * @param str + * string contains doc + * @param idx + * index in the document array */ - public void setDoc(String str, int idx){ - if (0 <= idx && idx < M){ - String [] words = str.split("[ \\t\\n]"); - + public void setDoc(String str, int idx) { + if (0 <= idx && idx < M) { + String[] words = str.split("[ \\t\\n]"); + Vector<Integer> ids = new Vector<Integer>(); - - for (String word : words){ + + for (String word : words) { int _id = localDict.word2id.size(); - - if (localDict.contains(word)) + + if (localDict.contains(word)) _id = localDict.getID(word); - - if (globalDict != null){ - //get the global id + + if (globalDict != null) { + // get the global id Integer id = globalDict.getID(word); - //System.out.println(id); - - if (id != null){ + // System.out.println(id); + + if (id != null) { localDict.addWord(word); - + lid2gid.put(_id, id); ids.add(_id); + } else { // not in global dictionary + // do nothing currently } - else { //not in global dictionary - //do nothing currently - } - } - else { + } else { localDict.addWord(word); ids.add(_id); } } - + Document doc = new Document(ids, str); docs[idx] = doc; - V = localDict.word2id.size(); + V = localDict.word2id.size(); } } - //--------------------------------------------------------------- + // --------------------------------------------------------------- // I/O methods - //--------------------------------------------------------------- - + // --------------------------------------------------------------- + /** - * read a dataset from a stream, create new dictionary - * @return dataset if success and null otherwise + * read a dataset from a stream, create new dictionary + * + * @return dataset if success and null otherwise */ - public static LDADataset readDataSet(String filename){ + public static LDADataset readDataSet(String filename) { try { - BufferedReader reader = new BufferedReader(new InputStreamReader( - new FileInputStream(filename), "UTF-8")); - + BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "UTF-8")); + LDADataset data = readDataSet(reader); - + reader.close(); return data; - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Read Dataset Error: " + e.getMessage()); e.printStackTrace(); return null; } } - + /** * read a dataset from a file with a preknown vocabulary - * @param filename file from which we read dataset - * @param dict the dictionary + * + * @param filename + * file from which we read dataset + * @param dict + * the dictionary * @return dataset if success and null otherwise */ - public static LDADataset readDataSet(String filename, Dictionary dict){ + public static LDADataset readDataSet(String filename, Dictionary dict) { try { - BufferedReader reader = new BufferedReader(new InputStreamReader( - new FileInputStream(filename), "UTF-8")); + BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "UTF-8")); LDADataset data = readDataSet(reader, dict); - + reader.close(); return data; - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Read Dataset Error: " + e.getMessage()); e.printStackTrace(); return null; } } - + /** - * read a dataset from a stream, create new dictionary - * @return dataset if success and null otherwise + * read a dataset from a stream, create new dictionary + * + * @return dataset if success and null otherwise */ - public static LDADataset readDataSet(BufferedReader reader){ + public static LDADataset readDataSet(BufferedReader reader) { try { - //read number of document + // read number of document String line; line = reader.readLine(); int M = Integer.parseInt(line); - + LDADataset data = new LDADataset(M); - for (int i = 0; i < M; ++i){ + for (int i = 0; i < M; ++i) { line = reader.readLine(); - + data.setDoc(line, i); } - + return data; - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Read Dataset Error: " + e.getMessage()); e.printStackTrace(); return null; } } - + /** * read a dataset from a stream with respect to a specified dictionary - * @param reader stream from which we read dataset - * @param dict the dictionary + * + * @param reader + * stream from which we read dataset + * @param dict + * the dictionary * @return dataset if success and null otherwise */ - public static LDADataset readDataSet(BufferedReader reader, Dictionary dict){ + public static LDADataset readDataSet(BufferedReader reader, Dictionary dict) { try { - //read number of document + // read number of document String line; line = reader.readLine(); int M = Integer.parseInt(line); System.out.println("NewM:" + M); - + LDADataset data = new LDADataset(M, dict); - for (int i = 0; i < M; ++i){ + for (int i = 0; i < M; ++i) { line = reader.readLine(); - + data.setDoc(line, i); } - + return data; - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Read Dataset Error: " + e.getMessage()); e.printStackTrace(); return null; } } - + /** * read a dataset from a string, create new dictionary - * @param str String from which we get the dataset, documents are seperated by newline character + * + * @param str + * String from which we get the dataset, documents are seperated + * by newline character * @return dataset if success and null otherwise */ - public static LDADataset readDataSet(String [] strs){ + public static LDADataset readDataSet(String[] strs) { LDADataset data = new LDADataset(strs.length); - - for (int i = 0 ; i < strs.length; ++i){ + + for (int i = 0; i < strs.length; ++i) { data.setDoc(strs[i], i); } return data; } - + /** * read a dataset from a string with respect to a specified dictionary - * @param str String from which we get the dataset, documents are seperated by newline character - * @param dict the dictionary + * + * @param str + * String from which we get the dataset, documents are seperated + * by newline character + * @param dict + * the dictionary * @return dataset if success and null otherwise */ - public static LDADataset readDataSet(String [] strs, Dictionary dict){ - //System.out.println("readDataset..."); + public static LDADataset readDataSet(String[] strs, Dictionary dict) { + // System.out.println("readDataset..."); LDADataset data = new LDADataset(strs.length, dict); - - for (int i = 0 ; i < strs.length; ++i){ - //System.out.println("set doc " + i); + + for (int i = 0; i < strs.length; ++i) { + // System.out.println("set doc " + i); data.setDoc(strs[i], i); } return data; diff --git a/jgibblda/src/jgibblda/Model.java b/jgibblda/src/jgibblda/Model.java index af5003b5af21bc53eb9f7ed82711612ed61fa43d..b8669df055afabd7e3ab7b97f8a53b93caafce10 100644 --- a/jgibblda/src/jgibblda/Model.java +++ b/jgibblda/src/jgibblda/Model.java @@ -43,67 +43,76 @@ import java.util.List; import java.util.StringTokenizer; import java.util.Vector; -public class Model { - - //--------------------------------------------------------------- - // Class Variables - //--------------------------------------------------------------- - - public static String tassignSuffix; //suffix for topic assignment file - public static String thetaSuffix; //suffix for theta (topic - document distribution) file - public static String phiSuffix; //suffix for phi file (topic - word distribution) file - public static String othersSuffix; //suffix for containing other parameters - public static String twordsSuffix; //suffix for file containing words-per-topics - - //--------------------------------------------------------------- - // Model Parameters and Variables - //--------------------------------------------------------------- - - public String wordMapFile; //file that contain word to id map - public String trainlogFile; //training log file - +public class Model { + + // --------------------------------------------------------------- + // Class Variables + // --------------------------------------------------------------- + + public static String tassignSuffix; // suffix for topic assignment file + public static String thetaSuffix; // suffix for theta (topic - document + // distribution) file + public static String phiSuffix; // suffix for phi file (topic - word + // distribution) file + public static String othersSuffix; // suffix for containing other parameters + public static String twordsSuffix; // suffix for file containing + // words-per-topics + + // --------------------------------------------------------------- + // Model Parameters and Variables + // --------------------------------------------------------------- + + public String wordMapFile; // file that contain word to id map + public String trainlogFile; // training log file + public String dir; public String dfile; public String modelName; - public int modelStatus; //see Constants class for status of model - public LDADataset data; // link to a dataset - - public int M; //dataset size (i.e., number of docs) - public int V; //vocabulary size - public int K; //number of topics - public double alpha, beta; //LDA hyperparameters - public int niters; //number of Gibbs sampling iteration - public int liter; //the iteration at which the model was saved - public int savestep; //saving period - public int twords; //print out top words per each topic + public int modelStatus; // see Constants class for status of model + public LDADataset data; // link to a dataset + + public int M; // dataset size (i.e., number of docs) + public int V; // vocabulary size + public int K; // number of topics + public double alpha, beta; // LDA hyperparameters + public int niters; // number of Gibbs sampling iteration + public int liter; // the iteration at which the model was saved + public int savestep; // saving period + public int twords; // print out top words per each topic public int withrawdata; - + // Estimated/Inferenced parameters - public double [][] theta; //theta: document - topic distributions, size M x K - public double [][] phi; // phi: topic-word distributions, size K x V - + public double[][] theta; // theta: document - topic distributions, size M x + // K + public double[][] phi; // phi: topic-word distributions, size K x V + // Temp variables while sampling - public Vector<Integer> [] z; //topic assignments for words, size M x doc.size() - protected int [][] nw; //nw[i][j]: number of instances of word/term i assigned to topic j, size V x K - protected int [][] nd; //nd[i][j]: number of words in document i assigned to topic j, size M x K - protected int [] nwsum; //nwsum[j]: total number of words assigned to topic j, size K - protected int [] ndsum; //ndsum[i]: total number of words in document i, size M - + public Vector<Integer>[] z; // topic assignments for words, size M x + // doc.size() + protected int[][] nw; // nw[i][j]: number of instances of word/term i + // assigned to topic j, size V x K + protected int[][] nd; // nd[i][j]: number of words in document i assigned to + // topic j, size M x K + protected int[] nwsum; // nwsum[j]: total number of words assigned to topic + // j, size K + protected int[] ndsum; // ndsum[i]: total number of words in document i, + // size M + // temp variables for sampling - protected double [] p; - - //--------------------------------------------------------------- - // Constructors - //--------------------------------------------------------------- - - public Model(){ - setDefaultValues(); + protected double[] p; + + // --------------------------------------------------------------- + // Constructors + // --------------------------------------------------------------- + + public Model() { + setDefaultValues(); } - + /** * Set default values for variables */ - public void setDefaultValues(){ + public void setDefaultValues() { wordMapFile = "wordmap.txt"; trainlogFile = "trainlog.txt"; tassignSuffix = ".tassign"; @@ -111,12 +120,12 @@ public class Model { phiSuffix = ".phi"; othersSuffix = ".others"; twordsSuffix = ".twords"; - + dir = "./"; dfile = "trndocs.dat"; modelName = "model-final"; - modelStatus = Constants.MODEL_STATUS_UNKNOWN; - + modelStatus = Constants.MODEL_STATUS_UNKNOWN; + M = 0; V = 0; K = 100; @@ -124,7 +133,7 @@ public class Model { beta = 0.1; niters = 2000; liter = 0; - + z = null; nw = null; nd = null; @@ -133,396 +142,382 @@ public class Model { theta = null; phi = null; } - - //--------------------------------------------------------------- - // I/O Methods - //--------------------------------------------------------------- + + // --------------------------------------------------------------- + // I/O Methods + // --------------------------------------------------------------- /** * read other file to get parameters */ - protected boolean readOthersFile(String otherFile){ - //open file <model>.others to read: - + protected boolean readOthersFile(String otherFile) { + // open file <model>.others to read: + try { BufferedReader reader = new BufferedReader(new FileReader(otherFile)); String line; - while((line = reader.readLine()) != null){ - StringTokenizer tknr = new StringTokenizer(line,"= \t\r\n"); - + while ((line = reader.readLine()) != null) { + StringTokenizer tknr = new StringTokenizer(line, "= \t\r\n"); + int count = tknr.countTokens(); if (count != 2) continue; - + String optstr = tknr.nextToken(); String optval = tknr.nextToken(); - - if (optstr.equalsIgnoreCase("alpha")){ - alpha = Double.parseDouble(optval); - } - else if (optstr.equalsIgnoreCase("beta")){ + + if (optstr.equalsIgnoreCase("alpha")) { + alpha = Double.parseDouble(optval); + } else if (optstr.equalsIgnoreCase("beta")) { beta = Double.parseDouble(optval); - } - else if (optstr.equalsIgnoreCase("ntopics")){ + } else if (optstr.equalsIgnoreCase("ntopics")) { K = Integer.parseInt(optval); - } - else if (optstr.equalsIgnoreCase("liter")){ + } else if (optstr.equalsIgnoreCase("liter")) { liter = Integer.parseInt(optval); - } - else if (optstr.equalsIgnoreCase("nwords")){ + } else if (optstr.equalsIgnoreCase("nwords")) { V = Integer.parseInt(optval); - } - else if (optstr.equalsIgnoreCase("ndocs")){ + } else if (optstr.equalsIgnoreCase("ndocs")) { M = Integer.parseInt(optval); - } - else { + } else { // any more? } } - + reader.close(); - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Error while reading other file:" + e.getMessage()); e.printStackTrace(); return false; } return true; } - - protected boolean readTAssignFile(String tassignFile){ + + protected boolean readTAssignFile(String tassignFile) { try { - int i,j; - BufferedReader reader = new BufferedReader(new InputStreamReader( - new FileInputStream(tassignFile), "UTF-8")); - + int i, j; + BufferedReader reader = new BufferedReader( + new InputStreamReader(new FileInputStream(tassignFile), "UTF-8")); + String line; - z = new Vector[M]; + z = new Vector[M]; data = new LDADataset(M); - data.V = V; - for (i = 0; i < M; i++){ + data.V = V; + for (i = 0; i < M; i++) { line = reader.readLine(); StringTokenizer tknr = new StringTokenizer(line, " \t\r\n"); - + int length = tknr.countTokens(); - + Vector<Integer> words = new Vector<Integer>(); Vector<Integer> topics = new Vector<Integer>(); - - for (j = 0; j < length; j++){ + + for (j = 0; j < length; j++) { String token = tknr.nextToken(); - + StringTokenizer tknr2 = new StringTokenizer(token, ":"); - if (tknr2.countTokens() != 2){ + if (tknr2.countTokens() != 2) { System.out.println("Invalid word-topic assignment line\n"); return false; } - + words.add(Integer.parseInt(tknr2.nextToken())); topics.add(Integer.parseInt(tknr2.nextToken())); - }//end for each topic assignment - - //allocate and add new document to the corpus + } // end for each topic assignment + + // allocate and add new document to the corpus Document doc = new Document(words); data.setDoc(doc, i); - - //assign values for z + + // assign values for z z[i] = new Vector<Integer>(); - for (j = 0; j < topics.size(); j++){ + for (j = 0; j < topics.size(); j++) { z[i].add(topics.get(j)); } - - }//end for each doc - + + } // end for each doc + reader.close(); - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Error while loading model: " + e.getMessage()); e.printStackTrace(); return false; } return true; } - + /** * load saved model */ - public boolean loadModel(){ + public boolean loadModel() { if (!readOthersFile(dir + File.separator + modelName + othersSuffix)) return false; - + if (!readTAssignFile(dir + File.separator + modelName + tassignSuffix)) return false; - + // read dictionary Dictionary dict = new Dictionary(); if (!dict.readWordMap(dir + File.separator + wordMapFile)) return false; - + data.localDict = dict; - + return true; } - + /** * Save word-topic assignments for this model */ - public boolean saveModelTAssign(String filename){ + public boolean saveModelTAssign(String filename) { int i, j; - - try{ + + try { BufferedWriter writer = new BufferedWriter(new FileWriter(filename)); - - //write docs with topic assignments for words - for (i = 0; i < data.M; i++){ - for (j = 0; j < data.docs[i].length; ++j){ - writer.write(data.docs[i].words[j] + ":" + z[i].get(j) + " "); + + // write docs with topic assignments for words + for (i = 0; i < data.M; i++) { + for (j = 0; j < data.docs[i].length; ++j) { + writer.write(data.docs[i].words[j] + ":" + z[i].get(j) + " "); } writer.write("\n"); } - + writer.close(); - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Error while saving model tassign: " + e.getMessage()); e.printStackTrace(); return false; } return true; } - + /** * Save theta (topic distribution) for this model */ - public boolean saveModelTheta(String filename){ - try{ + public boolean saveModelTheta(String filename) { + try { BufferedWriter writer = new BufferedWriter(new FileWriter(filename)); - for (int i = 0; i < M; i++){ - for (int j = 0; j < K; j++){ + for (int i = 0; i < M; i++) { + for (int j = 0; j < K; j++) { writer.write(theta[i][j] + " "); } writer.write("\n"); } writer.close(); - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Error while saving topic distribution file for this model: " + e.getMessage()); e.printStackTrace(); return false; } return true; } - + /** * Save word-topic distribution */ - - public boolean saveModelPhi(String filename){ + + public boolean saveModelPhi(String filename) { try { BufferedWriter writer = new BufferedWriter(new FileWriter(filename)); - - for (int i = 0; i < K; i++){ - for (int j = 0; j < V; j++){ + + for (int i = 0; i < K; i++) { + for (int j = 0; j < V; j++) { writer.write(phi[i][j] + " "); } writer.write("\n"); } writer.close(); - } - catch (Exception e){ + } catch (Exception e) { System.out.println("Error while saving word-topic distribution:" + e.getMessage()); e.printStackTrace(); return false; } return true; } - + /** * Save other information of this model */ - public boolean saveModelOthers(String filename){ - try{ + public boolean saveModelOthers(String filename) { + try { BufferedWriter writer = new BufferedWriter(new FileWriter(filename)); - + writer.write("alpha=" + alpha + "\n"); writer.write("beta=" + beta + "\n"); writer.write("ntopics=" + K + "\n"); writer.write("ndocs=" + M + "\n"); writer.write("nwords=" + V + "\n"); writer.write("liters=" + liter + "\n"); - + writer.close(); - } - catch(Exception e){ + } catch (Exception e) { System.out.println("Error while saving model others:" + e.getMessage()); e.printStackTrace(); return false; } return true; } - + /** * Save model the most likely words for each topic */ - public boolean saveModelTwords(String filename){ - try{ - BufferedWriter writer = new BufferedWriter(new OutputStreamWriter( - new FileOutputStream(filename), "UTF-8")); - - if (twords > V){ + public boolean saveModelTwords(String filename) { + try { + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(filename), "UTF-8")); + + if (twords > V) { twords = V; } - - for (int k = 0; k < K; k++){ - List<Pair> wordsProbsList = new ArrayList<Pair>(); - for (int w = 0; w < V; w++){ + + for (int k = 0; k < K; k++) { + List<Pair> wordsProbsList = new ArrayList<Pair>(); + for (int w = 0; w < V; w++) { Pair p = new Pair(w, phi[k][w], false); - + wordsProbsList.add(p); - }//end foreach word - - //print topic + } // end foreach word + + // print topic writer.write("Topic " + k + "th:\n"); Collections.sort(wordsProbsList); - - for (int i = 0; i < twords; i++){ - if (data.localDict.contains((Integer)wordsProbsList.get(i).first)){ - String word = data.localDict.getWord((Integer)wordsProbsList.get(i).first); - + + for (int i = 0; i < twords; i++) { + if (data.localDict.contains((Integer) wordsProbsList.get(i).first)) { + String word = data.localDict.getWord((Integer) wordsProbsList.get(i).first); + writer.write("\t" + word + " " + wordsProbsList.get(i).second + "\n"); } } - } //end foreach topic - + } // end foreach topic + writer.close(); - } - catch(Exception e){ + } catch (Exception e) { System.out.println("Error while saving model twords: " + e.getMessage()); e.printStackTrace(); return false; } return true; } - + /** * Save model */ - public boolean saveModel(String modelName){ - if (!saveModelTAssign(dir + File.separator + modelName + tassignSuffix)){ + public boolean saveModel(String modelName) { + if (!saveModelTAssign(dir + File.separator + modelName + tassignSuffix)) { return false; } - - if (!saveModelOthers(dir + File.separator + modelName + othersSuffix)){ + + if (!saveModelOthers(dir + File.separator + modelName + othersSuffix)) { return false; } - - if (!saveModelTheta(dir + File.separator + modelName + thetaSuffix)){ + + if (!saveModelTheta(dir + File.separator + modelName + thetaSuffix)) { return false; } - - if (!saveModelPhi(dir + File.separator + modelName + phiSuffix)){ + + if (!saveModelPhi(dir + File.separator + modelName + phiSuffix)) { return false; } - - if (twords > 0){ + + if (twords > 0) { if (!saveModelTwords(dir + File.separator + modelName + twordsSuffix)) return false; } return true; } - - //--------------------------------------------------------------- - // Init Methods - //--------------------------------------------------------------- + + // --------------------------------------------------------------- + // Init Methods + // --------------------------------------------------------------- /** * initialize the model */ - protected boolean init(LDACmdOption option){ + protected boolean init(LDACmdOption option) { if (option == null) return false; - + modelName = option.modelName; K = option.K; - + alpha = option.alpha; if (alpha < 0.0) alpha = 50.0 / K; - + if (option.beta >= 0) beta = option.beta; - + niters = option.niters; - + dir = option.dir; if (dir.endsWith(File.separator)) dir = dir.substring(0, dir.length() - 1); - + dfile = option.dfile; twords = option.twords; wordMapFile = option.wordMapFileName; - + return true; } - + /** * Init parameters for estimation */ - public boolean initNewModel(LDACmdOption option){ + public boolean initNewModel(LDACmdOption option) { if (!init(option)) return false; - - int m, n, w, k; - p = new double[K]; - + + int m, n, w, k; + p = new double[K]; + data = LDADataset.readDataSet(dir + File.separator + dfile); - if (data == null){ + if (data == null) { System.out.println("Fail to read training data!\n"); return false; } - - //+ allocate memory and assign values for variables + + // + allocate memory and assign values for variables M = data.M; V = data.V; dir = option.dir; savestep = option.savestep; - + // K: from command line or default value - // alpha, beta: from command line or default values - // niters, savestep: from command line or default values + // alpha, beta: from command line or default values + // niters, savestep: from command line or default values nw = new int[V][K]; - for (w = 0; w < V; w++){ - for (k = 0; k < K; k++){ + for (w = 0; w < V; w++) { + for (k = 0; k < K; k++) { nw[w][k] = 0; } } - + nd = new int[M][K]; - for (m = 0; m < M; m++){ - for (k = 0; k < K; k++){ + for (m = 0; m < M; m++) { + for (k = 0; k < K; k++) { nd[m][k] = 0; } } - + nwsum = new int[K]; - for (k = 0; k < K; k++){ + for (k = 0; k < K; k++) { nwsum[k] = 0; } - + ndsum = new int[M]; - for (m = 0; m < M; m++){ + for (m = 0; m < M; m++) { ndsum[m] = 0; } - + z = new Vector[M]; - for (m = 0; m < data.M; m++){ + for (m = 0; m < data.M; m++) { int N = data.docs[m].length; z[m] = new Vector<Integer>(); - - //initilize for z - for (n = 0; n < N; n++){ - int topic = (int)Math.floor(Math.random() * K); + + // initilize for z + for (n = 0; n < N; n++) { + int topic = (int) Math.floor(Math.random() * K); z[m].add(topic); - + // number of instances of word assigned to topic j nw[data.docs[m].words[n]][topic] += 1; // number of words in document i assigned to topic j @@ -533,78 +528,80 @@ public class Model { // total number of words in document i ndsum[m] = N; } - - theta = new double[M][K]; + + theta = new double[M][K]; phi = new double[K][V]; - + return true; } - + /** * Init parameters for inference - * @param newData DataSet for which we do inference + * + * @param newData + * DataSet for which we do inference */ - public boolean initNewModel(LDACmdOption option, LDADataset newData, Model trnModel){ + public boolean initNewModel(LDACmdOption option, LDADataset newData, Model trnModel) { if (!init(option)) return false; - + int m, n, w, k; - + K = trnModel.K; alpha = trnModel.alpha; - beta = trnModel.beta; - + beta = trnModel.beta; + p = new double[K]; System.out.println("K:" + K); - + data = newData; - - //+ allocate memory and assign values for variables + + // + allocate memory and assign values for variables M = data.M; V = data.V; dir = option.dir; savestep = option.savestep; System.out.println("M:" + M); System.out.println("V:" + V); - + // K: from command line or default value - // alpha, beta: from command line or default values - // niters, savestep: from command line or default values + // alpha, beta: from command line or default values + // niters, savestep: from command line or default values nw = new int[V][K]; - for (w = 0; w < V; w++){ - for (k = 0; k < K; k++){ + for (w = 0; w < V; w++) { + for (k = 0; k < K; k++) { nw[w][k] = 0; } } - + nd = new int[M][K]; - for (m = 0; m < M; m++){ - for (k = 0; k < K; k++){ + for (m = 0; m < M; m++) { + for (k = 0; k < K; k++) { nd[m][k] = 0; } } - + nwsum = new int[K]; - for (k = 0; k < K; k++){ + for (k = 0; k < K; k++) { nwsum[k] = 0; } - + ndsum = new int[M]; - for (m = 0; m < M; m++){ + for (m = 0; m < M; m++) { ndsum[m] = 0; } - + z = new Vector[M]; - for (m = 0; m < data.M; m++){ + for (m = 0; m < data.M; m++) { int N = data.docs[m].length; z[m] = new Vector<Integer>(); - - //initilize for z - for (n = 0; n < N; n++){ - int topic = (int)Math.floor(Math.random() * K); + + // initilize for z + for (n = 0; n < N; n++) { + int topic = (int) Math.floor(Math.random() * K); z[m].add(topic); - + // number of instances of word assigned to topic j nw[data.docs[m].words[n]][topic] += 1; // number of words in document i assigned to topic j @@ -615,102 +612,101 @@ public class Model { // total number of words in document i ndsum[m] = N; } - - theta = new double[M][K]; + + theta = new double[M][K]; phi = new double[K][V]; - + return true; } - + /** - * Init parameters for inference - * reading new dataset from file + * Init parameters for inference reading new dataset from file */ - public boolean initNewModel(LDACmdOption option, Model trnModel){ + public boolean initNewModel(LDACmdOption option, Model trnModel) { if (!init(option)) return false; - + LDADataset dataset = LDADataset.readDataSet(dir + File.separator + dfile, trnModel.data.localDict); - if (dataset == null){ + if (dataset == null) { System.out.println("Fail to read dataset!\n"); return false; } - - return initNewModel(option, dataset , trnModel); + + return initNewModel(option, dataset, trnModel); } - + /** * init parameter for continue estimating or for later inference */ - public boolean initEstimatedModel(LDACmdOption option){ + public boolean initEstimatedModel(LDACmdOption option) { if (!init(option)) return false; - + int m, n, w, k; - + p = new double[K]; - + // load model, i.e., read z and trndata - if (!loadModel()){ + if (!loadModel()) { System.out.println("Fail to load word-topic assignment file of the model!\n"); return false; } - + System.out.println("Model loaded:"); System.out.println("\talpha:" + alpha); System.out.println("\tbeta:" + beta); System.out.println("\tM:" + M); - System.out.println("\tV:" + V); - + System.out.println("\tV:" + V); + nw = new int[V][K]; - for (w = 0; w < V; w++){ - for (k = 0; k < K; k++){ + for (w = 0; w < V; w++) { + for (k = 0; k < K; k++) { nw[w][k] = 0; } } - + nd = new int[M][K]; - for (m = 0; m < M; m++){ - for (k = 0; k < K; k++){ + for (m = 0; m < M; m++) { + for (k = 0; k < K; k++) { nd[m][k] = 0; } } - + nwsum = new int[K]; - for (k = 0; k < K; k++) { - nwsum[k] = 0; - } - - ndsum = new int[M]; - for (m = 0; m < M; m++) { - ndsum[m] = 0; - } - - for (m = 0; m < data.M; m++){ - int N = data.docs[m].length; - - // assign values for nw, nd, nwsum, and ndsum - for (n = 0; n < N; n++){ - w = data.docs[m].words[n]; - int topic = (Integer)z[m].get(n); - - // number of instances of word i assigned to topic j - nw[w][topic] += 1; - // number of words in document i assigned to topic j - nd[m][topic] += 1; - // total number of words assigned to topic j - nwsum[topic] += 1; - } - // total number of words in document i - ndsum[m] = N; - } - - theta = new double[M][K]; - phi = new double[K][V]; - dir = option.dir; + for (k = 0; k < K; k++) { + nwsum[k] = 0; + } + + ndsum = new int[M]; + for (m = 0; m < M; m++) { + ndsum[m] = 0; + } + + for (m = 0; m < data.M; m++) { + int N = data.docs[m].length; + + // assign values for nw, nd, nwsum, and ndsum + for (n = 0; n < N; n++) { + w = data.docs[m].words[n]; + int topic = (Integer) z[m].get(n); + + // number of instances of word i assigned to topic j + nw[w][topic] += 1; + // number of words in document i assigned to topic j + nd[m][topic] += 1; + // total number of words assigned to topic j + nwsum[topic] += 1; + } + // total number of words in document i + ndsum[m] = N; + } + + theta = new double[M][K]; + phi = new double[K][V]; + dir = option.dir; savestep = option.savestep; - + return true; } - + } diff --git a/jgibblda/src/jgibblda/Pair.java b/jgibblda/src/jgibblda/Pair.java index 98402c894049ffa01d2d58f497812252b857f6ae..6eef0aa292cbc788947c3baa98c907f1024215db 100644 --- a/jgibblda/src/jgibblda/Pair.java +++ b/jgibblda/src/jgibblda/Pair.java @@ -34,22 +34,22 @@ public class Pair implements Comparable<Pair> { public Object first; public Comparable second; public static boolean naturalOrder = false; - - public Pair(Object k, Comparable v){ + + public Pair(Object k, Comparable v) { first = k; - second = v; + second = v; } - - public Pair(Object k, Comparable v, boolean naturalOrder){ + + public Pair(Object k, Comparable v, boolean naturalOrder) { first = k; second = v; - Pair.naturalOrder = naturalOrder; + Pair.naturalOrder = naturalOrder; } - - public int compareTo(Pair p){ + + public int compareTo(Pair p) { if (naturalOrder) return this.second.compareTo(p.second); - else return -this.second.compareTo(p.second); + else + return -this.second.compareTo(p.second); } } -