余弦相似性算法

余弦相似性算法的具体介绍请参考：http://www.ruanyifeng.com/blog/2013/03/cosine_similarity.html

下面是我根据上述介绍，用 Java 语言编写的实现：

import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.wltea.analyzer.lucene.IKAnalyzer;

import com.wjb.util.common.WjbTuple2;

/**
 * Helpers for computing the cosine similarity of two texts based on term-frequency vectors.
 *
 * <p>Typical pipeline: {@link #makeTermFrequency} on each text, optionally
 * {@link #filterByKeyLength}, then {@link #makeVector} and finally {@link #cosine}.
 */
public class CosineTextSimilarity {

    /**
     * Tokenizes {@code text} with IK Analyzer (smart mode) and counts how often each token occurs.
     *
     * @param text the text to tokenize
     * @return a new map from token to its occurrence count
     * @throws IOException if the analyzer fails while reading the text
     */
    public static Map<String, Integer> makeTermFrequency(String text) throws IOException {
        Map<String, Integer> tf = new HashMap<String, Integer>();
        Analyzer analyzer = new IKAnalyzer(true);
        try {
            TokenStream ts = analyzer.tokenStream("", new StringReader(text));
            try {
                // addAttribute (unlike getAttribute) never throws if the attribute is absent.
                CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
                // TokenStream consumer contract: reset() -> incrementToken()* -> end() -> close().
                ts.reset();
                while (ts.incrementToken()) {
                    String t = term.toString();
                    Integer count = tf.get(t);
                    tf.put(t, count == null ? 1 : count + 1);
                }
                ts.end();
            } finally {
                // Closing the stream also closes the underlying reader.
                ts.close();
            }
        } finally {
            analyzer.close();
        }
        return tf;
    }

    /**
     * Filters a term map by key length: a key is kept only if its trimmed length is at least
     * {@code length}. A {@code null} key is always kept.
     *
     * @param map the term map to filter; not modified
     * @param length the minimum trimmed key length to keep
     * @return a new map containing only the retained entries
     * @throws IOException never thrown; declared only for source compatibility with existing callers
     */
    public static Map<String, Integer> filterByKeyLength(Map<String, Integer> map, int length)
            throws IOException {
        Map<String, Integer> m = new HashMap<String, Integer>();
        for (Map.Entry<String, Integer> e : map.entrySet()) {
            String key = e.getKey();
            if (key == null || key.trim().length() >= length) {
                m.put(key, e.getValue());
            }
        }
        return m;
    }

    /**
     * Builds two aligned count vectors over the union of the keys of both maps.
     * Index {@code i} of both vectors refers to the same key; a key missing from one
     * map contributes 0 to that map's vector.
     *
     * @param first term counts of the first text
     * @param second term counts of the second text
     * @return a tuple of (vector for {@code first}, vector for {@code second}) of equal length
     */
    public static WjbTuple2<int[], int[]> makeVector(Map<String, Integer> first,
                                                     Map<String, Integer> second) {
        Set<String> keys = new HashSet<String>();
        keys.addAll(first.keySet());
        keys.addAll(second.keySet());
        int[] vector1 = new int[keys.size()];
        int[] vector2 = new int[keys.size()];
        int i = 0;
        for (String key : keys) {
            Integer count1 = first.get(key);
            if (count1 != null) {
                vector1[i] = count1;
            }
            Integer count2 = second.get(key);
            if (count2 != null) {
                vector2[i] = count2;
            }
            i++;
        }
        return new WjbTuple2<int[], int[]>(vector1, vector2);
    }

    /**
     * Computes the cosine of the angle between the two vectors in {@code tuple}:
     * {@code dot(a, b) / sqrt(|a|^2 * |b|^2)}.
     *
     * @param tuple a pair of vectors of equal length
     * @return the cosine similarity, or {@code 0.0} if either vector is all zeros
     *         (the cosine is undefined there; the unguarded formula would yield NaN)
     * @throws IllegalArgumentException if the two vectors have different lengths
     */
    public static double cosine(WjbTuple2<int[], int[]> tuple) {
        int[] vector1 = tuple._1;
        int[] vector2 = tuple._2;
        if (vector1.length != vector2.length) {
            throw new IllegalArgumentException(
                    "vector length mismatch: " + vector1.length + " vs " + vector2.length);
        }

        double dot = 0;
        double norm1 = 0;
        double norm2 = 0;
        for (int i = 0; i < vector1.length; i++) {
            // Widen to double before multiplying so large counts cannot overflow int.
            double a = vector1[i];
            double b = vector2[i];
            dot += a * b;
            norm1 += a * a;
            norm2 += b * b;
        }

        if (norm1 == 0 || norm2 == 0) {
            return 0.0;
        }
        return dot / Math.sqrt(norm1 * norm2);
    }

    /**
     * Returns the entries of {@code unsortMap} sorted in descending order.
     *
     * <p>If both values are {@code Integer}s, entries are ordered by value descending, with ties
     * broken by the key's {@code toString()} descending. Otherwise they are ordered by the values'
     * {@code toString()} descending.
     *
     * @param unsortMap the map whose entries to sort; not modified
     * @return a new list of the map's entries in sorted order
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static List<Entry> sort(Map unsortMap) {
        List<Map.Entry> list = new ArrayList<Map.Entry>(unsortMap.entrySet());
        Collections.sort(list, new Comparator<Map.Entry>() {
            @Override
            public int compare(Map.Entry o1, Map.Entry o2) {
                Object v1 = o1.getValue();
                Object v2 = o2.getValue();
                // Check BOTH sides so the comparator stays symmetric for mixed-type values.
                if (v1 instanceof Integer && v2 instanceof Integer) {
                    // compareTo instead of subtraction: subtraction can overflow.
                    int byValue = ((Integer) v2).compareTo((Integer) v1);
                    if (byValue != 0) {
                        return byValue;
                    }
                    return o2.getKey().toString().compareTo(o1.getKey().toString());
                }
                return v2.toString().compareTo(v1.toString());
            }
        });
        return list;
    }
}
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import com.wjb.util.common.WjbFileUtil;
import com.wjb.util.common.WjbTuple2;

/**
 * Command-line demo: computes the cosine similarity of two text files using their most
 * frequent terms.
 *
 * <p>Usage: {@code java Main [file1] [file2]}. When omitted, the paths default to
 * {@code d:/1.txt} and {@code d:/2.txt}. The second file is read as GBK.
 */
public class Main {

    /** Number of most frequent terms of each text used to build the vectors. */
    private static final int TOP_N = 20;

    public static void main(String[] args) throws Exception {
        // Paths may be overridden from the command line; defaults keep the original behavior.
        String path1 = (args != null && args.length > 0) ? args[0] : "d:/1.txt";
        String path2 = (args != null && args.length > 1) ? args[1] : "d:/2.txt";

        String text1 = WjbFileUtil.fromFile(path1);
        String text2 = WjbFileUtil.fromFile(path2, WjbFileUtil.GBK);

        System.out.println(text2);
        // nanoTime is monotonic and therefore suited to measuring elapsed time.
        long begin = System.nanoTime();
        Map<String, Integer> map1 = CosineTextSimilarity.makeTermFrequency(text1);
        Map<String, Integer> map2 = CosineTextSimilarity.makeTermFrequency(text2);

        List<Entry> list1 = CosineTextSimilarity.sort(map1);
        System.out.println(list1);
        list1 = topN(list1, TOP_N);

        List<Entry> list2 = CosineTextSimilarity.sort(map2);
        System.out.println(list2);
        list2 = topN(list2, TOP_N);

        map1 = list2Map(list1);
        map2 = list2Map(list2);

        WjbTuple2<int[], int[]> tuple = CosineTextSimilarity.makeVector(map1, map2);
        double cos = CosineTextSimilarity.cosine(tuple);

        long end = System.nanoTime();

        // Elapsed time in milliseconds.
        System.out.println((end - begin) / 1000000L);

        System.out.println(cos);
    }

    /** Returns a view of the first {@code n} elements of {@code list}, or all of it if shorter. */
    private static List<Entry> topN(List<Entry> list, int n) {
        return list.subList(0, Math.min(n, list.size()));
    }

    /**
     * Converts a list of entries back into a map.
     *
     * @param list entries whose keys are converted with {@code toString()} and whose values must be
     *             {@code Integer}s (a {@code ClassCastException} is thrown otherwise)
     * @return a new map from key string to value; later entries overwrite earlier ones with the same key
     */
    public static Map<String, Integer> list2Map(List<Entry> list) {
        Map<String, Integer> map = new HashMap<String, Integer>();
        for (Entry e : list) {
            map.put(e.getKey().toString(), (Integer) e.getValue());
        }
        return map;
    }
}

相关文章

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注