/*
 * Decompiled with CFR 0.152.
 */
package com.wcohen.ss;

import com.wcohen.ss.AbstractSourcedStatisticalTokenDistance;
import com.wcohen.ss.BagOfSourcedTokens;
import com.wcohen.ss.PrintfFormat;
import com.wcohen.ss.api.SourcedStringWrapper;
import com.wcohen.ss.api.SourcedToken;
import com.wcohen.ss.api.SourcedTokenizer;
import com.wcohen.ss.api.StringWrapper;
import com.wcohen.ss.api.Token;
import java.util.Iterator;

public class SourcedTFIDF
extends AbstractSourcedStatisticalTokenDistance {
    private UnitVector lastVector = null;

    public SourcedTFIDF(SourcedTokenizer tokenizer) {
        super(tokenizer);
    }

    public SourcedTFIDF() {
    }

    public double score(StringWrapper s0, StringWrapper t0) {
        SourcedStringWrapper s = (SourcedStringWrapper)s0;
        SourcedStringWrapper t = (SourcedStringWrapper)t0;
        this.checkTrainingHasHappened(s, t);
        UnitVector sBag = this.asUnitVector(s);
        UnitVector tBag = this.asUnitVector(t);
        double sim = 0.0;
        int numCommon = 0;
        Iterator i = sBag.tokenIterator();
        while (i.hasNext()) {
            Token sTok = (Token)i.next();
            SourcedToken tTok = null;
            tTok = tBag.getEquivalentToken(sTok);
            if (tTok == null) continue;
            sim += sBag.getWeight(sTok) * tBag.getWeight(tTok);
            ++numCommon;
        }
        return sim;
    }

    protected UnitVector asUnitVector(SourcedStringWrapper w) {
        if (w instanceof UnitVector) {
            return (UnitVector)w;
        }
        if (w instanceof BagOfSourcedTokens) {
            return new UnitVector((BagOfSourcedTokens)w);
        }
        return new UnitVector(w.unwrap(), this.tokenizer.sourcedTokenize(w.unwrap(), w.getSource()));
    }

    public StringWrapper prepare(String s) {
        System.out.println("unknown source for " + s);
        this.lastVector = new UnitVector(s, this.tokenizer.sourcedTokenize(s, "*UNKNOWN SOURCE*"));
        return this.lastVector;
    }

    public Token[] getTokens() {
        return this.lastVector.getTokens();
    }

    public double getWeight(Token token) {
        return this.lastVector.getWeight(token);
    }

    public int getDocumentFrequency(Token token) {
        Integer df = (Integer)this.documentFrequency.get(token);
        if (df == null) {
            return 0;
        }
        return df;
    }

    public void setDocumentFrequency(Token token, int df) {
        this.documentFrequency.put(token, new Integer(df));
    }

    public int getCollectionSize() {
        return this.collectionSize;
    }

    public void setCollectionSize(int n) {
        this.collectionSize = n;
    }

    public String explainScore(StringWrapper s, StringWrapper t) {
        BagOfSourcedTokens sBag = (BagOfSourcedTokens)s;
        BagOfSourcedTokens tBag = (BagOfSourcedTokens)t;
        StringBuffer buf = new StringBuffer("");
        PrintfFormat fmt = new PrintfFormat("%.3f");
        buf.append("Common tokens: ");
        Iterator i = sBag.tokenIterator();
        while (i.hasNext()) {
            SourcedToken sTok = (SourcedToken)i.next();
            SourcedToken tTok = null;
            tTok = tBag.getEquivalentToken(sTok);
            if (tTok == null) continue;
            buf.append(" " + sTok.getValue() + ": ");
            buf.append(fmt.sprintf(sBag.getWeight(sTok)));
            buf.append("*");
            buf.append(fmt.sprintf(tBag.getWeight(tTok)));
        }
        buf.append("\nscore = " + this.score(s, t));
        return buf.toString();
    }

    public String toString() {
        return "[SourcedTFIDF]";
    }

    public static void main(String[] argv) {
        SourcedTFIDF.doMain(new SourcedTFIDF(), argv);
    }

    protected class UnitVector
    extends BagOfSourcedTokens {
        public UnitVector(String s, SourcedToken[] tokens) {
            super(s, tokens);
            this.termFreq2TFIDF();
        }

        public UnitVector(BagOfSourcedTokens bag) {
            this(bag.unwrap(), bag.getSourcedTokens());
            this.termFreq2TFIDF();
        }

        private void termFreq2TFIDF() {
            Token tok;
            double normalizer = 0.0;
            Iterator i = this.tokenIterator();
            while (i.hasNext()) {
                tok = (Token)i.next();
                if (SourcedTFIDF.this.collectionSize > 0) {
                    Integer dfInteger = (Integer)SourcedTFIDF.this.documentFrequency.get(tok);
                    double df = dfInteger == null ? 1.0 : (double)dfInteger.intValue();
                    double w = Math.log(this.getWeight(tok) + 1.0) * Math.log((double)SourcedTFIDF.this.collectionSize / df);
                    this.setWeight(tok, w);
                    normalizer += w * w;
                    continue;
                }
                this.setWeight(tok, 1.0);
                normalizer += 1.0;
            }
            normalizer = Math.sqrt(normalizer);
            i = this.tokenIterator();
            while (i.hasNext()) {
                tok = (Token)i.next();
                this.setWeight(tok, this.getWeight(tok) / normalizer);
            }
        }
    }
}

