diff --git a/.gitignore b/.gitignore index 58bcbf8..4be0b9d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,32 +1,33 @@ -# Windows image file caches -Thumbs.db -ehthumbs.db - -# Folder config file -Desktop.ini - -# Recycle Bin used on file shares -$RECYCLE.BIN/ - -# Windows Installer files -*.cab -*.msi -*.msm -*.msp - -# ========================= -# Operating System Files -# ========================= - -# OSX -# ========================= - +# Windows image file caches +Thumbs.db +ehthumbs.db + +# Folder config file +Desktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msm +*.msp + +# ========================= +# Operating System Files +# ========================= + +# OSX +# ========================= + .DS_Store .AppleDouble .LSOverride # Icon must end with two \r -Icon +Icon + # Thumbnails ._* @@ -41,3 +42,7 @@ Icon Network Trash Folder Temporary Items .apdisk + +# Intellij idea +.idea +out diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/AssociationAnalysis/DataMining_Apriori/AprioriTool.java b/AssociationAnalysis/DataMining_Apriori/AprioriTool.java index 2cd08b4..972872c 100644 --- a/AssociationAnalysis/DataMining_Apriori/AprioriTool.java +++ b/AssociationAnalysis/DataMining_Apriori/AprioriTool.java @@ -1,432 +1,417 @@ -package DataMining_Apriori; +package AssociationAnalysis.DataMining_Apriori; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.text.MessageFormat; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; +import java.util.*; /** * apriori算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class AprioriTool { - // 最小支持度计数 - private int minSupportCount; - // 测试数据文件地址 - private String filePath; - // 每个事务中的商品ID - private ArrayList totalGoodsIDs; - // 过程中计算出来的所有频繁项集列表 - private ArrayList resultItem; - // 过程中计算出来频繁项集的ID集合 - private ArrayList resultItemID; - - public AprioriTool(String filePath, int minSupportCount) { - this.filePath = filePath; - this.minSupportCount = minSupportCount; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - String[] temp = null; - totalGoodsIDs = new ArrayList<>(); - for (String[] array : dataArray) { - temp = new String[array.length - 1]; - System.arraycopy(array, 1, temp, 0, array.length - 1); - - // 将事务ID加入列表吧中 - totalGoodsIDs.add(temp); - } - } - - /** - * 判读字符数组array2是否包含于数组array1中 - * - * @param array1 - * @param array2 - * @return - */ - public boolean iSStrContain(String[] array1, String[] array2) { - if (array1 == null || array2 == null) { - return false; - } - - boolean iSContain = false; - for (String s : array2) { - // 新的字母比较时,重新初始化变量 - iSContain = false; - // 判读array2中每个字符,只要包括在array1中 ,就算包含 - for (String s2 : array1) { - if (s.equals(s2)) { - iSContain = true; - break; - } - } - - // 如果已经判断出不包含了,则直接中断循环 - if (!iSContain) { - break; - } - } - - return iSContain; - } - - /** - * 项集进行连接运算 - */ - private void computeLink() { - // 连接计算的终止数,k项集必须算到k-1子项集为止 - int endNum = 0; - // 当前已经进行连接运算到几项集,开始时就是1项集 - int currentNum = 1; - // 商品,1频繁项集映射图 - HashMap itemMap = new HashMap<>(); - FrequentItem tempItem; - // 初始列表 - ArrayList list = new ArrayList<>(); - // 经过连接运算后产生的结果项集 - resultItem = new ArrayList<>(); - resultItemID = new ArrayList<>(); - // 商品ID的种类 - ArrayList idType = new ArrayList<>(); - for (String[] a : totalGoodsIDs) { - for (String s : a) { - if (!idType.contains(s)) { - tempItem = new FrequentItem(new String[] { s }, 1); - idType.add(s); - resultItemID.add(new String[] { s }); - } else { - // 支持度计数加1 - tempItem = itemMap.get(s); - tempItem.setCount(tempItem.getCount() + 1); - } - itemMap.put(s, tempItem); - } - } - // 将初始频繁项集转入到列表中,以便继续做连接运算 - for (Map.Entry entry : itemMap.entrySet()) { - list.add((FrequentItem) entry.getValue()); - } - // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少 - Collections.sort(list); - resultItem.addAll(list); - - String[] array1; - String[] array2; - String[] resultArray; - ArrayList tempIds; - ArrayList resultContainer; - // 总共要算到endNum项集 - endNum = list.size() - 1; - - while (currentNum < endNum) { - resultContainer = new ArrayList<>(); - for (int i = 0; i < list.size() - 1; i++) { - tempItem = list.get(i); - array1 = tempItem.getIdArray(); - for (int j = i + 1; j < list.size(); j++) { - tempIds = new ArrayList<>(); - array2 = list.get(j).getIdArray(); - for (int k = 0; k < array1.length; k++) { - // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作 - if (array1[k].equals(array2[k])) { - tempIds.add(array1[k]); - } else { - tempIds.add(array1[k]); - tempIds.add(array2[k]); - } - } - resultArray = new String[tempIds.size()]; - tempIds.toArray(resultArray); - - boolean isContain = false; - // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的 - if (resultArray.length == (array1.length + 1)) { - isContain = isIDArrayContains(resultContainer, - resultArray); - if (!isContain) { - resultContainer.add(resultArray); - } - } - } - } - - // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集 - list = cutItem(resultContainer); - currentNum++; - } - - // 输出频繁项集 - for (int k = 1; k <= currentNum; k++) { - System.out.println("频繁" + k + "项集:"); - for (FrequentItem i : resultItem) { - if (i.getLength() == k) { - System.out.print("{"); - for (String t : i.getIdArray()) { - System.out.print(t + ","); - } - System.out.print("},"); - } - } - System.out.println(); - } - } - - /** - * 判断列表结果中是否已经包含此数组 - * - * @param container - * ID数组容器 - * @param array - * 待比较数组 - * @return - */ - private boolean isIDArrayContains(ArrayList container, - String[] array) { - boolean isContain = true; - if (container.size() == 0) { - isContain = false; - return isContain; - } - - for (String[] s : container) { - // 比较的视乎必须保证长度一样 - if (s.length != array.length) { - continue; - } - - isContain = true; - for (int i = 0; i < s.length; i++) { - // 只要有一个id不等,就算不相等 - if (s[i] != array[i]) { - isContain = false; - break; - } - } - - // 如果已经判断是包含在容器中时,直接退出 - if (isContain) { - break; - } - } - - return isContain; - } - - /** - * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集 - */ - private ArrayList cutItem(ArrayList resultIds) { - String[] temp; - // 忽略的索引位置,以此构建子集 - int igNoreIndex = 0; - FrequentItem tempItem; - // 剪枝生成新的频繁项集 - ArrayList newItem = new ArrayList<>(); - // 不符合要求的id - ArrayList deleteIdArray = new ArrayList<>(); - // 子项集是否也为频繁子项集 - boolean isContain = true; - - for (String[] array : resultIds) { - // 列举出其中的一个个的子项集,判断存在于频繁项集列表中 - temp = new String[array.length - 1]; - for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) { - isContain = true; - for (int j = 0, k = 0; j < array.length; j++) { - if (j != igNoreIndex) { - temp[k] = array[j]; - k++; - } - } - - if (!isIDArrayContains(resultItemID, temp)) { - isContain = false; - break; - } - } - - if (!isContain) { - deleteIdArray.add(array); - } - } - - // 移除不符合条件的ID组合 - resultIds.removeAll(deleteIdArray); - - // 移除支持度计数不够的id集合 - int tempCount = 0; - for (String[] array : resultIds) { - tempCount = 0; - for (String[] array2 : totalGoodsIDs) { - if (isStrArrayContain(array2, array)) { - tempCount++; - } - } - - // 如果支持度计数大于等于最小最小支持度计数则生成新的频繁项集,并加入结果集中 - if (tempCount >= minSupportCount) { - tempItem = new FrequentItem(array, tempCount); - newItem.add(tempItem); - resultItemID.add(array); - resultItem.add(tempItem); - } - } - - return newItem; - } - - /** - * 数组array2是否包含于array1中,不需要完全一样 - * - * @param array1 - * @param array2 - * @return - */ - private boolean isStrArrayContain(String[] array1, String[] array2) { - boolean isContain = true; - for (String s2 : array2) { - isContain = false; - for (String s1 : array1) { - // 只要s2字符存在于array1中,这个字符就算包含在array1中 - if (s2.equals(s1)) { - isContain = true; - break; - } - } - - // 一旦发现不包含的字符,则array2数组不包含于array1中 - if (!isContain) { - break; - } - } - - return isContain; - } - - /** - * 根据产生的频繁项集输出关联规则 - * - * @param minConf - * 最小置信度阈值 - */ - public void printAttachRule(double minConf) { - // 进行连接和剪枝操作 - computeLink(); - - int count1 = 0; - int count2 = 0; - ArrayList childGroup1; - ArrayList childGroup2; - String[] group1; - String[] group2; - // 以最后一个频繁项集做关联规则的输出 - String[] array = resultItem.get(resultItem.size() - 1).getIdArray(); - // 子集总数,计算的时候除去自身和空集 - int totalNum = (int) Math.pow(2, array.length); - String[] temp; - // 二进制数组,用来代表各个子集 - int[] binaryArray; - // 除去头和尾部 - for (int i = 1; i < totalNum - 1; i++) { - binaryArray = new int[array.length]; - numToBinaryArray(binaryArray, i); - - childGroup1 = new ArrayList<>(); - childGroup2 = new ArrayList<>(); - count1 = 0; - count2 = 0; - // 按照二进制位关系取出子集 - for (int j = 0; j < binaryArray.length; j++) { - if (binaryArray[j] == 1) { - childGroup1.add(array[j]); - } else { - childGroup2.add(array[j]); - } - } - - group1 = new String[childGroup1.size()]; - group2 = new String[childGroup2.size()]; - - childGroup1.toArray(group1); - childGroup2.toArray(group2); - - for (String[] a : totalGoodsIDs) { - if (isStrArrayContain(a, group1)) { - count1++; - - // 在group1的条件下,统计group2的事件发生次数 - if (isStrArrayContain(a, group2)) { - count2++; - } - } - } - - // {A}-->{B}的意思为在A的情况下发生B的概率 - System.out.print("{"); - for (String s : group1) { - System.out.print(s + ", "); - } - System.out.print("}-->"); - System.out.print("{"); - for (String s : group2) { - System.out.print(s + ", "); - } - System.out.print(MessageFormat.format( - "},confidence(置信度):{0}/{1}={2}", count2, count1, count2 - * 1.0 / count1)); - if (count2 * 1.0 / count1 < minConf) { - // 不符合要求,不是强规则 - System.out.println("由于此规则置信度未达到最小置信度的要求,不是强规则"); - } else { - System.out.println("为强规则"); - } - } - - } - - /** - * 数字转为二进制形式 - * - * @param binaryArray - * 转化后的二进制数组形式 - * @param num - * 待转化数字 - */ - private void numToBinaryArray(int[] binaryArray, int num) { - int index = 0; - while (num != 0) { - binaryArray[index] = num % 2; - index++; - num /= 2; - } - } +class AprioriTool { + // 最小支持度计数 + private int minSupportCount; + // 测试数据文件地址 + private String filePath; + // 每个事务中的商品ID + private ArrayList totalGoodsIDs; + // 过程中计算出来的所有频繁项集列表 + private ArrayList resultItem; + // 过程中计算出来频繁项集的ID集合 + private ArrayList resultItemID; + + AprioriTool(String filePath, int minSupportCount){ + this.filePath = filePath; + this.minSupportCount = minSupportCount; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + String[] temp; + totalGoodsIDs = new ArrayList<>(); + for (String[] array : dataArray) { + temp = new String[array.length - 1]; + System.arraycopy(array, 1, temp, 0, array.length - 1); + + // 将事务ID加入列表吧中 + totalGoodsIDs.add(temp); + } + } + + /** + * 判读字符数组array2是否包含于数组array1中 + * + * @param array1 字符数组1 + * @param array2 字符数组2 + */ + public boolean iSStrContain(String[] array1, String[] array2){ + if (array1 == null || array2 == null) { + return false; + } + + boolean iSContain = false; + for (String s : array2) { + // 新的字母比较时,重新初始化变量 + iSContain = false; + // 判读array2中每个字符,只要包括在array1中 ,就算包含 + for (String s2 : array1) { + if (s.equals(s2)) { + iSContain = true; + break; + } + } + + // 如果已经判断出不包含了,则直接中断循环 + if (!iSContain) { + break; + } + } + + return iSContain; + } + + /** + * 项集进行连接运算 + */ + private void computeLink(){ + // 连接计算的终止数,k项集必须算到k-1子项集为止 + int endNum; + // 当前已经进行连接运算到几项集,开始时就是1项集 + int currentNum = 1; + // 商品,1频繁项集映射图 + HashMap itemMap = new HashMap<>(); + FrequentItem tempItem; + // 初始列表 + ArrayList list = new ArrayList<>(); + // 经过连接运算后产生的结果项集 + resultItem = new ArrayList<>(); + resultItemID = new ArrayList<>(); + // 商品ID的种类 + ArrayList idType = new ArrayList<>(); + for (String[] a : totalGoodsIDs) { + for (String s : a) { + if (!idType.contains(s)) { + tempItem = new FrequentItem(new String[]{s}, 1); + idType.add(s); + resultItemID.add(new String[]{s}); + } else { + // 支持度计数加1 + tempItem = itemMap.get(s); + tempItem.setCount(tempItem.getCount() + 1); + } + itemMap.put(s, tempItem); + } + } + // 将初始频繁项集转入到列表中,以便继续做连接运算 + for (Map.Entry entry : itemMap.entrySet()) { + list.add((FrequentItem) entry.getValue()); + } + // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少 + Collections.sort(list); + resultItem.addAll(list); + + String[] array1; + String[] array2; + String[] resultArray; + ArrayList tempIds; + ArrayList resultContainer; + // 总共要算到endNum项集 + endNum = list.size() - 1; + + while (currentNum < endNum) { + resultContainer = new ArrayList<>(); + for (int i = 0; i < list.size() - 1; i++) { + tempItem = list.get(i); + array1 = tempItem.getIdArray(); + for (int j = i + 1; j < list.size(); j++) { + tempIds = new ArrayList<>(); + array2 = list.get(j).getIdArray(); + for (int k = 0; k < array1.length; k++) { + // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作 + if (array1[k].equals(array2[k])) { + tempIds.add(array1[k]); + } else { + tempIds.add(array1[k]); + tempIds.add(array2[k]); + } + } + resultArray = new String[tempIds.size()]; + tempIds.toArray(resultArray); + + boolean isContain; + // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的 + if (resultArray.length == (array1.length + 1)) { + isContain = isIDArrayContains(resultContainer, + resultArray); + if (!isContain) { + resultContainer.add(resultArray); + } + } + } + } + + // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集 + list = cutItem(resultContainer); + currentNum++; + } + + // 输出频繁项集 + for (int k = 1; k <= currentNum; k++) { + System.out.println("频繁" + k + "项集:"); + for (FrequentItem i : resultItem) { + if (i.getLength() == k) { + System.out.print("{"); + for (String t : i.getIdArray()) { + System.out.print(t + ","); + } + System.out.print("},"); + } + } + System.out.println(); + } + } + + /** + * 判断列表结果中是否已经包含此数组 + * + * @param container ID数组容器 + * @param array 待比较数组 + */ + private boolean isIDArrayContains(ArrayList container, + String[] array){ + boolean isContain = true; + if (container.size() == 0) { + return false; + } + + for (String[] s : container) { + // 比较的视乎必须保证长度一样 + if (s.length != array.length) { + continue; + } + + isContain = true; + for (int i = 0; i < s.length; i++) { + // 只要有一个id不等,就算不相等 + if (!Objects.equals(s[i], array[i])) { + isContain = false; + break; + } + } + + // 如果已经判断是包含在容器中时,直接退出 + if (isContain) { + break; + } + } + + return isContain; + } + + /** + * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集 + */ + private ArrayList cutItem(ArrayList resultIds){ + String[] temp; + // 忽略的索引位置,以此构建子集 + int igNoreIndex; + FrequentItem tempItem; + // 剪枝生成新的频繁项集 + ArrayList newItem = new ArrayList<>(); + // 不符合要求的id + ArrayList deleteIdArray = new ArrayList<>(); + // 子项集是否也为频繁子项集 + boolean isContain = true; + + for (String[] array : resultIds) { + // 列举出其中的一个个的子项集,判断存在于频繁项集列表中 + temp = new String[array.length - 1]; + for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) { + isContain = true; + for (int j = 0, k = 0; j < array.length; j++) { + if (j != igNoreIndex) { + temp[k] = array[j]; + k++; + } + } + + if (!isIDArrayContains(resultItemID, temp)) { + isContain = false; + break; + } + } + + if (!isContain) { + deleteIdArray.add(array); + } + } + + // 移除不符合条件的ID组合 + resultIds.removeAll(deleteIdArray); + + // 移除支持度计数不够的id集合 + int tempCount; + for (String[] array : resultIds) { + tempCount = 0; + for (String[] array2 : totalGoodsIDs) { + if (isStrArrayContain(array2, array)) { + tempCount++; + } + } + + // 如果支持度计数大于等于最小最小支持度计数则生成新的频繁项集,并加入结果集中 + if (tempCount >= minSupportCount) { + tempItem = new FrequentItem(array, tempCount); + newItem.add(tempItem); + resultItemID.add(array); + resultItem.add(tempItem); + } + } + + return newItem; + } + + /** + * 数组array2是否包含于array1中,不需要完全一样 + * + * @param array1 + * @param array2 + */ + private boolean isStrArrayContain(String[] array1, String[] array2){ + boolean isContain = true; + for (String s2 : array2) { + isContain = false; + for (String s1 : array1) { + // 只要s2字符存在于array1中,这个字符就算包含在array1中 + if (s2.equals(s1)) { + isContain = true; + break; + } + } + + // 一旦发现不包含的字符,则array2数组不包含于array1中 + if (!isContain) { + break; + } + } + + return isContain; + } + + /** + * 根据产生的频繁项集输出关联规则 + * + * @param minConf 最小置信度阈值 + */ + void printAttachRule(double minConf){ + // 进行连接和剪枝操作 + computeLink(); + + int count1; + int count2; + ArrayList childGroup1; + ArrayList childGroup2; + String[] group1; + String[] group2; + // 以最后一个频繁项集做关联规则的输出 + String[] array = resultItem.get(resultItem.size() - 1).getIdArray(); + // 子集总数,计算的时候除去自身和空集 + int totalNum = (int) Math.pow(2, array.length); + // 二进制数组,用来代表各个子集 + int[] binaryArray; + // 除去头和尾部 + for (int i = 1; i < totalNum - 1; i++) { + binaryArray = new int[array.length]; + numToBinaryArray(binaryArray, i); + + childGroup1 = new ArrayList<>(); + childGroup2 = new ArrayList<>(); + count1 = 0; + count2 = 0; + // 按照二进制位关系取出子集 + for (int j = 0; j < binaryArray.length; j++) { + if (binaryArray[j] == 1) { + childGroup1.add(array[j]); + } else { + childGroup2.add(array[j]); + } + } + + group1 = new String[childGroup1.size()]; + group2 = new String[childGroup2.size()]; + + childGroup1.toArray(group1); + childGroup2.toArray(group2); + + for (String[] a : totalGoodsIDs) { + if (isStrArrayContain(a, group1)) { + count1++; + + // 在group1的条件下,统计group2的事件发生次数 + if (isStrArrayContain(a, group2)) { + count2++; + } + } + } + + // {A}-->{B}的意思为在A的情况下发生B的概率 + System.out.print("{"); + for (String s : group1) { + System.out.print(s + ", "); + } + System.out.print("}-->"); + System.out.print("{"); + for (String s : group2) { + System.out.print(s + ", "); + } + System.out.print(MessageFormat.format( + "},confidence(置信度):{0}/{1}={2}", count2, count1, count2 + * 1.0 / count1)); + if (count2 * 1.0 / count1 < minConf) { + // 不符合要求,不是强规则 + System.out.println("由于此规则置信度未达到最小置信度的要求,不是强规则"); + } else { + System.out.println("为强规则"); + } + } + } + + /** + * 数字转为二进制形式 + * + * @param binaryArray 转化后的二进制数组形式 + * @param num 待转化数字 + */ + private void numToBinaryArray(int[] binaryArray, int num){ + int index = 0; + while (num != 0) { + binaryArray[index] = num % 2; + index++; + num /= 2; + } + } } diff --git a/AssociationAnalysis/DataMining_Apriori/Client.java b/AssociationAnalysis/DataMining_Apriori/Client.java index 7791c6f..726ae8d 100644 --- a/AssociationAnalysis/DataMining_Apriori/Client.java +++ b/AssociationAnalysis/DataMining_Apriori/Client.java @@ -1,15 +1,15 @@ -package DataMining_Apriori; +package AssociationAnalysis.DataMining_Apriori; /** * apriori关联规则挖掘算法调用类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt"; - - AprioriTool tool = new AprioriTool(filePath, 2); - tool.printAttachRule(0.7); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/AssociationAnalysis/DataMining_Apriori/testInput.txt"; + + AprioriTool tool = new AprioriTool(filePath, 2); + tool.printAttachRule(0.7); + } } diff --git a/AssociationAnalysis/DataMining_Apriori/FrequentItem.java b/AssociationAnalysis/DataMining_Apriori/FrequentItem.java index 592d40d..bd3b9ab 100644 --- a/AssociationAnalysis/DataMining_Apriori/FrequentItem.java +++ b/AssociationAnalysis/DataMining_Apriori/FrequentItem.java @@ -1,56 +1,51 @@ -package DataMining_Apriori; +package AssociationAnalysis.DataMining_Apriori; /** * 频繁项集 - * - * @author lyq - * + * + * @author Qstar */ -public class FrequentItem implements Comparable{ - // 频繁项集的集合ID - private String[] idArray; - // 频繁项集的支持度计数 - private int count; - //频繁项集的长度,1项集或是2项集,亦或是3项集 - private int length; - - public FrequentItem(String[] idArray, int count){ - this.idArray = idArray; - this.count = count; - length = idArray.length; - } - - public String[] getIdArray() { - return idArray; - } - - public void setIdArray(String[] idArray) { - this.idArray = idArray; - } - - public int getCount() { - return count; - } - - public void setCount(int count) { - this.count = count; - } - - public int getLength() { - return length; - } - - public void setLength(int length) { - this.length = length; - } - - @Override - public int compareTo(FrequentItem o) { - // TODO Auto-generated method stub - Integer int1 = Integer.parseInt(this.getIdArray()[0]); - Integer int2 = Integer.parseInt(o.getIdArray()[0]); - - return int1.compareTo(int2); - } - +public class FrequentItem implements Comparable { + // 频繁项集的集合ID + private String[] idArray; + // 频繁项集的支持度计数 + private int count; + //频繁项集的长度,1项集或是2项集,亦或是3项集 + private int length; + + public FrequentItem(String[] idArray, int count){ + this.idArray = idArray; + this.count = count; + length = idArray.length; + } + + public String[] getIdArray(){ + return idArray; + } + + public int getCount(){ + return count; + } + + public void setCount(int count){ + this.count = count; + } + + public int getLength(){ + return length; + } + + public void setLength(int length){ + this.length = length; + } + + @Override + public int compareTo(FrequentItem o){ + // TODO Auto-generated method stub + Integer int1 = Integer.parseInt(this.getIdArray()[0]); + Integer int2 = Integer.parseInt(o.getIdArray()[0]); + + return int1.compareTo(int2); + } + } diff --git a/AssociationAnalysis/DataMining_FPTree/Client.java b/AssociationAnalysis/DataMining_FPTree/Client.java index 69789b7..fe4f10d 100644 --- a/AssociationAnalysis/DataMining_FPTree/Client.java +++ b/AssociationAnalysis/DataMining_FPTree/Client.java @@ -1,17 +1,17 @@ -package DataMining_FPTree; +package AssociationAnalysis.DataMining_FPTree; /** * FPTree频繁模式树算法 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt"; - //最小支持度阈值 - int minSupportCount = 2; - - FPTreeTool tool = new FPTreeTool(filePath, minSupportCount); - tool.startBuildingTree(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/AssociationAnalysis/DataMining_FPTree/testInput.txt"; + //最小支持度阈值 + int minSupportCount = 2; + + FPTreeTool tool = new FPTreeTool(filePath, minSupportCount); + tool.startBuildingTree(); + } } diff --git a/AssociationAnalysis/DataMining_FPTree/FPTreeTool.java b/AssociationAnalysis/DataMining_FPTree/FPTreeTool.java index e53a946..3175955 100644 --- a/AssociationAnalysis/DataMining_FPTree/FPTreeTool.java +++ b/AssociationAnalysis/DataMining_FPTree/FPTreeTool.java @@ -1,4 +1,4 @@ -package DataMining_FPTree; +package AssociationAnalysis.DataMining_FPTree; import java.io.BufferedReader; import java.io.File; @@ -11,453 +11,439 @@ /** * FPTree算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class FPTreeTool { - // 输入数据文件位置 - private String filePath; - // 最小支持度阈值 - private int minSupportCount; - // 所有事物ID记录 - private ArrayList totalGoodsID; - // 各个ID的统计数目映射表项,计数用于排序使用 - private HashMap itemCountMap; - - public FPTreeTool(String filePath, int minSupportCount) { - this.filePath = filePath; - this.minSupportCount = minSupportCount; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - String[] temp; - int count = 0; - itemCountMap = new HashMap<>(); - totalGoodsID = new ArrayList<>(); - for (String[] a : dataArray) { - temp = new String[a.length - 1]; - System.arraycopy(a, 1, temp, 0, a.length - 1); - totalGoodsID.add(temp); - for (String s : temp) { - if (!itemCountMap.containsKey(s)) { - count = 1; - } else { - count = ((int) itemCountMap.get(s)); - // 支持度计数加1 - count++; - } - // 更新表项 - itemCountMap.put(s, count); - } - } - } - - /** - * 根据事物记录构造FP树 - */ - private void buildFPTree(ArrayList suffixPattern, - ArrayList> transctionList) { - // 设置一个空根节点 - TreeNode rootNode = new TreeNode(null, 0); - int count = 0; - // 节点是否存在 - boolean isExist = false; - ArrayList childNodes; - ArrayList pathList; - // 相同类型节点链表,用于构造的新的FP树 - HashMap> linkedNode = new HashMap<>(); - HashMap countNode = new HashMap<>(); - // 根据事物记录,一步步构建FP树 - for (ArrayList array : transctionList) { - TreeNode searchedNode; - pathList = new ArrayList<>(); - for (TreeNode node : array) { - pathList.add(node); - nodeCounted(node, countNode); - searchedNode = searchNode(rootNode, pathList); - childNodes = searchedNode.getChildNodes(); - - if (childNodes == null) { - childNodes = new ArrayList<>(); - childNodes.add(node); - searchedNode.setChildNodes(childNodes); - node.setParentNode(searchedNode); - nodeAddToLinkedList(node, linkedNode); - } else { - isExist = false; - for (TreeNode node2 : childNodes) { - // 如果找到名称相同,则更新支持度计数 - if (node.getName().equals(node2.getName())) { - count = node2.getCount() + node.getCount(); - node2.setCount(count); - // 标识已找到节点位置 - isExist = true; - break; - } - } - - if (!isExist) { - // 如果没有找到,需添加子节点 - childNodes.add(node); - node.setParentNode(searchedNode); - nodeAddToLinkedList(node, linkedNode); - } - } - - } - } - - // 如果FP树已经是单条路径,则输出此时的频繁模式 - if (isSinglePath(rootNode)) { - printFrequentPattern(suffixPattern, rootNode); - System.out.println("-------"); - } else { - ArrayList> tList; - ArrayList sPattern; - if (suffixPattern == null) { - sPattern = new ArrayList<>(); - } else { - // 进行一个拷贝,避免互相引用的影响 - sPattern = (ArrayList) suffixPattern.clone(); - } - - // 利用节点链表构造新的事务 - for (Map.Entry entry : countNode.entrySet()) { - // 添加到后缀模式中 - sPattern.add((String) entry.getKey()); - //获取到了条件模式机,作为新的事务 - tList = getTransactionList((String) entry.getKey(), linkedNode); - - System.out.print("[后缀模式]:{"); - for(String s: sPattern){ - System.out.print(s + ", "); - } - System.out.print("}, 此时的条件模式基:"); - for(ArrayList tnList: tList){ - System.out.print("{"); - for(TreeNode n: tnList){ - System.out.print(n.getName() + ", "); - } - System.out.print("}, "); - } - System.out.println(); - // 递归构造FP树 - buildFPTree(sPattern, tList); - // 再次移除此项,构造不同的后缀模式,防止对后面造成干扰 - sPattern.remove((String) entry.getKey()); - } - } - } - - /** - * 将节点加入到同类型节点的链表中 - * - * @param node - * 待加入节点 - * @param linkedList - * 链表图 - */ - private void nodeAddToLinkedList(TreeNode node, - HashMap> linkedList) { - String name = node.getName(); - ArrayList list; - - if (linkedList.containsKey(name)) { - list = linkedList.get(name); - // 将node添加到此队列中 - list.add(node); - } else { - list = new ArrayList<>(); - list.add(node); - linkedList.put(name, list); - } - } - - /** - * 根据链表构造出新的事务 - * - * @param name - * 节点名称 - * @param linkedList - * 链表 - * @return - */ - private ArrayList> getTransactionList(String name, - HashMap> linkedList) { - ArrayList> tList = new ArrayList<>(); - ArrayList targetNode = linkedList.get(name); - ArrayList singleTansaction; - TreeNode temp; - - for (TreeNode node : targetNode) { - singleTansaction = new ArrayList<>(); - - temp = node; - while (temp.getParentNode().getName() != null) { - temp = temp.getParentNode(); - singleTansaction.add(new TreeNode(temp.getName(), 1)); - } - - // 按照支持度计数得反转一下 - Collections.reverse(singleTansaction); - - for (TreeNode node2 : singleTansaction) { - // 支持度计数调成与模式后缀一样 - node2.setCount(node.getCount()); - } - - if (singleTansaction.size() > 0) { - tList.add(singleTansaction); - } - } - - return tList; - } - - /** - * 节点计数 - * - * @param node - * 待加入节点 - * @param nodeCount - * 计数映射图 - */ - private void nodeCounted(TreeNode node, HashMap nodeCount) { - int count = 0; - String name = node.getName(); - - if (nodeCount.containsKey(name)) { - count = nodeCount.get(name); - count++; - } else { - count = 1; - } - - nodeCount.put(name, count); - } - - /** - * 显示决策树 - * - * @param node - * 待显示的节点 - * @param blankNum - * 行空格符,用于显示树型结构 - */ - private void showFPTree(TreeNode node, int blankNum) { - System.out.println(); - for (int i = 0; i < blankNum; i++) { - System.out.print("\t"); - } - System.out.print("--"); - System.out.print("--"); - - if (node.getChildNodes() == null) { - System.out.print("["); - System.out.print("I" + node.getName() + ":" + node.getCount()); - System.out.print("]"); - } else { - // 递归显示子节点 - // System.out.print("【" + node.getName() + "】"); - for (TreeNode childNode : node.getChildNodes()) { - showFPTree(childNode, 2 * blankNum); - } - } - - } - - /** - * 待插入节点的抵达位置节点,从根节点开始向下寻找待插入节点的位置 - * - * @param root - * @param list - * @return - */ - private TreeNode searchNode(TreeNode node, ArrayList list) { - ArrayList pathList = new ArrayList<>(); - TreeNode tempNode = null; - TreeNode firstNode = list.get(0); - boolean isExist = false; - // 重新转一遍,避免出现同一引用 - for (TreeNode node2 : list) { - pathList.add(node2); - } - - // 如果没有孩子节点,则直接返回,在此节点下添加子节点 - if (node.getChildNodes() == null) { - return node; - } - - for (TreeNode n : node.getChildNodes()) { - if (n.getName().equals(firstNode.getName()) && list.size() == 1) { - tempNode = node; - isExist = true; - break; - } else if (n.getName().equals(firstNode.getName())) { - // 还没有找到最后的位置,继续找 - pathList.remove(firstNode); - tempNode = searchNode(n, pathList); - return tempNode; - } - } - - // 如果没有找到,则新添加到孩子节点中 - if (!isExist) { - tempNode = node; - } - - return tempNode; - } - - /** - * 判断目前构造的FP树是否是单条路径的 - * - * @param rootNode - * 当前FP树的根节点 - * @return - */ - private boolean isSinglePath(TreeNode rootNode) { - // 默认是单条路径 - boolean isSinglePath = true; - ArrayList childList; - TreeNode node; - node = rootNode; - - while (node.getChildNodes() != null) { - childList = node.getChildNodes(); - if (childList.size() == 1) { - node = childList.get(0); - } else { - isSinglePath = false; - break; - } - } - - return isSinglePath; - } - - /** - * 开始构建FP树 - */ - public void startBuildingTree() { - ArrayList singleTransaction; - ArrayList> transactionList = new ArrayList<>(); - TreeNode tempNode; - int count = 0; - - for (String[] idArray : totalGoodsID) { - singleTransaction = new ArrayList<>(); - for (String id : idArray) { - count = itemCountMap.get(id); - tempNode = new TreeNode(id, count); - singleTransaction.add(tempNode); - } - - // 根据支持度数的多少进行排序 - Collections.sort(singleTransaction); - for (TreeNode node : singleTransaction) { - // 支持度计数重新归为1 - node.setCount(1); - } - transactionList.add(singleTransaction); - } - - buildFPTree(null, transactionList); - } - - /** - * 输出此单条路径下的频繁模式 - * - * @param suffixPattern - * 后缀模式 - * @param rootNode - * 单条路径FP树根节点 - */ - private void printFrequentPattern(ArrayList suffixPattern, - TreeNode rootNode) { - ArrayList idArray = new ArrayList<>(); - TreeNode temp; - temp = rootNode; - // 用于输出组合模式 - int length = 0; - int num = 0; - int[] binaryArray; - - while (temp.getChildNodes() != null) { - temp = temp.getChildNodes().get(0); - - // 筛选支持度系数大于最小阈值的值 - if (temp.getCount() >= minSupportCount) { - idArray.add(temp.getName()); - } - } - - length = idArray.size(); - num = (int) Math.pow(2, length); - for (int i = 0; i < num; i++) { - binaryArray = new int[length]; - numToBinaryArray(binaryArray, i); - - // 如果后缀模式只有1个,不能输出自身 - if (suffixPattern.size() == 1 && i == 0) { - continue; - } - - System.out.print("频繁模式:{【后缀模式:"); - // 先输出固有的后缀模式 - if (suffixPattern.size() > 1 - || (suffixPattern.size() == 1 && idArray.size() > 0)) { - for (String s : suffixPattern) { - System.out.print(s + ", "); - } - } - System.out.print("】"); - // 输出路径上的组合模式 - for (int j = 0; j < length; j++) { - if (binaryArray[j] == 1) { - System.out.print(idArray.get(j) + ", "); - } - } - System.out.println("}"); - } - } - - /** - * 数字转为二进制形式 - * - * @param binaryArray - * 转化后的二进制数组形式 - * @param num - * 待转化数字 - */ - private void numToBinaryArray(int[] binaryArray, int num) { - int index = 0; - while (num != 0) { - binaryArray[index] = num % 2; - index++; - num /= 2; - } - } +class FPTreeTool { + // 输入数据文件位置 + private String filePath; + // 最小支持度阈值 + private int minSupportCount; + // 所有事物ID记录 + private ArrayList totalGoodsID; + // 各个ID的统计数目映射表项,计数用于排序使用 + private HashMap itemCountMap; + + FPTreeTool(String filePath, int minSupportCount){ + this.filePath = filePath; + this.minSupportCount = minSupportCount; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + String[] temp; + int count; + itemCountMap = new HashMap<>(); + totalGoodsID = new ArrayList<>(); + for (String[] a : dataArray) { + temp = new String[a.length - 1]; + System.arraycopy(a, 1, temp, 0, a.length - 1); + totalGoodsID.add(temp); + for (String s : temp) { + if (!itemCountMap.containsKey(s)) { + count = 1; + } else { + count = itemCountMap.get(s); + // 支持度计数加1 + count++; + } + // 更新表项 + itemCountMap.put(s, count); + } + } + } + + /** + * 根据事物记录构造FP树 + */ + private void buildFPTree(ArrayList suffixPattern, + ArrayList> transctionList){ + // 设置一个空根节点 + TreeNode rootNode = new TreeNode(null, 0); + int count; + // 节点是否存在 + boolean isExist; + ArrayList childNodes; + ArrayList pathList; + // 相同类型节点链表,用于构造的新的FP树 + HashMap> linkedNode = new HashMap<>(); + HashMap countNode = new HashMap<>(); + // 根据事物记录,一步步构建FP树 + for (ArrayList array : transctionList) { + TreeNode searchedNode; + pathList = new ArrayList<>(); + for (TreeNode node : array) { + pathList.add(node); + nodeCounted(node, countNode); + searchedNode = searchNode(rootNode, pathList); + childNodes = searchedNode.getChildNodes(); + + if (childNodes == null) { + childNodes = new ArrayList<>(); + childNodes.add(node); + searchedNode.setChildNodes(childNodes); + node.setParentNode(searchedNode); + nodeAddToLinkedList(node, linkedNode); + } else { + isExist = false; + for (TreeNode node2 : childNodes) { + // 如果找到名称相同,则更新支持度计数 + if (node.getName().equals(node2.getName())) { + count = node2.getCount() + node.getCount(); + node2.setCount(count); + // 标识已找到节点位置 + isExist = true; + break; + } + } + + if (!isExist) { + // 如果没有找到,需添加子节点 + childNodes.add(node); + node.setParentNode(searchedNode); + nodeAddToLinkedList(node, linkedNode); + } + } + + } + } + + // 如果FP树已经是单条路径,则输出此时的频繁模式 + if (isSinglePath(rootNode)) { + printFrequentPattern(suffixPattern, rootNode); + System.out.println("-------"); + } else { + ArrayList> tList; + ArrayList sPattern; + if (suffixPattern == null) { + sPattern = new ArrayList<>(); + } else { + // 进行一个拷贝,避免互相引用的影响 + sPattern = (ArrayList) suffixPattern.clone(); + } + + // 利用节点链表构造新的事务 + for (Map.Entry entry : countNode.entrySet()) { + // 添加到后缀模式中 + sPattern.add((String) entry.getKey()); + //获取到了条件模式机,作为新的事务 + tList = getTransactionList((String) entry.getKey(), linkedNode); + + System.out.print("[后缀模式]:{"); + for (String s : sPattern) { + System.out.print(s + ", "); + } + System.out.print("}, 此时的条件模式基:"); + for (ArrayList tnList : tList) { + System.out.print("{"); + for (TreeNode n : tnList) { + System.out.print(n.getName() + ", "); + } + System.out.print("}, "); + } + System.out.println(); + // 递归构造FP树 + buildFPTree(sPattern, tList); + // 再次移除此项,构造不同的后缀模式,防止对后面造成干扰 + sPattern.remove(entry.getKey()); + } + } + } + + /** + * 将节点加入到同类型节点的链表中 + * + * @param node 待加入节点 + * @param linkedList 链表图 + */ + private void nodeAddToLinkedList(TreeNode node, + HashMap> linkedList){ + String name = node.getName(); + ArrayList list; + + if (linkedList.containsKey(name)) { + list = linkedList.get(name); + // 将node添加到此队列中 + list.add(node); + } else { + list = new ArrayList<>(); + list.add(node); + linkedList.put(name, list); + } + } + + /** + * 根据链表构造出新的事务 + * + * @param name 节点名称 + * @param linkedList 链表 + * @return + */ + private ArrayList> getTransactionList(String name, + HashMap> linkedList){ + ArrayList> tList = new ArrayList<>(); + ArrayList targetNode = linkedList.get(name); + ArrayList singleTansaction; + TreeNode temp; + + for (TreeNode node : targetNode) { + singleTansaction = new ArrayList<>(); + + temp = node; + while (temp.getParentNode().getName() != null) { + temp = temp.getParentNode(); + singleTansaction.add(new TreeNode(temp.getName(), 1)); + } + + // 按照支持度计数得反转一下 + Collections.reverse(singleTansaction); + + for (TreeNode node2 : singleTansaction) { + // 支持度计数调成与模式后缀一样 + node2.setCount(node.getCount()); + } + + if (singleTansaction.size() > 0) { + tList.add(singleTansaction); + } + } + + return tList; + } + + /** + * 节点计数 + * + * @param node 待加入节点 + * @param nodeCount 计数映射图 + */ + private void nodeCounted(TreeNode node, HashMap nodeCount){ + int count; + String name = node.getName(); + + if (nodeCount.containsKey(name)) { + count = nodeCount.get(name); + count++; + } else { + count = 1; + } + + nodeCount.put(name, count); + } + + /** + * 显示决策树 + * + * @param node 待显示的节点 + * @param blankNum 行空格符,用于显示树型结构 + */ + private void showFPTree(TreeNode node, int blankNum){ + System.out.println(); + for (int i = 0; i < blankNum; i++) { + System.out.print("\t"); + } + System.out.print("--"); + System.out.print("--"); + + if (node.getChildNodes() == null) { + System.out.print("["); + System.out.print("I" + node.getName() + ":" + node.getCount()); + System.out.print("]"); + } else { + // 递归显示子节点 + // System.out.print("【" + node.getName() + "】"); + for (TreeNode childNode : node.getChildNodes()) { + showFPTree(childNode, 2 * blankNum); + } + } + + } + + /** + * 待插入节点的抵达位置节点,从根节点开始向下寻找待插入节点的位置 + * + * @param node + * @param list + * @return + */ + private TreeNode searchNode(TreeNode node, ArrayList list){ + ArrayList pathList = new ArrayList<>(); + TreeNode tempNode = null; + TreeNode firstNode = list.get(0); + boolean isExist = false; + // 重新转一遍,避免出现同一引用 + for (TreeNode node2 : list) { + pathList.add(node2); + } + + // 如果没有孩子节点,则直接返回,在此节点下添加子节点 + if (node.getChildNodes() == null) { + return node; + } + + for (TreeNode n : node.getChildNodes()) { + if (n.getName().equals(firstNode.getName()) && list.size() == 1) { + tempNode = node; + isExist = true; + break; + } else if (n.getName().equals(firstNode.getName())) { + // 还没有找到最后的位置,继续找 + pathList.remove(firstNode); + tempNode = searchNode(n, pathList); + return tempNode; + } + } + + // 如果没有找到,则新添加到孩子节点中 + if (!isExist) { + tempNode = node; + } + + return tempNode; + } + + /** + * 判断目前构造的FP树是否是单条路径的 + * + * @param rootNode 当前FP树的根节点 + * @return + */ + private boolean isSinglePath(TreeNode rootNode){ + // 默认是单条路径 + boolean isSinglePath = true; + ArrayList childList; + TreeNode node; + node = rootNode; + + while (node.getChildNodes() != null) { + childList = node.getChildNodes(); + if (childList.size() == 1) { + node = childList.get(0); + } else { + isSinglePath = false; + break; + } + } + + return isSinglePath; + } + + /** + * 开始构建FP树 + */ + void startBuildingTree(){ + ArrayList singleTransaction; + ArrayList> transactionList = new ArrayList<>(); + TreeNode tempNode; + int count; + + for (String[] idArray : totalGoodsID) { + singleTransaction = new ArrayList<>(); + for (String id : idArray) { + count = itemCountMap.get(id); + tempNode = new TreeNode(id, count); + singleTransaction.add(tempNode); + } + + // 根据支持度数的多少进行排序 + Collections.sort(singleTransaction); + for (TreeNode node : singleTransaction) { + // 支持度计数重新归为1 + node.setCount(1); + } + transactionList.add(singleTransaction); + } + + buildFPTree(null, transactionList); + } + + /** + * 输出此单条路径下的频繁模式 + * + * @param suffixPattern 后缀模式 + * @param rootNode 单条路径FP树根节点 + */ + private void printFrequentPattern(ArrayList suffixPattern, + TreeNode rootNode){ + ArrayList idArray = new ArrayList<>(); + TreeNode temp; + temp = rootNode; + // 用于输出组合模式 + int length; + int num; + int[] binaryArray; + + while (temp.getChildNodes() != null) { + temp = temp.getChildNodes().get(0); + + // 筛选支持度系数大于最小阈值的值 + if (temp.getCount() >= minSupportCount) { + idArray.add(temp.getName()); + } + } + + length = idArray.size(); + num = (int) Math.pow(2, length); + for (int i = 0; i < num; i++) { + binaryArray = new int[length]; + numToBinaryArray(binaryArray, i); + + // 如果后缀模式只有1个,不能输出自身 + if (suffixPattern.size() == 1 && i == 0) { + continue; + } + + System.out.print("频繁模式:{【后缀模式:"); + // 先输出固有的后缀模式 + if (suffixPattern.size() > 1 + || (suffixPattern.size() == 1 && idArray.size() > 0)) { + for (String s : suffixPattern) { + System.out.print(s + ", "); + } + } + System.out.print("】"); + // 输出路径上的组合模式 + for (int j = 0; j < length; j++) { + if (binaryArray[j] == 1) { + System.out.print(idArray.get(j) + ", "); + } + } + System.out.println("}"); + } + } + + /** + * 数字转为二进制形式 + * + * @param binaryArray 转化后的二进制数组形式 + * @param num 待转化数字 + */ + private void numToBinaryArray(int[] binaryArray, int num){ + int index = 0; + while (num != 0) { + binaryArray[index] = num % 2; + index++; + num /= 2; + } + } } diff --git a/AssociationAnalysis/DataMining_FPTree/TreeNode.java b/AssociationAnalysis/DataMining_FPTree/TreeNode.java index 154a94d..7201922 100644 --- a/AssociationAnalysis/DataMining_FPTree/TreeNode.java +++ b/AssociationAnalysis/DataMining_FPTree/TreeNode.java @@ -1,14 +1,14 @@ -package DataMining_FPTree; +package AssociationAnalysis.DataMining_FPTree; import java.util.ArrayList; /** * FP树节点 * - * @author lyq + * @author Qstar * */ -public class TreeNode implements Comparable, Cloneable{ +class TreeNode implements Comparable, Cloneable{ // 节点类别名称 private String name; // 计数数量 @@ -18,7 +18,7 @@ public class TreeNode implements Comparable, Cloneable{ // 孩子节点,可以为多个 private ArrayList childNodes; - public TreeNode(String name, int count){ + TreeNode(String name, int count){ this.name = name; this.count = count; } @@ -27,11 +27,7 @@ public String getName() { return name; } - public void setName(String name) { - this.name = name; - } - - public Integer getCount() { + public Integer getCount() { return count; } @@ -39,19 +35,19 @@ public void setCount(Integer count) { this.count = count; } - public TreeNode getParentNode() { + TreeNode getParentNode() { return parentNode; } - public void setParentNode(TreeNode parentNode) { + void setParentNode(TreeNode parentNode) { this.parentNode = parentNode; } - public ArrayList getChildNodes() { + ArrayList getChildNodes() { return childNodes; } - public void setChildNodes(ArrayList childNodes) { + void setChildNodes(ArrayList childNodes) { this.childNodes = childNodes; } diff --git a/BaggingAndBoosting/DataMining_AdaBoost/AdaBoostTool.java b/BaggingAndBoosting/DataMining_AdaBoost/AdaBoostTool.java index 9731e35..6665542 100644 --- a/BaggingAndBoosting/DataMining_AdaBoost/AdaBoostTool.java +++ b/BaggingAndBoosting/DataMining_AdaBoost/AdaBoostTool.java @@ -1,4 +1,4 @@ -package DataMining_AdaBoost; +package BaggingAndBoosting.DataMining_AdaBoost; import java.io.BufferedReader; import java.io.File; @@ -11,303 +11,293 @@ /** * AdaBoost提升算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class AdaBoostTool { - // 分类的类别,程序默认为正类1和负类-1 - public static final int CLASS_POSITIVE = 1; - public static final int CLASS_NEGTIVE = -1; - - // 事先假设的3个分类器(理论上应该重新对数据集进行训练得到) - public static final String CLASSIFICATION1 = "X=2.5"; - public static final String CLASSIFICATION2 = "X=7.5"; - public static final String CLASSIFICATION3 = "Y=5.5"; - - // 分类器组 - public static final String[] ClASSIFICATION = new String[] { - CLASSIFICATION1, CLASSIFICATION2, CLASSIFICATION3 }; - // 分类权重组 - private double[] CLASSIFICATION_WEIGHT; - - // 测试数据文件地址 - private String filePath; - // 误差率阈值 - private double errorValue; - // 所有的数据点 - private ArrayList totalPoint; - - public AdaBoostTool(String filePath, double errorValue) { - this.filePath = filePath; - this.errorValue = errorValue; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - Point temp; - totalPoint = new ArrayList<>(); - for (String[] array : dataArray) { - temp = new Point(array[0], array[1], array[2]); - temp.setProbably(1.0 / dataArray.size()); - totalPoint.add(temp); - } - } - - /** - * 根据当前的误差值算出所得的权重 - * - * @param errorValue - * 当前划分的坐标点误差率 - * @return - */ - private double calculateWeight(double errorValue) { - double alpha = 0; - double temp = 0; - - temp = (1 - errorValue) / errorValue; - alpha = 0.5 * Math.log(temp); - - return alpha; - } - - /** - * 计算当前划分的误差率 - * - * @param pointMap - * 划分之后的点集 - * @param weight - * 本次划分得到的分类器权重 - * @return - */ - private double calculateErrorValue( - HashMap> pointMap) { - double resultValue = 0; - double temp = 0; - double weight = 0; - int tempClassType; - ArrayList pList; - for (Map.Entry entry : pointMap.entrySet()) { - tempClassType = (int) entry.getKey(); - - pList = (ArrayList) entry.getValue(); - for (Point p : pList) { - temp = p.getProbably(); - // 如果划分类型不相等,代表划错了 - if (tempClassType != p.getClassType()) { - resultValue += temp; - } - } - } - - weight = calculateWeight(resultValue); - for (Map.Entry entry : pointMap.entrySet()) { - tempClassType = (int) entry.getKey(); - - pList = (ArrayList) entry.getValue(); - for (Point p : pList) { - temp = p.getProbably(); - // 如果划分类型不相等,代表划错了 - if (tempClassType != p.getClassType()) { - // 划错的点的权重比例变大 - temp *= Math.exp(weight); - p.setProbably(temp); - } else { - // 划对的点的权重比减小 - temp *= Math.exp(-weight); - p.setProbably(temp); - } - } - } - - // 如果误差率没有小于阈值,继续处理 - dataNormalized(); - - return resultValue; - } - - /** - * 概率做归一化处理 - */ - private void dataNormalized() { - double sumProbably = 0; - double temp = 0; - - for (Point p : totalPoint) { - sumProbably += p.getProbably(); - } - - // 归一化处理 - for (Point p : totalPoint) { - temp = p.getProbably(); - p.setProbably(temp / sumProbably); - } - } - - /** - * 用AdaBoost算法得到的组合分类器对数据进行分类 - * - */ - public void adaBoostClassify() { - double value = 0; - Point p; - - calculateWeightArray(); - for (int i = 0; i < ClASSIFICATION.length; i++) { - System.out.println(MessageFormat.format("分类器{0}权重为:{1}", (i+1), CLASSIFICATION_WEIGHT[i])); - } - - for (int j = 0; j < totalPoint.size(); j++) { - p = totalPoint.get(j); - value = 0; - - for (int i = 0; i < ClASSIFICATION.length; i++) { - value += 1.0 * classifyData(ClASSIFICATION[i], p) - * CLASSIFICATION_WEIGHT[i]; - } - - //进行符号判断 - if (value > 0) { - System.out - .println(MessageFormat.format( - "点({0}, {1})的组合分类结果为:1,该点的实际分类为{2}", p.getX(), p.getY(), - p.getClassType())); - } else { - System.out.println(MessageFormat.format( - "点({0}, {1})的组合分类结果为:-1,该点的实际分类为{2}", p.getX(), p.getY(), - p.getClassType())); - } - } - } - - /** - * 计算分类器权重数组 - */ - private void calculateWeightArray() { - int tempClassType = 0; - double errorValue = 0; - ArrayList posPointList; - ArrayList negPointList; - HashMap> mapList; - CLASSIFICATION_WEIGHT = new double[ClASSIFICATION.length]; - - for (int i = 0; i < CLASSIFICATION_WEIGHT.length; i++) { - mapList = new HashMap<>(); - posPointList = new ArrayList<>(); - negPointList = new ArrayList<>(); - - for (Point p : totalPoint) { - tempClassType = classifyData(ClASSIFICATION[i], p); - - if (tempClassType == CLASS_POSITIVE) { - posPointList.add(p); - } else { - negPointList.add(p); - } - } - - mapList.put(CLASS_POSITIVE, posPointList); - mapList.put(CLASS_NEGTIVE, negPointList); - - if (i == 0) { - // 最开始的各个点的权重一样,所以传入0,使得e的0次方等于1 - errorValue = calculateErrorValue(mapList); - } else { - // 每次把上次计算所得的权重代入,进行概率的扩大或缩小 - errorValue = calculateErrorValue(mapList); - } - - // 计算当前分类器的所得权重 - CLASSIFICATION_WEIGHT[i] = calculateWeight(errorValue); - } - } - - /** - * 用各个子分类器进行分类 - * - * @param classification - * 分类器名称 - * @param p - * 待划分坐标点 - * @return - */ - private int classifyData(String classification, Point p) { - // 分割线所属坐标轴 - String position; - // 分割线的值 - double value = 0; - double posProbably = 0; - double negProbably = 0; - // 划分是否是大于一边的划分 - boolean isLarger = false; - String[] array; - ArrayList pList = new ArrayList<>(); - - array = classification.split("="); - position = array[0]; - value = Double.parseDouble(array[1]); - - if (position.equals("X")) { - if (p.getX() > value) { - isLarger = true; - } - - // 将训练数据中所有属于这边的点加入 - for (Point point : totalPoint) { - if (isLarger && point.getX() > value) { - pList.add(point); - } else if (!isLarger && point.getX() < value) { - pList.add(point); - } - } - } else if (position.equals("Y")) { - if (p.getY() > value) { - isLarger = true; - } - - // 将训练数据中所有属于这边的点加入 - for (Point point : totalPoint) { - if (isLarger && point.getY() > value) { - pList.add(point); - } else if (!isLarger && point.getY() < value) { - pList.add(point); - } - } - } - - for (Point p2 : pList) { - if (p2.getClassType() == CLASS_POSITIVE) { - posProbably++; - } else { - negProbably++; - } - } - - //分类按正负类数量进行划分 - if (posProbably > negProbably) { - return CLASS_POSITIVE; - } else { - return CLASS_NEGTIVE; - } - } +class AdaBoostTool { + // 分类的类别,程序默认为正类1和负类-1 + private static final int CLASS_POSITIVE = 1; + private static final int CLASS_NEGTIVE = -1; + + // 事先假设的3个分类器(理论上应该重新对数据集进行训练得到) + private static final String CLASSIFICATION1 = "X=2.5"; + private static final String CLASSIFICATION2 = "X=7.5"; + private static final String CLASSIFICATION3 = "Y=5.5"; + + // 分类器组 + private static final String[] ClASSIFICATION = new String[]{ + CLASSIFICATION1, CLASSIFICATION2, CLASSIFICATION3}; + // 分类权重组 + private double[] CLASSIFICATION_WEIGHT; + + // 测试数据文件地址 + private String filePath; + // 误差率阈值 + private double errorValue; + // 所有的数据点 + private ArrayList totalPoint; + + AdaBoostTool(String filePath, double errorValue){ + this.filePath = filePath; + this.errorValue = errorValue; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + Point temp; + totalPoint = new ArrayList<>(); + for (String[] array : dataArray) { + temp = new Point(array[0], array[1], array[2]); + temp.setProbably(1.0 / dataArray.size()); + totalPoint.add(temp); + } + } + + /** + * 根据当前的误差值算出所得的权重 + * + * @param errorValue 当前划分的坐标点误差率 + * @return + */ + private double calculateWeight(double errorValue){ + double alpha; + double temp; + + temp = (1 - errorValue) / errorValue; + alpha = 0.5 * Math.log(temp); + + return alpha; + } + + /** + * @param pointMap + * @return weight 本次划分得到的分类器权重 + */ + private double calculateErrorValue( + HashMap> pointMap){ + double resultValue = 0; + double temp; + double weight; + int tempClassType; + ArrayList pList; + for (Map.Entry entry : pointMap.entrySet()) { + tempClassType = (int) entry.getKey(); + + pList = (ArrayList) entry.getValue(); + for (Point p : pList) { + temp = p.getProbably(); + // 如果划分类型不相等,代表划错了 + if (tempClassType != p.getClassType()) { + resultValue += temp; + } + } + } + + weight = calculateWeight(resultValue); + for (Map.Entry entry : pointMap.entrySet()) { + tempClassType = (int) entry.getKey(); + + pList = (ArrayList) entry.getValue(); + for (Point p : pList) { + temp = p.getProbably(); + // 如果划分类型不相等,代表划错了 + if (tempClassType != p.getClassType()) { + // 划错的点的权重比例变大 + temp *= Math.exp(weight); + p.setProbably(temp); + } else { + // 划对的点的权重比减小 + temp *= Math.exp(-weight); + p.setProbably(temp); + } + } + } + + // 如果误差率没有小于阈值,继续处理 + dataNormalized(); + + return resultValue; + } + + /** + * 概率做归一化处理 + */ + private void dataNormalized(){ + double sumProbably = 0; + double temp; + + for (Point p : totalPoint) { + sumProbably += p.getProbably(); + } + + // 归一化处理 + for (Point p : totalPoint) { + temp = p.getProbably(); + p.setProbably(temp / sumProbably); + } + } + + /** + * 用AdaBoost算法得到的组合分类器对数据进行分类 + */ + void adaBoostClassify(){ + double value; + Point p; + + calculateWeightArray(); + for (int i = 0; i < ClASSIFICATION.length; i++) { + System.out.println(MessageFormat.format("分类器{0}权重为:{1}", (i + 1), CLASSIFICATION_WEIGHT[i])); + } + + for (Point aTotalPoint : totalPoint) { + p = aTotalPoint; + value = 0; + + for (int i = 0; i < ClASSIFICATION.length; i++) { + value += 1.0 * classifyData(ClASSIFICATION[i], p) + * CLASSIFICATION_WEIGHT[i]; + } + + //进行符号判断 + if (value > 0) { + System.out + .println(MessageFormat.format( + "点({0}, {1})的组合分类结果为:1,该点的实际分类为{2}", p.getX(), p.getY(), + p.getClassType())); + } else { + System.out.println(MessageFormat.format( + "点({0}, {1})的组合分类结果为:-1,该点的实际分类为{2}", p.getX(), p.getY(), + p.getClassType())); + } + } + } + + /** + * 计算分类器权重数组 + */ + private void calculateWeightArray(){ + int tempClassType; + double errorValue; + ArrayList posPointList; + ArrayList negPointList; + HashMap> mapList; + CLASSIFICATION_WEIGHT = new double[ClASSIFICATION.length]; + + for (int i = 0; i < CLASSIFICATION_WEIGHT.length; i++) { + mapList = new HashMap<>(); + posPointList = new ArrayList<>(); + negPointList = new ArrayList<>(); + + for (Point p : totalPoint) { + tempClassType = classifyData(ClASSIFICATION[i], p); + + if (tempClassType == CLASS_POSITIVE) { + posPointList.add(p); + } else { + negPointList.add(p); + } + } + + mapList.put(CLASS_POSITIVE, posPointList); + mapList.put(CLASS_NEGTIVE, negPointList); + + if (i == 0) { + // 最开始的各个点的权重一样,所以传入0,使得e的0次方等于1 + errorValue = calculateErrorValue(mapList); + } else { + // 每次把上次计算所得的权重代入,进行概率的扩大或缩小 + errorValue = calculateErrorValue(mapList); + } + + // 计算当前分类器的所得权重 + CLASSIFICATION_WEIGHT[i] = calculateWeight(errorValue); + } + } + + /** + * 用各个子分类器进行分类 + * + * @param classification 分类器名称 + * @param p 待划分坐标点 + * @return + */ + private int classifyData(String classification, Point p){ + // 分割线所属坐标轴 + String position; + // 分割线的值 + double value; + double posProbably = 0; + double negProbably = 0; + // 划分是否是大于一边的划分 + boolean isLarger = false; + String[] array; + ArrayList pList = new ArrayList<>(); + + array = classification.split("="); + position = array[0]; + value = Double.parseDouble(array[1]); + + if (position.equals("X")) { + if (p.getX() > value) { + isLarger = true; + } + + // 将训练数据中所有属于这边的点加入 + for (Point point : totalPoint) { + if (isLarger && point.getX() > value) { + pList.add(point); + } else if (!isLarger && point.getX() < value) { + pList.add(point); + } + } + } else if (position.equals("Y")) { + if (p.getY() > value) { + isLarger = true; + } + + // 将训练数据中所有属于这边的点加入 + for (Point point : totalPoint) { + if (isLarger && point.getY() > value) { + pList.add(point); + } else if (!isLarger && point.getY() < value) { + pList.add(point); + } + } + } + + for (Point p2 : pList) { + if (p2.getClassType() == CLASS_POSITIVE) { + posProbably++; + } else { + negProbably++; + } + } + + //分类按正负类数量进行划分 + if (posProbably > negProbably) { + return CLASS_POSITIVE; + } else { + return CLASS_NEGTIVE; + } + } } diff --git a/BaggingAndBoosting/DataMining_AdaBoost/Client.java b/BaggingAndBoosting/DataMining_AdaBoost/Client.java index 30e98ec..f00bcdc 100644 --- a/BaggingAndBoosting/DataMining_AdaBoost/Client.java +++ b/BaggingAndBoosting/DataMining_AdaBoost/Client.java @@ -1,13 +1,13 @@ -package DataMining_AdaBoost; +package BaggingAndBoosting.DataMining_AdaBoost; /** * AdaBoost提升算法调用类 - * @author lyq + * @author Qstar * */ public class Client { public static void main(String[] agrs){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/BaggingAndBoosting/DataMining_AdaBoost/input.txt"; //误差率阈值 double errorValue = 0.2; diff --git a/BaggingAndBoosting/DataMining_AdaBoost/Point.java b/BaggingAndBoosting/DataMining_AdaBoost/Point.java index 7d352ab..f2fdac5 100644 --- a/BaggingAndBoosting/DataMining_AdaBoost/Point.java +++ b/BaggingAndBoosting/DataMining_AdaBoost/Point.java @@ -1,62 +1,57 @@ -package DataMining_AdaBoost; +package BaggingAndBoosting.DataMining_AdaBoost; /** * 坐标点类 - * - * @author lyq - * + * + * @author Qstar */ public class Point { - // 坐标点x坐标 - private int x; - // 坐标点y坐标 - private int y; - // 坐标点的分类类别 - private int classType; - //如果此节点被划错,他的误差率,不能用个数除以总数,因为不同坐标点的权重不一定相等 - private double probably; - - public Point(int x, int y, int classType){ - this.x = x; - this.y = y; - this.classType = classType; - } - - public Point(String x, String y, String classType){ - this.x = Integer.parseInt(x); - this.y = Integer.parseInt(y); - this.classType = Integer.parseInt(classType); - } - - public int getX() { - return x; - } - - public void setX(int x) { - this.x = x; - } - - public int getY() { - return y; - } - - public void setY(int y) { - this.y = y; - } - - public int getClassType() { - return classType; - } - - public void setClassType(int classType) { - this.classType = classType; - } - - public double getProbably() { - return probably; - } - - public void setProbably(double probably) { - this.probably = probably; - } + // 坐标点x坐标 + private int x; + // 坐标点y坐标 + private int y; + // 坐标点的分类类别 + private int classType; + //如果此节点被划错,他的误差率,不能用个数除以总数,因为不同坐标点的权重不一定相等 + private double probably; + + public Point(int x, int y, int classType){ + this.x = x; + this.y = y; + this.classType = classType; + } + + public Point(String x, String y, String classType){ + this.x = Integer.parseInt(x); + this.y = Integer.parseInt(y); + this.classType = Integer.parseInt(classType); + } + + public int getX(){ + return x; + } + + public void setX(int x){ + this.x = x; + } + + public int getY(){ + return y; + } + + public void setY(int y){ + this.y = y; + } + + int getClassType(){ + return classType; + } + + double getProbably(){ + return probably; + } + + void setProbably(double probably){ + this.probably = probably; + } } diff --git a/Classification/DataMining_CART/AttrNode.java b/Classification/DataMining_CART/AttrNode.java index c259ff1..5e60f9d 100644 --- a/Classification/DataMining_CART/AttrNode.java +++ b/Classification/DataMining_CART/AttrNode.java @@ -1,85 +1,83 @@ -package DataMining_CART; +package Classification.DataMining_CART; import java.util.ArrayList; /** * 回归分类树节点 - * - * @author lyq - * + * + * @author Qstar */ -public class AttrNode { - // 节点属性名字 - private String attrName; - // 节点索引标号 - private int nodeIndex; - //包含的叶子节点数 - private int leafNum; - // 节点误差率 - private double alpha; - // 父亲分类属性值 - private String parentAttrValue; - // 孩子节点 - private AttrNode[] childAttrNode; - // 数据记录索引 - private ArrayList dataIndex; - - public String getAttrName() { - return attrName; - } - - public void setAttrName(String attrName) { - this.attrName = attrName; - } - - public int getNodeIndex() { - return nodeIndex; - } - - public void setNodeIndex(int nodeIndex) { - this.nodeIndex = nodeIndex; - } - - public double getAlpha() { - return alpha; - } - - public void setAlpha(double alpha) { - this.alpha = alpha; - } - - public String getParentAttrValue() { - return parentAttrValue; - } - - public void setParentAttrValue(String parentAttrValue) { - this.parentAttrValue = parentAttrValue; - } - - public AttrNode[] getChildAttrNode() { - return childAttrNode; - } - - public void setChildAttrNode(AttrNode[] childAttrNode) { - this.childAttrNode = childAttrNode; - } - - public ArrayList getDataIndex() { - return dataIndex; - } - - public void setDataIndex(ArrayList dataIndex) { - this.dataIndex = dataIndex; - } - - public int getLeafNum() { - return leafNum; - } - - public void setLeafNum(int leafNum) { - this.leafNum = leafNum; - } - - - +class AttrNode { + // 节点属性名字 + private String attrName; + // 节点索引标号 + private int nodeIndex; + //包含的叶子节点数 + private int leafNum; + // 节点误差率 + private double alpha; + // 父亲分类属性值 + private String parentAttrValue; + // 孩子节点 + private AttrNode[] childAttrNode; + // 数据记录索引 + private ArrayList dataIndex; + + public String getAttrName(){ + return attrName; + } + + public void setAttrName(String attrName){ + this.attrName = attrName; + } + + int getNodeIndex(){ + return nodeIndex; + } + + void setNodeIndex(int nodeIndex){ + this.nodeIndex = nodeIndex; + } + + double getAlpha(){ + return alpha; + } + + void setAlpha(double alpha){ + this.alpha = alpha; + } + + String getParentAttrValue(){ + return parentAttrValue; + } + + void setParentAttrValue(String parentAttrValue){ + this.parentAttrValue = parentAttrValue; + } + + AttrNode[] getChildAttrNode(){ + return childAttrNode; + } + + void setChildAttrNode(AttrNode[] childAttrNode){ + this.childAttrNode = childAttrNode; + } + + ArrayList getDataIndex(){ + return dataIndex; + } + + void setDataIndex(ArrayList dataIndex){ + this.dataIndex = dataIndex; + } + + int getLeafNum(){ + return leafNum; + } + + void setLeafNum(int leafNum){ + this.leafNum = leafNum; + } + + } diff --git a/Classification/DataMining_CART/CARTTool.java b/Classification/DataMining_CART/CARTTool.java index 2ac9d7a..3ecf0f6 100644 --- a/Classification/DataMining_CART/CARTTool.java +++ b/Classification/DataMining_CART/CARTTool.java @@ -1,546 +1,514 @@ -package DataMining_CART; +package Classification.DataMining_CART; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.Map; -import java.util.Queue; - -import javax.lang.model.element.NestingKind; -import javax.swing.text.DefaultEditorKit.CutAction; -import javax.swing.text.html.MinimalHTMLWriter; +import java.util.*; /** * CART分类回归树算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class CARTTool { - // 类标号的值类型 - private final String YES = "Yes"; - private final String NO = "No"; - - // 所有属性的类型总数,在这里就是data源数据的列数 - private int attrNum; - private String filePath; - // 初始源数据,用一个二维字符数组存放模仿表格数据 - private String[][] data; - // 数据的属性行的名字 - private String[] attrNames; - // 每个属性的值所有类型 - private HashMap> attrValue; - - public CARTTool(String filePath) { - this.filePath = filePath; - attrValue = new HashMap<>(); - } - - /** - * 从文件中读取数据 - */ - public void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - data = new String[dataArray.size()][]; - dataArray.toArray(data); - attrNum = data[0].length; - attrNames = data[0]; +class CARTTool { + // 类标号的值类型 + private final String YES = "Yes"; + private final String NO = "No"; + + // 所有属性的类型总数,在这里就是data源数据的列数 + private int attrNum; + private String filePath; + // 初始源数据,用一个二维字符数组存放模仿表格数据 + private String[][] data; + // 数据的属性行的名字 + private String[] attrNames; + // 每个属性的值所有类型 + private HashMap> attrValue; + + CARTTool(String filePath){ + this.filePath = filePath; + attrValue = new HashMap<>(); + } + + /** + * 从文件中读取数据 + */ + public void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + data = new String[dataArray.size()][]; + dataArray.toArray(data); + attrNum = data[0].length; + attrNames = data[0]; /* - * for (int i = 0; i < data.length; i++) { for (int j = 0; j < + * for (int i = 0; i < data.length; i++) { for (int j = 0; j < * data[0].length; j++) { System.out.print(" " + data[i][j]); } * System.out.print("\n"); } */ - } - - /** - * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用 - */ - public void initAttrValue() { - ArrayList tempValues; - - // 按照列的方式,从左往右找 - for (int j = 1; j < attrNum; j++) { - // 从一列中的上往下开始寻找值 - tempValues = new ArrayList<>(); - for (int i = 1; i < data.length; i++) { - if (!tempValues.contains(data[i][j])) { - // 如果这个属性的值没有添加过,则添加 - tempValues.add(data[i][j]); - } - } - - // 一列属性的值已经遍历完毕,复制到map属性表中 - attrValue.put(data[0][j], tempValues); - } + } + + /** + * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用 + */ + private void initAttrValue(){ + ArrayList tempValues; + + // 按照列的方式,从左往右找 + for (int j = 1; j < attrNum; j++) { + // 从一列中的上往下开始寻找值 + tempValues = new ArrayList<>(); + for (int i = 1; i < data.length; i++) { + if (!tempValues.contains(data[i][j])) { + // 如果这个属性的值没有添加过,则添加 + tempValues.add(data[i][j]); + } + } + + // 一列属性的值已经遍历完毕,复制到map属性表中 + attrValue.put(data[0][j], tempValues); + } /* - * for (Map.Entry entry : attrValue.entrySet()) { + * for (Map.Entry entry : attrValue.entrySet()) { * System.out.println("key:value " + entry.getKey() + ":" + * entry.getValue()); } */ - } - - /** - * 计算机基尼指数 - * - * @param remainData - * 剩余数据 - * @param attrName - * 属性名称 - * @param value - * 属性值 - * @param beLongValue - * 分类是否属于此属性值 - * @return - */ - public double computeGini(String[][] remainData, String attrName, - String value, boolean beLongValue) { - // 实例总数 - int total = 0; - // 正实例数 - int posNum = 0; - // 负实例数 - int negNum = 0; - // 基尼指数 - double gini = 0; - - // 还是按列从左往右遍历属性 - for (int j = 1; j < attrNames.length; j++) { - // 找到了指定的属性 - if (attrName.equals(attrNames[j])) { - for (int i = 1; i < remainData.length; i++) { - // 统计正负实例按照属于和不属于值类型进行划分 - if ((beLongValue && remainData[i][j].equals(value)) - || (!beLongValue && !remainData[i][j].equals(value))) { - if (remainData[i][attrNames.length - 1].equals(YES)) { - // 判断此行数据是否为正实例 - posNum++; - } else { - negNum++; - } - } - } - } - } - - total = posNum + negNum; - double posProbobly = (double) posNum / total; - double negProbobly = (double) negNum / total; - gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly; - - // 返回计算基尼指数 - return gini; - } - - /** - * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中 - * - * @param remainData - * 剩余谁 - * @param attrName - * 属性名称 - * @return - */ - public String[] computeAttrGini(String[][] remainData, String attrName) { - String[] str = new String[2]; - // 最终该属性的划分类型值 - String spiltValue = ""; - // 临时变量 - int tempNum = 0; - // 保存属性的值划分时的最小的基尼指数 - double minGini = Integer.MAX_VALUE; - ArrayList valueTypes = attrValue.get(attrName); - // 属于此属性值的实例数 - HashMap belongNum = new HashMap<>(); - - for (String string : valueTypes) { - // 重新计数的时候,数字归0 - tempNum = 0; - // 按列从左往右遍历属性 - for (int j = 1; j < attrNames.length; j++) { - // 找到了指定的属性 - if (attrName.equals(attrNames[j])) { - for (int i = 1; i < remainData.length; i++) { - // 统计正负实例按照属于和不属于值类型进行划分 - if (remainData[i][j].equals(string)) { - tempNum++; - } - } - } - } - - belongNum.put(string, tempNum); - } - - double tempGini = 0; - double posProbably = 1.0; - double negProbably = 1.0; - for (String string : valueTypes) { - tempGini = 0; - - posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1); - negProbably = 1 - posProbably; - - tempGini += posProbably - * computeGini(remainData, attrName, string, true); - tempGini += negProbably - * computeGini(remainData, attrName, string, false); - - if (tempGini < minGini) { - minGini = tempGini; - spiltValue = string; - } - } - - str[0] = spiltValue; - str[1] = minGini + ""; - - return str; - } - - public void buildDecisionTree(AttrNode node, String parentAttrValue, - String[][] remainData, ArrayList remainAttr, - boolean beLongParentValue) { - // 属性划分值 - String valueType = ""; - // 划分属性名称 - String spiltAttrName = ""; - double minGini = Integer.MAX_VALUE; - double tempGini = 0; - // 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值 - String[] giniArray; - - if (beLongParentValue) { - node.setParentAttrValue(parentAttrValue); - } else { - node.setParentAttrValue("!" + parentAttrValue); - } - - if (remainAttr.size() == 0) { - if (remainData.length > 1) { - ArrayList indexArray = new ArrayList<>(); - for (int i = 1; i < remainData.length; i++) { - indexArray.add(remainData[i][0]); - } - node.setDataIndex(indexArray); - } - System.out.println("attr remain null"); - return; - } - - for (String str : remainAttr) { - giniArray = computeAttrGini(remainData, str); - tempGini = Double.parseDouble(giniArray[1]); - - if (tempGini < minGini) { - spiltAttrName = str; - minGini = tempGini; - valueType = giniArray[0]; - } - } - // 移除划分属性 - remainAttr.remove(spiltAttrName); - node.setAttrName(spiltAttrName); - - // 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点 - AttrNode[] childNode = new AttrNode[2]; - String[][] rData; - - boolean[] bArray = new boolean[] { true, false }; - for (int i = 0; i < bArray.length; i++) { - // 二元划分属于属性值的划分 - rData = removeData(remainData, spiltAttrName, valueType, bArray[i]); - - boolean sameClass = true; - ArrayList indexArray = new ArrayList<>(); - for (int k = 1; k < rData.length; k++) { - indexArray.add(rData[k][0]); - // 判断是否为同一类的 - if (!rData[k][attrNames.length - 1] - .equals(rData[1][attrNames.length - 1])) { - // 只要有1个不相等,就不是同类型的 - sameClass = false; - break; - } - } - - childNode[i] = new AttrNode(); - if (!sameClass) { - // 创建新的对象属性,对象的同个引用会出错 - ArrayList rAttr = new ArrayList<>(); - for (String str : remainAttr) { - rAttr.add(str); - } - buildDecisionTree(childNode[i], valueType, rData, rAttr, - bArray[i]); - } else { - String pAtr = (bArray[i] ? valueType : "!" + valueType); - childNode[i].setParentAttrValue(pAtr); - childNode[i].setDataIndex(indexArray); - } - } - - node.setChildAttrNode(childNode); - } - - /** - * 属性划分完毕,进行数据的移除 - * - * @param srcData - * 源数据 - * @param attrName - * 划分的属性名称 - * @param valueType - * 属性的值类型 - * @parame beLongValue 分类是否属于此值类型 - */ - private String[][] removeData(String[][] srcData, String attrName, - String valueType, boolean beLongValue) { - String[][] desDataArray; - ArrayList desData = new ArrayList<>(); - // 待删除数据 - ArrayList selectData = new ArrayList<>(); - selectData.add(attrNames); - - // 数组数据转化到列表中,方便移除 - for (int i = 0; i < srcData.length; i++) { - desData.add(srcData[i]); - } - - // 还是从左往右一列列的查找 - for (int j = 1; j < attrNames.length; j++) { - if (attrNames[j].equals(attrName)) { - for (int i = 1; i < desData.size(); i++) { - if (desData.get(i)[j].equals(valueType)) { - // 如果匹配这个数据,则移除其他的数据 - selectData.add(desData.get(i)); - } - } - } - } - - if (beLongValue) { - desDataArray = new String[selectData.size()][]; - selectData.toArray(desDataArray); - } else { - // 属性名称行不移除 - selectData.remove(attrNames); - // 如果是划分不属于此类型的数据时,进行移除 - desData.removeAll(selectData); - desDataArray = new String[desData.size()][]; - desData.toArray(desDataArray); - } - - return desDataArray; - } - - public void startBuildingTree() { - readDataFile(); - initAttrValue(); - - ArrayList remainAttr = new ArrayList<>(); - // 添加属性,除了最后一个类标号属性 - for (int i = 1; i < attrNames.length - 1; i++) { - remainAttr.add(attrNames[i]); - } - - AttrNode rootNode = new AttrNode(); - buildDecisionTree(rootNode, "", data, remainAttr, false); - setIndexAndAlpah(rootNode, 0, false); - System.out.println("剪枝前:"); - showDecisionTree(rootNode, 1); - setIndexAndAlpah(rootNode, 0, true); - System.out.println("\n剪枝后:"); - showDecisionTree(rootNode, 1); - } - - /** - * 显示决策树 - * - * @param node - * 待显示的节点 - * @param blankNum - * 行空格符,用于显示树型结构 - */ - private void showDecisionTree(AttrNode node, int blankNum) { - System.out.println(); - for (int i = 0; i < blankNum; i++) { - System.out.print(" "); - } - System.out.print("--"); - // 显示分类的属性值 - if (node.getParentAttrValue() != null - && node.getParentAttrValue().length() > 0) { - System.out.print(node.getParentAttrValue()); - } else { - System.out.print("--"); - } - System.out.print("--"); - - if (node.getDataIndex() != null && node.getDataIndex().size() > 0) { - String i = node.getDataIndex().get(0); - System.out.print("【" + node.getNodeIndex() + "】类别:" - + data[Integer.parseInt(i)][attrNames.length - 1]); - System.out.print("["); - for (String index : node.getDataIndex()) { - System.out.print(index + ", "); - } - System.out.print("]"); - } else { - // 递归显示子节点 - System.out.print("【" + node.getNodeIndex() + ":" - + node.getAttrName() + "】"); - if (node.getChildAttrNode() != null) { - for (AttrNode childNode : node.getChildAttrNode()) { - showDecisionTree(childNode, 2 * blankNum); - } - } else { - System.out.print("【 Child Null】"); - } - } - } - - /** - * 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝 - * - * @param node - * 开始的时候传入的是根节点 - * @param index - * 开始的索引号,从1开始 - * @param ifCutNode - * 是否需要剪枝 - */ - private void setIndexAndAlpah(AttrNode node, int index, boolean ifCutNode) { - AttrNode tempNode; - // 最小误差代价节点,即将被剪枝的节点 - AttrNode minAlphaNode = null; - double minAlpah = Integer.MAX_VALUE; - Queue nodeQueue = new LinkedList(); - - nodeQueue.add(node); - while (nodeQueue.size() > 0) { - index++; - // 从队列头部获取首个节点 - tempNode = nodeQueue.poll(); - tempNode.setNodeIndex(index); - if (tempNode.getChildAttrNode() != null) { - for (AttrNode childNode : tempNode.getChildAttrNode()) { - nodeQueue.add(childNode); - } - computeAlpha(tempNode); - if (tempNode.getAlpha() < minAlpah) { - minAlphaNode = tempNode; - minAlpah = tempNode.getAlpha(); - } else if (tempNode.getAlpha() == minAlpah) { - // 如果误差代价值一样,比较包含的叶子节点个数,剪枝有多叶子节点数的节点 - if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) { - minAlphaNode = tempNode; - } - } - } - } - - if (ifCutNode) { - // 进行树的剪枝,让其左右孩子节点为null - minAlphaNode.setChildAttrNode(null); - } - } - - /** - * 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝 - * - * @param node - * 待计算的非叶子节点 - */ - private void computeAlpha(AttrNode node) { - double rt = 0; - double Rt = 0; - double alpha = 0; - // 当前节点的数据总数 - int sumNum = 0; - // 最少的偏差数 - int minNum = 0; - - ArrayList dataIndex; - ArrayList leafNodes = new ArrayList<>(); - - addLeafNode(node, leafNodes); - node.setLeafNum(leafNodes.size()); - for (AttrNode attrNode : leafNodes) { - dataIndex = attrNode.getDataIndex(); - - int num = 0; - sumNum += dataIndex.size(); - for (String s : dataIndex) { - // 统计分类数据中的正负实例数 - if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) { - num++; - } - } - minNum += num; - - // 取小数量的值部分 - if (1.0 * num / dataIndex.size() > 0.5) { - num = dataIndex.size() - num; - } - - rt += (1.0 * num / (data.length - 1)); - } - - //同样取出少偏差的那部分 - if (1.0 * minNum / sumNum > 0.5) { - minNum = sumNum - minNum; - } - - Rt = 1.0 * minNum / (data.length - 1); - alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1); - node.setAlpha(alpha); - } - - /** - * 筛选出节点所包含的叶子节点数 - * - * @param node - * 待筛选节点 - * @param leafNode - * 叶子节点列表容器 - */ - private void addLeafNode(AttrNode node, ArrayList leafNode) { - ArrayList dataIndex; - - if (node.getChildAttrNode() != null) { - for (AttrNode childNode : node.getChildAttrNode()) { - dataIndex = childNode.getDataIndex(); - if (dataIndex != null && dataIndex.size() > 0) { - // 说明此节点为叶子节点 - leafNode.add(childNode); - } else { - // 如果还是非叶子节点则继续递归调用 - addLeafNode(childNode, leafNode); - } - } - } - } + } + + /** + * 计算机基尼指数 + * + * @param remainData 剩余数据 + * @param attrName 属性名称 + * @param value 属性值 + * @param beLongValue 分类是否属于此属性值 + */ + private double computeGini(String[][] remainData, String attrName, + String value, boolean beLongValue){ + // 实例总数 + int total; + // 正实例数 + int posNum = 0; + // 负实例数 + int negNum = 0; + // 基尼指数 + double gini; + + // 还是按列从左往右遍历属性 + for (int j = 1; j < attrNames.length; j++) { + // 找到了指定的属性 + if (attrName.equals(attrNames[j])) { + for (int i = 1; i < remainData.length; i++) { + // 统计正负实例按照属于和不属于值类型进行划分 + if ((beLongValue && remainData[i][j].equals(value)) + || (!beLongValue && !remainData[i][j].equals(value))) { + if (remainData[i][attrNames.length - 1].equals(YES)) { + // 判断此行数据是否为正实例 + posNum++; + } else { + negNum++; + } + } + } + } + } + + total = posNum + negNum; + double posProbobly = (double) posNum / total; + double negProbobly = (double) negNum / total; + gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly; + + // 返回计算基尼指数 + return gini; + } + + /** + * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中 + * + * @param remainData 剩余谁 + * @param attrName 属性名称 + */ + private String[] computeAttrGini(String[][] remainData, String attrName){ + String[] str = new String[2]; + // 最终该属性的划分类型值 + String spiltValue = ""; + // 临时变量 + int tempNum; + // 保存属性的值划分时的最小的基尼指数 + double minGini = Integer.MAX_VALUE; + ArrayList valueTypes = attrValue.get(attrName); + // 属于此属性值的实例数 + HashMap belongNum = new HashMap<>(); + + for (String string : valueTypes) { + // 重新计数的时候,数字归0 + tempNum = 0; + // 按列从左往右遍历属性 + for (int j = 1; j < attrNames.length; j++) { + // 找到了指定的属性 + if (attrName.equals(attrNames[j])) { + for (int i = 1; i < remainData.length; i++) { + // 统计正负实例按照属于和不属于值类型进行划分 + if (remainData[i][j].equals(string)) { + tempNum++; + } + } + } + } + + belongNum.put(string, tempNum); + } + + double tempGini; + double posProbably; + double negProbably; + for (String string : valueTypes) { + tempGini = 0; + + posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1); + negProbably = 1 - posProbably; + + tempGini += posProbably + * computeGini(remainData, attrName, string, true); + tempGini += negProbably + * computeGini(remainData, attrName, string, false); + + if (tempGini < minGini) { + minGini = tempGini; + spiltValue = string; + } + } + + str[0] = spiltValue; + str[1] = minGini + ""; + + return str; + } + + private void buildDecisionTree(AttrNode node, String parentAttrValue, + String[][] remainData, ArrayList remainAttr, + boolean beLongParentValue){ + // 属性划分值 + String valueType = ""; + // 划分属性名称 + String spiltAttrName = ""; + double minGini = Integer.MAX_VALUE; + double tempGini; + // 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值 + String[] giniArray; + + if (beLongParentValue) { + node.setParentAttrValue(parentAttrValue); + } else { + node.setParentAttrValue("!" + parentAttrValue); + } + + if (remainAttr.size() == 0) { + if (remainData.length > 1) { + ArrayList indexArray = new ArrayList<>(); + for (int i = 1; i < remainData.length; i++) { + indexArray.add(remainData[i][0]); + } + node.setDataIndex(indexArray); + } + System.out.println("attr remain null"); + return; + } + + for (String str : remainAttr) { + giniArray = computeAttrGini(remainData, str); + tempGini = Double.parseDouble(giniArray[1]); + + if (tempGini < minGini) { + spiltAttrName = str; + minGini = tempGini; + valueType = giniArray[0]; + } + } + // 移除划分属性 + remainAttr.remove(spiltAttrName); + node.setAttrName(spiltAttrName); + + // 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点 + AttrNode[] childNode = new AttrNode[2]; + String[][] rData; + + boolean[] bArray = new boolean[]{true, false}; + for (int i = 0; i < bArray.length; i++) { + // 二元划分属于属性值的划分 + rData = removeData(remainData, spiltAttrName, valueType, bArray[i]); + + boolean sameClass = true; + ArrayList indexArray = new ArrayList<>(); + for (int k = 1; k < rData.length; k++) { + indexArray.add(rData[k][0]); + // 判断是否为同一类的 + if (!rData[k][attrNames.length - 1] + .equals(rData[1][attrNames.length - 1])) { + // 只要有1个不相等,就不是同类型的 + sameClass = false; + break; + } + } + + childNode[i] = new AttrNode(); + if (!sameClass) { + // 创建新的对象属性,对象的同个引用会出错 + ArrayList rAttr = new ArrayList<>(); + for (String str : remainAttr) { + rAttr.add(str); + } + buildDecisionTree(childNode[i], valueType, rData, rAttr, + bArray[i]); + } else { + String pAtr = (bArray[i] ? valueType : "!" + valueType); + childNode[i].setParentAttrValue(pAtr); + childNode[i].setDataIndex(indexArray); + } + } + + node.setChildAttrNode(childNode); + } + + /** + * 属性划分完毕,进行数据的移除 + * + * @param srcData 源数据 + * @param attrName 划分的属性名称 + * @param valueType 属性的值类型 + * @param beLongValue 分类是否属于此值类型 + */ + private String[][] removeData(String[][] srcData, String attrName, + String valueType, boolean beLongValue){ + String[][] desDataArray; + ArrayList desData = new ArrayList<>(); + // 待删除数据 + ArrayList selectData = new ArrayList<>(); + selectData.add(attrNames); + + // 数组数据转化到列表中,方便移除 + Collections.addAll(desData, srcData); + + // 还是从左往右一列列的查找 + for (int j = 1; j < attrNames.length; j++) { + if (attrNames[j].equals(attrName)) { + for (int i = 1; i < desData.size(); i++) { + if (desData.get(i)[j].equals(valueType)) { + // 如果匹配这个数据,则移除其他的数据 + selectData.add(desData.get(i)); + } + } + } + } + + if (beLongValue) { + desDataArray = new String[selectData.size()][]; + selectData.toArray(desDataArray); + } else { + // 属性名称行不移除 + selectData.remove(attrNames); + // 如果是划分不属于此类型的数据时,进行移除 + desData.removeAll(selectData); + desDataArray = new String[desData.size()][]; + desData.toArray(desDataArray); + } + + return desDataArray; + } + + void startBuildingTree(){ + readDataFile(); + initAttrValue(); + + ArrayList remainAttr = new ArrayList<>(); + // 添加属性,除了最后一个类标号属性 + remainAttr.addAll(Arrays.asList(attrNames).subList(1, attrNames.length - 1)); + + AttrNode rootNode = new AttrNode(); + buildDecisionTree(rootNode, "", data, remainAttr, false); + setIndexAndAlpah(rootNode, 0, false); + System.out.println("剪枝前:"); + showDecisionTree(rootNode, 1); + setIndexAndAlpah(rootNode, 0, true); + System.out.println("\n剪枝后:"); + showDecisionTree(rootNode, 1); + } + + /** + * 显示决策树 + * + * @param node 待显示的节点 + * @param blankNum 行空格符,用于显示树型结构 + */ + private void showDecisionTree(AttrNode node, int blankNum){ + System.out.println(); + for (int i = 0; i < blankNum; i++) { + System.out.print(" "); + } + System.out.print("--"); + // 显示分类的属性值 + if (node.getParentAttrValue() != null + && node.getParentAttrValue().length() > 0) { + System.out.print(node.getParentAttrValue()); + } else { + System.out.print("--"); + } + System.out.print("--"); + + if (node.getDataIndex() != null && node.getDataIndex().size() > 0) { + String i = node.getDataIndex().get(0); + System.out.print("【" + node.getNodeIndex() + "】类别:" + + data[Integer.parseInt(i)][attrNames.length - 1]); + System.out.print("["); + for (String index : node.getDataIndex()) { + System.out.print(index + ", "); + } + System.out.print("]"); + } else { + // 递归显示子节点 + System.out.print("【" + node.getNodeIndex() + ":" + + node.getAttrName() + "】"); + if (node.getChildAttrNode() != null) { + for (AttrNode childNode : node.getChildAttrNode()) { + showDecisionTree(childNode, 2 * blankNum); + } + } else { + System.out.print("【 Child Null】"); + } + } + } + + /** + * 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝 + * + * @param node 开始的时候传入的是根节点 + * @param index 开始的索引号,从1开始 + * @param ifCutNode 是否需要剪枝 + */ + private void setIndexAndAlpah(AttrNode node, int index, boolean ifCutNode){ + AttrNode tempNode; + // 最小误差代价节点,即将被剪枝的节点 + AttrNode minAlphaNode = null; + double minAlpah = Integer.MAX_VALUE; + Queue nodeQueue = new LinkedList<>(); + + nodeQueue.add(node); + while (nodeQueue.size() > 0) { + index++; + // 从队列头部获取首个节点 + tempNode = nodeQueue.poll(); + tempNode.setNodeIndex(index); + if (tempNode.getChildAttrNode() != null) { + Collections.addAll(nodeQueue, tempNode.getChildAttrNode()); + computeAlpha(tempNode); + if (tempNode.getAlpha() < minAlpah) { + minAlphaNode = tempNode; + minAlpah = tempNode.getAlpha(); + } else if (tempNode.getAlpha() == minAlpah) { + // 如果误差代价值一样,比较包含的叶子节点个数,剪枝有多叶子节点数的节点 + if (minAlphaNode != null && tempNode.getLeafNum() > minAlphaNode.getLeafNum()) { + minAlphaNode = tempNode; + } + } + } + } + + if (ifCutNode) { + // 进行树的剪枝,让其左右孩子节点为null + if (minAlphaNode != null) { + minAlphaNode.setChildAttrNode(null); + } + } + } + + /** + * 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝 + * + * @param node 待计算的非叶子节点 + */ + private void computeAlpha(AttrNode node){ + double rt = 0; + double Rt; + double alpha; + // 当前节点的数据总数 + int sumNum = 0; + // 最少的偏差数 + int minNum = 0; + + ArrayList dataIndex; + ArrayList leafNodes = new ArrayList<>(); + + addLeafNode(node, leafNodes); + node.setLeafNum(leafNodes.size()); + for (AttrNode attrNode : leafNodes) { + dataIndex = attrNode.getDataIndex(); + + int num = 0; + sumNum += dataIndex.size(); + for (String s : dataIndex) { + // 统计分类数据中的正负实例数 + if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) { + num++; + } + } + minNum += num; + + // 取小数量的值部分 + if (1.0 * num / dataIndex.size() > 0.5) { + num = dataIndex.size() - num; + } + + rt += (1.0 * num / (data.length - 1)); + } + + //同样取出少偏差的那部分 + if (1.0 * minNum / sumNum > 0.5) { + minNum = sumNum - minNum; + } + + Rt = 1.0 * minNum / (data.length - 1); + alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1); + node.setAlpha(alpha); + } + + /** + * 筛选出节点所包含的叶子节点数 + * + * @param node 待筛选节点 + * @param leafNode 叶子节点列表容器 + */ + private void addLeafNode(AttrNode node, ArrayList leafNode){ + ArrayList dataIndex; + + if (node.getChildAttrNode() != null) { + for (AttrNode childNode : node.getChildAttrNode()) { + dataIndex = childNode.getDataIndex(); + if (dataIndex != null && dataIndex.size() > 0) { + // 说明此节点为叶子节点 + leafNode.add(childNode); + } else { + // 如果还是非叶子节点则继续递归调用 + addLeafNode(childNode, leafNode); + } + } + } + } } diff --git a/Classification/DataMining_CART/Client.java b/Classification/DataMining_CART/Client.java index 0aabb0c..5c474a8 100644 --- a/Classification/DataMining_CART/Client.java +++ b/Classification/DataMining_CART/Client.java @@ -1,11 +1,11 @@ -package DataMining_CART; +package Classification.DataMining_CART; public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - - CARTTool tool = new CARTTool(filePath); - - tool.startBuildingTree(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Classification/DataMining_CART/input.txt"; + + CARTTool tool = new CARTTool(filePath); + + tool.startBuildingTree(); + } } diff --git a/Classification/DataMining_ID3/AttrNode.java b/Classification/DataMining_ID3/AttrNode.java index dfe2ce9..4018f81 100644 --- a/Classification/DataMining_ID3/AttrNode.java +++ b/Classification/DataMining_ID3/AttrNode.java @@ -1,51 +1,51 @@ -package DataMing_ID3; +package Classification.DataMining_ID3; import java.util.ArrayList; /** * 属性节点,不是叶子节点 - * @author lyq * + * @author Qstar */ -public class AttrNode { - //当前属性的名字 - private String attrName; - //父节点的分类属性值 - private String parentAttrValue; - //属性子节点 - private AttrNode[] childAttrNode; - //孩子叶子节点 - private ArrayList childDataIndex; - - public String getAttrName() { - return attrName; - } - - public void setAttrName(String attrName) { - this.attrName = attrName; - } - - public AttrNode[] getChildAttrNode() { - return childAttrNode; - } - - public void setChildAttrNode(AttrNode[] childAttrNode) { - this.childAttrNode = childAttrNode; - } - - public String getParentAttrValue() { - return parentAttrValue; - } - - public void setParentAttrValue(String parentAttrValue) { - this.parentAttrValue = parentAttrValue; - } - - public ArrayList getChildDataIndex() { - return childDataIndex; - } - - public void setChildDataIndex(ArrayList childDataIndex) { - this.childDataIndex = childDataIndex; - } +class AttrNode { + //当前属性的名字 + private String attrName; + //父节点的分类属性值 + private String parentAttrValue; + //属性子节点 + private AttrNode[] childAttrNode; + //孩子叶子节点 + private ArrayList childDataIndex; + + public String getAttrName(){ + return attrName; + } + + public void setAttrName(String attrName){ + this.attrName = attrName; + } + + AttrNode[] getChildAttrNode(){ + return childAttrNode; + } + + void setChildAttrNode(AttrNode[] childAttrNode){ + this.childAttrNode = childAttrNode; + } + + String getParentAttrValue(){ + return parentAttrValue; + } + + void setParentAttrValue(String parentAttrValue){ + this.parentAttrValue = parentAttrValue; + } + + ArrayList getChildDataIndex(){ + return childDataIndex; + } + + void setChildDataIndex(ArrayList childDataIndex){ + this.childDataIndex = childDataIndex; + } } diff --git a/Classification/DataMining_ID3/Client.java b/Classification/DataMining_ID3/Client.java index 9b70211..47aca9c 100644 --- a/Classification/DataMining_ID3/Client.java +++ b/Classification/DataMining_ID3/Client.java @@ -1,15 +1,15 @@ -package DataMing_ID3; +package Classification.DataMining_ID3; /** * ID3决策树分类算法测试场景类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - - ID3Tool tool = new ID3Tool(filePath); - tool.startBuildingTree(true); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Classification/DataMining_ID3/input.txt"; + + ID3Tool tool = new ID3Tool(filePath); + tool.startBuildingTree(true); + } } diff --git a/Classification/DataMining_ID3/DataNode.java b/Classification/DataMining_ID3/DataNode.java index e29ab51..a623706 100644 --- a/Classification/DataMining_ID3/DataNode.java +++ b/Classification/DataMining_ID3/DataNode.java @@ -1,17 +1,17 @@ -package DataMing_ID3; +package Classification.DataMining_ID3; /** * 存放数据的叶子节点 - * @author lyq * + * @author Qstar */ public class DataNode { - /** - * 数据的标号 - */ - private int dataIndex; - - public DataNode(int dataIndex){ - this.dataIndex = dataIndex; - } + /** + * 数据的标号 + */ + private int dataIndex; + + public DataNode(int dataIndex){ + this.dataIndex = dataIndex; + } } diff --git a/Classification/DataMining_ID3/ID3Tool.java b/Classification/DataMining_ID3/ID3Tool.java index 67cfad7..0284eb5 100644 --- a/Classification/DataMining_ID3/ID3Tool.java +++ b/Classification/DataMining_ID3/ID3Tool.java @@ -1,447 +1,422 @@ -package DataMing_ID3; +package Classification.DataMining_ID3; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; -import java.util.Iterator; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; /** * ID3算法实现类 - * - * @author lyq - * + * + * @author Qstar */ -public class ID3Tool { - // 类标号的值类型 - private final String YES = "Yes"; - private final String NO = "No"; - - // 所有属性的类型总数,在这里就是data源数据的列数 - private int attrNum; - private String filePath; - // 初始源数据,用一个二维字符数组存放模仿表格数据 - private String[][] data; - // 数据的属性行的名字 - private String[] attrNames; - // 每个属性的值所有类型 - private HashMap> attrValue; - - public ID3Tool(String filePath) { - this.filePath = filePath; - attrValue = new HashMap<>(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - data = new String[dataArray.size()][]; - dataArray.toArray(data); - attrNum = data[0].length; - attrNames = data[0]; +class ID3Tool { + // 类标号的值类型 + private final String YES = "Yes"; + private final String NO = "No"; + + // 所有属性的类型总数,在这里就是data源数据的列数 + private int attrNum; + private String filePath; + // 初始源数据,用一个二维字符数组存放模仿表格数据 + private String[][] data; + // 数据的属性行的名字 + private String[] attrNames; + // 每个属性的值所有类型 + private HashMap> attrValue; + + ID3Tool(String filePath){ + this.filePath = filePath; + attrValue = new HashMap<>(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + data = new String[dataArray.size()][]; + dataArray.toArray(data); + attrNum = data[0].length; + attrNames = data[0]; /* - * for(int i=0; i tempValues; - - // 按照列的方式,从左往右找 - for (int j = 1; j < attrNum; j++) { - // 从一列中的上往下开始寻找值 - tempValues = new ArrayList<>(); - for (int i = 1; i < data.length; i++) { - if (!tempValues.contains(data[i][j])) { - // 如果这个属性的值没有添加过,则添加 - tempValues.add(data[i][j]); - } - } - - // 一列属性的值已经遍历完毕,复制到map属性表中 - attrValue.put(data[0][j], tempValues); - } + } + + /** + * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用 + */ + private void initAttrValue(){ + ArrayList tempValues; + + // 按照列的方式,从左往右找 + for (int j = 1; j < attrNum; j++) { + // 从一列中的上往下开始寻找值 + tempValues = new ArrayList<>(); + for (int i = 1; i < data.length; i++) { + if (!tempValues.contains(data[i][j])) { + // 如果这个属性的值没有添加过,则添加 + tempValues.add(data[i][j]); + } + } + + // 一列属性的值已经遍历完毕,复制到map属性表中 + attrValue.put(data[0][j], tempValues); + } /* - * for(Map.Entry entry : attrValue.entrySet()){ + * for(Map.Entry entry : attrValue.entrySet()){ * System.out.println("key:value " + entry.getKey() + ":" + * entry.getValue()); } */ - } - - /** - * 计算数据按照不同方式划分的熵 - * - * @param remainData - * 剩余的数据 - * @param attrName - * 待划分的属性,在算信息增益的时候会使用到 - * @param attrValue - * 划分的子属性值 - * @param isParent - * 是否分子属性划分还是原来不变的划分 - */ - private double computeEntropy(String[][] remainData, String attrName, - String value, boolean isParent) { - // 实例总数 - int total = 0; - // 正实例数 - int posNum = 0; - // 负实例数 - int negNum = 0; - - // 还是按列从左往右遍历属性 - for (int j = 1; j < attrNames.length; j++) { - // 找到了指定的属性 - if (attrName.equals(attrNames[j])) { - for (int i = 1; i < remainData.length; i++) { - // 如果是父结点直接计算熵或者是通过子属性划分计算熵,这时要进行属性值的过滤 - if (isParent - || (!isParent && remainData[i][j].equals(value))) { - if (remainData[i][attrNames.length - 1].equals(YES)) { - // 判断此行数据是否为正实例 - posNum++; - } else { - negNum++; - } - } - } - } - } - - total = posNum + negNum; - double posProbobly = (double) posNum / total; - double negProbobly = (double) negNum / total; - - if (posProbobly == 1 || posProbobly == 0) { - // 如果数据全为同种类型,则熵为0,否则带入下面的公式会报错 - return 0; - } - - double entropyValue = -posProbobly * Math.log(posProbobly) - / Math.log(2.0) - negProbobly * Math.log(negProbobly) - / Math.log(2.0); - - // 返回计算所得熵 - return entropyValue; - } - - /** - * 为某个属性计算信息增益 - * - * @param remainData - * 剩余的数据 - * @param value - * 待划分的属性名称 - * @return - */ - private double computeGain(String[][] remainData, String value) { - double gainValue = 0; - // 源熵的大小将会与属性划分后进行比较 - double entropyOri = 0; - // 子划分熵和 - double childEntropySum = 0; - // 属性子类型的个数 - int childValueNum = 0; - // 属性值的种数 - ArrayList attrTypes = attrValue.get(value); - // 子属性对应的权重比 - HashMap ratioValues = new HashMap<>(); - - for (int i = 0; i < attrTypes.size(); i++) { - // 首先都统一计数为0 - ratioValues.put(attrTypes.get(i), 0); - } - - // 还是按照一列,从左往右遍历 - for (int j = 1; j < attrNames.length; j++) { - // 判断是否到了划分的属性列 - if (value.equals(attrNames[j])) { - for (int i = 1; i <= remainData.length - 1; i++) { - childValueNum = ratioValues.get(remainData[i][j]); - // 增加个数并且重新存入 - childValueNum++; - ratioValues.put(remainData[i][j], childValueNum); - } - } - } - - // 计算原熵的大小 - entropyOri = computeEntropy(remainData, value, null, true); - for (int i = 0; i < attrTypes.size(); i++) { - double ratio = (double) ratioValues.get(attrTypes.get(i)) - / (remainData.length - 1); - childEntropySum += ratio - * computeEntropy(remainData, value, attrTypes.get(i), false); - - // System.out.println("ratio:value: " + ratio + " " + - // computeEntropy(remainData, value, - // attrTypes.get(i), false)); - } - - // 二者熵相减就是信息增益 - gainValue = entropyOri - childEntropySum; - return gainValue; - } - - /** - * 计算信息增益比 - * - * @param remainData - * 剩余数据 - * @param value - * 待划分属性 - * @return - */ - private double computeGainRatio(String[][] remainData, String value) { - double gain = 0; - double spiltInfo = 0; - int childValueNum = 0; - // 属性值的种数 - ArrayList attrTypes = attrValue.get(value); - // 子属性对应的权重比 - HashMap ratioValues = new HashMap<>(); - - for (int i = 0; i < attrTypes.size(); i++) { - // 首先都统一计数为0 - ratioValues.put(attrTypes.get(i), 0); - } - - // 还是按照一列,从左往右遍历 - for (int j = 1; j < attrNames.length; j++) { - // 判断是否到了划分的属性列 - if (value.equals(attrNames[j])) { - for (int i = 1; i <= remainData.length - 1; i++) { - childValueNum = ratioValues.get(remainData[i][j]); - // 增加个数并且重新存入 - childValueNum++; - ratioValues.put(remainData[i][j], childValueNum); - } - } - } - - // 计算信息增益 - gain = computeGain(remainData, value); - // 计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀): - for (int i = 0; i < attrTypes.size(); i++) { - double ratio = (double) ratioValues.get(attrTypes.get(i)) - / (remainData.length - 1); - spiltInfo += -ratio * Math.log(ratio) / Math.log(2.0); - } - - // 计算机信息增益率 - return gain / spiltInfo; - } - - /** - * 利用源数据构造决策树 - */ - private void buildDecisionTree(AttrNode node, String parentAttrValue, - String[][] remainData, ArrayList remainAttr, boolean isID3) { - node.setParentAttrValue(parentAttrValue); - - String attrName = ""; - double gainValue = 0; - double tempValue = 0; - - // 如果只有1个属性则直接返回 - if (remainAttr.size() == 1) { - System.out.println("attr null"); - return; - } - - // 选择剩余属性中信息增益最大的作为下一个分类的属性 - for (int i = 0; i < remainAttr.size(); i++) { - // 判断是否用ID3算法还是C4.5算法 - if (isID3) { - // ID3算法采用的是按照信息增益的值来比 - tempValue = computeGain(remainData, remainAttr.get(i)); - } else { - // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足 - tempValue = computeGainRatio(remainData, remainAttr.get(i)); - } - - if (tempValue > gainValue) { - gainValue = tempValue; - attrName = remainAttr.get(i); - } - } - - node.setAttrName(attrName); - ArrayList valueTypes = attrValue.get(attrName); - remainAttr.remove(attrName); - - AttrNode[] childNode = new AttrNode[valueTypes.size()]; - String[][] rData; - for (int i = 0; i < valueTypes.size(); i++) { - // 移除非此值类型的数据 - rData = removeData(remainData, attrName, valueTypes.get(i)); - - childNode[i] = new AttrNode(); - boolean sameClass = true; - ArrayList indexArray = new ArrayList<>(); - for (int k = 1; k < rData.length; k++) { - indexArray.add(rData[k][0]); - // 判断是否为同一类的 - if (!rData[k][attrNames.length - 1] - .equals(rData[1][attrNames.length - 1])) { - // 只要有1个不相等,就不是同类型的 - sameClass = false; - break; - } - } - - if (!sameClass) { - // 创建新的对象属性,对象的同个引用会出错 - ArrayList rAttr = new ArrayList<>(); - for (String str : remainAttr) { - rAttr.add(str); - } - - buildDecisionTree(childNode[i], valueTypes.get(i), rData, - rAttr, isID3); - } else { - // 如果是同种类型,则直接为数据节点 - childNode[i].setParentAttrValue(valueTypes.get(i)); - childNode[i].setChildDataIndex(indexArray); - } - - } - node.setChildAttrNode(childNode); - } - - /** - * 属性划分完毕,进行数据的移除 - * - * @param srcData - * 源数据 - * @param attrName - * 划分的属性名称 - * @param valueType - * 属性的值类型 - */ - private String[][] removeData(String[][] srcData, String attrName, - String valueType) { - String[][] desDataArray; - ArrayList desData = new ArrayList<>(); - // 待删除数据 - ArrayList selectData = new ArrayList<>(); - selectData.add(attrNames); - - // 数组数据转化到列表中,方便移除 - for (int i = 0; i < srcData.length; i++) { - desData.add(srcData[i]); - } - - // 还是从左往右一列列的查找 - for (int j = 1; j < attrNames.length; j++) { - if (attrNames[j].equals(attrName)) { - for (int i = 1; i < desData.size(); i++) { - if (desData.get(i)[j].equals(valueType)) { - // 如果匹配这个数据,则移除其他的数据 - selectData.add(desData.get(i)); - } - } - } - } - - desDataArray = new String[selectData.size()][]; - selectData.toArray(desDataArray); - - return desDataArray; - } - - /** - * 开始构建决策树 - * - * @param isID3 - * 是否采用ID3算法构架决策树 - */ - public void startBuildingTree(boolean isID3) { - readDataFile(); - initAttrValue(); - - ArrayList remainAttr = new ArrayList<>(); - // 添加属性,除了最后一个类标号属性 - for (int i = 1; i < attrNames.length - 1; i++) { - remainAttr.add(attrNames[i]); - } - - AttrNode rootNode = new AttrNode(); - buildDecisionTree(rootNode, "", data, remainAttr, isID3); - showDecisionTree(rootNode, 1); - } - - /** - * 显示决策树 - * - * @param node - * 待显示的节点 - * @param blankNum - * 行空格符,用于显示树型结构 - */ - private void showDecisionTree(AttrNode node, int blankNum) { - System.out.println(); - for (int i = 0; i < blankNum; i++) { - System.out.print("\t"); - } - System.out.print("--"); - // 显示分类的属性值 - if (node.getParentAttrValue() != null - && node.getParentAttrValue().length() > 0) { - System.out.print(node.getParentAttrValue()); - } else { - System.out.print("--"); - } - System.out.print("--"); - - if (node.getChildDataIndex() != null - && node.getChildDataIndex().size() > 0) { - String i = node.getChildDataIndex().get(0); - System.out.print("类别:" - + data[Integer.parseInt(i)][attrNames.length - 1]); - System.out.print("["); - for (String index : node.getChildDataIndex()) { - System.out.print(index + ", "); - } - System.out.print("]"); - } else { - // 递归显示子节点 - System.out.print("【" + node.getAttrName() + "】"); - for (AttrNode childNode : node.getChildAttrNode()) { - showDecisionTree(childNode, 2 * blankNum); - } - } - - } + } + + /** + * 计算数据按照不同方式划分的熵 + * + * @param remainData 剩余的数据 + * @param attrName 待划分的属性,在算信息增益的时候会使用到 + * @param value 划分的子属性值 + * @param isParent 是否分子属性划分还是原来不变的划分 + */ + private double computeEntropy(String[][] remainData, String attrName, + String value, boolean isParent){ + // 实例总数 + int total; + // 正实例数 + int posNum = 0; + // 负实例数 + int negNum = 0; + + // 还是按列从左往右遍历属性 + for (int j = 1; j < attrNames.length; j++) { + // 找到了指定的属性 + if (attrName.equals(attrNames[j])) { + for (int i = 1; i < remainData.length; i++) { + // 如果是父结点直接计算熵或者是通过子属性划分计算熵,这时要进行属性值的过滤 + if (isParent + || (!isParent && remainData[i][j].equals(value))) { + if (remainData[i][attrNames.length - 1].equals(YES)) { + // 判断此行数据是否为正实例 + posNum++; + } else { + negNum++; + } + } + } + } + } + + total = posNum + negNum; + double posProbobly = (double) posNum / total; + double negProbobly = (double) negNum / total; + + if (posProbobly == 1 || posProbobly == 0) { + // 如果数据全为同种类型,则熵为0,否则带入下面的公式会报错 + return 0; + } + + // 返回计算所得熵 + return -posProbobly * Math.log(posProbobly) + / Math.log(2.0) - negProbobly * Math.log(negProbobly) + / Math.log(2.0); + } + + /** + * 为某个属性计算信息增益 + * + * @param remainData 剩余的数据 + * @param value 待划分的属性名称 + */ + private double computeGain(String[][] remainData, String value){ + double gainValue; + // 源熵的大小将会与属性划分后进行比较 + double entropyOri; + // 子划分熵和 + double childEntropySum = 0; + // 属性子类型的个数 + int childValueNum; + // 属性值的种数 + ArrayList attrTypes = attrValue.get(value); + // 子属性对应的权重比 + HashMap ratioValues = new HashMap<>(); + + for (String attrType : attrTypes) { + // 首先都统一计数为0 + ratioValues.put(attrType, 0); + } + + // 还是按照一列,从左往右遍历 + for (int j = 1; j < attrNames.length; j++) { + // 判断是否到了划分的属性列 + if (value.equals(attrNames[j])) { + for (int i = 1; i <= remainData.length - 1; i++) { + childValueNum = ratioValues.get(remainData[i][j]); + // 增加个数并且重新存入 + childValueNum++; + ratioValues.put(remainData[i][j], childValueNum); + } + } + } + + // 计算原熵的大小 + entropyOri = computeEntropy(remainData, value, null, true); + for (String attrType : attrTypes) { + double ratio = (double) ratioValues.get(attrType) + / (remainData.length - 1); + childEntropySum += ratio + * computeEntropy(remainData, value, attrType, false); + + // System.out.println("ratio:value: " + ratio + " " + + // computeEntropy(remainData, value, + // attrTypes.get(i), false)); + } + + // 二者熵相减就是信息增益 + gainValue = entropyOri - childEntropySum; + return gainValue; + } + + /** + * 计算信息增益比 + * + * @param remainData 剩余数据 + * @param value 待划分属性 + */ + private double computeGainRatio(String[][] remainData, String value){ + double gain; + double spiltInfo = 0; + int childValueNum; + // 属性值的种数 + ArrayList attrTypes = attrValue.get(value); + // 子属性对应的权重比 + HashMap ratioValues = new HashMap<>(); + + for (String attrType : attrTypes) { + // 首先都统一计数为0 + ratioValues.put(attrType, 0); + } + + // 还是按照一列,从左往右遍历 + for (int j = 1; j < attrNames.length; j++) { + // 判断是否到了划分的属性列 + if (value.equals(attrNames[j])) { + for (int i = 1; i <= remainData.length - 1; i++) { + childValueNum = ratioValues.get(remainData[i][j]); + // 增加个数并且重新存入 + childValueNum++; + ratioValues.put(remainData[i][j], childValueNum); + } + } + } + + // 计算信息增益 + gain = computeGain(remainData, value); + // 计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀): + for (String attrType : attrTypes) { + double ratio = (double) ratioValues.get(attrType) + / (remainData.length - 1); + spiltInfo += -ratio * Math.log(ratio) / Math.log(2.0); + } + + // 计算机信息增益率 + return gain / spiltInfo; + } + + /** + * 利用源数据构造决策树 + */ + private void buildDecisionTree(AttrNode node, String parentAttrValue, + String[][] remainData, ArrayList remainAttr, boolean isID3){ + node.setParentAttrValue(parentAttrValue); + + String attrName = ""; + double gainValue = 0; + double tempValue; + + // 如果只有1个属性则直接返回 + if (remainAttr.size() == 1) { + System.out.println("attr null"); + return; + } + + // 选择剩余属性中信息增益最大的作为下一个分类的属性 + for (String aRemainAttr : remainAttr) { + // 判断是否用ID3算法还是C4.5算法 + if (isID3) { + // ID3算法采用的是按照信息增益的值来比 + tempValue = computeGain(remainData, aRemainAttr); + } else { + // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足 + tempValue = computeGainRatio(remainData, aRemainAttr); + } + + if (tempValue > gainValue) { + gainValue = tempValue; + attrName = aRemainAttr; + } + } + + node.setAttrName(attrName); + ArrayList valueTypes = attrValue.get(attrName); + remainAttr.remove(attrName); + + AttrNode[] childNode = new AttrNode[valueTypes.size()]; + String[][] rData; + for (int i = 0; i < valueTypes.size(); i++) { + // 移除非此值类型的数据 + rData = removeData(remainData, attrName, valueTypes.get(i)); + + childNode[i] = new AttrNode(); + boolean sameClass = true; + ArrayList indexArray = new ArrayList<>(); + for (int k = 1; k < rData.length; k++) { + indexArray.add(rData[k][0]); + // 判断是否为同一类的 + if (!rData[k][attrNames.length - 1] + .equals(rData[1][attrNames.length - 1])) { + // 只要有1个不相等,就不是同类型的 + sameClass = false; + break; + } + } + + if (!sameClass) { + // 创建新的对象属性,对象的同个引用会出错 + ArrayList rAttr = new ArrayList<>(); + for (String str : remainAttr) { + rAttr.add(str); + } + + buildDecisionTree(childNode[i], valueTypes.get(i), rData, + rAttr, isID3); + } else { + // 如果是同种类型,则直接为数据节点 + childNode[i].setParentAttrValue(valueTypes.get(i)); + childNode[i].setChildDataIndex(indexArray); + } + + } + node.setChildAttrNode(childNode); + } + + /** + * 属性划分完毕,进行数据的移除 + * + * @param srcData 源数据 + * @param attrName 划分的属性名称 + * @param valueType 属性的值类型 + */ + private String[][] removeData(String[][] srcData, String attrName, + String valueType){ + String[][] desDataArray; + ArrayList desData = new ArrayList<>(); + // 待删除数据 + ArrayList selectData = new ArrayList<>(); + selectData.add(attrNames); + + // 数组数据转化到列表中,方便移除 + Collections.addAll(desData, srcData); + + // 还是从左往右一列列的查找 + for (int j = 1; j < attrNames.length; j++) { + if (attrNames[j].equals(attrName)) { + for (int i = 1; i < desData.size(); i++) { + if (desData.get(i)[j].equals(valueType)) { + // 如果匹配这个数据,则移除其他的数据 + selectData.add(desData.get(i)); + } + } + } + } + + desDataArray = new String[selectData.size()][]; + selectData.toArray(desDataArray); + + return desDataArray; + } + + /** + * 开始构建决策树 + * + * @param isID3 是否采用ID3算法构架决策树 + */ + void startBuildingTree(boolean isID3){ + readDataFile(); + initAttrValue(); + + ArrayList remainAttr = new ArrayList<>(); + // 添加属性,除了最后一个类标号属性 + remainAttr.addAll(Arrays.asList(attrNames).subList(1, attrNames.length - 1)); + + AttrNode rootNode = new AttrNode(); + buildDecisionTree(rootNode, "", data, remainAttr, isID3); + showDecisionTree(rootNode, 1); + } + + /** + * 显示决策树 + * + * @param node 待显示的节点 + * @param blankNum 行空格符,用于显示树型结构 + */ + private void showDecisionTree(AttrNode node, int blankNum){ + System.out.println(); + for (int i = 0; i < blankNum; i++) { + System.out.print("\t"); + } + System.out.print("--"); + // 显示分类的属性值 + if (node.getParentAttrValue() != null + && node.getParentAttrValue().length() > 0) { + System.out.print(node.getParentAttrValue()); + } else { + System.out.print("--"); + } + System.out.print("--"); + + if (node.getChildDataIndex() != null + && node.getChildDataIndex().size() > 0) { + String i = node.getChildDataIndex().get(0); + System.out.print("类别:" + + data[Integer.parseInt(i)][attrNames.length - 1]); + System.out.print("["); + for (String index : node.getChildDataIndex()) { + System.out.print(index + ", "); + } + System.out.print("]"); + } else { + // 递归显示子节点 + System.out.print("【" + node.getAttrName() + "】"); + for (AttrNode childNode : node.getChildAttrNode()) { + showDecisionTree(childNode, 2 * blankNum); + } + } + + } } diff --git a/Classification/DataMining_KNN/Client.java b/Classification/DataMining_KNN/Client.java index 03e1d5a..51f8574 100644 --- a/Classification/DataMining_KNN/Client.java +++ b/Classification/DataMining_KNN/Client.java @@ -1,26 +1,16 @@ -package DataMining_KNN; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; - +package Classification.DataMining_KNN; /** * k最近邻算法场景类型 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt"; - String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testinput.txt"; - - KNNTool tool = new KNNTool(trainDataPath, testDataPath); - tool.knnCompute(3); - - } - - + public static void main(String[] args){ + String trainDataPath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Classification/DataMining_KNN/trainInput.txt"; + String testDataPath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Classification/DataMining_KNN/testinput.txt"; + KNNTool tool = new KNNTool(trainDataPath, testDataPath); + tool.knnCompute(3); + } } diff --git a/Classification/DataMining_KNN/KNNTool.java b/Classification/DataMining_KNN/KNNTool.java index bb316fb..af24af5 100644 --- a/Classification/DataMining_KNN/KNNTool.java +++ b/Classification/DataMining_KNN/KNNTool.java @@ -1,200 +1,186 @@ -package DataMining_KNN; +package Classification.DataMining_KNN; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; import java.util.Collections; -import java.util.Comparator; import java.util.HashMap; import java.util.Map; -import org.apache.activemq.filter.ComparisonExpression; - /** * k最近邻算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class KNNTool { - // 为4个类别设置权重,默认权重比一致 - public int[] classWeightArray = new int[] { 1, 1, 1, 1 }; - // 测试数据 - private String testDataPath; - // 训练集数据地址 - private String trainDataPath; - // 分类的不同类型 - private ArrayList classTypes; - // 结果数据 - private ArrayList resultSamples; - // 训练集数据列表容器 - private ArrayList trainSamples; - // 训练集数据 - private String[][] trainData; - // 测试集数据 - private String[][] testData; - - public KNNTool(String trainDataPath, String testDataPath) { - this.trainDataPath = trainDataPath; - this.testDataPath = testDataPath; - readDataFormFile(); - } - - /** - * 从文件中阅读测试数和训练数据集 - */ - private void readDataFormFile() { - ArrayList tempArray; - - tempArray = fileDataToArray(trainDataPath); - trainData = new String[tempArray.size()][]; - tempArray.toArray(trainData); - - classTypes = new ArrayList<>(); - for (String[] s : tempArray) { - if (!classTypes.contains(s[0])) { - // 添加类型 - classTypes.add(s[0]); - } - } - - tempArray = fileDataToArray(testDataPath); - testData = new String[tempArray.size()][]; - tempArray.toArray(testData); - } - - /** - * 将文件转为列表数据输出 - * - * @param filePath - * 数据文件的内容 - */ - private ArrayList fileDataToArray(String filePath) { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - return dataArray; - } - - /** - * 计算样本特征向量的欧几里得距离 - * - * @param f1 - * 待比较样本1 - * @param f2 - * 待比较样本2 - * @return - */ - private int computeEuclideanDistance(Sample s1, Sample s2) { - String[] f1 = s1.getFeatures(); - String[] f2 = s2.getFeatures(); - // 欧几里得距离 - int distance = 0; - - for (int i = 0; i < f1.length; i++) { - int subF1 = Integer.parseInt(f1[i]); - int subF2 = Integer.parseInt(f2[i]); - - distance += (subF1 - subF2) * (subF1 - subF2); - } - - return distance; - } - - /** - * 计算K最近邻 - * @param k - * 在多少的k范围内 - */ - public void knnCompute(int k) { - String className = ""; - String[] tempF = null; - Sample temp; - resultSamples = new ArrayList<>(); - trainSamples = new ArrayList<>(); - // 分类类别计数 - HashMap classCount; - // 类别权重比 - HashMap classWeight = new HashMap<>(); - // 首先讲测试数据转化到结果数据中 - for (String[] s : testData) { - temp = new Sample(s); - resultSamples.add(temp); - } - - for (String[] s : trainData) { - className = s[0]; - tempF = new String[s.length - 1]; - System.arraycopy(s, 1, tempF, 0, s.length - 1); - temp = new Sample(className, tempF); - trainSamples.add(temp); - } - - // 离样本最近排序的的训练集数据 - ArrayList kNNSample = new ArrayList<>(); - // 计算训练数据集中离样本数据最近的K个训练集数据 - for (Sample s : resultSamples) { - classCount = new HashMap<>(); - int index = 0; - for (String type : classTypes) { - // 开始时计数为0 - classCount.put(type, 0); - classWeight.put(type, classWeightArray[index++]); - } - for (Sample tS : trainSamples) { - int dis = computeEuclideanDistance(s, tS); - tS.setDistance(dis); - } - - Collections.sort(trainSamples); - kNNSample.clear(); - // 挑选出前k个数据作为分类标准 - for (int i = 0; i < trainSamples.size(); i++) { - if (i < k) { - kNNSample.add(trainSamples.get(i)); - } else { - break; - } - } - // 判定K个训练数据的多数的分类标准 - for (Sample s1 : kNNSample) { - int num = classCount.get(s1.getClassName()); - // 进行分类权重的叠加,默认类别权重平等,可自行改变,近的权重大,远的权重小 - num += classWeight.get(s1.getClassName()); - classCount.put(s1.getClassName(), num); - } - - int maxCount = 0; - // 筛选出k个训练集数据中最多的一个分类 - for (Map.Entry entry : classCount.entrySet()) { - if ((Integer) entry.getValue() > maxCount) { - maxCount = (Integer) entry.getValue(); - s.setClassName((String) entry.getKey()); - } - } - - System.out.print("测试数据特征:"); - for (String s1 : s.getFeatures()) { - System.out.print(s1 + " "); - } - System.out.println("分类:" + s.getClassName()); - } - } +class KNNTool { + // 为4个类别设置权重,默认权重比一致 + private int[] classWeightArray = new int[]{1, 1, 1, 1}; + // 测试数据 + private String testDataPath; + // 训练集数据地址 + private String trainDataPath; + // 分类的不同类型 + private ArrayList classTypes; + // 训练集数据 + private String[][] trainData; + // 测试集数据 + private String[][] testData; + + KNNTool(String trainDataPath, String testDataPath){ + this.trainDataPath = trainDataPath; + this.testDataPath = testDataPath; + readDataFormFile(); + } + + /** + * 从文件中阅读测试数和训练数据集 + */ + private void readDataFormFile(){ + ArrayList tempArray; + + tempArray = fileDataToArray(trainDataPath); + trainData = new String[tempArray.size()][]; + tempArray.toArray(trainData); + + classTypes = new ArrayList<>(); + for (String[] s : tempArray) { + if (!classTypes.contains(s[0])) { + // 添加类型 + classTypes.add(s[0]); + } + } + + tempArray = fileDataToArray(testDataPath); + testData = new String[tempArray.size()][]; + tempArray.toArray(testData); + } + + /** + * 将文件转为列表数据输出 + * + * @param filePath 数据文件的内容 + */ + private ArrayList fileDataToArray(String filePath){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + return dataArray; + } + + /** + * 计算样本特征向量的欧几里得距离 + * + * @param s1 待比较样本1 + * @param s2 待比较样本2 + */ + private int computeEuclideanDistance(Sample s1, Sample s2){ + String[] f1 = s1.getFeatures(); + String[] f2 = s2.getFeatures(); + // 欧几里得距离 + int distance = 0; + + for (int i = 0; i < f1.length; i++) { + int subF1 = Integer.parseInt(f1[i]); + int subF2 = Integer.parseInt(f2[i]); + + distance += (subF1 - subF2) * (subF1 - subF2); + } + + return distance; + } + + /** + * 计算K最近邻 + * + * @param k 在多少的k范围内 + */ + void knnCompute(int k){ + String className; + String[] tempF; + Sample temp; + ArrayList resultSamples = new ArrayList<>(); + ArrayList trainSamples = new ArrayList<>(); + // 分类类别计数 + HashMap classCount; + // 类别权重比 + HashMap classWeight = new HashMap<>(); + // 首先讲测试数据转化到结果数据中 + for (String[] s : testData) { + temp = new Sample(s); + resultSamples.add(temp); + } + + for (String[] s : trainData) { + className = s[0]; + tempF = new String[s.length - 1]; + System.arraycopy(s, 1, tempF, 0, s.length - 1); + temp = new Sample(className, tempF); + trainSamples.add(temp); + } + + // 离样本最近排序的的训练集数据 + ArrayList kNNSample = new ArrayList<>(); + // 计算训练数据集中离样本数据最近的K个训练集数据 + for (Sample s : resultSamples) { + classCount = new HashMap<>(); + int index = 0; + for (String type : classTypes) { + // 开始时计数为0 + classCount.put(type, 0); + classWeight.put(type, classWeightArray[index++]); + } + for (Sample tS : trainSamples) { + int dis = computeEuclideanDistance(s, tS); + tS.setDistance(dis); + } + + Collections.sort(trainSamples); + kNNSample.clear(); + // 挑选出前k个数据作为分类标准 + for (int i = 0; i < trainSamples.size(); i++) { + if (i < k) { + kNNSample.add(trainSamples.get(i)); + } else { + break; + } + } + // 判定K个训练数据的多数的分类标准 + for (Sample s1 : kNNSample) { + int num = classCount.get(s1.getClassName()); + // 进行分类权重的叠加,默认类别权重平等,可自行改变,近的权重大,远的权重小 + num += classWeight.get(s1.getClassName()); + classCount.put(s1.getClassName(), num); + } + + int maxCount = 0; + // 筛选出k个训练集数据中最多的一个分类 + for (Map.Entry entry : classCount.entrySet()) { + if ((Integer) entry.getValue() > maxCount) { + maxCount = (Integer) entry.getValue(); + s.setClassName((String) entry.getKey()); + } + } + + System.out.print("测试数据特征:"); + for (String s1 : s.getFeatures()) { + System.out.print(s1 + " "); + } + System.out.println("分类:" + s.getClassName()); + } + } } diff --git a/Classification/DataMining_KNN/Sample.java b/Classification/DataMining_KNN/Sample.java index c4e185d..bf6e3e9 100644 --- a/Classification/DataMining_KNN/Sample.java +++ b/Classification/DataMining_KNN/Sample.java @@ -1,57 +1,52 @@ -package DataMining_KNN; +package Classification.DataMining_KNN; /** * 样本数据类 - * - * @author lyq - * + * + * @author Qstar */ -public class Sample implements Comparable{ - // 样本数据的分类名称 - private String className; - // 样本数据的特征向量 - private String[] features; - //测试样本之间的间距值,以此做排序 - private Integer distance; - - public Sample(String[] features){ - this.features = features; - } - - public Sample(String className, String[] features){ - this.className = className; - this.features = features; - } - - public String getClassName() { - return className; - } - - public void setClassName(String className) { - this.className = className; - } - - public String[] getFeatures() { - return features; - } - - public void setFeatures(String[] features) { - this.features = features; - } - - public Integer getDistance() { - return distance; - } - - public void setDistance(int distance) { - this.distance = distance; - } - - @Override - public int compareTo(Sample o) { - // TODO Auto-generated method stub - return this.getDistance().compareTo(o.getDistance()); - } - +class Sample implements Comparable { + // 样本数据的分类名称 + private String className; + // 样本数据的特征向量 + private String[] features; + //测试样本之间的间距值,以此做排序 + private Integer distance; + + Sample(String[] features){ + this.features = features; + } + + Sample(String className, String[] features){ + this.className = className; + this.features = features; + } + + String getClassName(){ + return className; + } + + void setClassName(String className){ + this.className = className; + } + + String[] getFeatures(){ + return features; + } + + public Integer getDistance(){ + return distance; + } + + public void setDistance(int distance){ + this.distance = distance; + } + + @Override + public int compareTo(Sample o){ + // TODO Auto-generated method stub + return this.getDistance().compareTo(o.getDistance()); + } + } diff --git a/Classification/DataMining_NaiveBayes/Client.java b/Classification/DataMining_NaiveBayes/Client.java index 05713b7..3e307b9 100644 --- a/Classification/DataMining_NaiveBayes/Client.java +++ b/Classification/DataMining_NaiveBayes/Client.java @@ -1,17 +1,17 @@ -package DataMining_NaiveBayes; +package Classification.DataMining_NaiveBayes; /** * 朴素贝叶斯算法场景调用类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - //训练集数据 - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - String testData = "Youth Medium Yes Fair"; - NaiveBayesTool tool = new NaiveBayesTool(filePath); - System.out.println(testData + " 数据的分类为:" + tool.naiveBayesClassificate(testData)); - } + public static void main(String[] args){ + //训练集数据 + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Classification/DataMining_NaiveBayes/input.txt"; + String testData = "Youth Medium Yes Fair"; + NaiveBayesTool tool = new NaiveBayesTool(filePath); + System.out.println(testData + " 数据的分类为:" + tool.naiveBayesClassificate(testData)); + } } diff --git a/Classification/DataMining_NaiveBayes/NaiveBayesTool.java b/Classification/DataMining_NaiveBayes/NaiveBayesTool.java index 02db2c6..6239058 100644 --- a/Classification/DataMining_NaiveBayes/NaiveBayesTool.java +++ b/Classification/DataMining_NaiveBayes/NaiveBayesTool.java @@ -1,4 +1,4 @@ -package DataMining_NaiveBayes; +package Classification.DataMining_NaiveBayes; import java.io.BufferedReader; import java.io.File; @@ -10,200 +10,193 @@ /** * 朴素贝叶斯算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class NaiveBayesTool { - // 类标记符,这里分为2类,YES和NO - private String YES = "Yes"; - private String NO = "No"; - - // 已分类训练数据集文件路径 - private String filePath; - // 属性名称数组 - private String[] attrNames; - // 训练数据集 - private String[][] data; - - // 每个属性的值所有类型 - private HashMap> attrValue; - - public NaiveBayesTool(String filePath) { - this.filePath = filePath; - - readDataFile(); - initAttrValue(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - data = new String[dataArray.size()][]; - dataArray.toArray(data); - attrNames = data[0]; +class NaiveBayesTool { + // 类标记符,这里分为2类,YES和NO + private String YES = "Yes"; + + // 已分类训练数据集文件路径 + private String filePath; + // 属性名称数组 + private String[] attrNames; + // 训练数据集 + private String[][] data; + + // 每个属性的值所有类型 + private HashMap> attrValue; + + NaiveBayesTool(String filePath){ + this.filePath = filePath; + + readDataFile(); + initAttrValue(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + data = new String[dataArray.size()][]; + dataArray.toArray(data); + attrNames = data[0]; /* - * for(int i=0; i(); - ArrayList tempValues; - - // 按照列的方式,从左往右找 - for (int j = 1; j < attrNames.length; j++) { - // 从一列中的上往下开始寻找值 - tempValues = new ArrayList<>(); - for (int i = 1; i < data.length; i++) { - if (!tempValues.contains(data[i][j])) { - // 如果这个属性的值没有添加过,则添加 - tempValues.add(data[i][j]); - } - } - - // 一列属性的值已经遍历完毕,复制到map属性表中 - attrValue.put(data[0][j], tempValues); - } - - } - - /** - * 在classType的情况下,发生condition条件的概率 - * - * @param condition - * 属性条件 - * @param classType - * 分类的类型 - * @return - */ - private double computeConditionProbably(String condition, String classType) { - // 条件计数器 - int count = 0; - // 条件属性的索引列 - int attrIndex = 1; - // yes类标记符数据 - ArrayList yClassData = new ArrayList<>(); - // no类标记符数据 - ArrayList nClassData = new ArrayList<>(); - ArrayList classData; - - for (int i = 1; i < data.length; i++) { - // data数据按照yes和no分类 - if (data[i][attrNames.length - 1].equals(YES)) { - yClassData.add(data[i]); - } else { - nClassData.add(data[i]); - } - } - - if (classType.equals(YES)) { - classData = yClassData; - } else { - classData = nClassData; - } - - // 如果没有设置条件则,计算的是纯粹的类事件概率 - if (condition == null) { - return 1.0 * classData.size() / (data.length - 1); - } - - // 寻找此条件的属性列 - attrIndex = getConditionAttrName(condition); - - for (String[] s : classData) { - if (s[attrIndex].equals(condition)) { - count++; - } - } - - return 1.0 * count / classData.size(); - } - - /** - * 根据条件值返回条件所属属性的列值 - * - * @param condition - * 条件 - * @return - */ - private int getConditionAttrName(String condition) { - // 条件所属属性名 - String attrName = ""; - // 条件所在属性列索引 - int attrIndex = 1; - // 临时属性值类型 - ArrayList valueTypes; - for (Map.Entry entry : attrValue.entrySet()) { - valueTypes = (ArrayList) entry.getValue(); - if (valueTypes.contains(condition) - && !((String) entry.getKey()).equals("BuysComputer")) { - attrName = (String) entry.getKey(); - } - } - - for (int i = 0; i < attrNames.length - 1; i++) { - if (attrNames[i].equals(attrName)) { - attrIndex = i; - break; - } - } - - return attrIndex; - } - - /** - * 进行朴素贝叶斯分类 - * - * @param data - * 待分类数据 - */ - public String naiveBayesClassificate(String data) { - // 测试数据的属性值特征 - String[] dataFeatures; - // 在yes的条件下,x事件发生的概率 - double xWhenYes = 1.0; - // 在no的条件下,x事件发生的概率 - double xWhenNo = 1.0; - // 最后也是yes和no分类的总概率,用P(X|Ci)*P(Ci)的公式计算 - double pYes = 1; - double pNo = 1; - - dataFeatures = data.split(" "); - for (int i = 0; i < dataFeatures.length; i++) { - // 因为朴素贝叶斯算法是类条件独立的,所以可以进行累积的计算 - xWhenYes *= computeConditionProbably(dataFeatures[i], YES); - xWhenNo *= computeConditionProbably(dataFeatures[i], NO); - } - - pYes = xWhenYes * computeConditionProbably(null, YES); - pNo = xWhenNo * computeConditionProbably(null, NO); - - return (pYes > pNo ? YES : NO); - } + } + + /** + * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用 + */ + private void initAttrValue(){ + attrValue = new HashMap<>(); + ArrayList tempValues; + + // 按照列的方式,从左往右找 + for (int j = 1; j < attrNames.length; j++) { + // 从一列中的上往下开始寻找值 + tempValues = new ArrayList<>(); + for (int i = 1; i < data.length; i++) { + if (!tempValues.contains(data[i][j])) { + // 如果这个属性的值没有添加过,则添加 + tempValues.add(data[i][j]); + } + } + + // 一列属性的值已经遍历完毕,复制到map属性表中 + attrValue.put(data[0][j], tempValues); + } + + } + + /** + * 在classType的情况下,发生condition条件的概率 + * + * @param condition 属性条件 + * @param classType 分类的类型 + */ + private double computeConditionProbably(String condition, String classType){ + // 条件计数器 + int count = 0; + // 条件属性的索引列 + int attrIndex; + // yes类标记符数据 + ArrayList yClassData = new ArrayList<>(); + // no类标记符数据 + ArrayList nClassData = new ArrayList<>(); + ArrayList classData; + + for (int i = 1; i < data.length; i++) { + // data数据按照yes和no分类 + if (data[i][attrNames.length - 1].equals(YES)) { + yClassData.add(data[i]); + } else { + nClassData.add(data[i]); + } + } + + if (classType.equals(YES)) { + classData = yClassData; + } else { + classData = nClassData; + } + + // 如果没有设置条件则,计算的是纯粹的类事件概率 + if (condition == null) { + return 1.0 * classData.size() / (data.length - 1); + } + + // 寻找此条件的属性列 + attrIndex = getConditionAttrName(condition); + + for (String[] s : classData) { + if (s[attrIndex].equals(condition)) { + count++; + } + } + + return 1.0 * count / classData.size(); + } + + /** + * 根据条件值返回条件所属属性的列值 + * + * @param condition 条件 + */ + private int getConditionAttrName(String condition){ + // 条件所属属性名 + String attrName = ""; + // 条件所在属性列索引 + int attrIndex = 1; + // 临时属性值类型 + ArrayList valueTypes; + for (Map.Entry entry : attrValue.entrySet()) { + valueTypes = (ArrayList) entry.getValue(); + if (valueTypes.contains(condition) + && !entry.getKey().equals("BuysComputer")) { + attrName = (String) entry.getKey(); + } + } + + for (int i = 0; i < attrNames.length - 1; i++) { + if (attrNames[i].equals(attrName)) { + attrIndex = i; + break; + } + } + + return attrIndex; + } + + /** + * 进行朴素贝叶斯分类 + * + * @param data 待分类数据 + */ + String naiveBayesClassificate(String data){ + // 测试数据的属性值特征 + String[] dataFeatures; + // 在yes的条件下,x事件发生的概率 + double xWhenYes = 1.0; + // 在no的条件下,x事件发生的概率 + double xWhenNo = 1.0; + // 最后也是yes和no分类的总概率,用P(X|Ci)*P(Ci)的公式计算 + double pYes; + double pNo; + + dataFeatures = data.split(" "); + String NO = "No"; + for (String dataFeature : dataFeatures) { + // 因为朴素贝叶斯算法是类条件独立的,所以可以进行累积的计算 + xWhenYes *= computeConditionProbably(dataFeature, YES); + xWhenNo *= computeConditionProbably(dataFeature, NO); + } + + pYes = xWhenYes * computeConditionProbably(null, YES); + pNo = xWhenNo * computeConditionProbably(null, NO); + + return (pYes > pNo ? YES : NO); + } } diff --git a/Clustering/DataMining_BIRCH/BIRCHTool.java b/Clustering/DataMining_BIRCH/BIRCHTool.java index 3aec6aa..be3f039 100644 --- a/Clustering/DataMining_BIRCH/BIRCHTool.java +++ b/Clustering/DataMining_BIRCH/BIRCHTool.java @@ -1,4 +1,4 @@ -package DataMining_BIRCH; +package Clustering.DataMining_BIRCH; import java.io.BufferedReader; import java.io.File; @@ -10,243 +10,240 @@ /** * BIRCH聚类算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class BIRCHTool { - // 节点类型名称 - public static final String NON_LEAFNODE = "【NonLeafNode】"; - public static final String LEAFNODE = "【LeafNode】"; - public static final String CLUSTER = "【Cluster】"; - - // 测试数据文件地址 - private String filePath; - // 内部节点平衡因子B - public static int B; - // 叶子节点平衡因子L - public static int L; - // 簇直径阈值T - public static double T; - // 总的测试数据记录 - private ArrayList totalDataRecords; - - public BIRCHTool(String filePath, int B, int L, double T) { - this.filePath = filePath; - this.B = B; - this.L = L; - this.T = T; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - totalDataRecords = new ArrayList<>(); - for (String[] array : dataArray) { - totalDataRecords.add(array); - } - } - - /** - * 构建CF聚类特征树 - * - * @return - */ - private ClusteringFeature buildCFTree() { - NonLeafNode rootNode = null; - LeafNode leafNode = null; - Cluster cluster = null; - - for (String[] record : totalDataRecords) { - cluster = new Cluster(record); - - if (rootNode == null) { - // CF树只有1个节点的时候的情况 - if (leafNode == null) { - leafNode = new LeafNode(); - } - leafNode.addingCluster(cluster); - if (leafNode.getParentNode() != null) { - rootNode = leafNode.getParentNode(); - } - } else { - if (rootNode.getParentNode() != null) { - rootNode = rootNode.getParentNode(); - } - - // 从根节点开始,从上往下寻找到最近的添加目标叶子节点 - LeafNode temp = rootNode.findedClosestNode(cluster); - temp.addingCluster(cluster); - } - } - - // 从下往上找出最上面的节点 - LeafNode node = cluster.getParentNode(); - NonLeafNode upNode = node.getParentNode(); - if (upNode == null) { - return node; - } else { - while (upNode.getParentNode() != null) { - upNode = upNode.getParentNode(); - } - - return upNode; - } - } - - /** - * 开始构建CF聚类特征树 - */ - public void startBuilding() { - // 树深度 - int level = 1; - ClusteringFeature rootNode = buildCFTree(); - - setTreeLevel(rootNode, level); - showCFTree(rootNode); - } - - /** - * 设置节点深度 - * - * @param clusteringFeature - * 当前节点 - * @param level - * 当前深度值 - */ - private void setTreeLevel(ClusteringFeature clusteringFeature, int level) { - LeafNode leafNode = null; - NonLeafNode nonLeafNode = null; - - if (clusteringFeature instanceof LeafNode) { - leafNode = (LeafNode) clusteringFeature; - } else if (clusteringFeature instanceof NonLeafNode) { - nonLeafNode = (NonLeafNode) clusteringFeature; - } - - if (nonLeafNode != null) { - nonLeafNode.setLevel(level); - level++; - // 设置子节点 - if (nonLeafNode.getNonLeafChilds() != null) { - for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) { - setTreeLevel(n1, level); - } - } else { - for (LeafNode n2 : nonLeafNode.getLeafChilds()) { - setTreeLevel(n2, level); - } - } - } else { - leafNode.setLevel(level); - level++; - // 设置子聚簇 - for (Cluster c : leafNode.getClusterChilds()) { - c.setLevel(level); - } - } - } - - /** - * 显示CF聚类特征树 - * - * @param rootNode - * CF树根节点 - */ - private void showCFTree(ClusteringFeature rootNode) { - // 空格数,用于输出 - int blankNum = 5; - // 当前树深度 - int currentLevel = 1; - LinkedList nodeQueue = new LinkedList<>(); - ClusteringFeature cf; - LeafNode leafNode; - NonLeafNode nonLeafNode; - ArrayList clusterList = new ArrayList<>(); - String typeName; - - nodeQueue.add(rootNode); - while (nodeQueue.size() > 0) { - cf = nodeQueue.poll(); - - if (cf instanceof LeafNode) { - leafNode = (LeafNode) cf; - typeName = LEAFNODE; - - if (leafNode.getClusterChilds() != null) { - for (Cluster c : leafNode.getClusterChilds()) { - nodeQueue.add(c); - } - } - } else if (cf instanceof NonLeafNode) { - nonLeafNode = (NonLeafNode) cf; - typeName = NON_LEAFNODE; - - if (nonLeafNode.getNonLeafChilds() != null) { - for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) { - nodeQueue.add(n1); - } - } else { - for (LeafNode n2 : nonLeafNode.getLeafChilds()) { - nodeQueue.add(n2); - } - } - } else { - clusterList.add((Cluster)cf); - typeName = CLUSTER; - } - - if (currentLevel != cf.getLevel()) { - currentLevel = cf.getLevel(); - System.out.println(); - System.out.println("|"); - System.out.println("|"); - }else if(currentLevel == cf.getLevel() && currentLevel != 1){ - for (int i = 0; i < blankNum; i++) { - System.out.print("-"); - } - } - - System.out.print(typeName); - System.out.print("N:" + cf.getN() + ", LS:"); - System.out.print("["); - for (double d : cf.getLS()) { - System.out.print(MessageFormat.format("{0}, ", d)); - } - System.out.print("]"); - } - - System.out.println(); - System.out.println("*******最终分好的聚簇****"); - //显示已经分好类的聚簇点 - for(int i=0; i totalDataRecords; + + BIRCHTool(String filePath, int B, int L, double T){ + this.filePath = filePath; + BIRCHTool.B = B; + BIRCHTool.L = L; + BIRCHTool.T = T; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + totalDataRecords = new ArrayList<>(); + for (String[] array : dataArray) { + totalDataRecords.add(array); + } + } + + /** + * 构建CF聚类特征树 + */ + private ClusteringFeature buildCFTree(){ + NonLeafNode rootNode = null; + LeafNode leafNode = null; + Cluster cluster = null; + + for (String[] record : totalDataRecords) { + cluster = new Cluster(record); + + if (rootNode == null) { + // CF树只有1个节点的时候的情况 + if (leafNode == null) { + leafNode = new LeafNode(); + } + leafNode.addingCluster(cluster); + if (leafNode.getParentNode() != null) { + rootNode = leafNode.getParentNode(); + } + } else { + if (rootNode.getParentNode() != null) { + rootNode = rootNode.getParentNode(); + } + + // 从根节点开始,从上往下寻找到最近的添加目标叶子节点 + LeafNode temp = rootNode.findedClosestNode(cluster); + temp.addingCluster(cluster); + } + } + + // 从下往上找出最上面的节点 + LeafNode node = cluster != null ? cluster.getParentNode() : null; + NonLeafNode upNode = node != null ? node.getParentNode() : null; + if (upNode == null) { + return node; + } else { + while (upNode.getParentNode() != null) { + upNode = upNode.getParentNode(); + } + + return upNode; + } + } + + /** + * 开始构建CF聚类特征树 + */ + void startBuilding(){ + // 树深度 + int level = 1; + ClusteringFeature rootNode = buildCFTree(); + + setTreeLevel(rootNode, level); + showCFTree(rootNode); + } + + /** + * 设置节点深度 + * + * @param clusteringFeature 当前节点 + * @param level 当前深度值 + */ + private void setTreeLevel(ClusteringFeature clusteringFeature, int level){ + LeafNode leafNode = null; + NonLeafNode nonLeafNode = null; + + if (clusteringFeature instanceof LeafNode) { + leafNode = (LeafNode) clusteringFeature; + } else if (clusteringFeature instanceof NonLeafNode) { + nonLeafNode = (NonLeafNode) clusteringFeature; + } + + if (nonLeafNode != null) { + nonLeafNode.setLevel(level); + level++; + // 设置子节点 + if (nonLeafNode.getNonLeafChilds() != null) { + for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) { + setTreeLevel(n1, level); + } + } else { + for (LeafNode n2 : nonLeafNode.getLeafChilds()) { + setTreeLevel(n2, level); + } + } + } else { + if (leafNode != null) { + leafNode.setLevel(level); + } + level++; + // 设置子聚簇 + if (leafNode != null) { + for (Cluster c : leafNode.getClusterChilds()) { + c.setLevel(level); + } + } + } + } + + /** + * 显示CF聚类特征树 + * + * @param rootNode CF树根节点 + */ + private void showCFTree(ClusteringFeature rootNode){ + // 空格数,用于输出 + int blankNum = 5; + // 当前树深度 + int currentLevel = 1; + LinkedList nodeQueue = new LinkedList<>(); + ClusteringFeature cf; + LeafNode leafNode; + NonLeafNode nonLeafNode; + ArrayList clusterList = new ArrayList<>(); + String typeName; + + nodeQueue.add(rootNode); + while (nodeQueue.size() > 0) { + cf = nodeQueue.poll(); + + if (cf instanceof LeafNode) { + leafNode = (LeafNode) cf; + typeName = LEAFNODE; + + if (leafNode.getClusterChilds() != null) { + for (Cluster c : leafNode.getClusterChilds()) { + nodeQueue.add(c); + } + } + } else if (cf instanceof NonLeafNode) { + nonLeafNode = (NonLeafNode) cf; + typeName = NON_LEAFNODE; + + if (nonLeafNode.getNonLeafChilds() != null) { + for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) { + nodeQueue.add(n1); + } + } else { + for (LeafNode n2 : nonLeafNode.getLeafChilds()) { + nodeQueue.add(n2); + } + } + } else { + clusterList.add((Cluster) cf); + typeName = CLUSTER; + } + + if (currentLevel != cf.getLevel()) { + currentLevel = cf.getLevel(); + System.out.println(); + System.out.println("|"); + System.out.println("|"); + } else if (currentLevel == cf.getLevel() && currentLevel != 1) { + for (int i = 0; i < blankNum; i++) { + System.out.print("-"); + } + } + + System.out.print(typeName); + System.out.print("N:" + cf.getN() + ", LS:"); + System.out.print("["); + for (double d : cf.getLS()) { + System.out.print(MessageFormat.format("{0}, ", d)); + } + System.out.print("]"); + } + + System.out.println(); + System.out.println("*******最终分好的聚簇****"); + //显示已经分好类的聚簇点 + for (int i = 0; i < clusterList.size(); i++) { + System.out.println("Cluster" + (i + 1) + ":"); + for (double[] point : clusterList.get(i).getData()) { + System.out.print("["); + for (double d : point) { + System.out.print(MessageFormat.format("{0}, ", d)); + } + System.out.println("]"); + } + } + } } diff --git a/Clustering/DataMining_BIRCH/Client.java b/Clustering/DataMining_BIRCH/Client.java index 0f6ea28..125de7b 100644 --- a/Clustering/DataMining_BIRCH/Client.java +++ b/Clustering/DataMining_BIRCH/Client.java @@ -1,21 +1,21 @@ -package DataMining_BIRCH; +package Clustering.DataMining_BIRCH; /** * BIRCH聚类算法调用类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt"; - //内部节点平衡因子B - int B = 2; - //叶子节点平衡因子L - int L = 2; - //簇直径阈值T - double T = 0.6; - - BIRCHTool tool = new BIRCHTool(filePath, B, L, T); - tool.startBuilding(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Clustering/DataMining_BIRCH/testInput.txt"; + //内部节点平衡因子B + int B = 2; + //叶子节点平衡因子L + int L = 2; + //簇直径阈值T + double T = 0.6; + + BIRCHTool tool = new BIRCHTool(filePath, B, L, T); + tool.startBuilding(); + } } diff --git a/Clustering/DataMining_BIRCH/Cluster.java b/Clustering/DataMining_BIRCH/Cluster.java index 0edda4f..5269264 100644 --- a/Clustering/DataMining_BIRCH/Cluster.java +++ b/Clustering/DataMining_BIRCH/Cluster.java @@ -1,60 +1,56 @@ -package DataMining_BIRCH; +package Clustering.DataMining_BIRCH; import java.util.ArrayList; /** * 叶子节点中的小集群 - * @author lyq * + * @author Qstar */ -public class Cluster extends ClusteringFeature{ - //集群中的数据点 - private ArrayList data; - //父亲节点 - private LeafNode parentNode; - - public Cluster(String[] record){ - double[] d = new double[record.length]; - data = new ArrayList<>(); - for(int i=0; i getData() { - return data; - } - - public void setData(ArrayList data) { - this.data = data; - } - - @Override - protected void directAddCluster(ClusteringFeature node) { - //如果是聚类包括数据记录,则还需合并数据记录 - Cluster c = (Cluster)node; - ArrayList dataRecords = c.getData(); - this.data.addAll(dataRecords); - - super.directAddCluster(node); - } - - public LeafNode getParentNode() { - return parentNode; - } - - public void setParentNode(LeafNode parentNode) { - this.parentNode = parentNode; - } - - @Override - public void addingCluster(ClusteringFeature clusteringFeature) { - // TODO Auto-generated method stub - - } +class Cluster extends ClusteringFeature { + //集群中的数据点 + private ArrayList data; + //父亲节点 + private LeafNode parentNode; + + Cluster(String[] record){ + double[] d = new double[record.length]; + data = new ArrayList<>(); + for (int i = 0; i < record.length; i++) { + d[i] = Double.parseDouble(record[i]); + } + data.add(d); + //计算CF聚类特征 + this.setLS(data); + this.setSS(data); + this.setN(data); + } + + ArrayList getData(){ + return data; + } + + @Override + protected void directAddCluster(ClusteringFeature node){ + //如果是聚类包括数据记录,则还需合并数据记录 + Cluster c = (Cluster) node; + ArrayList dataRecords = c.getData(); + this.data.addAll(dataRecords); + + super.directAddCluster(node); + } + + LeafNode getParentNode(){ + return parentNode; + } + + void setParentNode(LeafNode parentNode){ + this.parentNode = parentNode; + } + + @Override + public void addingCluster(ClusteringFeature clusteringFeature){ + // TODO Auto-generated method stub + + } } diff --git a/Clustering/DataMining_BIRCH/ClusteringFeature.java b/Clustering/DataMining_BIRCH/ClusteringFeature.java index 5df5a0b..8cb48ee 100644 --- a/Clustering/DataMining_BIRCH/ClusteringFeature.java +++ b/Clustering/DataMining_BIRCH/ClusteringFeature.java @@ -1,202 +1,191 @@ -package DataMining_BIRCH; +package Clustering.DataMining_BIRCH; import java.util.ArrayList; /** * 聚类特征基本属性 - * - * @author lyq - * + * + * @author Qstar */ -public abstract class ClusteringFeature { - // 子类中节点的总数目 - protected int N; - // 子类中N个节点的线性和 - protected double[] LS; - // 子类中N个节点的平方和 - protected double[] SS; - //节点深度,用于CF树的输出 - protected int level; - - public int getN() { - return N; - } - - public void setN(int n) { - N = n; - } - - public double[] getLS() { - return LS; - } - - public void setLS(double[] lS) { - LS = lS; - } - - public double[] getSS() { - return SS; - } - - public void setSS(double[] sS) { - SS = sS; - } - - protected void setN(ArrayList dataRecords) { - this.N = dataRecords.size(); - } - - public int getLevel() { - return level; - } - - public void setLevel(int level) { - this.level = level; - } - - /** - * 根据节点数据计算线性和 - * - * @param dataRecords - * 节点数据记录 - */ - protected void setLS(ArrayList dataRecords) { - int num = dataRecords.get(0).length; - double[] record; - LS = new double[num]; - for (int j = 0; j < num; j++) { - LS[j] = 0; - } - - for (int i = 0; i < dataRecords.size(); i++) { - record = dataRecords.get(i); - for (int j = 0; j < record.length; j++) { - LS[j] += record[j]; - } - } - } - - /** - * 根据节点数据计算平方 - * - * @param dataRecords - * 节点数据 - */ - protected void setSS(ArrayList dataRecords) { - int num = dataRecords.get(0).length; - double[] record; - SS = new double[num]; - for (int j = 0; j < num; j++) { - SS[j] = 0; - } - - for (int i = 0; i < dataRecords.size(); i++) { - record = dataRecords.get(i); - for (int j = 0; j < record.length; j++) { - SS[j] += record[j] * record[j]; - } - } - } - - /** - * CF向量特征的叠加,无须考虑划分 - * - * @param node - */ - protected void directAddCluster(ClusteringFeature node) { - int N = node.getN(); - double[] otherLS = node.getLS(); - double[] otherSS = node.getSS(); - - if(LS == null){ - this.N = 0; - LS = new double[otherLS.length]; - SS = new double[otherLS.length]; - - for(int i=0; i records) { - double sumDistance = 0; - double[] data1; - double[] data2; - // 数据总数 - int totalNum = records.size(); - - for (int i = 0; i < totalNum - 1; i++) { - data1 = records.get(i); - for (int j = i + 1; j < totalNum; j++) { - data2 = records.get(j); - sumDistance += computeOuDistance(data1, data2); - } - } - - // 返回的值除以总对数,总对数应减半,会重复算一次 - return Math.sqrt(sumDistance / (totalNum * (totalNum - 1) / 2)); - } - - /** - * 对给定的2个向量,计算欧式距离 - * - * @param record1 - * 向量点1 - * @param record2 - * 向量点2 - */ - private double computeOuDistance(double[] record1, double[] record2) { - double distance = 0; - - for (int i = 0; i < record1.length; i++) { - distance += (record1[i] - record2[i]) * (record1[i] - record2[i]); - } - - return distance; - } - - /** - * 聚类添加节点包括,超出阈值进行分裂的操作 - * - * @param clusteringFeature - * 待添加聚簇 - */ - public abstract void addingCluster(ClusteringFeature clusteringFeature); +abstract class ClusteringFeature { + // 子类中节点的总数目 + private int N; + // 子类中N个节点的线性和 + private double[] LS; + // 子类中N个节点的平方和 + private double[] SS; + //节点深度,用于CF树的输出 + private int level; + + public int getN(){ + return N; + } + + protected void setN(ArrayList dataRecords){ + this.N = dataRecords.size(); + } + + public void setN(int n){ + N = n; + } + + double[] getLS(){ + return LS; + } + + /** + * 根据节点数据计算线性和 + * + * @param dataRecords 节点数据记录 + */ + void setLS(ArrayList dataRecords){ + int num = dataRecords.get(0).length; + double[] record; + LS = new double[num]; + for (int j = 0; j < num; j++) { + LS[j] = 0; + } + + for (double[] dataRecord : dataRecords) { + record = dataRecord; + for (int j = 0; j < record.length; j++) { + LS[j] += record[j]; + } + } + } + + public void setLS(double[] lS){ + LS = lS; + } + + private double[] getSS(){ + return SS; + } + + /** + * 根据节点数据计算平方 + * + * @param dataRecords 节点数据 + */ + void setSS(ArrayList dataRecords){ + int num = dataRecords.get(0).length; + double[] record; + SS = new double[num]; + for (int j = 0; j < num; j++) { + SS[j] = 0; + } + + for (double[] dataRecord : dataRecords) { + record = dataRecord; + for (int j = 0; j < record.length; j++) { + SS[j] += record[j] * record[j]; + } + } + } + + public void setSS(double[] sS){ + SS = sS; + } + + int getLevel(){ + return level; + } + + void setLevel(int level){ + this.level = level; + } + + /** + * CF向量特征的叠加,无须考虑划分 + * + * @param node 簇 + */ + protected void directAddCluster(ClusteringFeature node){ + int N = node.getN(); + double[] otherLS = node.getLS(); + double[] otherSS = node.getSS(); + + if (LS == null) { + this.N = 0; + LS = new double[otherLS.length]; + SS = new double[otherLS.length]; + + for (int i = 0; i < LS.length; i++) { + LS[i] = 0; + SS[i] = 0; + } + } + + // 3个数量上进行叠加 + for (int i = 0; i < LS.length; i++) { + LS[i] += otherLS[i]; + SS[i] += otherSS[i]; + } + this.N += N; + } + + /** + * 计算簇与簇之间的距离即簇中心之间的距离 + */ + double computerClusterDistance(ClusteringFeature cluster){ + double distance = 0; + double[] otherLS = cluster.LS; + int num = N; + + int otherNum = cluster.N; + + for (int i = 0; i < LS.length; i++) { + distance += (LS[i] / num - otherLS[i] / otherNum) + * (LS[i] / num - otherLS[i] / otherNum); + } + distance = Math.sqrt(distance); + return distance; + } + + /** + * 计算簇内对象的平均距离 + * + * @param records 簇内的数据记录 + */ + double computerInClusterDistance(ArrayList records){ + double sumDistance = 0; + double[] data1; + double[] data2; + // 数据总数 + int totalNum = records.size(); + + for (int i = 0; i < totalNum - 1; i++) { + data1 = records.get(i); + for (int j = i + 1; j < totalNum; j++) { + data2 = records.get(j); + sumDistance += computeOuDistance(data1, data2); + } + } + + // 返回的值除以总对数,总对数应减半,会重复算一次 + return Math.sqrt(sumDistance / (totalNum * (totalNum - 1) / 2)); + } + + /** + * 对给定的2个向量,计算欧式距离 + * + * @param record1 向量点1 + * @param record2 向量点2 + */ + private double computeOuDistance(double[] record1, double[] record2){ + double distance = 0; + + for (int i = 0; i < record1.length; i++) { + distance += (record1[i] - record2[i]) * (record1[i] - record2[i]); + } + + return distance; + } + + /** + * 聚类添加节点包括,超出阈值进行分裂的操作 + * + * @param clusteringFeature 待添加聚簇 + */ + public abstract void addingCluster(ClusteringFeature clusteringFeature); } diff --git a/Clustering/DataMining_BIRCH/LeafNode.java b/Clustering/DataMining_BIRCH/LeafNode.java index dda5605..f23ed13 100644 --- a/Clustering/DataMining_BIRCH/LeafNode.java +++ b/Clustering/DataMining_BIRCH/LeafNode.java @@ -1,155 +1,159 @@ -package DataMining_BIRCH; +package Clustering.DataMining_BIRCH; import java.util.ArrayList; /** * CF树叶子节点 - * - * @author lyq - * + * + * @author Qstar */ -public class LeafNode extends ClusteringFeature { - // 孩子集群 - private ArrayList clusterChilds; - // 父亲节点 - private NonLeafNode parentNode; - - public ArrayList getClusterChilds() { - return clusterChilds; - } - - public void setClusterChilds(ArrayList clusterChilds) { - this.clusterChilds = clusterChilds; - } - - /** - * 将叶子节点划分出2个 - * - * @return - */ - public LeafNode[] divideLeafNode() { - LeafNode[] leafNodeArray = new LeafNode[2]; - // 簇间距离差距最大的2个簇,后面的簇按照就近原则划分即可 - Cluster cluster1 = null; - Cluster cluster2 = null; - Cluster tempCluster = null; - double maxValue = 0; - double temp = 0; - - // 找出簇心距离差距最大的2个簇 - for (int i = 0; i < clusterChilds.size() - 1; i++) { - tempCluster = clusterChilds.get(i); - for (int j = i + 1; j < clusterChilds.size(); j++) { - temp = tempCluster - .computerClusterDistance(clusterChilds.get(j)); - - if (temp > maxValue) { - maxValue = temp; - cluster1 = tempCluster; - cluster2 = clusterChilds.get(j); - } - } - } - - leafNodeArray[0] = new LeafNode(); - leafNodeArray[0].addingCluster(cluster1); - cluster1.setParentNode(leafNodeArray[0]); - leafNodeArray[1] = new LeafNode(); - leafNodeArray[1].addingCluster(cluster2); - cluster2.setParentNode(leafNodeArray[1]); - clusterChilds.remove(cluster1); - clusterChilds.remove(cluster2); - // 就近分配簇 - for (Cluster c : clusterChilds) { - if (cluster1.computerClusterDistance(c) < cluster2 - .computerClusterDistance(c)) { - // 簇间距离如果接近最小簇,就加入最小簇所属叶子节点 - leafNodeArray[0].addingCluster(c); - c.setParentNode(leafNodeArray[0]); - } else { - leafNodeArray[1].addingCluster(c); - c.setParentNode(leafNodeArray[1]); - } - } - - return leafNodeArray; - } - - public NonLeafNode getParentNode() { - return parentNode; - } - - public void setParentNode(NonLeafNode parentNode) { - this.parentNode = parentNode; - } - - @Override - public void addingCluster(ClusteringFeature clusteringFeature) { - //更新聚类特征值 - directAddCluster(clusteringFeature); - - // 寻找到的目标集群 - Cluster findedCluster = null; - Cluster cluster = (Cluster) clusteringFeature; - // 簇内对象平均距离 - double disance = Integer.MAX_VALUE; - // 簇间距离差值 - double errorDistance = 0; - boolean needDivided = false; - if (clusterChilds == null) { - clusterChilds = new ArrayList<>(); - clusterChilds.add(cluster); - cluster.setParentNode(this); - } else { - for (Cluster c : clusterChilds) { - errorDistance = c.computerClusterDistance(cluster); - if (disance > errorDistance) { - // 选出簇间距离最近的 - disance = errorDistance; - findedCluster = c; - } - } - - ArrayList data1 = (ArrayList) findedCluster - .getData().clone(); - ArrayList data2 = cluster.getData(); - data1.addAll(data2); - // 如果添加后的聚类的簇间距离超过给定阈值,需要额外新建簇 - if (findedCluster.computerInClusterDistance(data1) > BIRCHTool.T) { - // 叶子节点的孩子数不能超过平衡因子L - if (clusterChilds.size() + 1 > BIRCHTool.L) { - needDivided = true; - } - clusterChilds.add(cluster); - cluster.setParentNode(this); - } else { - findedCluster.directAddCluster(cluster); - cluster.setParentNode(this); - } - } - - if(needDivided){ - if(parentNode == null){ - parentNode = new NonLeafNode(); - }else{ - parentNode.getLeafChilds().remove(this); - } - - LeafNode[] nodeArray = divideLeafNode(); - for(LeafNode n: nodeArray){ - parentNode.addingCluster(n); - } - } - } - - @Override - protected void directAddCluster(ClusteringFeature node) { - // TODO Auto-generated method stub - if(parentNode != null){ - parentNode.directAddCluster(node); - } - - super.directAddCluster(node); - } - +class LeafNode extends ClusteringFeature { + // 孩子集群 + private ArrayList clusterChilds; + // 父亲节点 + private NonLeafNode parentNode; + + ArrayList getClusterChilds(){ + return clusterChilds; + } + + /** + * 将叶子节点划分出2个 + */ + private LeafNode[] divideLeafNode(){ + LeafNode[] leafNodeArray = new LeafNode[2]; + // 簇间距离差距最大的2个簇,后面的簇按照就近原则划分即可 + Cluster cluster1 = null; + Cluster cluster2 = null; + Cluster tempCluster; + double maxValue = 0; + double temp; + + // 找出簇心距离差距最大的2个簇 + for (int i = 0; i < clusterChilds.size() - 1; i++) { + tempCluster = clusterChilds.get(i); + for (int j = i + 1; j < clusterChilds.size(); j++) { + temp = tempCluster + .computerClusterDistance(clusterChilds.get(j)); + + if (temp > maxValue) { + maxValue = temp; + cluster1 = tempCluster; + cluster2 = clusterChilds.get(j); + } + } + } + + leafNodeArray[0] = new LeafNode(); + leafNodeArray[0].addingCluster(cluster1); + if (cluster1 != null) { + cluster1.setParentNode(leafNodeArray[0]); + } + leafNodeArray[1] = new LeafNode(); + leafNodeArray[1].addingCluster(cluster2); + if (cluster2 != null) { + cluster2.setParentNode(leafNodeArray[1]); + } + clusterChilds.remove(cluster1); + clusterChilds.remove(cluster2); + // 就近分配簇 + for (Cluster c : clusterChilds) { + if ((cluster1 != null ? cluster1.computerClusterDistance(c) : 0) < (cluster2 != null ? cluster2 + .computerClusterDistance(c) : 0)) { + // 簇间距离如果接近最小簇,就加入最小簇所属叶子节点 + leafNodeArray[0].addingCluster(c); + c.setParentNode(leafNodeArray[0]); + } else { + leafNodeArray[1].addingCluster(c); + c.setParentNode(leafNodeArray[1]); + } + } + + return leafNodeArray; + } + + NonLeafNode getParentNode(){ + return parentNode; + } + + void setParentNode(NonLeafNode parentNode){ + this.parentNode = parentNode; + } + + @Override + public void addingCluster(ClusteringFeature clusteringFeature){ + //更新聚类特征值 + directAddCluster(clusteringFeature); + + // 寻找到的目标集群 + Cluster findedCluster = null; + Cluster cluster = (Cluster) clusteringFeature; + // 簇内对象平均距离 + double disance = Integer.MAX_VALUE; + // 簇间距离差值 + double errorDistance; + boolean needDivided = false; + if (clusterChilds == null) { + clusterChilds = new ArrayList<>(); + clusterChilds.add(cluster); + cluster.setParentNode(this); + } else { + for (Cluster c : clusterChilds) { + errorDistance = c.computerClusterDistance(cluster); + if (disance > errorDistance) { + // 选出簇间距离最近的 + disance = errorDistance; + findedCluster = c; + } + } + + ArrayList data1 = null; + if (findedCluster != null) { + data1 = (ArrayList) findedCluster + .getData().clone(); + } + ArrayList data2 = cluster.getData(); + if (data1 != null) { + data1.addAll(data2); + } + // 如果添加后的聚类的簇间距离超过给定阈值,需要额外新建簇 + if (findedCluster != null) { + if (findedCluster.computerInClusterDistance(data1) > BIRCHTool.T) { + // 叶子节点的孩子数不能超过平衡因子L + if (clusterChilds.size() + 1 > BIRCHTool.L) { + needDivided = true; + } + clusterChilds.add(cluster); + cluster.setParentNode(this); + } else { + findedCluster.directAddCluster(cluster); + cluster.setParentNode(this); + } + } + } + + if (needDivided) { + if (parentNode == null) { + parentNode = new NonLeafNode(); + } else { + parentNode.getLeafChilds().remove(this); + } + + LeafNode[] nodeArray = divideLeafNode(); + for (LeafNode n : nodeArray) { + parentNode.addingCluster(n); + } + } + } + + @Override + protected void directAddCluster(ClusteringFeature node){ + // TODO Auto-generated method stub + if (parentNode != null) { + parentNode.directAddCluster(node); + } + + super.directAddCluster(node); + } + } diff --git a/Clustering/DataMining_BIRCH/NonLeafNode.java b/Clustering/DataMining_BIRCH/NonLeafNode.java index 749d614..ef8b392 100644 --- a/Clustering/DataMining_BIRCH/NonLeafNode.java +++ b/Clustering/DataMining_BIRCH/NonLeafNode.java @@ -1,299 +1,282 @@ -package DataMining_BIRCH; +package Clustering.DataMining_BIRCH; import java.util.ArrayList; import java.util.LinkedList; /** * 非叶子节点 - * - * @author lyq - * + * + * @author Qstar */ -public class NonLeafNode extends ClusteringFeature { - // 非叶子节点的孩子节点可能为非叶子节点,也可能为叶子节点 - private ArrayList nonLeafChilds; - // 如果是叶子节点的孩子,则以双向链表的形式存在 - private LinkedList leafChilds; - // 父亲节点 - private NonLeafNode parentNode; - - public ArrayList getNonLeafChilds() { - return nonLeafChilds; - } - - public void setNonLeafChilds(ArrayList nonLeafChilds) { - this.nonLeafChilds = nonLeafChilds; - } - - public LinkedList getLeafChilds() { - return leafChilds; - } - - public void setLeafChilds(LinkedList leafChilds) { - this.leafChilds = leafChilds; - } - - /** - * 添加叶子节点 - * - * @param leafNode - * 待添加叶子节点 - * @return - */ - public boolean addingNeededDivide(LeafNode leafNode) { - boolean needDivided = false; - if (leafChilds == null) { - leafChilds = new LinkedList<>(); - leafChilds.add(leafNode); - leafNode.setParentNode(this); - } else { - // 如果添加后,叶子节点数超过平衡因子,则添加后需要分裂 - if (leafChilds.size() + 1 > BIRCHTool.B) { - needDivided = true; - } - leafChilds.add(leafNode); - leafNode.setParentNode(this); - - // 测试程序 - /* - * if(leafChilds.size() == 2){ if(BIRCHTool.B == 2){ needDivided = +class NonLeafNode extends ClusteringFeature { + // 非叶子节点的孩子节点可能为非叶子节点,也可能为叶子节点 + private ArrayList nonLeafChilds; + // 如果是叶子节点的孩子,则以双向链表的形式存在 + private LinkedList leafChilds; + // 父亲节点 + private NonLeafNode parentNode; + + ArrayList getNonLeafChilds(){ + return nonLeafChilds; + } + + LinkedList getLeafChilds(){ + return leafChilds; + } + + /** + * 添加叶子节点 + * + * @param leafNode 待添加叶子节点 + */ + private boolean addingNeededDivide(LeafNode leafNode){ + boolean needDivided = false; + if (leafChilds == null) { + leafChilds = new LinkedList<>(); + leafChilds.add(leafNode); + leafNode.setParentNode(this); + } else { + // 如果添加后,叶子节点数超过平衡因子,则添加后需要分裂 + if (leafChilds.size() + 1 > BIRCHTool.B) { + needDivided = true; + } + leafChilds.add(leafNode); + leafNode.setParentNode(this); + + // 测试程序 + /* + * if(leafChilds.size() == 2){ if(BIRCHTool.B == 2){ needDivided = * true; LeafNode node = new LeafNode(); node.setN(1); * node.setLS(new double[]{5.1, 3.3, 1.5, 0.33}); node.setSS(new * double[]{1, 1, 1, 1}); leafChilds.add(node); BIRCHTool.B++; } } */ - } - - return needDivided; - } - - /** - * 添加非叶子节点 - * - * @param nonLeafNode - * 待添加非叶子节点 - * @return - */ - public boolean addingNeededDivide(NonLeafNode nonLeafNode) { - boolean needDivided = false; - if (nonLeafChilds == null) { - nonLeafChilds = new ArrayList<>(); - nonLeafChilds.add(nonLeafNode); - nonLeafNode.setParentNode(this); - } else { - // 如果添加后,叶子节点数超过平衡因子,则添加失败 - if (nonLeafChilds.size() + 1 > BIRCHTool.B) { - needDivided = false; - } - nonLeafChilds.add(nonLeafNode); - nonLeafNode.setParentNode(this); - } - - return needDivided; - } - - /** - * 因为叶子节点数超过阈值,进行分裂 - * - * @return - */ - public NonLeafNode[] leafNodeDivided() { - NonLeafNode[] nonLeafNodes = new NonLeafNode[2]; - - // 簇间距离差距最大的2个簇,后面的簇按照就近原则划分即可 - LeafNode node1 = null; - LeafNode node2 = null; - LeafNode tempNode = null; - double maxValue = 0; - double temp = 0; - - // 找出簇心距离差距最大的2个簇 - for (int i = 0; i < leafChilds.size() - 1; i++) { - tempNode = leafChilds.get(i); - for (int j = i + 1; j < leafChilds.size(); j++) { - temp = tempNode.computerClusterDistance(leafChilds.get(j)); - - if (temp > maxValue) { - maxValue = temp; - node1 = tempNode; - node2 = leafChilds.get(j); - } - } - } - - nonLeafNodes[0] = new NonLeafNode(); - nonLeafNodes[0].addingCluster(node1); - nonLeafNodes[1] = new NonLeafNode(); - nonLeafNodes[1].addingCluster(node2); - leafChilds.remove(node1); - leafChilds.remove(node2); - // 就近分配簇 - for (LeafNode c : leafChilds) { - if (node1.computerClusterDistance(c) < node2 - .computerClusterDistance(c)) { - // 簇间距离如果接近最小簇,就加入最小簇所属叶子节点 - nonLeafNodes[0].addingCluster(c); - c.setParentNode(nonLeafNodes[0]); - } else { - nonLeafNodes[1].addingCluster(c); - c.setParentNode(nonLeafNodes[1]); - } - } - - return nonLeafNodes; - } - - /** - * 因为非叶子节点数超过阈值,进行分裂 - * - * @return - */ - public NonLeafNode[] nonLeafNodeDivided() { - NonLeafNode[] nonLeafNodes = new NonLeafNode[2]; - - // 簇间距离差距最大的2个簇,后面的簇按照就近原则划分即可 - NonLeafNode node1 = null; - NonLeafNode node2 = null; - NonLeafNode tempNode = null; - double maxValue = 0; - double temp = 0; - - // 找出簇心距离差距最大的2个簇 - for (int i = 0; i < nonLeafChilds.size() - 1; i++) { - tempNode = nonLeafChilds.get(i); - for (int j = i + 1; j < nonLeafChilds.size(); j++) { - temp = tempNode.computerClusterDistance(nonLeafChilds.get(j)); - - if (temp > maxValue) { - maxValue = temp; - node1 = tempNode; - node2 = nonLeafChilds.get(j); - } - } - } - - nonLeafNodes[0] = new NonLeafNode(); - nonLeafNodes[0].addingCluster(node1); - nonLeafNodes[1] = new NonLeafNode(); - nonLeafNodes[1].addingCluster(node2); - nonLeafChilds.remove(node1); - nonLeafChilds.remove(node2); - // 就近分配簇 - for (NonLeafNode c : nonLeafChilds) { - if (node1.computerClusterDistance(c) < node2 - .computerClusterDistance(c)) { - // 簇间距离如果接近最小簇,就加入最小簇所属叶子节点 - nonLeafNodes[0].addingCluster(c); - c.setParentNode(nonLeafNodes[0]); - } else { - nonLeafNodes[1].addingCluster(c); - c.setParentNode(nonLeafNodes[1]); - } - } - - return nonLeafNodes; - } - - /** - * 寻找到最接近的叶子节点 - * - * @param cluster - * 待添加聚簇 - * @return - */ - public LeafNode findedClosestNode(Cluster cluster) { - LeafNode node = null; - NonLeafNode nonLeafNode = null; - double temp; - double distance = Integer.MAX_VALUE; - - if (nonLeafChilds == null) { - for (LeafNode n : leafChilds) { - temp = n.computerClusterDistance(cluster); - if (temp < distance) { - distance = temp; - node = n; - } - } - } else { - for (NonLeafNode n : nonLeafChilds) { - temp = n.computerClusterDistance(cluster); - if (temp < distance) { - distance = temp; - nonLeafNode = n; - } - } - - // 递归继续往下找 - node = nonLeafNode.findedClosestNode(cluster); - } - - return node; - } - - public NonLeafNode getParentNode() { - return parentNode; - } - - public void setParentNode(NonLeafNode parentNode) { - this.parentNode = parentNode; - } - - @Override - public void addingCluster(ClusteringFeature clusteringFeature) { - LeafNode leafNode = null; - NonLeafNode nonLeafNode = null; - NonLeafNode[] nonLeafNodeArrays; - boolean neededDivide = false; - // 更新聚类特征值 - directAddCluster(clusteringFeature); - - if (clusteringFeature instanceof LeafNode) { - leafNode = (LeafNode) clusteringFeature; - } else { - nonLeafNode = (NonLeafNode) clusteringFeature; - } - - if (nonLeafNode != null) { - neededDivide = addingNeededDivide(nonLeafNode); - - if (neededDivide) { - if (parentNode == null) { - parentNode = new NonLeafNode(); - } else { - parentNode.nonLeafChilds.remove(this); - } - - nonLeafNodeArrays = this.nonLeafNodeDivided(); - for (NonLeafNode n1 : nonLeafNodeArrays) { - parentNode.addingCluster(n1); - } - } - } else { - neededDivide = addingNeededDivide(leafNode); - - if (neededDivide) { - if (parentNode == null) { - parentNode = new NonLeafNode(); - } else { - parentNode.nonLeafChilds.remove(this); - } - - nonLeafNodeArrays = this.leafNodeDivided(); - for (NonLeafNode n2 : nonLeafNodeArrays) { - parentNode.addingCluster(n2); - } - } - } - } - - @Override - protected void directAddCluster(ClusteringFeature node) { - // TODO Auto-generated method stub - if (parentNode != null) { - parentNode.directAddCluster(node); - } - - super.directAddCluster(node); - } + } + + return needDivided; + } + + /** + * 添加非叶子节点 + * + * @param nonLeafNode 待添加非叶子节点 + */ + private boolean addingNeededDivide(NonLeafNode nonLeafNode){ + boolean needDivided = false; + if (nonLeafChilds == null) { + nonLeafChilds = new ArrayList<>(); + nonLeafChilds.add(nonLeafNode); + nonLeafNode.setParentNode(this); + } else { + // 如果添加后,叶子节点数超过平衡因子,则添加失败 + if (nonLeafChilds.size() + 1 > BIRCHTool.B) { + needDivided = false; + } + nonLeafChilds.add(nonLeafNode); + nonLeafNode.setParentNode(this); + } + + return needDivided; + } + + /** + * 因为叶子节点数超过阈值,进行分裂 + */ + private NonLeafNode[] leafNodeDivided(){ + NonLeafNode[] nonLeafNodes = new NonLeafNode[2]; + + // 簇间距离差距最大的2个簇,后面的簇按照就近原则划分即可 + LeafNode node1 = null; + LeafNode node2 = null; + LeafNode tempNode; + double maxValue = 0; + double temp; + + // 找出簇心距离差距最大的2个簇 + for (int i = 0; i < leafChilds.size() - 1; i++) { + tempNode = leafChilds.get(i); + for (int j = i + 1; j < leafChilds.size(); j++) { + temp = tempNode.computerClusterDistance(leafChilds.get(j)); + + if (temp > maxValue) { + maxValue = temp; + node1 = tempNode; + node2 = leafChilds.get(j); + } + } + } + + nonLeafNodes[0] = new NonLeafNode(); + nonLeafNodes[0].addingCluster(node1); + nonLeafNodes[1] = new NonLeafNode(); + nonLeafNodes[1].addingCluster(node2); + leafChilds.remove(node1); + leafChilds.remove(node2); + // 就近分配簇 + for (LeafNode c : leafChilds) { + if ((node1 != null ? node1.computerClusterDistance(c) : 0) < (node2 != null ? node2 + .computerClusterDistance(c) : 0)) { + // 簇间距离如果接近最小簇,就加入最小簇所属叶子节点 + nonLeafNodes[0].addingCluster(c); + c.setParentNode(nonLeafNodes[0]); + } else { + nonLeafNodes[1].addingCluster(c); + c.setParentNode(nonLeafNodes[1]); + } + } + + return nonLeafNodes; + } + + /** + * 因为非叶子节点数超过阈值,进行分裂 + */ + private NonLeafNode[] nonLeafNodeDivided(){ + NonLeafNode[] nonLeafNodes = new NonLeafNode[2]; + + // 簇间距离差距最大的2个簇,后面的簇按照就近原则划分即可 + NonLeafNode node1 = null; + NonLeafNode node2 = null; + NonLeafNode tempNode; + double maxValue = 0; + double temp; + + // 找出簇心距离差距最大的2个簇 + for (int i = 0; i < nonLeafChilds.size() - 1; i++) { + tempNode = nonLeafChilds.get(i); + for (int j = i + 1; j < nonLeafChilds.size(); j++) { + temp = tempNode.computerClusterDistance(nonLeafChilds.get(j)); + + if (temp > maxValue) { + maxValue = temp; + node1 = tempNode; + node2 = nonLeafChilds.get(j); + } + } + } + + nonLeafNodes[0] = new NonLeafNode(); + nonLeafNodes[0].addingCluster(node1); + nonLeafNodes[1] = new NonLeafNode(); + nonLeafNodes[1].addingCluster(node2); + nonLeafChilds.remove(node1); + nonLeafChilds.remove(node2); + // 就近分配簇 + for (NonLeafNode c : nonLeafChilds) { + if (node1 != null) { + if (node1.computerClusterDistance(c) < node2 + .computerClusterDistance(c)) { + // 簇间距离如果接近最小簇,就加入最小簇所属叶子节点 + nonLeafNodes[0].addingCluster(c); + c.setParentNode(nonLeafNodes[0]); + } else { + nonLeafNodes[1].addingCluster(c); + c.setParentNode(nonLeafNodes[1]); + } + } + } + + return nonLeafNodes; + } + + /** + * 寻找到最接近的叶子节点 + * + * @param cluster 待添加聚簇 + */ + LeafNode findedClosestNode(Cluster cluster){ + LeafNode node = null; + NonLeafNode nonLeafNode = null; + double temp; + double distance = Integer.MAX_VALUE; + + if (nonLeafChilds == null) { + for (LeafNode n : leafChilds) { + temp = n.computerClusterDistance(cluster); + if (temp < distance) { + distance = temp; + node = n; + } + } + } else { + for (NonLeafNode n : nonLeafChilds) { + temp = n.computerClusterDistance(cluster); + if (temp < distance) { + distance = temp; + nonLeafNode = n; + } + } + + // 递归继续往下找 + node = nonLeafNode != null ? nonLeafNode.findedClosestNode(cluster) : null; + } + + return node; + } + + NonLeafNode getParentNode(){ + return parentNode; + } + + private void setParentNode(NonLeafNode parentNode){ + this.parentNode = parentNode; + } + + @Override + public void addingCluster(ClusteringFeature clusteringFeature){ + LeafNode leafNode = null; + NonLeafNode nonLeafNode = null; + NonLeafNode[] nonLeafNodeArrays; + boolean neededDivide; + // 更新聚类特征值 + directAddCluster(clusteringFeature); + + if (clusteringFeature instanceof LeafNode) { + leafNode = (LeafNode) clusteringFeature; + } else { + nonLeafNode = (NonLeafNode) clusteringFeature; + } + + if (nonLeafNode != null) { + neededDivide = addingNeededDivide(nonLeafNode); + + if (neededDivide) { + if (parentNode == null) { + parentNode = new NonLeafNode(); + } else { + parentNode.nonLeafChilds.remove(this); + } + + nonLeafNodeArrays = this.nonLeafNodeDivided(); + for (NonLeafNode n1 : nonLeafNodeArrays) { + parentNode.addingCluster(n1); + } + } + } else { + neededDivide = addingNeededDivide(leafNode); + + if (neededDivide) { + if (parentNode == null) { + parentNode = new NonLeafNode(); + } else { + parentNode.nonLeafChilds.remove(this); + } + + nonLeafNodeArrays = this.leafNodeDivided(); + for (NonLeafNode n2 : nonLeafNodeArrays) { + parentNode.addingCluster(n2); + } + } + } + } + + @Override + protected void directAddCluster(ClusteringFeature node){ + // TODO Auto-generated method stub + if (parentNode != null) { + parentNode.directAddCluster(node); + } + + super.directAddCluster(node); + } } diff --git a/Clustering/DataMining_KMeans/Client.java b/Clustering/DataMining_KMeans/Client.java index fd6aa31..124b08a 100644 --- a/Clustering/DataMining_KMeans/Client.java +++ b/Clustering/DataMining_KMeans/Client.java @@ -1,17 +1,17 @@ -package DataMining_KMeans; +package Clustering.DataMining_KMeans; /** * K-means(K均值)算法调用类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - //聚类中心数量设定 - int classNum = 3; - - KMeansTool tool = new KMeansTool(filePath, classNum); - tool.kMeansClustering(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Clustering/DataMining_KMeans/input.txt"; + //聚类中心数量设定 + int classNum = 3; + + KMeansTool tool = new KMeansTool(filePath, classNum); + tool.kMeansClustering(); + } } diff --git a/Clustering/DataMining_KMeans/KMeansTool.java b/Clustering/DataMining_KMeans/KMeansTool.java index 8122701..d86c9ba 100644 --- a/Clustering/DataMining_KMeans/KMeansTool.java +++ b/Clustering/DataMining_KMeans/KMeansTool.java @@ -1,4 +1,4 @@ -package DataMining_KMeans; +package Clustering.DataMining_KMeans; import java.io.BufferedReader; import java.io.File; @@ -10,124 +10,121 @@ /** * k均值算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class KMeansTool { - // 输入数据文件地址 - private String filePath; - // 分类类别个数 - private int classNum; - // 类名称 - private ArrayList classNames; - // 聚类坐标点 - private ArrayList classPoints; - // 所有的数据左边点 - private ArrayList totalPoints; - - public KMeansTool(String filePath, int classNum) { - this.filePath = filePath; - this.classNum = classNum; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - classPoints = new ArrayList<>(); - totalPoints = new ArrayList<>(); - classNames = new ArrayList<>(); - for (int i = 0, j = 1; i < dataArray.size(); i++) { - if (j <= classNum) { - classPoints.add(new Point(dataArray.get(i)[0], - dataArray.get(i)[1], j + "")); - classNames.add(i + ""); - j++; - } - totalPoints - .add(new Point(dataArray.get(i)[0], dataArray.get(i)[1])); - } - } - - /** - * K均值聚类算法实现 - */ - public void kMeansClustering() { - double tempX = 0; - double tempY = 0; - int count = 0; - double error = Integer.MAX_VALUE; - Point temp; - - while (error > 0.01 * classNum) { - for (Point p1 : totalPoints) { - // 将所有的测试坐标点就近分类 - for (Point p2 : classPoints) { - p2.computerDistance(p1); - } - Collections.sort(classPoints); - - // 取出p1离类坐标点最近的那个点 - p1.setClassName(classPoints.get(0).getClassName()); - } - - error = 0; - // 按照均值重新划分聚类中心点 - for (Point p1 : classPoints) { - count = 0; - tempX = 0; - tempY = 0; - for (Point p : totalPoints) { - if (p.getClassName().equals(p1.getClassName())) { - count++; - tempX += p.getX(); - tempY += p.getY(); - } - } - tempX /= count; - tempY /= count; - - error += Math.abs((tempX - p1.getX())); - error += Math.abs((tempY - p1.getY())); - // 计算均值 - p1.setX(tempX); - p1.setY(tempY); - - } - - for (int i = 0; i < classPoints.size(); i++) { - temp = classPoints.get(i); - System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}", - (i + 1), temp.getX(), temp.getY())); - } - System.out.println("----------"); - } - - System.out.println("结果值收敛"); - for (int i = 0; i < classPoints.size(); i++) { - temp = classPoints.get(i); - System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}", - (i + 1), temp.getX(), temp.getY())); - } - - } +class KMeansTool { + // 输入数据文件地址 + private String filePath; + // 分类类别个数 + private int classNum; + // 聚类坐标点 + private ArrayList classPoints; + // 所有的数据左边点 + private ArrayList totalPoints; + + KMeansTool(String filePath, int classNum){ + this.filePath = filePath; + this.classNum = classNum; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + classPoints = new ArrayList<>(); + totalPoints = new ArrayList<>(); + ArrayList classNames = new ArrayList<>(); + for (int i = 0, j = 1; i < dataArray.size(); i++) { + if (j <= classNum) { + classPoints.add(new Point(dataArray.get(i)[0], + dataArray.get(i)[1], j + "")); + classNames.add(i + ""); + j++; + } + totalPoints + .add(new Point(dataArray.get(i)[0], dataArray.get(i)[1])); + } + } + + /** + * K均值聚类算法实现 + */ + void kMeansClustering(){ + double tempX; + double tempY; + int count; + double error = Integer.MAX_VALUE; + Point temp; + + while (error > 0.01 * classNum) { + for (Point p1 : totalPoints) { + // 将所有的测试坐标点就近分类 + for (Point p2 : classPoints) { + p2.computerDistance(p1); + } + Collections.sort(classPoints); + + // 取出p1离类坐标点最近的那个点 + p1.setClassName(classPoints.get(0).getClassName()); + } + + error = 0; + // 按照均值重新划分聚类中心点 + for (Point p1 : classPoints) { + count = 0; + tempX = 0; + tempY = 0; + for (Point p : totalPoints) { + if (p.getClassName().equals(p1.getClassName())) { + count++; + tempX += p.getX(); + tempY += p.getY(); + } + } + tempX /= count; + tempY /= count; + + error += Math.abs((tempX - p1.getX())); + error += Math.abs((tempY - p1.getY())); + // 计算均值 + p1.setX(tempX); + p1.setY(tempY); + + } + + for (int i = 0; i < classPoints.size(); i++) { + temp = classPoints.get(i); + System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}", + (i + 1), temp.getX(), temp.getY())); + } + System.out.println("----------"); + } + + System.out.println("结果值收敛"); + for (int i = 0; i < classPoints.size(); i++) { + temp = classPoints.get(i); + System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}", + (i + 1), temp.getX(), temp.getY())); + } + + } } diff --git a/Clustering/DataMining_KMeans/Point.java b/Clustering/DataMining_KMeans/Point.java index 11cdb86..d0cc290 100644 --- a/Clustering/DataMining_KMeans/Point.java +++ b/Clustering/DataMining_KMeans/Point.java @@ -1,87 +1,86 @@ -package DataMining_KMeans; +package Clustering.DataMining_KMeans; /** * 坐标点类 - * - * @author lyq - * + * + * @author Qstar */ -public class Point implements Comparable{ - // 坐标点横坐标 - private double x; - // 坐标点纵坐标 - private double y; - //以此点作为聚类中心的类的类名称 - private String className; - // 坐标点之间的欧式距离 - private Double distance; - - public Point(double x, double y) { - this.x = x; - this.y = y; - } - - public Point(String x, String y) { - this.x = Double.parseDouble(x); - this.y = Double.parseDouble(y); - } - - public Point(String x, String y, String className) { - this.x = Double.parseDouble(x); - this.y = Double.parseDouble(y); - this.className = className; - } - - /** - * 距离目标点p的欧几里得距离 - * - * @param p - */ - public void computerDistance(Point p) { - if (p == null) { - return; - } - - this.distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) - * (this.y - p.y); - } - - public double getX() { - return x; - } - - public void setX(double x) { - this.x = x; - } - - public double getY() { - return y; - } - - public void setY(double y) { - this.y = y; - } - - public String getClassName() { - return className; - } - - public void setClassName(String className) { - this.className = className; - } - - public double getDistance() { - return distance; - } - - public void setDistance(double distance) { - this.distance = distance; - } - - @Override - public int compareTo(Point o) { - // TODO Auto-generated method stub - return this.distance.compareTo(o.distance); - } - +public class Point implements Comparable { + // 坐标点横坐标 + private double x; + // 坐标点纵坐标 + private double y; + //以此点作为聚类中心的类的类名称 + private String className; + // 坐标点之间的欧式距离 + private Double distance; + + public Point(double x, double y){ + this.x = x; + this.y = y; + } + + public Point(String x, String y){ + this.x = Double.parseDouble(x); + this.y = Double.parseDouble(y); + } + + public Point(String x, String y, String className){ + this.x = Double.parseDouble(x); + this.y = Double.parseDouble(y); + this.className = className; + } + + /** + * 距离目标点p的欧几里得距离 + * + * @param p 目标点 + */ + void computerDistance(Point p){ + if (p == null) { + return; + } + + this.distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) + * (this.y - p.y); + } + + public double getX(){ + return x; + } + + public void setX(double x){ + this.x = x; + } + + public double getY(){ + return y; + } + + public void setY(double y){ + this.y = y; + } + + String getClassName(){ + return className; + } + + void setClassName(String className){ + this.className = className; + } + + public double getDistance(){ + return distance; + } + + public void setDistance(double distance){ + this.distance = distance; + } + + @Override + public int compareTo(Point o){ + // TODO Auto-generated method stub + return this.distance.compareTo(o.distance); + } + } diff --git a/GraphMining/DataMining_GSpan/Client.java b/GraphMining/DataMining_GSpan/Client.java index 17c1094..d0defd0 100644 --- a/GraphMining/DataMining_GSpan/Client.java +++ b/GraphMining/DataMining_GSpan/Client.java @@ -1,18 +1,18 @@ -package DataMining_GSpan; +package GraphMining.DataMining_GSpan; /** * gSpan频繁子图挖掘算法 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - //测试数据文件地址 - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - //最小支持度率 - double minSupportRate = 0.2; - - GSpanTool tool = new GSpanTool(filePath, minSupportRate); - tool.freqGraphMining(); - } + public static void main(String[] args){ + //测试数据文件地址 + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/GraphMining/DataMining_GSpan/input.txt"; + //最小支持度率 + double minSupportRate = 0.2; + + GSpanTool tool = new GSpanTool(filePath, minSupportRate); + tool.freqGraphMining(); + } } diff --git a/GraphMining/DataMining_GSpan/DFSCodeTraveler.java b/GraphMining/DataMining_GSpan/DFSCodeTraveler.java index 01982f9..7d4d7c9 100644 --- a/GraphMining/DataMining_GSpan/DFSCodeTraveler.java +++ b/GraphMining/DataMining_GSpan/DFSCodeTraveler.java @@ -1,150 +1,147 @@ -package DataMining_GSpan; +package GraphMining.DataMining_GSpan; import java.util.ArrayList; import java.util.Stack; /** * 图编码深度优先搜索类,判断当前编码在给定图中是否为最小编码 - * - * @author lyq - * + * + * @author Qstar */ -public class DFSCodeTraveler { - // 当前的编码是否为最下编码标识 - boolean isMin; - // 当前挖掘的图的边五元组编码组 - ArrayList edgeSeqs; - // 当前的图结构 - Graph graph; - // 图节点id对应的边五元组中的id标识 - int[] g2s; - // 代表图中的边是否被用到了 - boolean f[][]; - public DFSCodeTraveler(ArrayList edgeSeqs, Graph graph) { - this.isMin = true; - this.edgeSeqs = edgeSeqs; - this.graph = graph; - } +class DFSCodeTraveler { + // 当前的编码是否为最下编码标识 + boolean isMin; + // 代表图中的边是否被用到了 + private boolean f[][]; + // 当前挖掘的图的边五元组编码组 + private ArrayList edgeSeqs; + // 当前的图结构 + private Graph graph; + // 图节点id对应的边五元组中的id标识 + private int[] g2s; - public void traveler() { - int nodeLNums = graph.nodeLabels.size(); - g2s = new int[nodeLNums]; - for (int i = 0; i < nodeLNums; i++) { - // 设置-1代表此点还未被计入编码 - g2s[i] = -1; - } + DFSCodeTraveler(ArrayList edgeSeqs, Graph graph){ + this.isMin = true; + this.edgeSeqs = edgeSeqs; + this.graph = graph; + } - f = new boolean[nodeLNums][nodeLNums]; - for (int i = 0; i < nodeLNums; i++) { - for (int j = 0; j < nodeLNums; j++) { - f[i][j] = false; - } - } + void traveler(){ + int nodeLNums = graph.nodeLabels.size(); + g2s = new int[nodeLNums]; + for (int i = 0; i < nodeLNums; i++) { + // 设置-1代表此点还未被计入编码 + g2s[i] = -1; + } - // 从每个点开始寻找最小编码五元组 - for (int i = 0; i < nodeLNums; i++) { - //对选择的第一个点的标号做判断 - if(graph.getNodeLabels().get(i) > edgeSeqs.get(0).x){ - continue; - } - // 五元组id从0开始设置 - g2s[i] = 0; + f = new boolean[nodeLNums][nodeLNums]; + for (int i = 0; i < nodeLNums; i++) { + for (int j = 0; j < nodeLNums; j++) { + f[i][j] = false; + } + } - Stack s = new Stack<>(); - s.push(i); - dfsSearch(s, 0, 1); - if (!isMin) { - return; - } - g2s[i] = -1; - } - } + // 从每个点开始寻找最小编码五元组 + for (int i = 0; i < nodeLNums; i++) { + //对选择的第一个点的标号做判断 + if (graph.getNodeLabels().get(i) > edgeSeqs.get(0).x) { + continue; + } + // 五元组id从0开始设置 + g2s[i] = 0; - /** - * 深度优先搜索最小编码组 - * - * @param stack - * 加入的节点id栈 - * @param currentPosition - * 当前进行的层次,代表找到的第几条边 - * @param next - * 五元组边下一条边的点的临时标识 - */ - private void dfsSearch(Stack stack, int currentPosition, int next) { - if (currentPosition >= edgeSeqs.size()) { - stack.pop(); - // 比较到底了则返回 - return; - } + Stack s = new Stack<>(); + s.push(i); + dfsSearch(s, 0, 1); + if (!isMin) { + return; + } + g2s[i] = -1; + } + } - while (!stack.isEmpty()) { - int x = stack.pop(); - for (int i = 0; i < graph.edgeNexts.get(x).size(); i++) { - // 从此id节点所连接的点中选取1个点作为下一个点 - int y = graph.edgeNexts.get(x).get(i); - // 如果这2个点所构成的边已经被用过,则继续 - if (f[x][y] || f[y][x]) { - continue; - } + /** + * 深度优先搜索最小编码组 + * + * @param stack 加入的节点id栈 + * @param currentPosition 当前进行的层次,代表找到的第几条边 + * @param next 五元组边下一条边的点的临时标识 + */ + private void dfsSearch(Stack stack, int currentPosition, int next){ + if (currentPosition >= edgeSeqs.size()) { + stack.pop(); + // 比较到底了则返回 + return; + } - // 如果y这个点未被用过 - if (g2s[y] < 0) { - // 新建这条边五元组 - Edge e = new Edge(g2s[x], next, graph.nodeLabels.get(x), - graph.edgeLabels.get(x).get(i), - graph.nodeLabels.get(y)); + while (!stack.isEmpty()) { + int x = stack.pop(); + for (int i = 0; i < graph.edgeNexts.get(x).size(); i++) { + // 从此id节点所连接的点中选取1个点作为下一个点 + int y = graph.edgeNexts.get(x).get(i); + // 如果这2个点所构成的边已经被用过,则继续 + if (f[x][y] || f[y][x]) { + continue; + } - // 与相应位置的边做比较,如果不是最小则失败 - int compareResult = e.compareWith(edgeSeqs - .get(currentPosition)); - if (compareResult == Edge.EDGE_SMALLER) { - isMin = false; - return; - } else if (compareResult == Edge.EDGE_LARGER) { - continue; - } - // 如果相等则继续比 - g2s[y] = next; - f[x][y] = true; - f[y][x] = true; - stack.push(y); - dfsSearch(stack, currentPosition + 1, next + 1); - if (!isMin) { - return; - } - f[x][y] = false; - f[y][x] = false; - g2s[y] = -1; - } else { - // 这个点已经被用过的时候,不需要再设置五元组id标识 - // 新建这条边五元组 - Edge e = new Edge(g2s[x], g2s[y], graph.nodeLabels.get(x), - graph.edgeLabels.get(x).get(i), - graph.nodeLabels.get(y)); + // 如果y这个点未被用过 + if (g2s[y] < 0) { + // 新建这条边五元组 + Edge e = new Edge(g2s[x], next, graph.nodeLabels.get(x), + graph.edgeLabels.get(x).get(i), + graph.nodeLabels.get(y)); - // 与相应位置的边做比较,如果不是最小则失败 - int compareResult = e.compareWith(edgeSeqs - .get(currentPosition)); - if (compareResult == Edge.EDGE_SMALLER) { - isMin = false; - return; - } else if (compareResult == Edge.EDGE_LARGER) { - continue; - } - // 如果相等则继续比 - g2s[y] = next; - f[x][y] = true; - f[y][x] = true; - stack.push(y); - dfsSearch(stack, currentPosition + 1, next); - if (!isMin) { - return; - } - f[x][y] = false; - f[y][x] = false; - } - } - } - } + // 与相应位置的边做比较,如果不是最小则失败 + int compareResult = e.compareWith(edgeSeqs + .get(currentPosition)); + if (compareResult == Edge.EDGE_SMALLER) { + isMin = false; + return; + } else if (compareResult == Edge.EDGE_LARGER) { + continue; + } + // 如果相等则继续比 + g2s[y] = next; + f[x][y] = true; + f[y][x] = true; + stack.push(y); + dfsSearch(stack, currentPosition + 1, next + 1); + if (!isMin) { + return; + } + f[x][y] = false; + f[y][x] = false; + g2s[y] = -1; + } else { + // 这个点已经被用过的时候,不需要再设置五元组id标识 + // 新建这条边五元组 + Edge e = new Edge(g2s[x], g2s[y], graph.nodeLabels.get(x), + graph.edgeLabels.get(x).get(i), + graph.nodeLabels.get(y)); + + // 与相应位置的边做比较,如果不是最小则失败 + int compareResult = e.compareWith(edgeSeqs + .get(currentPosition)); + if (compareResult == Edge.EDGE_SMALLER) { + isMin = false; + return; + } else if (compareResult == Edge.EDGE_LARGER) { + continue; + } + // 如果相等则继续比 + g2s[y] = next; + f[x][y] = true; + f[y][x] = true; + stack.push(y); + dfsSearch(stack, currentPosition + 1, next); + if (!isMin) { + return; + } + f[x][y] = false; + f[y][x] = false; + } + } + } + } } diff --git a/GraphMining/DataMining_GSpan/Edge.java b/GraphMining/DataMining_GSpan/Edge.java index 3abc720..00289ce 100644 --- a/GraphMining/DataMining_GSpan/Edge.java +++ b/GraphMining/DataMining_GSpan/Edge.java @@ -1,62 +1,54 @@ -package DataMining_GSpan; +package GraphMining.DataMining_GSpan; /** * 边,用五元组表示 - * - * @author lyq - * + * + * @author Qstar */ -public class Edge { - // 五元组的大小比较结果 - public static final int EDGE_EQUAL = 0; - public static final int EDGE_SMALLER = 1; - public static final int EDGE_LARGER = 2; +class Edge { + static final int EDGE_SMALLER = 1; + static final int EDGE_LARGER = 2; + // 五元组的大小比较结果 + private static final int EDGE_EQUAL = 0; + // 边的一端的id号标识 + int ix; + // 边的另一端的id号标识 + int iy; + // 边的一端的点标号 + int x; + // 边的标号 + int a; + // 边的另一端的点标号 + int y; - // 边的一端的id号标识 - int ix; - // 边的另一端的id号标识 - int iy; - // 边的一端的点标号 - int x; - // 边的标号 - int a; - // 边的另一端的点标号 - int y; + Edge(int ix, int iy, int x, int a, int y){ + this.ix = ix; + this.iy = iy; + this.x = x; + this.a = a; + this.y = y; + } - public Edge(int ix, int iy, int x, int a, int y) { - this.ix = ix; - this.iy = iy; - this.x = x; - this.a = a; - this.y = y; - } - - /** - * 当前边是与给定的边的大小比较关系 - * - * @param e - * @return - */ - public int compareWith(Edge e) { - int result = EDGE_EQUAL; - int[] array1 = new int[] { ix, iy, x, y, a }; - int[] array2 = new int[] { e.ix, e.iy, e.x, e.y, e.a }; - - // 按照ix, iy,x,y,a的次序依次比较 - for (int i = 0; i < array1.length; i++) { - if (array1[i] < array2[i]) { - result = EDGE_SMALLER; - break; - } else if (array1[i] > array2[i]) { - result = EDGE_LARGER; - break; - } else { - // 如果相等,继续比较下一个 - continue; - } - } - - return result; - } + /** + * 当前边是与给定的边的大小比较关系 + * + * @param e 当前边 + */ + int compareWith(Edge e){ + int result = EDGE_EQUAL; + int[] array1 = new int[]{ix, iy, x, y, a}; + int[] array2 = new int[]{e.ix, e.iy, e.x, e.y, e.a}; + // 按照ix, iy,x,y,a的次序依次比较 + for (int i = 0; i < array1.length; i++) { + if (array1[i] < array2[i]) { + result = EDGE_SMALLER; + break; + } else if (array1[i] > array2[i]) { + result = EDGE_LARGER; + break; + } + } + return result; + } } diff --git a/GraphMining/DataMining_GSpan/EdgeFrequency.java b/GraphMining/DataMining_GSpan/EdgeFrequency.java index b095b9a..ea22718 100644 --- a/GraphMining/DataMining_GSpan/EdgeFrequency.java +++ b/GraphMining/DataMining_GSpan/EdgeFrequency.java @@ -1,31 +1,25 @@ -package DataMining_GSpan; +package GraphMining.DataMining_GSpan; /** * 边的频繁统计 - * @author lyq * + * @author Qstar */ -public class EdgeFrequency { - //节点标号数量 - private int nodeLabelNum; - //边的标号数量 - private int edgeLabelNum; - //用于存放边计数的3维数组 - public int[][][] edgeFreqCount; - - public EdgeFrequency(int nodeLabelNum, int edgeLabelNum){ - this.nodeLabelNum = nodeLabelNum; - this.edgeLabelNum = edgeLabelNum; - - edgeFreqCount = new int[nodeLabelNum][edgeLabelNum][nodeLabelNum]; - //最初始化操作 - for(int i=0; i totalGraphDatas; - // 所有的图结构数据 - private ArrayList totalGraphs; - // 挖掘出的频繁子图 - private ArrayList resultGraphs; - // 边的频度统计 - private EdgeFrequency ef; - // 节点的频度 - private int[] freqNodeLabel; - // 边的频度 - private int[] freqEdgeLabel; - // 重新标号之后的点的标号数 - private int newNodeLabelNum = 0; - // 重新标号后的边的标号数 - private int newEdgeLabelNum = 0; - - public GSpanTool(String filePath, double minSupportRate) { - this.filePath = filePath; - this.minSupportRate = minSupportRate; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - calFrequentAndRemove(dataArray); - } - - /** - * 统计边和点的频度,并移除不频繁的点边,以标号作为统计的变量 - * - * @param dataArray - * 原始数据 - */ - private void calFrequentAndRemove(ArrayList dataArray) { - int tempCount = 0; - freqNodeLabel = new int[LABEL_MAX]; - freqEdgeLabel = new int[LABEL_MAX]; - - // 做初始化操作 - for (int i = 0; i < LABEL_MAX; i++) { - // 代表标号为i的节点目前的数量为0 - freqNodeLabel[i] = 0; - freqEdgeLabel[i] = 0; - } - - GraphData gd = null; - totalGraphDatas = new ArrayList<>(); - for (String[] array : dataArray) { - if (array[0].equals(INPUT_NEW_GRAPH)) { - if (gd != null) { - totalGraphDatas.add(gd); - } - - // 新建图 - gd = new GraphData(); - } else if (array[0].equals(INPUT_VERTICE)) { - // 每个图中的每种图只统计一次 - if (!gd.getNodeLabels().contains(Integer.parseInt(array[2]))) { - tempCount = freqNodeLabel[Integer.parseInt(array[2])]; - tempCount++; - freqNodeLabel[Integer.parseInt(array[2])] = tempCount; - } - - gd.getNodeLabels().add(Integer.parseInt(array[2])); - gd.getNodeVisibles().add(true); - } else if (array[0].equals(INPUT_EDGE)) { - // 每个图中的每种图只统计一次 - if (!gd.getEdgeLabels().contains(Integer.parseInt(array[3]))) { - tempCount = freqEdgeLabel[Integer.parseInt(array[3])]; - tempCount++; - freqEdgeLabel[Integer.parseInt(array[3])] = tempCount; - } - - int i = Integer.parseInt(array[1]); - int j = Integer.parseInt(array[2]); - - gd.getEdgeLabels().add(Integer.parseInt(array[3])); - gd.getEdgeX().add(i); - gd.getEdgeY().add(j); - gd.getEdgeVisibles().add(true); - } - } - // 把最后一块gd数据加入 - totalGraphDatas.add(gd); - minSupportCount = (int) (minSupportRate * totalGraphDatas.size()); - - for (GraphData g : totalGraphDatas) { - g.removeInFreqNodeAndEdge(freqNodeLabel, freqEdgeLabel, - minSupportCount); - } - } - - /** - * 根据标号频繁度进行排序并且重新标号 - */ - private void sortAndReLabel() { - int label1 = 0; - int label2 = 0; - int temp = 0; - // 点排序名次 - int[] rankNodeLabels = new int[LABEL_MAX]; - // 边排序名次 - int[] rankEdgeLabels = new int[LABEL_MAX]; - // 标号对应排名 - int[] nodeLabel2Rank = new int[LABEL_MAX]; - int[] edgeLabel2Rank = new int[LABEL_MAX]; - - for (int i = 0; i < LABEL_MAX; i++) { - // 表示排名第i位的标号为i,[i]中的i表示排名 - rankNodeLabels[i] = i; - rankEdgeLabels[i] = i; - } - - for (int i = 0; i < freqNodeLabel.length - 1; i++) { - int k = 0; - label1 = rankNodeLabels[i]; - temp = label1; - for (int j = i + 1; j < freqNodeLabel.length; j++) { - label2 = rankNodeLabels[j]; - - if (freqNodeLabel[temp] < freqNodeLabel[label2]) { - // 进行标号的互换 - temp = label2; - k = j; - } - } - - if (temp != label1) { - // 进行i,k排名下的标号对调 - temp = rankNodeLabels[k]; - rankNodeLabels[k] = rankNodeLabels[i]; - rankNodeLabels[i] = temp; - } - } - - // 对边同样进行排序 - for (int i = 0; i < freqEdgeLabel.length - 1; i++) { - int k = 0; - label1 = rankEdgeLabels[i]; - temp = label1; - for (int j = i + 1; j < freqEdgeLabel.length; j++) { - label2 = rankEdgeLabels[j]; - - if (freqEdgeLabel[temp] < freqEdgeLabel[label2]) { - // 进行标号的互换 - temp = label2; - k = j; - } - } - - if (temp != label1) { - // 进行i,k排名下的标号对调 - temp = rankEdgeLabels[k]; - rankEdgeLabels[k] = rankEdgeLabels[i]; - rankEdgeLabels[i] = temp; - } - } - - // 将排名对标号转为标号对排名 - for (int i = 0; i < rankNodeLabels.length; i++) { - nodeLabel2Rank[rankNodeLabels[i]] = i; - } - - for (int i = 0; i < rankEdgeLabels.length; i++) { - edgeLabel2Rank[rankEdgeLabels[i]] = i; - } - - for (GraphData gd : totalGraphDatas) { - gd.reLabelByRank(nodeLabel2Rank, edgeLabel2Rank); - } - - // 根据排名找出小于支持度值的最大排名值 - for (int i = 0; i < rankNodeLabels.length; i++) { - if (freqNodeLabel[rankNodeLabels[i]] > minSupportCount) { - newNodeLabelNum = i; - } - } - for (int i = 0; i < rankEdgeLabels.length; i++) { - if (freqEdgeLabel[rankEdgeLabels[i]] > minSupportCount) { - newEdgeLabelNum = i; - } - } - //排名号比数量少1,所以要加回来 - newNodeLabelNum++; - newEdgeLabelNum++; - } - - /** - * 进行频繁子图的挖掘 - */ - public void freqGraphMining() { - long startTime = System.currentTimeMillis(); - long endTime = 0; - Graph g; - sortAndReLabel(); - - resultGraphs = new ArrayList<>(); - totalGraphs = new ArrayList<>(); - // 通过图数据构造图结构 - for (GraphData gd : totalGraphDatas) { - g = new Graph(); - g = g.constructGraph(gd); - totalGraphs.add(g); - } - - // 根据新的点边的标号数初始化边频繁度对象 - ef = new EdgeFrequency(newNodeLabelNum, newEdgeLabelNum); - for (int i = 0; i < newNodeLabelNum; i++) { - for (int j = 0; j < newEdgeLabelNum; j++) { - for (int k = 0; k < newNodeLabelNum; k++) { - for (Graph tempG : totalGraphs) { - if (tempG.hasEdge(i, j, k)) { - ef.edgeFreqCount[i][j][k]++; - } - } - } - } - } - - Edge edge; - GraphCode gc; - for (int i = 0; i < newNodeLabelNum; i++) { - for (int j = 0; j < newEdgeLabelNum; j++) { - for (int k = 0; k < newNodeLabelNum; k++) { - if (ef.edgeFreqCount[i][j][k] >= minSupportCount) { - gc = new GraphCode(); - edge = new Edge(0, 1, i, j, k); - gc.getEdgeSeq().add(edge); - - // 将含有此边的图id加入到gc中 - for (int y = 0; y < totalGraphs.size(); y++) { - if (totalGraphs.get(y).hasEdge(i, j, k)) { - gc.getGs().add(y); - } - } - // 对某条满足阈值的边进行挖掘 - subMining(gc, 2); - } - } - } - } - - endTime = System.currentTimeMillis(); - System.out.println("算法执行时间"+ (endTime-startTime) + "ms"); - printResultGraphInfo(); - } - - /** - * 进行频繁子图的挖掘 - * - * @param gc - * 图编码 - * @param next - * 图所含的点的个数 - */ - public void subMining(GraphCode gc, int next) { - Edge e; - Graph graph = new Graph(); - int id1; - int id2; - - for(int i=0; i()); - graph.edgeNexts.add(new ArrayList()); - } - - // 首先根据图编码中的边五元组构造图 - for (int i = 0; i < gc.getEdgeSeq().size(); i++) { - e = gc.getEdgeSeq().get(i); - id1 = e.ix; - id2 = e.iy; - - graph.nodeLabels.set(id1, e.x); - graph.nodeLabels.set(id2, e.y); - graph.edgeLabels.get(id1).add(e.a); - graph.edgeLabels.get(id2).add(e.a); - graph.edgeNexts.get(id1).add(id2); - graph.edgeNexts.get(id2).add(id1); - } - - DFSCodeTraveler dTraveler = new DFSCodeTraveler(gc.getEdgeSeq(), graph); - dTraveler.traveler(); - if (!dTraveler.isMin) { - return; - } - - // 如果当前是最小编码则将此图加入到结果集中 - resultGraphs.add(graph); - Edge e1; - ArrayList gIds; - SubChildTraveler sct; - ArrayList edgeArray; - // 添加潜在的孩子边,每条孩子边所属的图id - HashMap> edge2GId = new HashMap<>(); - for (int i = 0; i < gc.gs.size(); i++) { - int id = gc.gs.get(i); - - // 在此结构的条件下,在多加一条边构成子图继续挖掘 - sct = new SubChildTraveler(gc.edgeSeq, totalGraphs.get(id)); - sct.traveler(); - edgeArray = sct.getResultChildEdge(); - - // 做边id的更新 - for (Edge e2 : edgeArray) { - if (!edge2GId.containsKey(e2)) { - gIds = new ArrayList<>(); - } else { - gIds = edge2GId.get(e2); - } - - gIds.add(id); - edge2GId.put(e2, gIds); - } - } - - for (Map.Entry entry : edge2GId.entrySet()) { - e1 = (Edge) entry.getKey(); - gIds = (ArrayList) entry.getValue(); - - // 如果此边的频度大于最小支持度值,则继续挖掘 - if (gIds.size() < minSupportCount) { - continue; - } - - GraphCode nGc = new GraphCode(); - nGc.edgeSeq.addAll(gc.edgeSeq); - // 在当前图中新加入一条边,构成新的子图进行挖掘 - nGc.edgeSeq.add(e1); - nGc.gs.addAll(gIds); - - if (e1.iy == next) { - // 如果边的点id设置是为当前最大值的时候,则开始寻找下一个点 - subMining(nGc, next + 1); - } else { - // 如果此点已经存在,则next值不变 - subMining(nGc, next); - } - } - } - - /** - * 输出频繁子图结果信息 - */ - public void printResultGraphInfo(){ - System.out.println(MessageFormat.format("挖掘出的频繁子图的个数为:{0}个", resultGraphs.size())); - } +class GSpanTool { + // Label标号的最大数量,包括点标号和边标号 + private final int LABEL_MAX = 100; + + // 测试数据文件地址 + private String filePath; + // 最小支持度率 + private double minSupportRate; + // 最小支持度数,通过图总数与最小支持度率的乘积计算所得 + private int minSupportCount; + // 初始所有图的数据 + private ArrayList totalGraphDatas; + // 所有的图结构数据 + private ArrayList totalGraphs; + // 挖掘出的频繁子图 + private ArrayList resultGraphs; + // 节点的频度 + private int[] freqNodeLabel; + // 边的频度 + private int[] freqEdgeLabel; + // 重新标号之后的点的标号数 + private int newNodeLabelNum = 0; + // 重新标号后的边的标号数 + private int newEdgeLabelNum = 0; + + GSpanTool(String filePath, double minSupportRate){ + this.filePath = filePath; + this.minSupportRate = minSupportRate; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + calFrequentAndRemove(dataArray); + } + + /** + * 统计边和点的频度,并移除不频繁的点边,以标号作为统计的变量 + * + * @param dataArray 原始数据 + */ + private void calFrequentAndRemove(ArrayList dataArray){ + int tempCount; + freqNodeLabel = new int[LABEL_MAX]; + freqEdgeLabel = new int[LABEL_MAX]; + + // 做初始化操作 + for (int i = 0; i < LABEL_MAX; i++) { + // 代表标号为i的节点目前的数量为0 + freqNodeLabel[i] = 0; + freqEdgeLabel[i] = 0; + } + + GraphData gd = null; + totalGraphDatas = new ArrayList<>(); + for (String[] array : dataArray) { + String INPUT_NEW_GRAPH = "t"; + String INPUT_VERTICE = "v"; + String INPUT_EDGE = "e"; + if (array[0].equals(INPUT_NEW_GRAPH)) { + if (gd != null) { + totalGraphDatas.add(gd); + } + + // 新建图 + gd = new GraphData(); + } else if (array[0].equals(INPUT_VERTICE)) { + // 每个图中的每种图只统计一次 + if (gd != null && !gd.getNodeLabels().contains(Integer.parseInt(array[2]))) { + tempCount = freqNodeLabel[Integer.parseInt(array[2])]; + tempCount++; + freqNodeLabel[Integer.parseInt(array[2])] = tempCount; + } + + if (gd != null) { + gd.getNodeLabels().add(Integer.parseInt(array[2])); + gd.getNodeVisibles().add(true); + } + } else if (array[0].equals(INPUT_EDGE)) { + // 每个图中的每种图只统计一次 + if (gd != null && !gd.getEdgeLabels().contains(Integer.parseInt(array[3]))) { + tempCount = freqEdgeLabel[Integer.parseInt(array[3])]; + tempCount++; + freqEdgeLabel[Integer.parseInt(array[3])] = tempCount; + } + + int i = Integer.parseInt(array[1]); + int j = Integer.parseInt(array[2]); + + if (gd != null) { + gd.getEdgeLabels().add(Integer.parseInt(array[3])); + gd.getEdgeX().add(i); + gd.getEdgeY().add(j); + gd.getEdgeVisibles().add(true); + } + } + } + // 把最后一块gd数据加入 + totalGraphDatas.add(gd); + minSupportCount = (int) (minSupportRate * totalGraphDatas.size()); + + for (GraphData g : totalGraphDatas) { + g.removeInFreqNodeAndEdge(freqNodeLabel, freqEdgeLabel, + minSupportCount); + } + } + + /** + * 根据标号频繁度进行排序并且重新标号 + */ + private void sortAndReLabel(){ + int label1; + int label2; + int temp; + // 点排序名次 + int[] rankNodeLabels = new int[LABEL_MAX]; + // 边排序名次 + int[] rankEdgeLabels = new int[LABEL_MAX]; + // 标号对应排名 + int[] nodeLabel2Rank = new int[LABEL_MAX]; + int[] edgeLabel2Rank = new int[LABEL_MAX]; + + for (int i = 0; i < LABEL_MAX; i++) { + // 表示排名第i位的标号为i,[i]中的i表示排名 + rankNodeLabels[i] = i; + rankEdgeLabels[i] = i; + } + + for (int i = 0; i < freqNodeLabel.length - 1; i++) { + int k = 0; + label1 = rankNodeLabels[i]; + temp = label1; + for (int j = i + 1; j < freqNodeLabel.length; j++) { + label2 = rankNodeLabels[j]; + + if (freqNodeLabel[temp] < freqNodeLabel[label2]) { + // 进行标号的互换 + temp = label2; + k = j; + } + } + + if (temp != label1) { + // 进行i,k排名下的标号对调 + temp = rankNodeLabels[k]; + rankNodeLabels[k] = rankNodeLabels[i]; + rankNodeLabels[i] = temp; + } + } + + // 对边同样进行排序 + for (int i = 0; i < freqEdgeLabel.length - 1; i++) { + int k = 0; + label1 = rankEdgeLabels[i]; + temp = label1; + for (int j = i + 1; j < freqEdgeLabel.length; j++) { + label2 = rankEdgeLabels[j]; + + if (freqEdgeLabel[temp] < freqEdgeLabel[label2]) { + // 进行标号的互换 + temp = label2; + k = j; + } + } + + if (temp != label1) { + // 进行i,k排名下的标号对调 + temp = rankEdgeLabels[k]; + rankEdgeLabels[k] = rankEdgeLabels[i]; + rankEdgeLabels[i] = temp; + } + } + + // 将排名对标号转为标号对排名 + for (int i = 0; i < rankNodeLabels.length; i++) { + nodeLabel2Rank[rankNodeLabels[i]] = i; + } + + for (int i = 0; i < rankEdgeLabels.length; i++) { + edgeLabel2Rank[rankEdgeLabels[i]] = i; + } + + for (GraphData gd : totalGraphDatas) { + gd.reLabelByRank(nodeLabel2Rank, edgeLabel2Rank); + } + + // 根据排名找出小于支持度值的最大排名值 + for (int i = 0; i < rankNodeLabels.length; i++) { + if (freqNodeLabel[rankNodeLabels[i]] > minSupportCount) { + newNodeLabelNum = i; + } + } + for (int i = 0; i < rankEdgeLabels.length; i++) { + if (freqEdgeLabel[rankEdgeLabels[i]] > minSupportCount) { + newEdgeLabelNum = i; + } + } + //排名号比数量少1,所以要加回来 + newNodeLabelNum++; + newEdgeLabelNum++; + } + + /** + * 进行频繁子图的挖掘 + */ + void freqGraphMining(){ + long startTime = System.currentTimeMillis(); + long endTime; + Graph g; + sortAndReLabel(); + + resultGraphs = new ArrayList<>(); + totalGraphs = new ArrayList<>(); + // 通过图数据构造图结构 + for (GraphData gd : totalGraphDatas) { + g = new Graph(); + g = g.constructGraph(gd); + totalGraphs.add(g); + } + + // 根据新的点边的标号数初始化边频繁度对象 + EdgeFrequency ef = new EdgeFrequency(newNodeLabelNum, newEdgeLabelNum); + for (int i = 0; i < newNodeLabelNum; i++) { + for (int j = 0; j < newEdgeLabelNum; j++) { + for (int k = 0; k < newNodeLabelNum; k++) { + for (Graph tempG : totalGraphs) { + if (tempG.hasEdge(i, j, k)) { + ef.edgeFreqCount[i][j][k]++; + } + } + } + } + } + + Edge edge; + GraphCode gc; + for (int i = 0; i < newNodeLabelNum; i++) { + for (int j = 0; j < newEdgeLabelNum; j++) { + for (int k = 0; k < newNodeLabelNum; k++) { + if (ef.edgeFreqCount[i][j][k] >= minSupportCount) { + gc = new GraphCode(); + edge = new Edge(0, 1, i, j, k); + gc.getEdgeSeq().add(edge); + + // 将含有此边的图id加入到gc中 + for (int y = 0; y < totalGraphs.size(); y++) { + if (totalGraphs.get(y).hasEdge(i, j, k)) { + gc.getGs().add(y); + } + } + // 对某条满足阈值的边进行挖掘 + subMining(gc, 2); + } + } + } + } + + endTime = System.currentTimeMillis(); + System.out.println("算法执行时间" + (endTime - startTime) + "ms"); + printResultGraphInfo(); + } + + /** + * 进行频繁子图的挖掘 + * + * @param gc 图编码 + * @param next 图所含的点的个数 + */ + private void subMining(GraphCode gc, int next){ + Edge e; + Graph graph = new Graph(); + int id1; + int id2; + + for (int i = 0; i < next; i++) { + graph.nodeLabels.add(-1); + graph.edgeLabels.add(new ArrayList()); + graph.edgeNexts.add(new ArrayList()); + } + + // 首先根据图编码中的边五元组构造图 + for (int i = 0; i < gc.getEdgeSeq().size(); i++) { + e = gc.getEdgeSeq().get(i); + id1 = e.ix; + id2 = e.iy; + + graph.nodeLabels.set(id1, e.x); + graph.nodeLabels.set(id2, e.y); + graph.edgeLabels.get(id1).add(e.a); + graph.edgeLabels.get(id2).add(e.a); + graph.edgeNexts.get(id1).add(id2); + graph.edgeNexts.get(id2).add(id1); + } + + DFSCodeTraveler dTraveler = new DFSCodeTraveler(gc.getEdgeSeq(), graph); + dTraveler.traveler(); + if (!dTraveler.isMin) { + return; + } + + // 如果当前是最小编码则将此图加入到结果集中 + resultGraphs.add(graph); + Edge e1; + ArrayList gIds; + SubChildTraveler sct; + ArrayList edgeArray; + // 添加潜在的孩子边,每条孩子边所属的图id + HashMap> edge2GId = new HashMap<>(); + for (int i = 0; i < gc.gs.size(); i++) { + int id = gc.gs.get(i); + + // 在此结构的条件下,在多加一条边构成子图继续挖掘 + sct = new SubChildTraveler(gc.edgeSeq, totalGraphs.get(id)); + sct.traveler(); + edgeArray = sct.getResultChildEdge(); + + // 做边id的更新 + for (Edge e2 : edgeArray) { + if (!edge2GId.containsKey(e2)) { + gIds = new ArrayList<>(); + } else { + gIds = edge2GId.get(e2); + } + + gIds.add(id); + edge2GId.put(e2, gIds); + } + } + + for (Map.Entry entry : edge2GId.entrySet()) { + e1 = (Edge) entry.getKey(); + gIds = (ArrayList) entry.getValue(); + + // 如果此边的频度大于最小支持度值,则继续挖掘 + if (gIds.size() < minSupportCount) { + continue; + } + + GraphCode nGc = new GraphCode(); + nGc.edgeSeq.addAll(gc.edgeSeq); + // 在当前图中新加入一条边,构成新的子图进行挖掘 + nGc.edgeSeq.add(e1); + nGc.gs.addAll(gIds); + + if (e1.iy == next) { + // 如果边的点id设置是为当前最大值的时候,则开始寻找下一个点 + subMining(nGc, next + 1); + } else { + // 如果此点已经存在,则next值不变 + subMining(nGc, next); + } + } + } + + /** + * 输出频繁子图结果信息 + */ + private void printResultGraphInfo(){ + System.out.println(MessageFormat.format("挖掘出的频繁子图的个数为:{0}个", resultGraphs.size())); + } } diff --git a/GraphMining/DataMining_GSpan/Graph.java b/GraphMining/DataMining_GSpan/Graph.java index 39e02e7..91ec348 100644 --- a/GraphMining/DataMining_GSpan/Graph.java +++ b/GraphMining/DataMining_GSpan/Graph.java @@ -1,155 +1,139 @@ -package DataMining_GSpan; +package GraphMining.DataMining_GSpan; import java.util.ArrayList; /** * 图结构类 - * - * @author lyq - * + * + * @author Qstar */ -public class Graph { - // 图节点标号组 - ArrayList nodeLabels; - // 图的边标号组 - ArrayList> edgeLabels; - // 边2头的节点id号,在这里可以理解为下标号 - ArrayList> edgeNexts; - - public Graph() { - nodeLabels = new ArrayList<>(); - edgeLabels = new ArrayList<>(); - edgeNexts = new ArrayList<>(); - } - - public ArrayList getNodeLabels() { - return nodeLabels; - } - - public void setNodeLabels(ArrayList nodeLabels) { - this.nodeLabels = nodeLabels; - } - - /** - * 判断图中是否存在某条边 - * - * @param x - * 边的一端的节点标号 - * @param a - * 边的标号 - * @param y - * 边的另外一端节点标号 - * @return - */ - public boolean hasEdge(int x, int a, int y) { - boolean isContained = false; - int t; - - for (int i = 0; i < nodeLabels.size(); i++) { - // 先寻找2个端点标号,t代表找到的点的另外一个端点标号 - if (nodeLabels.get(i) == x) { - t = y; - } else if (nodeLabels.get(i) == y) { - t = x; - } else { - continue; - } - - for (int j = 0; j < edgeNexts.get(i).size(); j++) { - // 从此端点的所连接的点去比较对应的点和边 - if (edgeLabels.get(i).get(j) == a - && nodeLabels.get(edgeNexts.get(i).get(j)) == t) { - isContained = true; - return isContained; - } - } - } - - return isContained; - } - - /** - * 在图中移除某个边 - * - * @param x - * 边的某端的一个点标号 - * @param a - * 边的标号 - * @param y - * 边的另一端的一个点标号 - */ - public void removeEdge(int x, int a, int y) { - int t; - - for (int i = 0; i < nodeLabels.size(); i++) { - // 先寻找2个端点标号,t代表找到的点的另外一个端点标号 - if (nodeLabels.get(i) == x) { - t = y; - } else if (nodeLabels.get(i) == y) { - t = x; - } else { - continue; - } - - for (int j = 0; j < edgeNexts.get(i).size(); j++) { - // 从此端点的所连接的点去比较对应的点和边 - if (edgeLabels.get(i).get(j) == a - && nodeLabels.get(edgeNexts.get(i).get(j)) == t) { - int id; - // 在连接的点中去除该点 - edgeLabels.get(i).remove(j); - - id = edgeNexts.get(i).get(j); - edgeNexts.get(i).remove(j); - for (int k = 0; k < edgeNexts.get(id).size(); k++) { - if (edgeNexts.get(id).get(k) == i) { - edgeNexts.get(id).remove(k); - break; - } - } - break; - } - } - } - - } - - /** - * 根据图数据构造一个图 - * - * @param gd - * 图数据 - * @return - */ - public Graph constructGraph(GraphData gd) { - Graph graph = new Graph(); - - - // 构造一个图需要知道3点,1.图中有哪些点2.图中的每个点周围连着哪些点3.每个点周围连着哪些边 - for (int i = 0; i < gd.getNodeVisibles().size(); i++) { - if (gd.getNodeVisibles().get(i)) { - graph.getNodeLabels().add(gd.getNodeLabels().get(i)); - } - - // 添加对应id下的集合 - // id节点后有多少相连的边的标号 - graph.edgeLabels.add(new ArrayList()); - // id节点后有多少相连的节点的id - graph.edgeNexts.add(new ArrayList()); - } - - for (int i = 0; i < gd.getEdgeLabels().size(); i++) { - if (gd.getEdgeVisibles().get(i)) { - // 在此后面添加一个边标号 - graph.edgeLabels.get(gd.getEdgeX().get(i)).add(gd.getEdgeLabels().get(i)); - graph.edgeLabels.get(gd.getEdgeY().get(i)).add(gd.getEdgeLabels().get(i)); - graph.edgeNexts.get(gd.getEdgeX().get(i)).add( - gd.getEdgeY().get(i)); - graph.edgeNexts.get(gd.getEdgeY().get(i)).add( - gd.getEdgeX().get(i)); - } - } - - return graph; - } +class Graph { + // 图节点标号组 + ArrayList nodeLabels; + // 图的边标号组 + ArrayList> edgeLabels; + // 边2头的节点id号,在这里可以理解为下标号 + ArrayList> edgeNexts; + + Graph(){ + nodeLabels = new ArrayList<>(); + edgeLabels = new ArrayList<>(); + edgeNexts = new ArrayList<>(); + } + + ArrayList getNodeLabels(){ + return nodeLabels; + } + + /** + * 判断图中是否存在某条边 + * + * @param x 边的一端的节点标号 + * @param a 边的标号 + * @param y 边的另外一端节点标号 + */ + boolean hasEdge(int x, int a, int y){ + int t; + + for (int i = 0; i < nodeLabels.size(); i++) { + // 先寻找2个端点标号,t代表找到的点的另外一个端点标号 + if (nodeLabels.get(i) == x) { + t = y; + } else if (nodeLabels.get(i) == y) { + t = x; + } else { + continue; + } + + for (int j = 0; j < edgeNexts.get(i).size(); j++) { + // 从此端点的所连接的点去比较对应的点和边 + if (edgeLabels.get(i).get(j) == a + && nodeLabels.get(edgeNexts.get(i).get(j)) == t) { + return true; + } + } + } + + return false; + } + + /** + * 在图中移除某个边 + * + * @param x 边的某端的一个点标号 + * @param a 边的标号 + * @param y 边的另一端的一个点标号 + */ + public void removeEdge(int x, int a, int y){ + int t; + + for (int i = 0; i < nodeLabels.size(); i++) { + // 先寻找2个端点标号,t代表找到的点的另外一个端点标号 + if (nodeLabels.get(i) == x) { + t = y; + } else if (nodeLabels.get(i) == y) { + t = x; + } else { + continue; + } + + for (int j = 0; j < edgeNexts.get(i).size(); j++) { + // 从此端点的所连接的点去比较对应的点和边 + if (edgeLabels.get(i).get(j) == a + && nodeLabels.get(edgeNexts.get(i).get(j)) == t) { + int id; + // 在连接的点中去除该点 + edgeLabels.get(i).remove(j); + + id = edgeNexts.get(i).get(j); + edgeNexts.get(i).remove(j); + for (int k = 0; k < edgeNexts.get(id).size(); k++) { + if (edgeNexts.get(id).get(k) == i) { + edgeNexts.get(id).remove(k); + break; + } + } + break; + } + } + } + + } + + /** + * 根据图数据构造一个图 + * + * @param gd 图数据 + */ + Graph constructGraph(GraphData gd){ + Graph graph = new Graph(); + + + // 构造一个图需要知道3点,1.图中有哪些点2.图中的每个点周围连着哪些点3.每个点周围连着哪些边 + for (int i = 0; i < gd.getNodeVisibles().size(); i++) { + if (gd.getNodeVisibles().get(i)) { + graph.getNodeLabels().add(gd.getNodeLabels().get(i)); + } + + // 添加对应id下的集合 + // id节点后有多少相连的边的标号 + graph.edgeLabels.add(new ArrayList()); + // id节点后有多少相连的节点的id + graph.edgeNexts.add(new ArrayList()); + } + + for (int i = 0; i < gd.getEdgeLabels().size(); i++) { + if (gd.getEdgeVisibles().get(i)) { + // 在此后面添加一个边标号 + graph.edgeLabels.get(gd.getEdgeX().get(i)).add(gd.getEdgeLabels().get(i)); + graph.edgeLabels.get(gd.getEdgeY().get(i)).add(gd.getEdgeLabels().get(i)); + graph.edgeNexts.get(gd.getEdgeX().get(i)).add( + gd.getEdgeY().get(i)); + graph.edgeNexts.get(gd.getEdgeY().get(i)).add( + gd.getEdgeX().get(i)); + } + } + + return graph; + } } diff --git a/GraphMining/DataMining_GSpan/GraphCode.java b/GraphMining/DataMining_GSpan/GraphCode.java index 485700e..5267bc2 100644 --- a/GraphMining/DataMining_GSpan/GraphCode.java +++ b/GraphMining/DataMining_GSpan/GraphCode.java @@ -1,36 +1,29 @@ -package DataMining_GSpan; +package GraphMining.DataMining_GSpan; import java.util.ArrayList; /** * 图编码类 - * @author lyq * + * @author Qstar */ -public class GraphCode { - //边的集合,边的排序代表着边的添加次序 - ArrayList edgeSeq; - //拥有这些边的图的id - ArrayList gs; - - public GraphCode() { - this.edgeSeq = new ArrayList<>(); - this.gs = new ArrayList<>(); - } +class GraphCode { + //边的集合,边的排序代表着边的添加次序 + ArrayList edgeSeq; + //拥有这些边的图的id + ArrayList gs; - public ArrayList getEdgeSeq() { - return edgeSeq; - } + GraphCode(){ + this.edgeSeq = new ArrayList<>(); + this.gs = new ArrayList<>(); + } - public void setEdgeSeq(ArrayList edgeSeq) { - this.edgeSeq = edgeSeq; - } + ArrayList getEdgeSeq(){ + return edgeSeq; + } - public ArrayList getGs() { - return gs; - } + ArrayList getGs(){ + return gs; + } - public void setGs(ArrayList gs) { - this.gs = gs; - } } diff --git a/GraphMining/DataMining_GSpan/GraphData.java b/GraphMining/DataMining_GSpan/GraphData.java index 0bf7246..997be25 100644 --- a/GraphMining/DataMining_GSpan/GraphData.java +++ b/GraphMining/DataMining_GSpan/GraphData.java @@ -1,165 +1,135 @@ -package DataMining_GSpan; +package GraphMining.DataMining_GSpan; import java.util.ArrayList; /** * 图的数据类 - * - * @author lyq - * + * + * @author Qstar */ -public class GraphData { - // 节点组标号 - private ArrayList nodeLabels; - // 节点是否可用,可能被移除 - private ArrayList nodeVisibles; - // 边的集合标号 - private ArrayList edgeLabels; - // 边的一边点id - private ArrayList edgeX; - // 边的另一边的点id - private ArrayList edgeY; - // 边是否可用 - private ArrayList edgeVisibles; - - public GraphData() { - nodeLabels = new ArrayList<>(); - nodeVisibles = new ArrayList<>(); - - edgeLabels = new ArrayList<>(); - edgeX = new ArrayList<>(); - edgeY = new ArrayList<>(); - edgeVisibles = new ArrayList<>(); - } - - public ArrayList getNodeLabels() { - return nodeLabels; - } - - public void setNodeLabels(ArrayList nodeLabels) { - this.nodeLabels = nodeLabels; - } - - public ArrayList getNodeVisibles() { - return nodeVisibles; - } - - public void setNodeVisibles(ArrayList nodeVisibles) { - this.nodeVisibles = nodeVisibles; - } - - public ArrayList getEdgeLabels() { - return edgeLabels; - } - - public void setEdgeLabels(ArrayList edgeLabels) { - this.edgeLabels = edgeLabels; - } - - public ArrayList getEdgeX() { - return edgeX; - } - - public void setEdgeX(ArrayList edgeX) { - this.edgeX = edgeX; - } - - public ArrayList getEdgeY() { - return edgeY; - } - - public void setEdgeY(ArrayList edgeY) { - this.edgeY = edgeY; - } - - public ArrayList getEdgeVisibles() { - return edgeVisibles; - } - - public void setEdgeVisibles(ArrayList edgeVisibles) { - this.edgeVisibles = edgeVisibles; - } - - /** - * 根据点边频繁度移除图中不频繁的点边 - * - * @param freqNodeLabel - * 点的频繁度统计 - * @param freqEdgeLabel - * 边的频繁度统计 - * @param minSupportCount - * 最小支持度计数 - */ - public void removeInFreqNodeAndEdge(int[] freqNodeLabel, - int[] freqEdgeLabel, int minSupportCount) { - int label = 0; - int x = 0; - int y = 0; - - for (int i = 0; i < nodeLabels.size(); i++) { - label = nodeLabels.get(i); - if (freqNodeLabel[label] < minSupportCount) { - // 如果小于支持度计数,则此点不可用 - nodeVisibles.set(i, false); - } - } - - for (int i = 0; i < edgeLabels.size(); i++) { - label = edgeLabels.get(i); - - if (freqEdgeLabel[label] < minSupportCount) { - // 如果小于支持度计数,则此边不可用 - edgeVisibles.set(i, false); - continue; - } - - // 如果此边的某个端的端点已经不可用了,则此边也不可用,x,y表示id号 - x = edgeX.get(i); - y = edgeY.get(i); - if (!nodeVisibles.get(x) || !nodeVisibles.get(y)) { - edgeVisibles.set(i, false); - } - } - } - - /** - * 根据标号排序重新对满足条件的点边重新编号 - * - * @param nodeLabel2Rank - * 点排名 - * @param edgeLabel2Rank - * 边排名 - */ - public void reLabelByRank(int[] nodeLabel2Rank, int[] edgeLabel2Rank) { - int label = 0; - int count = 0; - int temp = 0; - // 旧的id对新id号的映射 - int[] oldId2New = new int[nodeLabels.size()]; - for (int i = 0; i < nodeLabels.size(); i++) { - label = nodeLabels.get(i); - - // 如果当前点是可用的,将此标号的排名号作为此点新的标号 - if (nodeVisibles.get(i)) { - nodeLabels.set(i, nodeLabel2Rank[label]); - oldId2New[i] = count; - count++; - } - } - - for (int i = 0; i < edgeLabels.size(); i++) { - label = edgeLabels.get(i); - - // 如果当前边是可用的,将此标号的排名号作为此点新的标号 - if (edgeVisibles.get(i)) { - edgeLabels.set(i, edgeLabel2Rank[label]); - - // 对此点做x,y的id号替换 - temp = edgeX.get(i); - edgeX.set(i, oldId2New[temp]); - temp = edgeY.get(i); - edgeY.set(i, oldId2New[temp]); - } - } - } +class GraphData { + // 节点组标号 + private ArrayList nodeLabels; + // 节点是否可用,可能被移除 + private ArrayList nodeVisibles; + // 边的集合标号 + private ArrayList edgeLabels; + // 边的一边点id + private ArrayList edgeX; + // 边的另一边的点id + private ArrayList edgeY; + // 边是否可用 + private ArrayList edgeVisibles; + + GraphData(){ + nodeLabels = new ArrayList<>(); + nodeVisibles = new ArrayList<>(); + + edgeLabels = new ArrayList<>(); + edgeX = new ArrayList<>(); + edgeY = new ArrayList<>(); + edgeVisibles = new ArrayList<>(); + } + + ArrayList getNodeLabels(){ + return nodeLabels; + } + + ArrayList getNodeVisibles(){ + return nodeVisibles; + } + + ArrayList getEdgeLabels(){ + return edgeLabels; + } + + ArrayList getEdgeX(){ + return edgeX; + } + + ArrayList getEdgeY(){ + return edgeY; + } + + ArrayList getEdgeVisibles(){ + return edgeVisibles; + } + + /** + * 根据点边频繁度移除图中不频繁的点边 + * + * @param freqNodeLabel 点的频繁度统计 + * @param freqEdgeLabel 边的频繁度统计 + * @param minSupportCount 最小支持度计数 + */ + void removeInFreqNodeAndEdge(int[] freqNodeLabel, + int[] freqEdgeLabel, int minSupportCount){ + int label; + int x; + int y; + + for (int i = 0; i < nodeLabels.size(); i++) { + label = nodeLabels.get(i); + if (freqNodeLabel[label] < minSupportCount) { + // 如果小于支持度计数,则此点不可用 + nodeVisibles.set(i, false); + } + } + + for (int i = 0; i < edgeLabels.size(); i++) { + label = edgeLabels.get(i); + + if (freqEdgeLabel[label] < minSupportCount) { + // 如果小于支持度计数,则此边不可用 + edgeVisibles.set(i, false); + continue; + } + + // 如果此边的某个端的端点已经不可用了,则此边也不可用,x,y表示id号 + x = edgeX.get(i); + y = edgeY.get(i); + if (!nodeVisibles.get(x) || !nodeVisibles.get(y)) { + edgeVisibles.set(i, false); + } + } + } + + /** + * 根据标号排序重新对满足条件的点边重新编号 + * + * @param nodeLabel2Rank 点排名 + * @param edgeLabel2Rank 边排名 + */ + void reLabelByRank(int[] nodeLabel2Rank, int[] edgeLabel2Rank){ + int label; + int count = 0; + int temp; + // 旧的id对新id号的映射 + int[] oldId2New = new int[nodeLabels.size()]; + for (int i = 0; i < nodeLabels.size(); i++) { + label = nodeLabels.get(i); + + // 如果当前点是可用的,将此标号的排名号作为此点新的标号 + if (nodeVisibles.get(i)) { + nodeLabels.set(i, nodeLabel2Rank[label]); + oldId2New[i] = count; + count++; + } + } + + for (int i = 0; i < edgeLabels.size(); i++) { + label = edgeLabels.get(i); + + // 如果当前边是可用的,将此标号的排名号作为此点新的标号 + if (edgeVisibles.get(i)) { + edgeLabels.set(i, edgeLabel2Rank[label]); + + // 对此点做x,y的id号替换 + temp = edgeX.get(i); + edgeX.set(i, oldId2New[temp]); + temp = edgeY.get(i); + edgeY.set(i, oldId2New[temp]); + } + } + } } diff --git a/GraphMining/DataMining_GSpan/SubChildTraveler.java b/GraphMining/DataMining_GSpan/SubChildTraveler.java index 8e88bc6..b1edffc 100644 --- a/GraphMining/DataMining_GSpan/SubChildTraveler.java +++ b/GraphMining/DataMining_GSpan/SubChildTraveler.java @@ -1,202 +1,196 @@ -package DataMining_GSpan; +package GraphMining.DataMining_GSpan; import java.util.ArrayList; -import java.util.HashMap; /** * 孩子图搜寻类,在当前边的基础上寻找可能的孩子边 - * - * @author lyq - * + * + * @author Qstar */ -public class SubChildTraveler { - // 当前的五元组边 - ArrayList edgeSeq; - // 当前的图 - Graph graph; - // 结果数据,孩子边对所属的图id组 - ArrayList childEdge; - // 图的点id对五元组id标识的映射 - int[] g2s; - // 五元组id标识对图的点id的映射 - int[] s2g; - // 图中边是否被用的情况 - boolean f[][]; - // 最右路径,rm[id]表示的是此id节点在最右路径中的下一个节点id - int[] rm; - // 下一个五元组的id - int next; - - public SubChildTraveler(ArrayList edgeSeq, Graph graph) { - this.edgeSeq = edgeSeq; - this.graph = graph; - this.childEdge = new ArrayList<>(); - } - - /** - * 在图中搜索可能存在的孩子边 - * - * @param next - * 新加入边的节点将设置的id - */ - public void traveler() { - this.next = edgeSeq.size() + 1; - int size = graph.nodeLabels.size(); - // 做id映射的初始化操作 - g2s = new int[size]; - s2g = new int[size]; - f = new boolean[size][size]; - - for (int i = 0; i < size; i++) { - g2s[i] = -1; - s2g[i] = -1; - - for (int j = 0; j < size; j++) { - // 代表点id为i到id为j点此边没有被用过 - f[i][j] = false; - } - } - - rm = new int[edgeSeq.size()+1]; - for (int i = 0; i < edgeSeq.size()+1; i++) { - rm[i] = -1; - } - // 寻找最右路径 - for (Edge e : edgeSeq) { - if (e.ix < e.iy && e.iy > rm[e.ix]) { - rm[e.ix] = e.iy; - } - } - - for (int i = 0; i < size; i++) { - // 寻找第一个标号相等的点 - if (edgeSeq.get(0).x != graph.nodeLabels.get(i)) { - continue; - } - - g2s[i] = 0; - s2g[0] = i; - dfsSearchEdge(0); - g2s[i] = -1; - s2g[0] = -1; - } - - } - - /** - * 在当前图中深度优先寻找正确的子图 - * - * @param currentPosition - * 当前找到的位置 - */ - public void dfsSearchEdge(int currentPosition) { - int rmPosition = 0; - // 如果找到底了,则在当前的子图的最右路径中寻找可能的边 - if (currentPosition >= edgeSeq.size()) { - rmPosition = 0; - while (rmPosition >= 0) { - int gId = s2g[rmPosition]; - // 在此点附近寻找可能的边 - for (int i = 0; i < graph.edgeNexts.get(gId).size(); i++) { - int gId2 = graph.edgeNexts.get(gId).get(i); - // 如果这条边已经被用过 - if (f[gId][gId2] || f[gId][gId2]) { - continue; - } - - // 在最右路径中添加边分为2种情况,第一种为在最右节点上添加,第二中为在最右路径上 的点添加 - // 如果找到的点没有被用过,可以进行边的拓展 - if (g2s[gId2] < 0) { - g2s[gId2] = next; - Edge e = new Edge(g2s[gId], g2s[gId2], - graph.nodeLabels.get(gId), graph.edgeLabels - .get(gId).get(i), - graph.nodeLabels.get(gId2)); - // 将新建的子边加入集合 - childEdge.add(e); - } else { - boolean flag = true; - // 如果这点已经存在,判断他是不是最右的点 - for (int j = 0; j < graph.edgeNexts.get(gId2).size(); j++) { - int tempId = graph.edgeNexts.get(gId2).get(j); - if (g2s[gId2] < g2s[tempId]) { - flag = false; - break; - } - } - - if (flag) { - Edge e = new Edge(g2s[gId], g2s[gId2], - graph.nodeLabels.get(gId), graph.edgeLabels - .get(gId).get(i), - graph.nodeLabels.get(gId2)); - // 将新建的子边加入集合 - childEdge.add(e); - } - } - } - // 一个最右路径上点找完,继续下一个 - rmPosition = rm[rmPosition]; - } - return; - } - - Edge e = edgeSeq.get(currentPosition); - // 所连接的点标号 - int y = e.y; - // 所连接的边标号 - int a = e.a; - int gId1 = s2g[e.ix]; - int gId2 = 0; - - for (int i = 0; i < graph.edgeLabels.get(gId1).size(); i++) { - // 判断所连接的边对应的标号 - if (graph.edgeLabels.get(gId1).get(i) != a) { - continue; - } - - // 判断所连接的点的标号 - int tempId = graph.edgeNexts.get(gId1).get(i); - if (graph.nodeLabels.get(tempId) != y) { - continue; - } - - gId2 = tempId; - // 如果这两点是没有设置过的 - if (g2s[gId2] == -1 && s2g[e.iy] == -1) { - g2s[gId2] = e.iy; - s2g[e.iy] = gId2; - f[gId1][gId2] = true; - f[gId2][gId1] = true; - dfsSearchEdge(currentPosition + 1); - f[gId1][gId2] = false; - f[gId2][gId1] = false; - g2s[gId2] = -1; - s2g[e.iy] = -1; - } else { - if (g2s[gId2] != e.iy) { - continue; - } - if (s2g[e.iy] != gId2) { - continue; - } - f[gId1][gId2] = true; - f[gId2][gId1] = true; - dfsSearchEdge(currentPosition); - f[gId1][gId2] = false; - f[gId2][gId1] = false; - } - } - - } - - /** - * 获取结果数据对 - * - * @return - */ - public ArrayList getResultChildEdge() { - return this.childEdge; - } +class SubChildTraveler { + // 当前的五元组边 + private ArrayList edgeSeq; + // 当前的图 + private Graph graph; + // 结果数据,孩子边对所属的图id组 + private ArrayList childEdge; + // 图的点id对五元组id标识的映射 + private int[] g2s; + // 五元组id标识对图的点id的映射 + private int[] s2g; + // 图中边是否被用的情况 + private boolean f[][]; + // 最右路径,rm[id]表示的是此id节点在最右路径中的下一个节点id + private int[] rm; + // 下一个五元组的id + private int next; + + SubChildTraveler(ArrayList edgeSeq, Graph graph){ + this.edgeSeq = edgeSeq; + this.graph = graph; + this.childEdge = new ArrayList<>(); + } + + /** + * 在图中搜索可能存在的孩子边 + *

+ * next 新加入边的节点将设置的id + */ + void traveler(){ + this.next = edgeSeq.size() + 1; + int size = graph.nodeLabels.size(); + // 做id映射的初始化操作 + g2s = new int[size]; + s2g = new int[size]; + f = new boolean[size][size]; + + for (int i = 0; i < size; i++) { + g2s[i] = -1; + s2g[i] = -1; + + for (int j = 0; j < size; j++) { + // 代表点id为i到id为j点此边没有被用过 + f[i][j] = false; + } + } + + rm = new int[edgeSeq.size() + 1]; + for (int i = 0; i < edgeSeq.size() + 1; i++) { + rm[i] = -1; + } + // 寻找最右路径 + for (Edge e : edgeSeq) { + if (e.ix < e.iy && e.iy > rm[e.ix]) { + rm[e.ix] = e.iy; + } + } + + for (int i = 0; i < size; i++) { + // 寻找第一个标号相等的点 + if (edgeSeq.get(0).x != graph.nodeLabels.get(i)) { + continue; + } + + g2s[i] = 0; + s2g[0] = i; + dfsSearchEdge(0); + g2s[i] = -1; + s2g[0] = -1; + } + + } + + /** + * 在当前图中深度优先寻找正确的子图 + * + * @param currentPosition 当前找到的位置 + */ + private void dfsSearchEdge(int currentPosition){ + int rmPosition; + // 如果找到底了,则在当前的子图的最右路径中寻找可能的边 + if (currentPosition >= edgeSeq.size()) { + rmPosition = 0; + while (rmPosition >= 0) { + int gId = s2g[rmPosition]; + // 在此点附近寻找可能的边 + for (int i = 0; i < graph.edgeNexts.get(gId).size(); i++) { + int gId2 = graph.edgeNexts.get(gId).get(i); + // 如果这条边已经被用过 + if (f[gId][gId2] || f[gId][gId2]) { + continue; + } + + // 在最右路径中添加边分为2种情况,第一种为在最右节点上添加,第二中为在最右路径上 的点添加 + // 如果找到的点没有被用过,可以进行边的拓展 + if (g2s[gId2] < 0) { + g2s[gId2] = next; + Edge e = new Edge(g2s[gId], g2s[gId2], + graph.nodeLabels.get(gId), graph.edgeLabels + .get(gId).get(i), + graph.nodeLabels.get(gId2)); + // 将新建的子边加入集合 + childEdge.add(e); + } else { + boolean flag = true; + // 如果这点已经存在,判断他是不是最右的点 + for (int j = 0; j < graph.edgeNexts.get(gId2).size(); j++) { + int tempId = graph.edgeNexts.get(gId2).get(j); + if (g2s[gId2] < g2s[tempId]) { + flag = false; + break; + } + } + + if (flag) { + Edge e = new Edge(g2s[gId], g2s[gId2], + graph.nodeLabels.get(gId), graph.edgeLabels + .get(gId).get(i), + graph.nodeLabels.get(gId2)); + // 将新建的子边加入集合 + childEdge.add(e); + } + } + } + // 一个最右路径上点找完,继续下一个 + rmPosition = rm[rmPosition]; + } + return; + } + + Edge e = edgeSeq.get(currentPosition); + // 所连接的点标号 + int y = e.y; + // 所连接的边标号 + int a = e.a; + int gId1 = s2g[e.ix]; + int gId2; + + for (int i = 0; i < graph.edgeLabels.get(gId1).size(); i++) { + // 判断所连接的边对应的标号 + if (graph.edgeLabels.get(gId1).get(i) != a) { + continue; + } + + // 判断所连接的点的标号 + int tempId = graph.edgeNexts.get(gId1).get(i); + if (graph.nodeLabels.get(tempId) != y) { + continue; + } + + gId2 = tempId; + // 如果这两点是没有设置过的 + if (g2s[gId2] == -1 && s2g[e.iy] == -1) { + g2s[gId2] = e.iy; + s2g[e.iy] = gId2; + f[gId1][gId2] = true; + f[gId2][gId1] = true; + dfsSearchEdge(currentPosition + 1); + f[gId1][gId2] = false; + f[gId2][gId1] = false; + g2s[gId2] = -1; + s2g[e.iy] = -1; + } else { + if (g2s[gId2] != e.iy) { + continue; + } + if (s2g[e.iy] != gId2) { + continue; + } + f[gId1][gId2] = true; + f[gId2][gId1] = true; + dfsSearchEdge(currentPosition); + f[gId1][gId2] = false; + f[gId2][gId1] = false; + } + } + + } + + /** + * 获取结果数据对 + */ + ArrayList getResultChildEdge(){ + return this.childEdge; + } } diff --git a/IntegratedMining/DataMining_CBA/AprioriTool/AprioriTool.java b/IntegratedMining/DataMining_CBA/AprioriTool/AprioriTool.java index 5701fa9..b8031d8 100644 --- a/IntegratedMining/DataMining_CBA/AprioriTool/AprioriTool.java +++ b/IntegratedMining/DataMining_CBA/AprioriTool/AprioriTool.java @@ -1,403 +1,388 @@ -package DataMining_CBA.AprioriTool; +package IntegratedMining.DataMining_CBA.AprioriTool; import java.text.MessageFormat; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; +import java.util.*; /** * apriori算法工具类 - * - * @author lyq - * + * + * @author Qstar */ public class AprioriTool { - // 最小支持度计数 - private int minSupportCount; - // 每个事务中的商品ID - private ArrayList totalGoodsIDs; - // 过程中计算出来的所有频繁项集列表 - private ArrayList resultItem; - // 过程中计算出来频繁项集的ID集合 - private ArrayList resultItemID; - - public AprioriTool(ArrayList totalGoodsIDs, int minSupportCount) { - this.totalGoodsIDs = totalGoodsIDs; - this.minSupportCount = minSupportCount; - } - - - /** - * 判读字符数组array2是否包含于数组array1中 - * - * @param array1 - * @param array2 - * @return - */ - public boolean iSStrContain(String[] array1, String[] array2) { - if (array1 == null || array2 == null) { - return false; - } - - boolean iSContain = false; - for (String s : array2) { - // 新的字母比较时,重新初始化变量 - iSContain = false; - // 判读array2中每个字符,只要包括在array1中 ,就算包含 - for (String s2 : array1) { - if (s.equals(s2)) { - iSContain = true; - break; - } - } - - // 如果已经判断出不包含了,则直接中断循环 - if (!iSContain) { - break; - } - } - - return iSContain; - } - - /** - * 项集进行连接运算 - */ - public void computeLink() { - // 连接计算的终止数,k项集必须算到k-1子项集为止 - int endNum = 0; - // 当前已经进行连接运算到几项集,开始时就是1项集 - int currentNum = 1; - // 商品,1频繁项集映射图 - HashMap itemMap = new HashMap<>(); - FrequentItem tempItem; - // 初始列表 - ArrayList list = new ArrayList<>(); - // 经过连接运算后产生的结果项集 - resultItem = new ArrayList<>(); - resultItemID = new ArrayList<>(); - // 商品ID的种类 - ArrayList idType = new ArrayList<>(); - for (String[] a : totalGoodsIDs) { - for (String s : a) { - if (!idType.contains(s)) { - tempItem = new FrequentItem(new String[] { s }, 1); - idType.add(s); - resultItemID.add(new String[] { s }); - } else { - // 支持度计数加1 - tempItem = itemMap.get(s); - tempItem.setCount(tempItem.getCount() + 1); - } - itemMap.put(s, tempItem); - } - } - // 将初始频繁项集转入到列表中,以便继续做连接运算 - for (Map.Entry entry : itemMap.entrySet()) { - list.add((FrequentItem) entry.getValue()); - } - // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少 - Collections.sort(list); - resultItem.addAll(list); - - String[] array1; - String[] array2; - String[] resultArray; - ArrayList tempIds; - ArrayList resultContainer; - // 总共要算到endNum项集 - endNum = list.size() - 1; - - while (currentNum < endNum) { - resultContainer = new ArrayList<>(); - for (int i = 0; i < list.size() - 1; i++) { - tempItem = list.get(i); - array1 = tempItem.getIdArray(); - for (int j = i + 1; j < list.size(); j++) { - tempIds = new ArrayList<>(); - array2 = list.get(j).getIdArray(); - for (int k = 0; k < array1.length; k++) { - // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作 - if (array1[k].equals(array2[k])) { - tempIds.add(array1[k]); - } else { - tempIds.add(array1[k]); - tempIds.add(array2[k]); - } - } - resultArray = new String[tempIds.size()]; - tempIds.toArray(resultArray); - - boolean isContain = false; - // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的 - if (resultArray.length == (array1.length + 1)) { - isContain = isIDArrayContains(resultContainer, - resultArray); - if (!isContain) { - resultContainer.add(resultArray); - } - } - } - } - - // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集 - list = cutItem(resultContainer); - currentNum++; - } - - // 输出频繁项集 - for (int k = 1; k <= currentNum; k++) { - System.out.println("频繁" + k + "项集:"); - for (FrequentItem i : resultItem) { - if (i.getLength() == k) { - System.out.print("{"); - for (String t : i.getIdArray()) { - System.out.print(t + ","); - } - System.out.print("},"); - } - } - System.out.println(); - } - } - - /** - * 判断列表结果中是否已经包含此数组 - * - * @param container - * ID数组容器 - * @param array - * 待比较数组 - * @return - */ - private boolean isIDArrayContains(ArrayList container, - String[] array) { - boolean isContain = true; - if (container.size() == 0) { - isContain = false; - return isContain; - } - - for (String[] s : container) { - // 比较的视乎必须保证长度一样 - if (s.length != array.length) { - continue; - } - - isContain = true; - for (int i = 0; i < s.length; i++) { - // 只要有一个id不等,就算不相等 - if (s[i] != array[i]) { - isContain = false; - break; - } - } - - // 如果已经判断是包含在容器中时,直接退出 - if (isContain) { - break; - } - } - - return isContain; - } - - /** - * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集 - */ - private ArrayList cutItem(ArrayList resultIds) { - String[] temp; - // 忽略的索引位置,以此构建子集 - int igNoreIndex = 0; - FrequentItem tempItem; - // 剪枝生成新的频繁项集 - ArrayList newItem = new ArrayList<>(); - // 不符合要求的id - ArrayList deleteIdArray = new ArrayList<>(); - // 子项集是否也为频繁子项集 - boolean isContain = true; - - for (String[] array : resultIds) { - // 列举出其中的一个个的子项集,判断存在于频繁项集列表中 - temp = new String[array.length - 1]; - for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) { - isContain = true; - for (int j = 0, k = 0; j < array.length; j++) { - if (j != igNoreIndex) { - temp[k] = array[j]; - k++; - } - } - - if (!isIDArrayContains(resultItemID, temp)) { - isContain = false; - break; - } - } - - if (!isContain) { - deleteIdArray.add(array); - } - } - - // 移除不符合条件的ID组合 - resultIds.removeAll(deleteIdArray); - - // 移除支持度计数不够的id集合 - int tempCount = 0; - for (String[] array : resultIds) { - tempCount = 0; - for (String[] array2 : totalGoodsIDs) { - if (isStrArrayContain(array2, array)) { - tempCount++; - } - } - - // 如果支持度计数大于等于最小最小支持度计数则生成新的频繁项集,并加入结果集中 - if (tempCount >= minSupportCount) { - tempItem = new FrequentItem(array, tempCount); - newItem.add(tempItem); - resultItemID.add(array); - resultItem.add(tempItem); - } - } - - return newItem; - } - - /** - * 数组array2是否包含于array1中,不需要完全一样 - * - * @param array1 - * @param array2 - * @return - */ - private boolean isStrArrayContain(String[] array1, String[] array2) { - boolean isContain = true; - for (String s2 : array2) { - isContain = false; - for (String s1 : array1) { - // 只要s2字符存在于array1中,这个字符就算包含在array1中 - if (s2.equals(s1)) { - isContain = true; - break; - } - } - - // 一旦发现不包含的字符,则array2数组不包含于array1中 - if (!isContain) { - break; - } - } - - return isContain; - } - - /** - * 根据产生的频繁项集输出关联规则 - * - * @param minConf - * 最小置信度阈值 - */ - public void printAttachRule(double minConf) { - // 进行连接和剪枝操作 - computeLink(); - - int count1 = 0; - int count2 = 0; - ArrayList childGroup1; - ArrayList childGroup2; - String[] group1; - String[] group2; - // 以最后一个频繁项集做关联规则的输出 - String[] array = resultItem.get(resultItem.size() - 1).getIdArray(); - // 子集总数,计算的时候除去自身和空集 - int totalNum = (int) Math.pow(2, array.length); - String[] temp; - // 二进制数组,用来代表各个子集 - int[] binaryArray; - // 除去头和尾部 - for (int i = 1; i < totalNum - 1; i++) { - binaryArray = new int[array.length]; - numToBinaryArray(binaryArray, i); - - childGroup1 = new ArrayList<>(); - childGroup2 = new ArrayList<>(); - count1 = 0; - count2 = 0; - // 按照二进制位关系取出子集 - for (int j = 0; j < binaryArray.length; j++) { - if (binaryArray[j] == 1) { - childGroup1.add(array[j]); - } else { - childGroup2.add(array[j]); - } - } - - group1 = new String[childGroup1.size()]; - group2 = new String[childGroup2.size()]; - - childGroup1.toArray(group1); - childGroup2.toArray(group2); - - for (String[] a : totalGoodsIDs) { - if (isStrArrayContain(a, group1)) { - count1++; - - // 在group1的条件下,统计group2的事件发生次数 - if (isStrArrayContain(a, group2)) { - count2++; - } - } - } - - // {A}-->{B}的意思为在A的情况下发生B的概率 - System.out.print("{"); - for (String s : group1) { - System.out.print(s + ", "); - } - System.out.print("}-->"); - System.out.print("{"); - for (String s : group2) { - System.out.print(s + ", "); - } - System.out.print(MessageFormat.format( - "},confidence(置信度):{0}/{1}={2}", count2, count1, count2 - * 1.0 / count1)); - if (count2 * 1.0 / count1 < minConf) { - // 不符合要求,不是强规则 - System.out.println("由于此规则置信度未达到最小置信度的要求,不是强规则"); - } else { - System.out.println("为强规则"); - } - } - - } - - /** - * 数字转为二进制形式 - * - * @param binaryArray - * 转化后的二进制数组形式 - * @param num - * 待转化数字 - */ - private void numToBinaryArray(int[] binaryArray, int num) { - int index = 0; - while (num != 0) { - binaryArray[index] = num % 2; - index++; - num /= 2; - } - } - - /** - * 获取算法挖掘的所有频繁项集 - * @return - */ - public ArrayList getTotalFrequentItems(){ - return this.resultItem; - } + // 最小支持度计数 + private int minSupportCount; + // 每个事务中的商品ID + private ArrayList totalGoodsIDs; + // 过程中计算出来的所有频繁项集列表 + private ArrayList resultItem; + // 过程中计算出来频繁项集的ID集合 + private ArrayList resultItemID; + + public AprioriTool(ArrayList totalGoodsIDs, int minSupportCount){ + this.totalGoodsIDs = totalGoodsIDs; + this.minSupportCount = minSupportCount; + } + + + /** + * 判读字符数组array2是否包含于数组array1中 + * + * @param array1 + * @param array2 + * @return + */ + public boolean iSStrContain(String[] array1, String[] array2){ + if (array1 == null || array2 == null) { + return false; + } + + boolean iSContain = false; + for (String s : array2) { + // 新的字母比较时,重新初始化变量 + iSContain = false; + // 判读array2中每个字符,只要包括在array1中 ,就算包含 + for (String s2 : array1) { + if (s.equals(s2)) { + iSContain = true; + break; + } + } + + // 如果已经判断出不包含了,则直接中断循环 + if (!iSContain) { + break; + } + } + + return iSContain; + } + + /** + * 项集进行连接运算 + */ + public void computeLink(){ + // 连接计算的终止数,k项集必须算到k-1子项集为止 + int endNum; + // 当前已经进行连接运算到几项集,开始时就是1项集 + int currentNum = 1; + // 商品,1频繁项集映射图 + HashMap itemMap = new HashMap<>(); + FrequentItem tempItem; + // 初始列表 + ArrayList list = new ArrayList<>(); + // 经过连接运算后产生的结果项集 + resultItem = new ArrayList<>(); + resultItemID = new ArrayList<>(); + // 商品ID的种类 + ArrayList idType = new ArrayList<>(); + for (String[] a : totalGoodsIDs) { + for (String s : a) { + if (!idType.contains(s)) { + tempItem = new FrequentItem(new String[]{s}, 1); + idType.add(s); + resultItemID.add(new String[]{s}); + } else { + // 支持度计数加1 + tempItem = itemMap.get(s); + tempItem.setCount(tempItem.getCount() + 1); + } + itemMap.put(s, tempItem); + } + } + // 将初始频繁项集转入到列表中,以便继续做连接运算 + for (Map.Entry entry : itemMap.entrySet()) { + list.add((FrequentItem) entry.getValue()); + } + // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少 + Collections.sort(list); + resultItem.addAll(list); + + String[] array1; + String[] array2; + String[] resultArray; + ArrayList tempIds; + ArrayList resultContainer; + // 总共要算到endNum项集 + endNum = list.size() - 1; + + while (currentNum < endNum) { + resultContainer = new ArrayList<>(); + for (int i = 0; i < list.size() - 1; i++) { + tempItem = list.get(i); + array1 = tempItem.getIdArray(); + for (int j = i + 1; j < list.size(); j++) { + tempIds = new ArrayList<>(); + array2 = list.get(j).getIdArray(); + for (int k = 0; k < array1.length; k++) { + // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作 + if (array1[k].equals(array2[k])) { + tempIds.add(array1[k]); + } else { + tempIds.add(array1[k]); + tempIds.add(array2[k]); + } + } + resultArray = new String[tempIds.size()]; + tempIds.toArray(resultArray); + + boolean isContain; + // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的 + if (resultArray.length == (array1.length + 1)) { + isContain = isIDArrayContains(resultContainer, + resultArray); + if (!isContain) { + resultContainer.add(resultArray); + } + } + } + } + + // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集 + list = cutItem(resultContainer); + currentNum++; + } + + // 输出频繁项集 + for (int k = 1; k <= currentNum; k++) { + System.out.println("频繁" + k + "项集:"); + for (FrequentItem i : resultItem) { + if (i.getLength() == k) { + System.out.print("{"); + for (String t : i.getIdArray()) { + System.out.print(t + ","); + } + System.out.print("},"); + } + } + System.out.println(); + } + } + + /** + * 判断列表结果中是否已经包含此数组 + * + * @param container ID数组容器 + * @param array 待比较数组 + */ + private boolean isIDArrayContains(ArrayList container, + String[] array){ + boolean isContain = true; + if (container.size() == 0) { + return false; + } + + for (String[] s : container) { + // 比较的视乎必须保证长度一样 + if (s.length != array.length) { + continue; + } + + isContain = true; + for (int i = 0; i < s.length; i++) { + // 只要有一个id不等,就算不相等 + if (!Objects.equals(s[i], array[i])) { + isContain = false; + break; + } + } + + // 如果已经判断是包含在容器中时,直接退出 + if (isContain) { + break; + } + } + return isContain; + } + + /** + * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集 + */ + private ArrayList cutItem(ArrayList resultIds){ + String[] temp; + // 忽略的索引位置,以此构建子集 + int igNoreIndex; + FrequentItem tempItem; + // 剪枝生成新的频繁项集 + ArrayList newItem = new ArrayList<>(); + // 不符合要求的id + ArrayList deleteIdArray = new ArrayList<>(); + // 子项集是否也为频繁子项集 + boolean isContain = true; + + for (String[] array : resultIds) { + // 列举出其中的一个个的子项集,判断存在于频繁项集列表中 + temp = new String[array.length - 1]; + for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) { + isContain = true; + for (int j = 0, k = 0; j < array.length; j++) { + if (j != igNoreIndex) { + temp[k] = array[j]; + k++; + } + } + + if (!isIDArrayContains(resultItemID, temp)) { + isContain = false; + break; + } + } + + if (!isContain) { + deleteIdArray.add(array); + } + } + + // 移除不符合条件的ID组合 + resultIds.removeAll(deleteIdArray); + + // 移除支持度计数不够的id集合 + int tempCount; + for (String[] array : resultIds) { + tempCount = 0; + for (String[] array2 : totalGoodsIDs) { + if (isStrArrayContain(array2, array)) { + tempCount++; + } + } + + // 如果支持度计数大于等于最小最小支持度计数则生成新的频繁项集,并加入结果集中 + if (tempCount >= minSupportCount) { + tempItem = new FrequentItem(array, tempCount); + newItem.add(tempItem); + resultItemID.add(array); + resultItem.add(tempItem); + } + } + + return newItem; + } + + /** + * 数组array2是否包含于array1中,不需要完全一样 + * + * @param array1 数组1 + * @param array2 数组2 + */ + private boolean isStrArrayContain(String[] array1, String[] array2){ + boolean isContain = true; + for (String s2 : array2) { + isContain = false; + for (String s1 : array1) { + // 只要s2字符存在于array1中,这个字符就算包含在array1中 + if (s2.equals(s1)) { + isContain = true; + break; + } + } + + // 一旦发现不包含的字符,则array2数组不包含于array1中 + if (!isContain) { + break; + } + } + + return isContain; + } + + /** + * 根据产生的频繁项集输出关联规则 + * + * @param minConf 最小置信度阈值 + */ + public void printAttachRule(double minConf){ + // 进行连接和剪枝操作 + computeLink(); + + int count1; + int count2; + ArrayList childGroup1; + ArrayList childGroup2; + String[] group1; + String[] group2; + // 以最后一个频繁项集做关联规则的输出 + String[] array = resultItem.get(resultItem.size() - 1).getIdArray(); + // 子集总数,计算的时候除去自身和空集 + int totalNum = (int) Math.pow(2, array.length); + // 二进制数组,用来代表各个子集 + int[] binaryArray; + // 除去头和尾部 + for (int i = 1; i < totalNum - 1; i++) { + binaryArray = new int[array.length]; + numToBinaryArray(binaryArray, i); + + childGroup1 = new ArrayList<>(); + childGroup2 = new ArrayList<>(); + count1 = 0; + count2 = 0; + // 按照二进制位关系取出子集 + for (int j = 0; j < binaryArray.length; j++) { + if (binaryArray[j] == 1) { + childGroup1.add(array[j]); + } else { + childGroup2.add(array[j]); + } + } + + group1 = new String[childGroup1.size()]; + group2 = new String[childGroup2.size()]; + + childGroup1.toArray(group1); + childGroup2.toArray(group2); + + for (String[] a : totalGoodsIDs) { + if (isStrArrayContain(a, group1)) { + count1++; + + // 在group1的条件下,统计group2的事件发生次数 + if (isStrArrayContain(a, group2)) { + count2++; + } + } + } + + // {A}-->{B}的意思为在A的情况下发生B的概率 + System.out.print("{"); + for (String s : group1) { + System.out.print(s + ", "); + } + System.out.print("}-->"); + System.out.print("{"); + for (String s : group2) { + System.out.print(s + ", "); + } + System.out.print(MessageFormat.format( + "},confidence(置信度):{0}/{1}={2}", count2, count1, count2 + * 1.0 / count1)); + if (count2 * 1.0 / count1 < minConf) { + // 不符合要求,不是强规则 + System.out.println("由于此规则置信度未达到最小置信度的要求,不是强规则"); + } else { + System.out.println("为强规则"); + } + } + + } + + /** + * 数字转为二进制形式 + * + * @param binaryArray 转化后的二进制数组形式 + * @param num 待转化数字 + */ + private void numToBinaryArray(int[] binaryArray, int num){ + int index = 0; + while (num != 0) { + binaryArray[index] = num % 2; + index++; + num /= 2; + } + } + + /** + * 获取算法挖掘的所有频繁项集 + */ + public ArrayList getTotalFrequentItems(){ + return this.resultItem; + } } diff --git a/IntegratedMining/DataMining_CBA/AprioriTool/FrequentItem.java b/IntegratedMining/DataMining_CBA/AprioriTool/FrequentItem.java index 6335e3d..d95b6d3 100644 --- a/IntegratedMining/DataMining_CBA/AprioriTool/FrequentItem.java +++ b/IntegratedMining/DataMining_CBA/AprioriTool/FrequentItem.java @@ -1,56 +1,51 @@ -package DataMining_CBA.AprioriTool; +package IntegratedMining.DataMining_CBA.AprioriTool; /** * 频繁项集 - * - * @author lyq - * + * + * @author Qstar */ -public class FrequentItem implements Comparable{ - // 频繁项集的集合ID - private String[] idArray; - // 频繁项集的支持度计数 - private int count; - //频繁项集的长度,1项集或是2项集,亦或是3项集 - private int length; - - public FrequentItem(String[] idArray, int count){ - this.idArray = idArray; - this.count = count; - length = idArray.length; - } - - public String[] getIdArray() { - return idArray; - } - - public void setIdArray(String[] idArray) { - this.idArray = idArray; - } - - public int getCount() { - return count; - } - - public void setCount(int count) { - this.count = count; - } - - public int getLength() { - return length; - } - - public void setLength(int length) { - this.length = length; - } - - @Override - public int compareTo(FrequentItem o) { - // TODO Auto-generated method stub - Integer int1 = Integer.parseInt(this.getIdArray()[0]); - Integer int2 = Integer.parseInt(o.getIdArray()[0]); - - return int1.compareTo(int2); - } - +public class FrequentItem implements Comparable { + // 频繁项集的集合ID + private String[] idArray; + // 频繁项集的支持度计数 + private int count; + //频繁项集的长度,1项集或是2项集,亦或是3项集 + private int length; + + FrequentItem(String[] idArray, int count){ + this.idArray = idArray; + this.count = count; + length = idArray.length; + } + + public String[] getIdArray(){ + return idArray; + } + + public int getCount(){ + return count; + } + + public void setCount(int count){ + this.count = count; + } + + public int getLength(){ + return length; + } + + public void setLength(int length){ + this.length = length; + } + + @Override + public int compareTo(FrequentItem o){ + // TODO Auto-generated method stub + Integer int1 = Integer.parseInt(this.getIdArray()[0]); + Integer int2 = Integer.parseInt(o.getIdArray()[0]); + + return int1.compareTo(int2); + } + } diff --git a/IntegratedMining/DataMining_CBA/CBATool.java b/IntegratedMining/DataMining_CBA/CBATool.java index e079e7b..b1c6588 100644 --- a/IntegratedMining/DataMining_CBA/CBATool.java +++ b/IntegratedMining/DataMining_CBA/CBATool.java @@ -1,4 +1,7 @@ -package DataMining_CBA; +package IntegratedMining.DataMining_CBA; + +import IntegratedMining.DataMining_CBA.AprioriTool.AprioriTool; +import IntegratedMining.DataMining_CBA.AprioriTool.FrequentItem; import java.io.BufferedReader; import java.io.File; @@ -6,344 +9,322 @@ import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import DataMining_CBA.AprioriTool.AprioriTool; -import DataMining_CBA.AprioriTool.FrequentItem; /** * CBA算法(关联规则分类)工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class CBATool { - // 年龄的类别划分 - public final String AGE = "Age"; - public final String AGE_YOUNG = "Young"; - public final String AGE_MIDDLE_AGED = "Middle_aged"; - public final String AGE_Senior = "Senior"; - - // 测试数据地址 - private String filePath; - // 最小支持度阈值率 - private double minSupportRate; - // 最小置信度阈值,用来判断是否能够成为关联规则 - private double minConf; - // 最小支持度 - private int minSupportCount; - // 属性列名称 - private String[] attrNames; - // 类别属性所代表的数字集合 - private ArrayList classTypes; - // 用二维数组保存测试数据 - private ArrayList totalDatas; - // Apriori算法工具类 - private AprioriTool aprioriTool; - // 属性到数字的映射图 - private HashMap attr2Num; - private HashMap num2Attr; - - public CBATool(String filePath, double minSupportRate, double minConf) { - this.filePath = filePath; - this.minConf = minConf; - this.minSupportRate = minSupportRate; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - totalDatas = new ArrayList<>(); - for (String[] array : dataArray) { - totalDatas.add(array); - } - attrNames = totalDatas.get(0); - minSupportCount = (int) (minSupportRate * totalDatas.size()); - - attributeReplace(); - } - - /** - * 属性值的替换,替换成数字的形式,以便进行频繁项的挖掘 - */ - private void attributeReplace() { - int currentValue = 1; - int num = 0; - String s; - // 属性名到数字的映射图 - attr2Num = new HashMap<>(); - num2Attr = new HashMap<>(); - classTypes = new ArrayList<>(); - - // 按照1列列的方式来,从左往右边扫描,跳过列名称行和id列 - for (int j = 1; j < attrNames.length; j++) { - for (int i = 1; i < totalDatas.size(); i++) { - s = totalDatas.get(i)[j]; - // 如果是数字形式的,这里只做年龄类别转换,其他的数字情况类似 - if (attrNames[j].equals(AGE)) { - num = Integer.parseInt(s); - if (num <= 20 && num > 0) { - totalDatas.get(i)[j] = AGE_YOUNG; - } else if (num > 20 && num <= 40) { - totalDatas.get(i)[j] = AGE_MIDDLE_AGED; - } else if (num > 40) { - totalDatas.get(i)[j] = AGE_Senior; - } - } - - if (!attr2Num.containsKey(totalDatas.get(i)[j])) { - attr2Num.put(totalDatas.get(i)[j], currentValue); - num2Attr.put(currentValue, totalDatas.get(i)[j]); - if (j == attrNames.length - 1) { - // 如果是组后一列,说明是分类类别列,记录下来 - classTypes.add(currentValue); - } - - currentValue++; - } - } - } - - // 对原始的数据作属性替换,每条记录变为类似于事务数据的形式 - for (int i = 1; i < totalDatas.size(); i++) { - for (int j = 1; j < attrNames.length; j++) { - s = totalDatas.get(i)[j]; - if (attr2Num.containsKey(s)) { - totalDatas.get(i)[j] = attr2Num.get(s) + ""; - } - } - } - } - - /** - * Apriori计算全部频繁项集 - * @return - */ - private ArrayList aprioriCalculate() { - String[] tempArray; - ArrayList totalFrequentItems; - ArrayList copyData = (ArrayList) totalDatas.clone(); - // 去除属性名称行 - copyData.remove(0); - // 去除首列ID - for (int i = 0; i < copyData.size(); i++) { - String[] array = copyData.get(i); - tempArray = new String[array.length - 1]; - System.arraycopy(array, 1, tempArray, 0, tempArray.length); - copyData.set(i, tempArray); - } - aprioriTool = new AprioriTool(copyData, minSupportCount); - aprioriTool.computeLink(); - totalFrequentItems = aprioriTool.getTotalFrequentItems(); - - return totalFrequentItems; - } - - /** - * 基于关联规则的分类 - * - * @param attrValues - * 预先知道的一些属性 - * @return - */ - public String CBAJudge(String attrValues) { - int value = 0; - // 最终分类类别 - String classType = null; - String[] tempArray; - // 已知的属性值 - ArrayList attrValueList = new ArrayList<>(); - ArrayList totalFrequentItems; - - totalFrequentItems = aprioriCalculate(); - // 将查询条件进行逐一属性的分割 - String[] array = attrValues.split(","); - for (String record : array) { - tempArray = record.split("="); - value = attr2Num.get(tempArray[1]); - attrValueList.add(value + ""); - } - - // 在频繁项集中寻找符合条件的项 - for (FrequentItem item : totalFrequentItems) { - // 过滤掉不满足个数频繁项 - if (item.getIdArray().length < (attrValueList.size() + 1)) { - continue; - } - - // 要保证查询的属性都包含在频繁项集中 - if (itemIsSatisfied(item, attrValueList)) { - tempArray = item.getIdArray(); - classType = classificationBaseRules(tempArray); - - if (classType != null) { - // 作属性替换 - classType = num2Attr.get(Integer.parseInt(classType)); - break; - } - } - } - - return classType; - } - - /** - * 基于关联规则进行分类 - * - * @param items - * 频繁项 - * @return - */ - private String classificationBaseRules(String[] items) { - String classType = null; - String[] arrayTemp; - int count1 = 0; - int count2 = 0; - // 置信度 - double confidenceRate; - - String[] noClassTypeItems = new String[items.length - 1]; - for (int i = 0, k = 0; i < items.length; i++) { - if (!classTypes.contains(Integer.parseInt(items[i]))) { - noClassTypeItems[k] = items[i]; - k++; - } else { - classType = items[i]; - } - } - - for (String[] array : totalDatas) { - // 去除ID数字号 - arrayTemp = new String[array.length - 1]; - System.arraycopy(array, 1, arrayTemp, 0, array.length - 1); - if (isStrArrayContain(arrayTemp, noClassTypeItems)) { - count1++; - - if (isStrArrayContain(arrayTemp, items)) { - count2++; - } - } - } - - // 做置信度的计算 - confidenceRate = count1 * 1.0 / count2; - if (confidenceRate >= minConf) { - return classType; - } else { - // 如果不满足最小置信度要求,则此关联规则无效 - return null; - } - } - - /** - * 判断单个字符是否包含在字符数组中 - * - * @param array - * 字符数组 - * @param s - * 判断的单字符 - * @return - */ - private boolean strIsContained(String[] array, String s) { - boolean isContained = false; - - for (String str : array) { - if (str.equals(s)) { - isContained = true; - break; - } - } - - return isContained; - } - - /** - * 数组array2是否包含于array1中,不需要完全一样 - * - * @param array1 - * @param array2 - * @return - */ - private boolean isStrArrayContain(String[] array1, String[] array2) { - boolean isContain = true; - for (String s2 : array2) { - isContain = false; - for (String s1 : array1) { - // 只要s2字符存在于array1中,这个字符就算包含在array1中 - if (s2.equals(s1)) { - isContain = true; - break; - } - } - - // 一旦发现不包含的字符,则array2数组不包含于array1中 - if (!isContain) { - break; - } - } - - return isContain; - } - - /** - * 判断频繁项集是否满足查询 - * - * @param item - * 待判断的频繁项集 - * @param attrValues - * 查询的属性值列表 - * @return - */ - private boolean itemIsSatisfied(FrequentItem item, - ArrayList attrValues) { - boolean isContained = false; - String[] array = item.getIdArray(); - - for (String s : attrValues) { - isContained = true; - - if (!strIsContained(array, s)) { - isContained = false; - break; - } - - if (!isContained) { - break; - } - } - - if (isContained) { - isContained = false; - - // 还要验证是否频繁项集中是否包含分类属性 - for (Integer type : classTypes) { - if (strIsContained(array, type + "")) { - isContained = true; - break; - } - } - } - - return isContained; - } +class CBATool { + // 年龄的类别划分 + private final String AGE = "Age"; + private final String AGE_YOUNG = "Young"; + private final String AGE_MIDDLE_AGED = "Middle_aged"; + private final String AGE_Senior = "Senior"; + + // 测试数据地址 + private String filePath; + // 最小支持度阈值率 + private double minSupportRate; + // 最小置信度阈值,用来判断是否能够成为关联规则 + private double minConf; + // 最小支持度 + private int minSupportCount; + // 属性列名称 + private String[] attrNames; + // 类别属性所代表的数字集合 + private ArrayList classTypes; + // 用二维数组保存测试数据 + private ArrayList totalDatas; + // Apriori算法工具类 + private AprioriTool aprioriTool; + // 属性到数字的映射图 + private HashMap attr2Num; + private HashMap num2Attr; + + CBATool(String filePath, double minSupportRate, double minConf){ + this.filePath = filePath; + this.minConf = minConf; + this.minSupportRate = minSupportRate; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + totalDatas = new ArrayList<>(); + for (String[] array : dataArray) { + totalDatas.add(array); + } + attrNames = totalDatas.get(0); + minSupportCount = (int) (minSupportRate * totalDatas.size()); + + attributeReplace(); + } + + /** + * 属性值的替换,替换成数字的形式,以便进行频繁项的挖掘 + */ + private void attributeReplace(){ + int currentValue = 1; + int num; + String s; + // 属性名到数字的映射图 + attr2Num = new HashMap<>(); + num2Attr = new HashMap<>(); + classTypes = new ArrayList<>(); + + // 按照1列列的方式来,从左往右边扫描,跳过列名称行和id列 + for (int j = 1; j < attrNames.length; j++) { + for (int i = 1; i < totalDatas.size(); i++) { + s = totalDatas.get(i)[j]; + // 如果是数字形式的,这里只做年龄类别转换,其他的数字情况类似 + if (attrNames[j].equals(AGE)) { + num = Integer.parseInt(s); + if (num <= 20 && num > 0) { + totalDatas.get(i)[j] = AGE_YOUNG; + } else if (num > 20 && num <= 40) { + totalDatas.get(i)[j] = AGE_MIDDLE_AGED; + } else if (num > 40) { + totalDatas.get(i)[j] = AGE_Senior; + } + } + + if (!attr2Num.containsKey(totalDatas.get(i)[j])) { + attr2Num.put(totalDatas.get(i)[j], currentValue); + num2Attr.put(currentValue, totalDatas.get(i)[j]); + if (j == attrNames.length - 1) { + // 如果是组后一列,说明是分类类别列,记录下来 + classTypes.add(currentValue); + } + + currentValue++; + } + } + } + + // 对原始的数据作属性替换,每条记录变为类似于事务数据的形式 + for (int i = 1; i < totalDatas.size(); i++) { + for (int j = 1; j < attrNames.length; j++) { + s = totalDatas.get(i)[j]; + if (attr2Num.containsKey(s)) { + totalDatas.get(i)[j] = attr2Num.get(s) + ""; + } + } + } + } + + /** + * Apriori计算全部频繁项集 + */ + private ArrayList aprioriCalculate(){ + String[] tempArray; + ArrayList totalFrequentItems; + ArrayList copyData = (ArrayList) totalDatas.clone(); + // 去除属性名称行 + copyData.remove(0); + // 去除首列ID + for (int i = 0; i < copyData.size(); i++) { + String[] array = copyData.get(i); + tempArray = new String[array.length - 1]; + System.arraycopy(array, 1, tempArray, 0, tempArray.length); + copyData.set(i, tempArray); + } + aprioriTool = new AprioriTool(copyData, minSupportCount); + aprioriTool.computeLink(); + totalFrequentItems = aprioriTool.getTotalFrequentItems(); + + return totalFrequentItems; + } + + /** + * 基于关联规则的分类 + * + * @param attrValues 预先知道的一些属性 + */ + String CBAJudge(String attrValues){ + int value; + // 最终分类类别 + String classType = null; + String[] tempArray; + // 已知的属性值 + ArrayList attrValueList = new ArrayList<>(); + ArrayList totalFrequentItems; + + totalFrequentItems = aprioriCalculate(); + // 将查询条件进行逐一属性的分割 + String[] array = attrValues.split(","); + for (String record : array) { + tempArray = record.split("="); + value = attr2Num.get(tempArray[1]); + attrValueList.add(value + ""); + } + + // 在频繁项集中寻找符合条件的项 + for (FrequentItem item : totalFrequentItems) { + // 过滤掉不满足个数频繁项 + if (item.getIdArray().length < (attrValueList.size() + 1)) { + continue; + } + + // 要保证查询的属性都包含在频繁项集中 + if (itemIsSatisfied(item, attrValueList)) { + tempArray = item.getIdArray(); + classType = classificationBaseRules(tempArray); + + if (classType != null) { + // 作属性替换 + classType = num2Attr.get(Integer.parseInt(classType)); + break; + } + } + } + + return classType; + } + + /** + * 基于关联规则进行分类 + * + * @param items 频繁项 + */ + private String classificationBaseRules(String[] items){ + String classType = null; + String[] arrayTemp; + int count1 = 0; + int count2 = 0; + // 置信度 + double confidenceRate; + + String[] noClassTypeItems = new String[items.length - 1]; + for (int i = 0, k = 0; i < items.length; i++) { + if (!classTypes.contains(Integer.parseInt(items[i]))) { + noClassTypeItems[k] = items[i]; + k++; + } else { + classType = items[i]; + } + } + + for (String[] array : totalDatas) { + // 去除ID数字号 + arrayTemp = new String[array.length - 1]; + System.arraycopy(array, 1, arrayTemp, 0, array.length - 1); + if (isStrArrayContain(arrayTemp, noClassTypeItems)) { + count1++; + + if (isStrArrayContain(arrayTemp, items)) { + count2++; + } + } + } + + // 做置信度的计算 + confidenceRate = count1 * 1.0 / count2; + if (confidenceRate >= minConf) { + return classType; + } else { + // 如果不满足最小置信度要求,则此关联规则无效 + return null; + } + } + + /** + * 判断单个字符是否包含在字符数组中 + * + * @param array 字符数组 + * @param s 判断的单字符 + */ + private boolean strIsContained(String[] array, String s){ + boolean isContained = false; + + for (String str : array) { + if (str.equals(s)) { + isContained = true; + break; + } + } + + return isContained; + } + + /** + * 数组array2是否包含于array1中,不需要完全一样 + * + * @param array1 数组1 + * @param array2 数组2 + */ + private boolean isStrArrayContain(String[] array1, String[] array2){ + boolean isContain = true; + for (String s2 : array2) { + isContain = false; + for (String s1 : array1) { + // 只要s2字符存在于array1中,这个字符就算包含在array1中 + if (s2.equals(s1)) { + isContain = true; + break; + } + } + + // 一旦发现不包含的字符,则array2数组不包含于array1中 + if (!isContain) { + break; + } + } + + return isContain; + } + + /** + * 判断频繁项集是否满足查询 + * + * @param item 待判断的频繁项集 + * @param attrValues 查询的属性值列表 + */ + private boolean itemIsSatisfied(FrequentItem item, + ArrayList attrValues){ + boolean isContained = false; + String[] array = item.getIdArray(); + + for (String s : attrValues) { + isContained = true; + + if (!strIsContained(array, s)) { + isContained = false; + break; + } + } + + if (isContained) { + isContained = false; + + // 还要验证是否频繁项集中是否包含分类属性 + for (Integer type : classTypes) { + if (strIsContained(array, type + "")) { + isContained = true; + break; + } + } + } + + return isContained; + } } diff --git a/IntegratedMining/DataMining_CBA/Client.java b/IntegratedMining/DataMining_CBA/Client.java index 6d553f8..a6f6544 100644 --- a/IntegratedMining/DataMining_CBA/Client.java +++ b/IntegratedMining/DataMining_CBA/Client.java @@ -1,25 +1,25 @@ -package DataMining_CBA; +package IntegratedMining.DataMining_CBA; import java.text.MessageFormat; /** * CBA算法--基于关联规则的分类算法 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - String attrDesc = "Age=Senior,CreditRating=Fair"; - String classification = null; - - //最小支持度阈值率 - double minSupportRate = 0.2; - //最小置信度阈值 - double minConf = 0.7; - - CBATool tool = new CBATool(filePath, minSupportRate, minConf); - classification = tool.CBAJudge(attrDesc); - System.out.println(MessageFormat.format("{0}的关联分类结果为{1}", attrDesc, classification)); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/IntegratedMining/DataMining_CBA/input.txt"; + String attrDesc = "Age=Senior,CreditRating=Fair"; + String classification; + + //最小支持度阈值率 + double minSupportRate = 0.2; + //最小置信度阈值 + double minConf = 0.7; + + CBATool tool = new CBATool(filePath, minSupportRate, minConf); + classification = tool.CBAJudge(attrDesc); + System.out.println(MessageFormat.format("{0}的关联分类结果为{1}", attrDesc, classification)); + } } diff --git a/LinkMining/DataMining_HITS/Client.java b/LinkMining/DataMining_HITS/Client.java index 9de96c1..8ac7957 100644 --- a/LinkMining/DataMining_HITS/Client.java +++ b/LinkMining/DataMining_HITS/Client.java @@ -1,15 +1,15 @@ -package DataMining_HITS; +package LinkMining.DataMining_HITS; /** * HITS链接分析算法 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - - HITSTool tool = new HITSTool(filePath); - tool.printResultPage(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/LinkMining/DataMining_HITS/input.txt"; + + HITSTool tool = new HITSTool(filePath); + tool.printResultPage(); + } } diff --git a/LinkMining/DataMining_HITS/HITSTool.java b/LinkMining/DataMining_HITS/HITSTool.java index 4e26910..847cdc8 100644 --- a/LinkMining/DataMining_HITS/HITSTool.java +++ b/LinkMining/DataMining_HITS/HITSTool.java @@ -1,4 +1,4 @@ -package DataMining_HITS; +package LinkMining.DataMining_HITS; import java.io.BufferedReader; import java.io.File; @@ -8,143 +8,143 @@ /** * HITS链接分析算法工具类 - * @author lyq * + * @author Qstar */ -public class HITSTool { - //输入数据文件地址 - private String filePath; - //网页个数 - private int pageNum; - //网页Authority权威值 - private double[] authority; - //网页hub中心值 - private double[] hub; - //链接矩阵关系 - private int[][] linkMatrix; - //网页种类 - private ArrayList pageClass; - - public HITSTool(String filePath){ - this.filePath = filePath; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - pageClass = new ArrayList<>(); - // 统计网页类型种数 - for (String[] array : dataArray) { - for (String s : array) { - if (!pageClass.contains(s)) { - pageClass.add(s); - } - } - } - - int i = 0; - int j = 0; - pageNum = pageClass.size(); - linkMatrix = new int[pageNum][pageNum]; - authority = new double[pageNum]; - hub = new double[pageNum]; - for(int k=0; k 0.01 * pageNum){ - for(int k=0; k maxHub){ - maxHub = newHub[k]; - } - - if(newAuthority[k] > maxAuthority){ - maxAuthority = newAuthority[k]; - maxAuthorityIndex = k; - } - } - - error = 0; - //归一化处理 - for(int k=0; k pageClass; + + HITSTool(String filePath){ + this.filePath = filePath; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + pageClass = new ArrayList<>(); + // 统计网页类型种数 + for (String[] array : dataArray) { + for (String s : array) { + if (!pageClass.contains(s)) { + pageClass.add(s); + } + } + } + + int i; + int j; + pageNum = pageClass.size(); + linkMatrix = new int[pageNum][pageNum]; + authority = new double[pageNum]; + hub = new double[pageNum]; + for (int k = 0; k < pageNum; k++) { + //初始时默认权威值和中心值都为1 + authority[k] = 1; + hub[k] = 1; + } + + for (String[] array : dataArray) { + + i = Integer.parseInt(array[0]); + j = Integer.parseInt(array[1]); + + // 设置linkMatrix[i][j]为1代表i网页包含指向j网页的链接 + linkMatrix[i - 1][j - 1] = 1; + } + } + + /** + * 输出结果页面,也就是authority权威值最高的页面 + */ + void printResultPage(){ + //最大Hub和Authority值,用于后面的归一化计算 + double maxHub; + double maxAuthority; + int maxAuthorityIndex = 0; + //误差值,用于收敛判断 + double error = Integer.MAX_VALUE; + double[] newHub = new double[pageNum]; + double[] newAuthority = new double[pageNum]; + + + while (error > 0.01 * pageNum) { + for (int k = 0; k < pageNum; k++) { + newHub[k] = 0; + newAuthority[k] = 0; + } + + //hub和authority值的更新计算 + for (int i = 0; i < pageNum; i++) { + for (int j = 0; j < pageNum; j++) { + if (linkMatrix[i][j] == 1) { + newHub[i] += authority[j]; + newAuthority[j] += hub[i]; + } + } + } + + maxHub = 0; + maxAuthority = 0; + for (int k = 0; k < pageNum; k++) { + if (newHub[k] > maxHub) { + maxHub = newHub[k]; + } + + if (newAuthority[k] > maxAuthority) { + maxAuthority = newAuthority[k]; + maxAuthorityIndex = k; + } + } + + error = 0; + //归一化处理 + for (int k = 0; k < pageNum; k++) { + newHub[k] /= maxHub; + newAuthority[k] /= maxAuthority; + + error += Math.abs(newHub[k] - hub[k]); + System.out.println(newAuthority[k] + ":" + newHub[k]); + + hub[k] = newHub[k]; + authority[k] = newAuthority[k]; + } + System.out.println("---------"); + } + + System.out.println("****最终收敛的网页的权威值和中心值****"); + for (int k = 0; k < pageNum; k++) { + System.out.println("网页" + pageClass.get(k) + ":" + authority[k] + ":" + hub[k]); + } + System.out.println("权威值最高的网页为:网页" + pageClass.get(maxAuthorityIndex)); + } } diff --git a/LinkMining/DataMining_PageRank/Client.java b/LinkMining/DataMining_PageRank/Client.java index 9d89209..d91f019 100644 --- a/LinkMining/DataMining_PageRank/Client.java +++ b/LinkMining/DataMining_PageRank/Client.java @@ -1,15 +1,15 @@ -package DataMining_PageRank; +package LinkMining.DataMining_PageRank; /** * PageRank计算网页重要性/排名算法 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - - PageRankTool tool = new PageRankTool(filePath); - tool.printPageRankValue(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/LinkMining/DataMining_PageRank/input.txt"; + + PageRankTool tool = new PageRankTool(filePath); + tool.printPageRankValue(); + } } diff --git a/LinkMining/DataMining_PageRank/PageRankTool.java b/LinkMining/DataMining_PageRank/PageRankTool.java index 9421f62..6c6461a 100644 --- a/LinkMining/DataMining_PageRank/PageRankTool.java +++ b/LinkMining/DataMining_PageRank/PageRankTool.java @@ -1,184 +1,181 @@ -package DataMining_PageRank; +package LinkMining.DataMining_PageRank; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; -import java.lang.reflect.Array; import java.text.MessageFormat; import java.util.ArrayList; /** * PageRank网页排名算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class PageRankTool { - // 测试输入数据 - private String filePath; - // 网页总数量 - private int pageNum; - // 链接关系矩阵 - private double[][] linkMatrix; - // 每个页面pageRank值初始向量 - private double[] pageRankVecor; - - // 网页数量分类 - ArrayList pageClass; - - public PageRankTool(String filePath) { - this.filePath = filePath; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - pageClass = new ArrayList<>(); - // 统计网页类型种数 - for (String[] array : dataArray) { - for (String s : array) { - if (!pageClass.contains(s)) { - pageClass.add(s); - } - } - } - - int i = 0; - int j = 0; - pageNum = pageClass.size(); - linkMatrix = new double[pageNum][pageNum]; - pageRankVecor = new double[pageNum]; - for (int k = 0; k < pageNum; k++) { - // 初始每个页面的pageRank值为1 - pageRankVecor[k] = 1.0; - } - for (String[] array : dataArray) { - - i = Integer.parseInt(array[0]); - j = Integer.parseInt(array[1]); - - // 设置linkMatrix[i][j]为1代表i网页包含指向j网页的链接 - linkMatrix[i - 1][j - 1] = 1; - } - } - - /** - * 将矩阵转置 - */ - private void transferMatrix() { - int count = 0; - for (double[] array : linkMatrix) { - // 计算页面链接个数 - count = 0; - for (double d : array) { - if (d == 1) { - count++; - } - } - // 按概率均分 - for (int i = 0; i < array.length; i++) { - if (array[i] == 1) { - array[i] /= count; - } - } - } - - double t = 0; - // 将矩阵转置换,作为概率转移矩阵 - for (int i = 0; i < linkMatrix.length; i++) { - for (int j = i + 1; j < linkMatrix[0].length; j++) { - t = linkMatrix[i][j]; - linkMatrix[i][j] = linkMatrix[j][i]; - linkMatrix[j][i] = t; - } - } - } - - /** - * 利用幂法计算pageRank值 - */ - public void printPageRankValue() { - transferMatrix(); - // 阻尼系数 - double damp = 0.5; - // 链接概率矩阵 - double[][] A = new double[pageNum][pageNum]; - double[][] e = new double[pageNum][pageNum]; - - // 调用公式A=d*q+(1-d)*e/m,m为网页总个数,d就是damp - double temp = (1 - damp) / pageNum; - for (int i = 0; i < e.length; i++) { - for (int j = 0; j < e[0].length; j++) { - e[i][j] = temp; - } - } - - for (int i = 0; i < pageNum; i++) { - for (int j = 0; j < pageNum; j++) { - temp = damp * linkMatrix[i][j] + e[i][j]; - A[i][j] = temp; - - } - } - - // 误差值,作为判断收敛标准 - double errorValue = Integer.MAX_VALUE; - double[] newPRVector = new double[pageNum]; - // 当平均每个PR值误差小于0.001时就算达到收敛 - while (errorValue > 0.001 * pageNum) { - System.out.println("**********"); - for (int i = 0; i < pageNum; i++) { - temp = 0; - // 将A*pageRankVector,利用幂法求解,直到pageRankVector值收敛 - for (int j = 0; j < pageNum; j++) { - // temp就是每个网页到i页面的pageRank值 - temp += A[i][j] * pageRankVecor[j]; - } - - // 最后的temp就是i网页的总PageRank值 - newPRVector[i] = temp; - System.out.println(temp); - } - - errorValue = 0; - for (int i = 0; i < pageNum; i++) { - errorValue += Math.abs(pageRankVecor[i] - newPRVector[i]); - // 新的向量代替旧的向量 - pageRankVecor[i] = newPRVector[i]; - } - } - - String name = null; - temp = 0; - System.out.println("--------------------"); - for (int i = 0; i < pageNum; i++) { - System.out.println(MessageFormat.format("网页{0}的pageRank值:{1}", - pageClass.get(i), pageRankVecor[i])); - if (pageRankVecor[i] > temp) { - temp = pageRankVecor[i]; - name = pageClass.get(i); - } - } - System.out.println(MessageFormat.format("等级最高的网页为:{0}", name)); - } +class PageRankTool { + // 网页数量分类 + private ArrayList pageClass; + // 测试输入数据 + private String filePath; + // 网页总数量 + private int pageNum; + // 链接关系矩阵 + private double[][] linkMatrix; + // 每个页面pageRank值初始向量 + private double[] pageRankVecor; + + PageRankTool(String filePath){ + this.filePath = filePath; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + pageClass = new ArrayList<>(); + // 统计网页类型种数 + for (String[] array : dataArray) { + for (String s : array) { + if (!pageClass.contains(s)) { + pageClass.add(s); + } + } + } + + int i; + int j; + pageNum = pageClass.size(); + linkMatrix = new double[pageNum][pageNum]; + pageRankVecor = new double[pageNum]; + for (int k = 0; k < pageNum; k++) { + // 初始每个页面的pageRank值为1 + pageRankVecor[k] = 1.0; + } + for (String[] array : dataArray) { + + i = Integer.parseInt(array[0]); + j = Integer.parseInt(array[1]); + + // 设置linkMatrix[i][j]为1代表i网页包含指向j网页的链接 + linkMatrix[i - 1][j - 1] = 1; + } + } + + /** + * 将矩阵转置 + */ + private void transferMatrix(){ + int count; + for (double[] array : linkMatrix) { + // 计算页面链接个数 + count = 0; + for (double d : array) { + if (d == 1) { + count++; + } + } + // 按概率均分 + for (int i = 0; i < array.length; i++) { + if (array[i] == 1) { + array[i] /= count; + } + } + } + + double t; + // 将矩阵转置换,作为概率转移矩阵 + for (int i = 0; i < linkMatrix.length; i++) { + for (int j = i + 1; j < linkMatrix[0].length; j++) { + t = linkMatrix[i][j]; + linkMatrix[i][j] = linkMatrix[j][i]; + linkMatrix[j][i] = t; + } + } + } + + /** + * 利用幂法计算pageRank值 + */ + void printPageRankValue(){ + transferMatrix(); + // 阻尼系数 + double damp = 0.5; + // 链接概率矩阵 + double[][] A = new double[pageNum][pageNum]; + double[][] e = new double[pageNum][pageNum]; + + // 调用公式A=d*q+(1-d)*e/m,m为网页总个数,d就是damp + double temp = (1 - damp) / pageNum; + for (int i = 0; i < e.length; i++) { + for (int j = 0; j < e[0].length; j++) { + e[i][j] = temp; + } + } + + for (int i = 0; i < pageNum; i++) { + for (int j = 0; j < pageNum; j++) { + temp = damp * linkMatrix[i][j] + e[i][j]; + A[i][j] = temp; + + } + } + + // 误差值,作为判断收敛标准 + double errorValue = Integer.MAX_VALUE; + double[] newPRVector = new double[pageNum]; + // 当平均每个PR值误差小于0.001时就算达到收敛 + while (errorValue > 0.001 * pageNum) { + System.out.println("**********"); + for (int i = 0; i < pageNum; i++) { + temp = 0; + // 将A*pageRankVector,利用幂法求解,直到pageRankVector值收敛 + for (int j = 0; j < pageNum; j++) { + // temp就是每个网页到i页面的pageRank值 + temp += A[i][j] * pageRankVecor[j]; + } + + // 最后的temp就是i网页的总PageRank值 + newPRVector[i] = temp; + System.out.println(temp); + } + + errorValue = 0; + for (int i = 0; i < pageNum; i++) { + errorValue += Math.abs(pageRankVecor[i] - newPRVector[i]); + // 新的向量代替旧的向量 + pageRankVecor[i] = newPRVector[i]; + } + } + + String name = null; + temp = 0; + System.out.println("--------------------"); + for (int i = 0; i < pageNum; i++) { + System.out.println(MessageFormat.format("网页{0}的pageRank值:{1}", + pageClass.get(i), pageRankVecor[i])); + if (pageRankVecor[i] > temp) { + temp = pageRankVecor[i]; + name = pageClass.get(i); + } + } + System.out.println(MessageFormat.format("等级最高的网页为:{0}", name)); + } } diff --git a/Others/DataMining_ACO/ACOTool.java b/Others/DataMining_ACO/ACOTool.java index 21b9760..33fdbdb 100644 --- a/Others/DataMining_ACO/ACOTool.java +++ b/Others/DataMining_ACO/ACOTool.java @@ -1,351 +1,339 @@ -package DataMining_ACO; +package Others.DataMining_ACO; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.text.MessageFormat; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Random; +import java.util.*; /** * 蚁群算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class ACOTool { - // 输入数据类型 - public static final int INPUT_CITY_NAME = 1; - public static final int INPUT_CITY_DIS = 2; - - // 城市间距离邻接矩阵 - public static double[][] disMatrix; - // 当前时间 - public static int currentTime; - - // 测试数据地址 - private String filePath; - // 蚂蚁数量 - private int antNum; - // 控制参数 - private double alpha; - private double beita; - private double p; - private double Q; - // 随机数产生器 - private Random random; - // 城市名称集合,这里为了方便,将城市用数字表示 - private ArrayList totalCitys; - // 所有的蚂蚁集合 - private ArrayList totalAnts; - // 城市间的信息素浓度矩阵,随着时间的增多而减少 - private double[][] pheromoneMatrix; - // 目标的最短路径,顺序为从集合的前部往后挪动 - private ArrayList bestPath; - // 信息素矩阵存储图,key采用的格式(i,j,t)->value - private Map pheromoneTimeMap; - - public ACOTool(String filePath, int antNum, double alpha, double beita, - double p, double Q) { - this.filePath = filePath; - this.antNum = antNum; - this.alpha = alpha; - this.beita = beita; - this.p = p; - this.Q = Q; - this.currentTime = 0; - - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - int flag = -1; - int src = 0; - int des = 0; - int size = 0; - // 进行城市名称种数的统计 - this.totalCitys = new ArrayList<>(); - for (String[] array : dataArray) { - if (array[0].equals("#") && totalCitys.size() == 0) { - flag = INPUT_CITY_NAME; - - continue; - } else if (array[0].equals("#") && totalCitys.size() > 0) { - size = totalCitys.size(); - // 初始化距离矩阵 - this.disMatrix = new double[size + 1][size + 1]; - this.pheromoneMatrix = new double[size + 1][size + 1]; - - // 初始值-1代表此对应位置无值 - for (int i = 0; i < size; i++) { - for (int j = 0; j < size; j++) { - this.disMatrix[i][j] = -1; - this.pheromoneMatrix[i][j] = -1; - } - } - - flag = INPUT_CITY_DIS; - continue; - } - - if (flag == INPUT_CITY_NAME) { - this.totalCitys.add(array[0]); - } else { - src = Integer.parseInt(array[0]); - des = Integer.parseInt(array[1]); - - this.disMatrix[src][des] = Double.parseDouble(array[2]); - this.disMatrix[des][src] = Double.parseDouble(array[2]); - } - } - } - - /** - * 计算从蚂蚁城市i到j的概率 - * - * @param cityI - * 城市I - * @param cityJ - * 城市J - * @param currentTime - * 当前时间 - * @return - */ - private double calIToJProbably(String cityI, String cityJ, int currentTime) { - double pro = 0; - double n = 0; - double pheromone; - int i; - int j; - - i = Integer.parseInt(cityI); - j = Integer.parseInt(cityJ); - - pheromone = getPheromone(currentTime, cityI, cityJ); - n = 1.0 / disMatrix[i][j]; - - if (pheromone == 0) { - pheromone = 1; - } - - pro = Math.pow(n, alpha) * Math.pow(pheromone, beita); - - return pro; - } - - /** - * 计算综合概率蚂蚁从I城市走到J城市的概率 - * - * @return - */ - public String selectAntNextCity(Ant ant, int currentTime) { - double randomNum; - double tempPro; - // 总概率指数 - double proTotal; - String nextCity = null; - ArrayList allowedCitys; - // 各城市概率集 - double[] proArray; - - // 如果是刚刚开始的时候,没有路过任何城市,则随机返回一个城市 - if (ant.currentPath.size() == 0) { - nextCity = String.valueOf(random.nextInt(totalCitys.size()) + 1); - - return nextCity; - } else if (ant.nonVisitedCitys.isEmpty()) { - // 如果全部遍历完毕,则再次回到起点 - nextCity = ant.currentPath.get(0); - - return nextCity; - } - - proTotal = 0; - allowedCitys = ant.nonVisitedCitys; - proArray = new double[allowedCitys.size()]; - - for (int i = 0; i < allowedCitys.size(); i++) { - nextCity = allowedCitys.get(i); - proArray[i] = calIToJProbably(ant.currentPos, nextCity, currentTime); - proTotal += proArray[i]; - } - - for (int i = 0; i < allowedCitys.size(); i++) { - // 归一化处理 - proArray[i] /= proTotal; - } - - // 用随机数选择下一个城市 - randomNum = random.nextInt(100) + 1; - randomNum = randomNum / 100; - // 因为1.0是无法判断到的,,总和会无限接近1.0取为0.99做判断 - if (randomNum == 1) { - randomNum = randomNum - 0.01; - } - - tempPro = 0; - // 确定区间 - for (int j = 0; j < allowedCitys.size(); j++) { - if (randomNum > tempPro && randomNum <= tempPro + proArray[j]) { - // 采用拷贝的方式避免引用重复 - nextCity = allowedCitys.get(j); - break; - } else { - tempPro += proArray[j]; - } - } - - return nextCity; - } - - /** - * 获取给定时间点上从城市i到城市j的信息素浓度 - * - * @param t - * @param cityI - * @param cityJ - * @return - */ - private double getPheromone(int t, String cityI, String cityJ) { - double pheromone = 0; - String key; - - // 上一周期需将时间倒回一周期 - key = MessageFormat.format("{0},{1},{2}", cityI, cityJ, t); - - if (pheromoneTimeMap.containsKey(key)) { - pheromone = pheromoneTimeMap.get(key); - } - - return pheromone; - } - - /** - * 每轮结束,刷新信息素浓度矩阵 - * - * @param t - */ - private void refreshPheromone(int t) { - double pheromone = 0; - // 上一轮周期结束后的信息素浓度,丛信息素浓度图中查找 - double lastTimeP = 0; - // 本轮信息素浓度增加量 - double addPheromone; - String key; - - for (String i : totalCitys) { - for (String j : totalCitys) { - if (!i.equals(j)) { - // 上一周期需将时间倒回一周期 - key = MessageFormat.format("{0},{1},{2}", i, j, t - 1); - - if (pheromoneTimeMap.containsKey(key)) { - lastTimeP = pheromoneTimeMap.get(key); - } else { - lastTimeP = 0; - } - - addPheromone = 0; - for (Ant ant : totalAnts) { - if(ant.pathContained(i, j)){ - // 每只蚂蚁传播的信息素为控制因子除以距离总成本 - addPheromone += Q / ant.calSumDistance(); - } - } - - // 将上次的结果值加上递增的量,并存入图中 - pheromone = p * lastTimeP + addPheromone; - key = MessageFormat.format("{0},{1},{2}", i, j, t); - pheromoneTimeMap.put(key, pheromone); - } - } - } - - } - - /** - * 蚁群算法迭代次数 - * @param loopCount - * 具体遍历次数 - */ - public void antStartSearching(int loopCount) { - // 蚁群寻找的总次数 - int count = 0; - // 选中的下一个城市 - String selectedCity = ""; - - pheromoneTimeMap = new HashMap(); - totalAnts = new ArrayList<>(); - random = new Random(); - - while (count < loopCount) { - initAnts(); - - while (true) { - for (Ant ant : totalAnts) { - selectedCity = selectAntNextCity(ant, currentTime); - ant.goToNextCity(selectedCity); - } - - // 如果已经遍历完所有城市,则跳出此轮循环 - if (totalAnts.get(0).isBack()) { - break; - } - } - - // 周期时间叠加 - currentTime++; - refreshPheromone(currentTime); - count++; - } - - // 根据距离成本,选出所花距离最短的一个路径 - Collections.sort(totalAnts); - bestPath = totalAnts.get(0).currentPath; - System.out.println(MessageFormat.format("经过{0}次循环遍历,最终得出的最佳路径:", count)); - System.out.print("entrance"); - for (String cityName : bestPath) { - System.out.print(MessageFormat.format("-->{0}", cityName)); - } - } - - /** - * 初始化蚁群操作 - */ - private void initAnts() { - Ant tempAnt; - ArrayList nonVisitedCitys; - totalAnts.clear(); - - // 初始化蚁群 - for (int i = 0; i < antNum; i++) { - nonVisitedCitys = (ArrayList) totalCitys.clone(); - tempAnt = new Ant(pheromoneMatrix, nonVisitedCitys); - - totalAnts.add(tempAnt); - } - } +class ACOTool { + // 输入数据类型 + private static final int INPUT_CITY_NAME = 1; + private static final int INPUT_CITY_DIS = 2; + + // 城市间距离邻接矩阵 + static double[][] disMatrix; + // 当前时间 + private static int currentTime; + + // 测试数据地址 + private String filePath; + // 蚂蚁数量 + private int antNum; + // 控制参数 + private double alpha; + private double beita; + private double p; + private double Q; + // 随机数产生器 + private Random random; + // 城市名称集合,这里为了方便,将城市用数字表示 + private ArrayList totalCitys; + // 所有的蚂蚁集合 + private ArrayList totalAnts; + // 城市间的信息素浓度矩阵,随着时间的增多而减少 + private double[][] pheromoneMatrix; + // 目标的最短路径,顺序为从集合的前部往后挪动 + private ArrayList bestPath; + // 信息素矩阵存储图,key采用的格式(i,j,t)->value + private Map pheromoneTimeMap; + + ACOTool(String filePath, int antNum, double alpha, double beita, + double p, double Q){ + this.filePath = filePath; + this.antNum = antNum; + this.alpha = alpha; + this.beita = beita; + this.p = p; + this.Q = Q; + currentTime = 0; + + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + int flag = -1; + int src; + int des; + int size; + // 进行城市名称种数的统计 + this.totalCitys = new ArrayList<>(); + for (String[] array : dataArray) { + if (array[0].equals("#") && totalCitys.size() == 0) { + flag = INPUT_CITY_NAME; + + continue; + } else if (array[0].equals("#") && totalCitys.size() > 0) { + size = totalCitys.size(); + // 初始化距离矩阵 + disMatrix = new double[size + 1][size + 1]; + this.pheromoneMatrix = new double[size + 1][size + 1]; + + // 初始值-1代表此对应位置无值 + for (int i = 0; i < size; i++) { + for (int j = 0; j < size; j++) { + disMatrix[i][j] = -1; + this.pheromoneMatrix[i][j] = -1; + } + } + + flag = INPUT_CITY_DIS; + continue; + } + + if (flag == INPUT_CITY_NAME) { + this.totalCitys.add(array[0]); + } else { + src = Integer.parseInt(array[0]); + des = Integer.parseInt(array[1]); + + disMatrix[src][des] = Double.parseDouble(array[2]); + disMatrix[des][src] = Double.parseDouble(array[2]); + } + } + } + + /** + * 计算从蚂蚁城市i到j的概率 + * + * @param cityI 城市I + * @param cityJ 城市J + * @param currentTime 当前时间 + */ + private double calIToJProbably(String cityI, String cityJ, int currentTime){ + double pro; + double n; + double pheromone; + int i; + int j; + + i = Integer.parseInt(cityI); + j = Integer.parseInt(cityJ); + + pheromone = getPheromone(currentTime, cityI, cityJ); + n = 1.0 / disMatrix[i][j]; + + if (pheromone == 0) { + pheromone = 1; + } + + pro = Math.pow(n, alpha) * Math.pow(pheromone, beita); + + return pro; + } + + /** + * 计算综合概率蚂蚁从I城市走到J城市的概率 + */ + private String selectAntNextCity(Ant ant, int currentTime){ + double randomNum; + double tempPro; + // 总概率指数 + double proTotal; + String nextCity = null; + ArrayList allowedCitys; + // 各城市概率集 + double[] proArray; + + // 如果是刚刚开始的时候,没有路过任何城市,则随机返回一个城市 + if (ant.currentPath.size() == 0) { + nextCity = String.valueOf(random.nextInt(totalCitys.size()) + 1); + + return nextCity; + } else if (ant.nonVisitedCitys.isEmpty()) { + // 如果全部遍历完毕,则再次回到起点 + nextCity = ant.currentPath.get(0); + + return nextCity; + } + + proTotal = 0; + allowedCitys = ant.nonVisitedCitys; + proArray = new double[allowedCitys.size()]; + + for (int i = 0; i < allowedCitys.size(); i++) { + nextCity = allowedCitys.get(i); + proArray[i] = calIToJProbably(ant.currentPos, nextCity, currentTime); + proTotal += proArray[i]; + } + + for (int i = 0; i < allowedCitys.size(); i++) { + // 归一化处理 + proArray[i] /= proTotal; + } + + // 用随机数选择下一个城市 + randomNum = random.nextInt(100) + 1; + randomNum = randomNum / 100; + // 因为1.0是无法判断到的,,总和会无限接近1.0取为0.99做判断 + if (randomNum == 1) { + randomNum = randomNum - 0.01; + } + + tempPro = 0; + // 确定区间 + for (int j = 0; j < allowedCitys.size(); j++) { + if (randomNum > tempPro && randomNum <= tempPro + proArray[j]) { + // 采用拷贝的方式避免引用重复 + nextCity = allowedCitys.get(j); + break; + } else { + tempPro += proArray[j]; + } + } + + return nextCity; + } + + /** + * 获取给定时间点上从城市i到城市j的信息素浓度 + * + * @param t + * @param cityI + * @param cityJ + */ + private double getPheromone(int t, String cityI, String cityJ){ + double pheromone = 0; + String key; + + // 上一周期需将时间倒回一周期 + key = MessageFormat.format("{0},{1},{2}", cityI, cityJ, t); + + if (pheromoneTimeMap.containsKey(key)) { + pheromone = pheromoneTimeMap.get(key); + } + + return pheromone; + } + + /** + * 每轮结束,刷新信息素浓度矩阵 + * + * @param t + */ + private void refreshPheromone(int t){ + double pheromone; + // 上一轮周期结束后的信息素浓度,丛信息素浓度图中查找 + double lastTimeP; + // 本轮信息素浓度增加量 + double addPheromone; + String key; + + for (String i : totalCitys) { + for (String j : totalCitys) { + if (!i.equals(j)) { + // 上一周期需将时间倒回一周期 + key = MessageFormat.format("{0},{1},{2}", i, j, t - 1); + + if (pheromoneTimeMap.containsKey(key)) { + lastTimeP = pheromoneTimeMap.get(key); + } else { + lastTimeP = 0; + } + + addPheromone = 0; + for (Ant ant : totalAnts) { + if (ant.pathContained(i, j)) { + // 每只蚂蚁传播的信息素为控制因子除以距离总成本 + addPheromone += Q / ant.calSumDistance(); + } + } + + // 将上次的结果值加上递增的量,并存入图中 + pheromone = p * lastTimeP + addPheromone; + key = MessageFormat.format("{0},{1},{2}", i, j, t); + pheromoneTimeMap.put(key, pheromone); + } + } + } + + } + + /** + * 蚁群算法迭代次数 + * + * @param loopCount 具体遍历次数 + */ + void antStartSearching(int loopCount){ + // 蚁群寻找的总次数 + int count = 0; + // 选中的下一个城市 + String selectedCity; + + pheromoneTimeMap = new HashMap<>(); + totalAnts = new ArrayList<>(); + random = new Random(); + + while (count < loopCount) { + initAnts(); + + while (true) { + for (Ant ant : totalAnts) { + selectedCity = selectAntNextCity(ant, currentTime); + ant.goToNextCity(selectedCity); + } + + // 如果已经遍历完所有城市,则跳出此轮循环 + if (totalAnts.get(0).isBack()) { + break; + } + } + + // 周期时间叠加 + currentTime++; + refreshPheromone(currentTime); + count++; + } + + // 根据距离成本,选出所花距离最短的一个路径 + Collections.sort(totalAnts); + bestPath = totalAnts.get(0).currentPath; + System.out.println(MessageFormat.format("经过{0}次循环遍历,最终得出的最佳路径:", count)); + System.out.print("entrance"); + for (String cityName : bestPath) { + System.out.print(MessageFormat.format("-->{0}", cityName)); + } + } + + /** + * 初始化蚁群操作 + */ + private void initAnts(){ + Ant tempAnt; + ArrayList nonVisitedCitys; + totalAnts.clear(); + + // 初始化蚁群 + for (int i = 0; i < antNum; i++) { + nonVisitedCitys = (ArrayList) totalCitys.clone(); + tempAnt = new Ant(pheromoneMatrix, nonVisitedCitys); + + totalAnts.add(tempAnt); + } + } } diff --git a/Others/DataMining_ACO/Ant.java b/Others/DataMining_ACO/Ant.java index fd89c71..4db58a7 100644 --- a/Others/DataMining_ACO/Ant.java +++ b/Others/DataMining_ACO/Ant.java @@ -1,11 +1,11 @@ -package DataMining_ACO; +package Others.DataMining_ACO; import java.util.ArrayList; /** * 蚂蚁类,进行路径搜索的载体 * - * @author lyq + * @author Qstar * */ public class Ant implements Comparable { diff --git a/Others/DataMining_ACO/Client.java b/Others/DataMining_ACO/Client.java index 0e9ede9..1225e3e 100644 --- a/Others/DataMining_ACO/Client.java +++ b/Others/DataMining_ACO/Client.java @@ -1,32 +1,32 @@ -package DataMining_ACO; +package Others.DataMining_ACO; /** * 蚁群算法测试类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - //测试数据 - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - //蚂蚁数量 - int antNum; - //蚁群算法迭代次数 - int loopCount; - //控制参数 - double alpha; - double beita; - double p; - double Q; - - antNum = 3; - alpha = 0.5; - beita = 1; - p = 0.5; - Q = 5; - loopCount = 5; - - ACOTool tool = new ACOTool(filePath, antNum, alpha, beita, p, Q); - tool.antStartSearching(loopCount); - } + public static void main(String[] args){ + //测试数据 + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_ACO/input.txt"; + //蚂蚁数量 + int antNum; + //蚁群算法迭代次数 + int loopCount; + //控制参数 + double alpha; + double beita; + double p; + double Q; + + antNum = 3; + alpha = 0.5; + beita = 1; + p = 0.5; + Q = 5; + loopCount = 5; + + ACOTool tool = new ACOTool(filePath, antNum, alpha, beita, p, Q); + tool.antStartSearching(loopCount); + } } diff --git a/Others/DataMining_BayesNetwork/BayesNetWorkTool.java b/Others/DataMining_BayesNetwork/BayesNetWorkTool.java index cbf99ae..045fc03 100644 --- a/Others/DataMining_BayesNetwork/BayesNetWorkTool.java +++ b/Others/DataMining_BayesNetwork/BayesNetWorkTool.java @@ -1,4 +1,4 @@ -package DataMining_BayesNetwork; +package Others.DataMining_BayesNetwork; import java.io.BufferedReader; import java.io.File; @@ -9,320 +9,310 @@ /** * 贝叶斯网络算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class BayesNetWorkTool { - // 联合概率分布数据文件地址 - private String dataFilePath; - // 事件关联数据文件地址 - private String attachFilePath; - // 属性列列数 - private int columns; - // 概率分布数据 - private String[][] totalData; - // 关联数据对 - private ArrayList attachData; - // 节点存放列表 - private ArrayList nodes; - // 属性名与列数之间的对应关系 - private HashMap attr2Column; - - public BayesNetWorkTool(String dataFilePath, String attachFilePath) { - this.dataFilePath = dataFilePath; - this.attachFilePath = attachFilePath; - - initDatas(); - } - - /** - * 初始化关联数据和概率分布数据 - */ - private void initDatas() { - String[] columnValues; - String[] array; - ArrayList datas; - ArrayList adatas; - - // 从文件中读取数据 - datas = readDataFile(dataFilePath); - adatas = readDataFile(attachFilePath); - - columnValues = datas.get(0).split(" "); - // 属性割名称代表事件B(盗窃),E(地震),A(警铃响).M(接到M的电话),J同M的意思, - // 属性值都是y,n代表yes发生和no不发生 - this.attr2Column = new HashMap<>(); - for (int i = 0; i < columnValues.length; i++) { - // 从数据中取出属性名称行,列数值存入图中 - this.attr2Column.put(columnValues[i], i); - } - - this.columns = columnValues.length; - this.totalData = new String[datas.size()][columns]; - for (int i = 0; i < datas.size(); i++) { - this.totalData[i] = datas.get(i).split(" "); - } - - this.attachData = new ArrayList<>(); - // 解析关联数据对 - for (String str : adatas) { - array = str.split(" "); - this.attachData.add(array); - } - - // 构造贝叶斯网络结构图 - constructDAG(); - } - - /** - * 从文件中读取数据 - */ - private ArrayList readDataFile(String filePath) { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - while ((str = in.readLine()) != null) { - dataArray.add(str); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - return dataArray; - } - - /** - * 根据关联数据构造贝叶斯网络无环有向图 - */ - private void constructDAG() { - // 节点存在标识 - boolean srcExist; - boolean desExist; - String name1; - String name2; - Node srcNode; - Node desNode; - - this.nodes = new ArrayList<>(); - for (String[] array : this.attachData) { - srcExist = false; - desExist = false; - - name1 = array[0]; - name2 = array[1]; - - // 新建节点 - srcNode = new Node(name1); - desNode = new Node(name2); - - for (Node temp : this.nodes) { - // 如果找到相同节点,则取出 - if (srcNode.isEqual(temp)) { - srcExist = true; - srcNode = temp; - } else if (desNode.isEqual(temp)) { - desExist = true; - desNode = temp; - } - - // 如果2个节点都已找到,则跳出循环 - if (srcExist && desExist) { - break; - } - } - - // 将2个节点进行连接 - srcNode.connectNode(desNode); - - // 根据标识判断是否需要加入列表容器中 - if (!srcExist) { - this.nodes.add(srcNode); - } - - if (!desExist) { - this.nodes.add(desNode); - } - } - } - - /** - * 查询条件概率 - * - * @param attrValues - * 条件属性值 - * @return - */ - private double queryConditionPro(ArrayList attrValues) { - // 判断是否满足先验属性值条件 - boolean hasPrior; - // 判断是否满足后验属性值条件 - boolean hasBack; - int priorIndex; - int attrIndex; - double backPro; - double totalPro; - double pro; - double currentPro; - // 先验属性 - String[] priorValue; - String[] tempData; - - pro = 0; - totalPro = 0; - backPro = 0; - attrValues.get(0); - priorValue = attrValues.get(0); - // 得到后验概率 - attrValues.remove(0); - - // 取出先验属性的列数 - priorIndex = this.attr2Column.get(priorValue[0]); - // 跳过第一行的属性名称行 - for (int i = 1; i < this.totalData.length; i++) { - tempData = this.totalData[i]; - - hasPrior = false; - hasBack = true; - - // 当前行的概率 - currentPro = Double.parseDouble(tempData[this.columns - 1]); - // 判断是否满足先验条件 - if (tempData[priorIndex].equals(priorValue[1])) { - hasPrior = true; - } - - for (String[] array : attrValues) { - attrIndex = this.attr2Column.get(array[0]); - - // 判断值是否满足条件 - if (!tempData[attrIndex].equals(array[1])) { - hasBack = false; - break; - } - } - - // 进行计数统计,分别计算满足后验属性的值和同时满足条件的个数 - if (hasBack) { - backPro += currentPro; - if (hasPrior) { - totalPro += currentPro; - } - } else if (hasPrior && attrValues.size() == 0) { - // 如果只有先验概率则为纯概率的计算 - totalPro += currentPro; - backPro = 1.0; - } - } - - // 计算总的概率=都发生概率/只发生后验条件的时间概率 - pro = totalPro / backPro; - - return pro; - } - - /** - * 根据贝叶斯网络计算概率 - * - * @param queryStr - * 查询条件串 - * @return - */ - public double calProByNetWork(String queryStr) { - double temp; - double pro; - String[] array; - // 先验条件值 - String[] preValue; - // 后验条件值 - String[] backValue; - // 所有先验条件和后验条件值的属性值的汇总 - ArrayList attrValues; - - // 判断是否满足网络结构 - if (!satisfiedNewWork(queryStr)) { - return -1; - } - - pro = 1; - // 首先做查询条件的分解 - array = queryStr.split(","); - - // 概率的初值等于第一个事件发生的随机概率 - attrValues = new ArrayList<>(); - attrValues.add(array[0].split("=")); - pro = queryConditionPro(attrValues); - - for (int i = 0; i < array.length - 1; i++) { - attrValues.clear(); - - // 下标小的在前面的属于后验属性 - backValue = array[i].split("="); - preValue = array[i + 1].split("="); - attrValues.add(preValue); - attrValues.add(backValue); - - // 算出此种情况的概率值 - temp = queryConditionPro(attrValues); - // 进行积的相乘 - pro *= temp; - } - - return pro; - } - - /** - * 验证事件的查询因果关系是否满足贝叶斯网络 - * - * @param queryStr - * 查询字符串 - * @return - */ - private boolean satisfiedNewWork(String queryStr) { - String attrName; - String[] array; - boolean isExist; - boolean isSatisfied; - // 当前节点 - Node currentNode; - // 候选节点列表 - ArrayList nodeList; - - isSatisfied = true; - currentNode = null; - // 做查询字符串的分解 - array = queryStr.split(","); - nodeList = this.nodes; - - for (String s : array) { - // 开始时默认属性对应的节点不存在 - isExist = false; - // 得到属性事件名 - attrName = s.split("=")[0]; - - for (Node n : nodeList) { - if (n.name.equals(attrName)) { - isExist = true; - - currentNode = n; - // 下一轮的候选节点为当前节点的孩子节点 - nodeList = currentNode.childNodes; - - break; - } - } - - // 如果存在未找到的节点,则说明不满足依赖结构跳出循环 - if (!isExist) { - isSatisfied = false; - break; - } - } - - return isSatisfied; - } +class BayesNetWorkTool { + // 联合概率分布数据文件地址 + private String dataFilePath; + // 事件关联数据文件地址 + private String attachFilePath; + // 属性列列数 + private int columns; + // 概率分布数据 + private String[][] totalData; + // 关联数据对 + private ArrayList attachData; + // 节点存放列表 + private ArrayList nodes; + // 属性名与列数之间的对应关系 + private HashMap attr2Column; + + BayesNetWorkTool(String dataFilePath, String attachFilePath){ + this.dataFilePath = dataFilePath; + this.attachFilePath = attachFilePath; + + initDatas(); + } + + /** + * 初始化关联数据和概率分布数据 + */ + private void initDatas(){ + String[] columnValues; + String[] array; + ArrayList datas; + ArrayList adatas; + + // 从文件中读取数据 + datas = readDataFile(dataFilePath); + adatas = readDataFile(attachFilePath); + + columnValues = datas.get(0).split(" "); + // 属性割名称代表事件B(盗窃),E(地震),A(警铃响).M(接到M的电话),J同M的意思, + // 属性值都是y,n代表yes发生和no不发生 + this.attr2Column = new HashMap<>(); + for (int i = 0; i < columnValues.length; i++) { + // 从数据中取出属性名称行,列数值存入图中 + this.attr2Column.put(columnValues[i], i); + } + + this.columns = columnValues.length; + this.totalData = new String[datas.size()][columns]; + for (int i = 0; i < datas.size(); i++) { + this.totalData[i] = datas.get(i).split(" "); + } + + this.attachData = new ArrayList<>(); + // 解析关联数据对 + for (String str : adatas) { + array = str.split(" "); + this.attachData.add(array); + } + + // 构造贝叶斯网络结构图 + constructDAG(); + } + + /** + * 从文件中读取数据 + */ + private ArrayList readDataFile(String filePath){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + while ((str = in.readLine()) != null) { + dataArray.add(str); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + return dataArray; + } + + /** + * 根据关联数据构造贝叶斯网络无环有向图 + */ + private void constructDAG(){ + // 节点存在标识 + boolean srcExist; + boolean desExist; + String name1; + String name2; + Node srcNode; + Node desNode; + + this.nodes = new ArrayList<>(); + for (String[] array : this.attachData) { + srcExist = false; + desExist = false; + + name1 = array[0]; + name2 = array[1]; + + // 新建节点 + srcNode = new Node(name1); + desNode = new Node(name2); + + for (Node temp : this.nodes) { + // 如果找到相同节点,则取出 + if (srcNode.isEqual(temp)) { + srcExist = true; + srcNode = temp; + } else if (desNode.isEqual(temp)) { + desExist = true; + desNode = temp; + } + + // 如果2个节点都已找到,则跳出循环 + if (srcExist && desExist) { + break; + } + } + + // 将2个节点进行连接 + srcNode.connectNode(desNode); + + // 根据标识判断是否需要加入列表容器中 + if (!srcExist) { + this.nodes.add(srcNode); + } + + if (!desExist) { + this.nodes.add(desNode); + } + } + } + + /** + * 查询条件概率 + * + * @param attrValues 条件属性值 + */ + private double queryConditionPro(ArrayList attrValues){ + // 判断是否满足先验属性值条件 + boolean hasPrior; + // 判断是否满足后验属性值条件 + boolean hasBack; + int priorIndex; + int attrIndex; + double backPro; + double totalPro; + double pro; + double currentPro; + // 先验属性 + String[] priorValue; + String[] tempData; + + totalPro = 0; + backPro = 0; + attrValues.get(0); + priorValue = attrValues.get(0); + // 得到后验概率 + attrValues.remove(0); + + // 取出先验属性的列数 + priorIndex = this.attr2Column.get(priorValue[0]); + // 跳过第一行的属性名称行 + for (int i = 1; i < this.totalData.length; i++) { + tempData = this.totalData[i]; + + hasPrior = false; + hasBack = true; + + // 当前行的概率 + currentPro = Double.parseDouble(tempData[this.columns - 1]); + // 判断是否满足先验条件 + if (tempData[priorIndex].equals(priorValue[1])) { + hasPrior = true; + } + + for (String[] array : attrValues) { + attrIndex = this.attr2Column.get(array[0]); + + // 判断值是否满足条件 + if (!tempData[attrIndex].equals(array[1])) { + hasBack = false; + break; + } + } + + // 进行计数统计,分别计算满足后验属性的值和同时满足条件的个数 + if (hasBack) { + backPro += currentPro; + if (hasPrior) { + totalPro += currentPro; + } + } else if (hasPrior && attrValues.size() == 0) { + // 如果只有先验概率则为纯概率的计算 + totalPro += currentPro; + backPro = 1.0; + } + } + + // 计算总的概率=都发生概率/只发生后验条件的时间概率 + pro = totalPro / backPro; + + return pro; + } + + /** + * 根据贝叶斯网络计算概率 + * + * @param queryStr 查询条件串 + */ + double calProByNetWork(String queryStr){ + double temp; + double pro; + String[] array; + // 先验条件值 + String[] preValue; + // 后验条件值 + String[] backValue; + // 所有先验条件和后验条件值的属性值的汇总 + ArrayList attrValues; + + // 判断是否满足网络结构 + if (!satisfiedNewWork(queryStr)) { + return -1; + } + + // 首先做查询条件的分解 + array = queryStr.split(","); + + // 概率的初值等于第一个事件发生的随机概率 + attrValues = new ArrayList<>(); + attrValues.add(array[0].split("=")); + pro = queryConditionPro(attrValues); + + for (int i = 0; i < array.length - 1; i++) { + attrValues.clear(); + + // 下标小的在前面的属于后验属性 + backValue = array[i].split("="); + preValue = array[i + 1].split("="); + attrValues.add(preValue); + attrValues.add(backValue); + + // 算出此种情况的概率值 + temp = queryConditionPro(attrValues); + // 进行积的相乘 + pro *= temp; + } + + return pro; + } + + /** + * 验证事件的查询因果关系是否满足贝叶斯网络 + * + * @param queryStr 查询字符串 + */ + private boolean satisfiedNewWork(String queryStr){ + String attrName; + String[] array; + boolean isExist; + boolean isSatisfied; + // 当前节点 + Node currentNode; + // 候选节点列表 + ArrayList nodeList; + + isSatisfied = true; + // 做查询字符串的分解 + array = queryStr.split(","); + nodeList = this.nodes; + + for (String s : array) { + // 开始时默认属性对应的节点不存在 + isExist = false; + // 得到属性事件名 + attrName = s.split("=")[0]; + + for (Node n : nodeList) { + if (n.name.equals(attrName)) { + isExist = true; + + currentNode = n; + // 下一轮的候选节点为当前节点的孩子节点 + nodeList = currentNode.childNodes; + + break; + } + } + + // 如果存在未找到的节点,则说明不满足依赖结构跳出循环 + if (!isExist) { + isSatisfied = false; + break; + } + } + + return isSatisfied; + } } diff --git a/Others/DataMining_BayesNetwork/Client.java b/Others/DataMining_BayesNetwork/Client.java index 98706c4..80cccab 100644 --- a/Others/DataMining_BayesNetwork/Client.java +++ b/Others/DataMining_BayesNetwork/Client.java @@ -1,32 +1,29 @@ -package DataMining_BayesNetwork; - -import java.text.MessageFormat; +package Others.DataMining_BayesNetwork; /** * 贝叶斯网络场景测试类 - * - * @author lyq - * + * + * @author Qstar */ public class Client { - public static void main(String[] args) { - String dataFilePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - String attachFilePath = "C:\\Users\\lyq\\Desktop\\icon\\attach.txt"; - // 查询串语句 - String queryStr; - // 结果概率 - double result; + public static void main(String[] args){ + String dataFilePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_BayesNetwork/input.txt"; + String attachFilePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_BayesNetwork/attach.txt"; + // 查询串语句 + String queryStr; + // 结果概率 + double result; - // 查询语句的描述的事件是地震发生了,导致响铃响了,导致接到Mary的电话 - queryStr = "E=y,A=y,M=y"; - BayesNetWorkTool tool = new BayesNetWorkTool(dataFilePath, - attachFilePath); - result = tool.calProByNetWork(queryStr); + // 查询语句的描述的事件是地震发生了,导致响铃响了,导致接到Mary的电话 + queryStr = "E=y,A=y,M=y"; + BayesNetWorkTool tool = new BayesNetWorkTool(dataFilePath, + attachFilePath); + result = tool.calProByNetWork(queryStr); - if (result == -1) { - System.out.println("所描述的事件不满足贝叶斯网络的结构,无法求其概率"); - } else { - System.out.println(String.format("事件%s发生的概率为%s", queryStr, result)); - } - } + if (result == -1) { + System.out.println("所描述的事件不满足贝叶斯网络的结构,无法求其概率"); + } else { + System.out.println(String.format("事件%s发生的概率为%s", queryStr, result)); + } + } } diff --git a/Others/DataMining_BayesNetwork/Node.java b/Others/DataMining_BayesNetwork/Node.java index bb2a07d..1db2dc4 100644 --- a/Others/DataMining_BayesNetwork/Node.java +++ b/Others/DataMining_BayesNetwork/Node.java @@ -1,58 +1,54 @@ -package DataMining_BayesNetwork; +package Others.DataMining_BayesNetwork; import java.util.ArrayList; /** * 贝叶斯网络节点类 - * - * @author lyq - * + * + * @author Qstar */ -public class Node { - // 节点的属性名称 - String name; - // 节点的父亲节点,也就是上游节点,可能多个 - ArrayList parentNodes; - // 节点的子节点,也就是下游节点,可能多个 - ArrayList childNodes; - - public Node(String name) { - this.name = name; - - // 初始化变量 - this.parentNodes = new ArrayList<>(); - this.childNodes = new ArrayList<>(); - } - - /** - * 将自身节点连接到目标给定的节点 - * - * @param node - * 下游节点 - */ - public void connectNode(Node node) { - // 将下游节点加入自身节点的孩子节点中 - this.childNodes.add(node); - // 将自身节点加入到下游节点的父节点中 - node.parentNodes.add(this); - } - - /** - * 判断与目标节点是否相同,主要比较名称是否相同即可 - * - * @param node - * 目标结点 - * @return - */ - public boolean isEqual(Node node) { - boolean isEqual; - - isEqual = false; - // 节点名称相同则视为相等 - if (this.name.equals(node.name)) { - isEqual = true; - } - - return isEqual; - } +class Node { + // 节点的属性名称 + String name; + // 节点的子节点,也就是下游节点,可能多个 + ArrayList childNodes; + // 节点的父亲节点,也就是上游节点,可能多个 + private ArrayList parentNodes; + + Node(String name){ + this.name = name; + + // 初始化变量 + this.parentNodes = new ArrayList<>(); + this.childNodes = new ArrayList<>(); + } + + /** + * 将自身节点连接到目标给定的节点 + * + * @param node 下游节点 + */ + void connectNode(Node node){ + // 将下游节点加入自身节点的孩子节点中 + this.childNodes.add(node); + // 将自身节点加入到下游节点的父节点中 + node.parentNodes.add(this); + } + + /** + * 判断与目标节点是否相同,主要比较名称是否相同即可 + * + * @param node 目标结点 + */ + boolean isEqual(Node node){ + boolean isEqual; + + isEqual = false; + // 节点名称相同则视为相等 + if (this.name.equals(node.name)) { + isEqual = true; + } + + return isEqual; + } } diff --git a/Others/DataMining_CABDDCC/CABDDCCTool.java b/Others/DataMining_CABDDCC/CABDDCCTool.java index 34081b4..3fd93b7 100644 --- a/Others/DataMining_CABDDCC/CABDDCCTool.java +++ b/Others/DataMining_CABDDCC/CABDDCCTool.java @@ -1,4 +1,4 @@ -package DataMining_CABDDCC; +package Others.DataMining_CABDDCC; import java.io.BufferedReader; import java.io.File; @@ -9,94 +9,93 @@ /** * 基于连通图的分裂聚类算法 - * - * @author lyq - * + * + * @author Qstar */ -public class CABDDCCTool { - // 测试数据点数据 - private String filePath; - // 连通图距离阈值l - private int length; - // 原始坐标点 - public static ArrayList totalPoints; - // 聚类结果坐标点集合 - private ArrayList> resultClusters; - // 连通图 - private Graph graph; +class CABDDCCTool { + // 原始坐标点 + static ArrayList totalPoints; + // 测试数据点数据 + private String filePath; + // 连通图距离阈值l + private int length; + // 聚类结果坐标点集合 + private ArrayList> resultClusters; + // 连通图 + private Graph graph; - public CABDDCCTool(String filePath, int length) { - this.filePath = filePath; - this.length = length; + CABDDCCTool(String filePath, int length){ + this.filePath = filePath; + this.length = length; - readDataFile(); - } + readDataFile(); + } - /** - * 从文件中读取数据 - */ - public void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); + /** + * 从文件中读取数据 + */ + public void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } - Point p; - totalPoints = new ArrayList<>(); - for (String[] array : dataArray) { - p = new Point(array[0], array[1], array[2]); - totalPoints.add(p); - } + Point p; + totalPoints = new ArrayList<>(); + for (String[] array : dataArray) { + p = new Point(array[0], array[1], array[2]); + totalPoints.add(p); + } - // 用边和点构造图 - graph = new Graph(null, totalPoints); - } + // 用边和点构造图 + graph = new Graph(null, totalPoints); + } - /** - * 分裂连通图得到聚类 - */ - public void splitCluster() { - // 获取形成连通子图 - ArrayList subGraphs; - ArrayList> pointList; - resultClusters = new ArrayList<>(); + /** + * 分裂连通图得到聚类 + */ + void splitCluster(){ + // 获取形成连通子图 + ArrayList subGraphs; + ArrayList> pointList; + resultClusters = new ArrayList<>(); - subGraphs = graph.splitGraphByLength(length); + subGraphs = graph.splitGraphByLength(length); - for (Graph g : subGraphs) { - // 获取每个连通子图分裂后的聚类结果 - pointList = g.getClusterByDivding(); - resultClusters.addAll(pointList); - } - - printResultCluster(); - } + for (Graph g : subGraphs) { + // 获取每个连通子图分裂后的聚类结果 + pointList = g.getClusterByDivding(); + resultClusters.addAll(pointList); + } - /** - * 输出结果聚簇 - */ - private void printResultCluster() { - int i = 1; - for (ArrayList cluster : resultClusters) { - System.out.print("聚簇" + i + ":"); - for (Point p : cluster){ - System.out.print(MessageFormat.format("({0}, {1}) ", p.x, p.y)); - } - System.out.println(); - i++; - } - - } + printResultCluster(); + } + + /** + * 输出结果聚簇 + */ + private void printResultCluster(){ + int i = 1; + for (ArrayList cluster : resultClusters) { + System.out.print("聚簇" + i + ":"); + for (Point p : cluster) { + System.out.print(MessageFormat.format("({0}, {1}) ", p.x, p.y)); + } + System.out.println(); + i++; + } + + } } diff --git a/Others/DataMining_CABDDCC/Client.java b/Others/DataMining_CABDDCC/Client.java index c57e3f5..663105a 100644 --- a/Others/DataMining_CABDDCC/Client.java +++ b/Others/DataMining_CABDDCC/Client.java @@ -1,17 +1,17 @@ -package DataMining_CABDDCC; +package Others.DataMining_CABDDCC; /** * 基于连通图的分裂聚类算法 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] agrs){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\graphData.txt"; - //连通距离阈值 - int length = 3; - - CABDDCCTool tool = new CABDDCCTool(filePath, length); - tool.splitCluster(); - } + public static void main(String[] agrs){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_CABDDCC/graphData.txt"; + //连通距离阈值 + int length = 3; + + CABDDCCTool tool = new CABDDCCTool(filePath, length); + tool.splitCluster(); + } } diff --git a/Others/DataMining_CABDDCC/Graph.java b/Others/DataMining_CABDDCC/Graph.java index b59d06e..d696951 100644 --- a/Others/DataMining_CABDDCC/Graph.java +++ b/Others/DataMining_CABDDCC/Graph.java @@ -1,287 +1,250 @@ -package DataMining_CABDDCC; +package Others.DataMining_CABDDCC; import java.util.ArrayList; import java.util.Collections; /** * 连通图类 - * - * @author lyq - * + * + * @author Qstar */ -public class Graph { - // 坐标点之间的连接属性,括号内为坐标id号 - int[][] edges; - // 连通图内的坐标点数 - ArrayList points; - // 此图下分割后的聚类子图 - ArrayList> clusters; - - public Graph(int[][] edges) { - this.edges = edges; - this.points = getPointByEdges(edges); - } - - public Graph(int[][] edges, ArrayList points) { - this.edges = edges; - this.points = points; - } - - public int[][] getEdges() { - return edges; - } - - public void setEdges(int[][] edges) { - this.edges = edges; - } - - public ArrayList getPoints() { - return points; - } - - public void setPoints(ArrayList points) { - this.points = points; - } - - /** - * 根据距离阈值做连通图的划分,构成连通图集 - * - * @param length - * 距离阈值 - * @return - */ - public ArrayList splitGraphByLength(int length) { - int[][] edges; - Graph tempGraph; - ArrayList graphs = new ArrayList<>(); - - for (Point p : points) { - if (!p.isVisited) { - // 括号中的下标为id号 - edges = new int[points.size()][points.size()]; - dfsExpand(p, length, edges); - - tempGraph = new Graph(edges); - graphs.add(tempGraph); - } else { - continue; - } - } - - return graphs; - } - - /** - * 深度优先方式扩展连通图 - * - * @param points - * 需要继续深搜的坐标点 - * @param length - * 距离阈值 - * @param edges - * 边数组 - */ - private void dfsExpand(Point point, int length, int edges[][]) { - int id1 = 0; - int id2 = 0; - double distance = 0; - ArrayList tempPoints; - - // 如果处理过了,则跳过 - if (point.isVisited) { - return; - } - - id1 = point.id; - point.isVisited = true; - tempPoints = new ArrayList<>(); - for (Point p2 : points) { - id2 = p2.id; - - if (id1 == id2) { - continue; - } else { - distance = point.ouDistance(p2); - if (distance <= length) { - edges[id1][id2] = 1; - edges[id2][id1] = 1; - - tempPoints.add(p2); - } - } - } - - // 继续递归 - for (Point p : tempPoints) { - dfsExpand(p, length, edges); - } - } - - /** - * 判断连通图是否还需要再被划分 - * - * @param pointList1 - * 坐标点集合1 - * @param pointList2 - * 坐标点集合2 - * @return - */ - private boolean needDivided(ArrayList pointList1, - ArrayList pointList2) { - boolean needDivided = false; - // 承受系数t=轻的集合的坐标点数/2部分连接的边数 - double t = 0; - // 分裂阈值,即平均每边所要承受的重量 - double landa = 0; - int pointNum1 = pointList1.size(); - int pointNum2 = pointList2.size(); - // 总边数 - int totalEdgeNum = 0; - // 连接2部分的边数量 - int connectedEdgeNum = 0; - ArrayList totalPoints = new ArrayList<>(); - - totalPoints.addAll(pointList1); - totalPoints.addAll(pointList2); - int id1 = 0; - int id2 = 0; - for (Point p1 : totalPoints) { - id1 = p1.id; - for (Point p2 : totalPoints) { - id2 = p2.id; - - if (edges[id1][id2] == 1 && id1 < id2) { - if ((pointList1.contains(p1) && pointList2.contains(p2)) - || (pointList1.contains(p2) && pointList2 - .contains(p1))) { - connectedEdgeNum++; - } - totalEdgeNum++; - } - } - } - - if (pointNum1 < pointNum2) { - // 承受系数t=轻的集合的坐标点数/连接2部分的边数 - t = 1.0 * pointNum1 / connectedEdgeNum; - } else { - t = 1.0 * pointNum2 / connectedEdgeNum; - } - - // 计算分裂阈值,括号内为总边数/总点数,就是平均每边所承受的点数量 - landa = 0.5 * Math.exp((1.0 * totalEdgeNum / (pointNum1 + pointNum2))); - - // 如果承受系数不小于分裂阈值,则代表需要分裂 - if (t >= landa) { - needDivided = true; - } - - return needDivided; - } - - /** - * 递归的划分连通图 - * - * @param pointList - * 待划分的连通图的所有坐标点 - */ - public void divideGraph(ArrayList pointList) { - // 判断此坐标点集合是否能够被分割 - boolean canDivide = false; - ArrayList> pointGroup; - ArrayList pointList1 = new ArrayList<>(); - ArrayList pointList2 = new ArrayList<>(); - - for (int m = 2; m <= pointList.size() / 2; m++) { - // 进行坐标点的分割 - pointGroup = removePoint(pointList, m); - pointList1 = pointGroup.get(0); - pointList2 = pointGroup.get(1); - - // 判断是否满足分裂条件 - if (needDivided(pointList1, pointList2)) { - canDivide = true; - divideGraph(pointList1); - divideGraph(pointList2); - } - } - - // 如果所有的分割组合都无法分割,则说明此已经是一个聚类 - if (!canDivide) { - clusters.add(pointList); - } - } - - /** - * 获取分裂得到的聚类结果 - * - * @return - */ - public ArrayList> getClusterByDivding() { - clusters = new ArrayList<>(); - - divideGraph(points); - - return clusters; - } - - /** - * 将当前坐标点集合移除removeNum个点,构成2个子坐标点集合 - * - * @param pointList - * 原集合点 - * @param removeNum - * 移除的数量 - */ - private ArrayList> removePoint(ArrayList pointList, - int removeNum) { - //浅拷贝一份原坐标点数据 - ArrayList copyPointList = (ArrayList) pointList.clone(); - ArrayList> pointGroup = new ArrayList<>(); - ArrayList pointList2 = new ArrayList<>(); - // 进行按照坐标轴大小排序 - Collections.sort(copyPointList); - - for (int i = 0; i < removeNum; i++) { - pointList2.add(copyPointList.get(i)); - } - copyPointList.removeAll(pointList2); - - pointGroup.add(copyPointList); - pointGroup.add(pointList2); - - return pointGroup; - } - - /** - * 根据边的情况获取其中的点 - * - * @param edges - * 当前的已知的边的情况 - * @return - */ - private ArrayList getPointByEdges(int[][] edges) { - Point p1; - Point p2; - ArrayList pointList = new ArrayList<>(); - - for (int i = 0; i < edges.length; i++) { - for (int j = 0; j < edges[0].length; j++) { - if (edges[i][j] == 1) { - p1 = CABDDCCTool.totalPoints.get(i); - p2 = CABDDCCTool.totalPoints.get(j); - - if (!pointList.contains(p1)) { - pointList.add(p1); - } - - if (!pointList.contains(p2)) { - pointList.add(p2); - } - } - } - } - - return pointList; - } +class Graph { + // 坐标点之间的连接属性,括号内为坐标id号 + private int[][] edges; + // 连通图内的坐标点数 + private ArrayList points; + // 此图下分割后的聚类子图 + private ArrayList> clusters; + + private Graph(int[][] edges){ + this.edges = edges; + this.points = getPointByEdges(edges); + } + + Graph(int[][] edges, ArrayList points){ + this.edges = edges; + this.points = points; + } + + /** + * 根据距离阈值做连通图的划分,构成连通图集 + * + * @param length 距离阈值 + */ + ArrayList splitGraphByLength(int length){ + int[][] edges; + Graph tempGraph; + ArrayList graphs = new ArrayList<>(); + + for (Point p : points) { + if (!p.isVisited) { + // 括号中的下标为id号 + edges = new int[points.size()][points.size()]; + dfsExpand(p, length, edges); + + tempGraph = new Graph(edges); + graphs.add(tempGraph); + } + } + return graphs; + } + + /** + * 深度优先方式扩展连通图 + * + * @param point 需要继续深搜的坐标点 + * @param length 距离阈值 + * @param edges 边数组 + */ + private void dfsExpand(Point point, int length, int edges[][]){ + int id1; + int id2; + double distance; + ArrayList tempPoints; + + // 如果处理过了,则跳过 + if (point.isVisited) { + return; + } + + id1 = point.id; + point.isVisited = true; + tempPoints = new ArrayList<>(); + for (Point p2 : points) { + id2 = p2.id; + + if (id1 != id2) { + distance = point.ouDistance(p2); + if (distance <= length) { + edges[id1][id2] = 1; + edges[id2][id1] = 1; + + tempPoints.add(p2); + } + } + } + + // 继续递归 + for (Point p : tempPoints) { + dfsExpand(p, length, edges); + } + } + + /** + * 判断连通图是否还需要再被划分 + * + * @param pointList1 坐标点集合1 + * @param pointList2 坐标点集合2 + */ + private boolean needDivided(ArrayList pointList1, + ArrayList pointList2){ + boolean needDivided = false; + // 承受系数t=轻的集合的坐标点数/2部分连接的边数 + double t; + // 分裂阈值,即平均每边所要承受的重量 + double landa; + int pointNum1 = pointList1.size(); + int pointNum2 = pointList2.size(); + // 总边数 + int totalEdgeNum = 0; + // 连接2部分的边数量 + int connectedEdgeNum = 0; + ArrayList totalPoints = new ArrayList<>(); + + totalPoints.addAll(pointList1); + totalPoints.addAll(pointList2); + int id1; + int id2; + for (Point p1 : totalPoints) { + id1 = p1.id; + for (Point p2 : totalPoints) { + id2 = p2.id; + + if (edges[id1][id2] == 1 && id1 < id2) { + if ((pointList1.contains(p1) && pointList2.contains(p2)) + || (pointList1.contains(p2) && pointList2 + .contains(p1))) { + connectedEdgeNum++; + } + totalEdgeNum++; + } + } + } + + if (pointNum1 < pointNum2) { + // 承受系数t=轻的集合的坐标点数/连接2部分的边数 + t = 1.0 * pointNum1 / connectedEdgeNum; + } else { + t = 1.0 * pointNum2 / connectedEdgeNum; + } + + // 计算分裂阈值,括号内为总边数/总点数,就是平均每边所承受的点数量 + landa = 0.5 * Math.exp((1.0 * totalEdgeNum / (pointNum1 + pointNum2))); + + // 如果承受系数不小于分裂阈值,则代表需要分裂 + if (t >= landa) { + needDivided = true; + } + + return needDivided; + } + + /** + * 递归的划分连通图 + * + * @param pointList 待划分的连通图的所有坐标点 + */ + private void divideGraph(ArrayList pointList){ + // 判断此坐标点集合是否能够被分割 + boolean canDivide = false; + ArrayList> pointGroup; + ArrayList pointList1; + ArrayList pointList2; + + for (int m = 2; m <= pointList.size() / 2; m++) { + // 进行坐标点的分割 + pointGroup = removePoint(pointList, m); + pointList1 = pointGroup.get(0); + pointList2 = pointGroup.get(1); + + // 判断是否满足分裂条件 + if (needDivided(pointList1, pointList2)) { + canDivide = true; + divideGraph(pointList1); + divideGraph(pointList2); + } + } + + // 如果所有的分割组合都无法分割,则说明此已经是一个聚类 + if (!canDivide) { + clusters.add(pointList); + } + } + + /** + * 获取分裂得到的聚类结果 + */ + ArrayList> getClusterByDivding(){ + clusters = new ArrayList<>(); + + divideGraph(points); + + return clusters; + } + + /** + * 将当前坐标点集合移除removeNum个点,构成2个子坐标点集合 + * + * @param pointList 原集合点 + * @param removeNum 移除的数量 + */ + private ArrayList> removePoint(ArrayList pointList, + int removeNum){ + //浅拷贝一份原坐标点数据 + ArrayList copyPointList = (ArrayList) pointList.clone(); + ArrayList> pointGroup = new ArrayList<>(); + ArrayList pointList2 = new ArrayList<>(); + // 进行按照坐标轴大小排序 + Collections.sort(copyPointList); + + for (int i = 0; i < removeNum; i++) { + pointList2.add(copyPointList.get(i)); + } + copyPointList.removeAll(pointList2); + + pointGroup.add(copyPointList); + pointGroup.add(pointList2); + + return pointGroup; + } + + /** + * 根据边的情况获取其中的点 + * + * @param edges 当前的已知的边的情况 + */ + private ArrayList getPointByEdges(int[][] edges){ + Point p1; + Point p2; + ArrayList pointList = new ArrayList<>(); + + for (int i = 0; i < edges.length; i++) { + for (int j = 0; j < edges[0].length; j++) { + if (edges[i][j] == 1) { + p1 = CABDDCCTool.totalPoints.get(i); + p2 = CABDDCCTool.totalPoints.get(j); + + if (!pointList.contains(p1)) { + pointList.add(p1); + } + + if (!pointList.contains(p2)) { + pointList.add(p2); + } + } + } + } + + return pointList; + } } diff --git a/Others/DataMining_CABDDCC/Point.java b/Others/DataMining_CABDDCC/Point.java index 2763be4..010ff78 100644 --- a/Others/DataMining_CABDDCC/Point.java +++ b/Others/DataMining_CABDDCC/Point.java @@ -1,69 +1,65 @@ -package DataMining_CABDDCC; - +package Others.DataMining_CABDDCC; +import java.util.Objects; /** * 坐标点类 - * @author lyq * + * @author Qstar */ -public class Point implements Comparable{ - //坐标点id号,id号唯一 - int id; - //坐标横坐标 - Integer x; - //坐标纵坐标 - Integer y; - //坐标点是否已经被访问(处理)过,在生成连通子图的时候用到 - boolean isVisited; - - public Point(String id, String x, String y){ - this.id = Integer.parseInt(id); - this.x = Integer.parseInt(x); - this.y = Integer.parseInt(y); - } - - /** - * 计算当前点与制定点之间的欧式距离 - * - * @param p - * 待计算聚类的p点 - * @return - */ - public double ouDistance(Point p) { - double distance = 0; +public class Point implements Comparable { + //坐标点id号,id号唯一 + int id; + //坐标横坐标 + Integer x; + //坐标纵坐标 + Integer y; + //坐标点是否已经被访问(处理)过,在生成连通子图的时候用到 + boolean isVisited; + + public Point(String id, String x, String y){ + this.id = Integer.parseInt(id); + this.x = Integer.parseInt(x); + this.y = Integer.parseInt(y); + } + + /** + * 计算当前点与制定点之间的欧式距离 + * + * @param p 待计算聚类的p点 + */ + double ouDistance(Point p){ + double distance; + + distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) + * (this.y - p.y); + distance = Math.sqrt(distance); - distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) - * (this.y - p.y); - distance = Math.sqrt(distance); + return distance; + } - return distance; - } - - /** - * 判断2个坐标点是否为用个坐标点 - * - * @param p - * 待比较坐标点 - * @return - */ - public boolean isTheSame(Point p) { - boolean isSamed = false; + /** + * 判断2个坐标点是否为用个坐标点 + * + * @param p 待比较坐标点 + */ + public boolean isTheSame(Point p){ + boolean isSamed = false; - if (this.x == p.x && this.y == p.y) { - isSamed = true; - } + if (Objects.equals(this.x, p.x) && Objects.equals(this.y, p.y)) { + isSamed = true; + } - return isSamed; - } + return isSamed; + } - @Override - public int compareTo(Point p) { - if(this.x.compareTo(p.x) != 0){ - return this.x.compareTo(p.x); - }else{ - //如果在x坐标相等的情况下比较y坐标 - return this.y.compareTo(p.y); - } - } + @Override + public int compareTo(Point p){ + if (this.x.compareTo(p.x) != 0) { + return this.x.compareTo(p.x); + } else { + //如果在x坐标相等的情况下比较y坐标 + return this.y.compareTo(p.y); + } + } } diff --git a/Others/DataMining_Chameleon/ChameleonTool.java b/Others/DataMining_Chameleon/ChameleonTool.java index 811ea3d..2f81a70 100644 --- a/Others/DataMining_Chameleon/ChameleonTool.java +++ b/Others/DataMining_Chameleon/ChameleonTool.java @@ -1,4 +1,4 @@ -package DataMining_Chameleon; +package Others.DataMining_Chameleon; import java.io.BufferedReader; import java.io.File; @@ -9,415 +9,400 @@ /** * Chameleon 两阶段聚类算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class ChameleonTool { - // 测试数据点文件地址 - private String filePath; - // 第一阶段的k近邻的k大小 - private int k; - // 簇度量函数阈值 - private double minMetric; - // 总的坐标点的个数 - private int pointNum; - // 总的连接矩阵的情况,括号表示的是坐标点的id号 - public static int[][] edges; - // 点与点之间的边的权重 - public static double[][] weights; - // 原始坐标点数据 - private ArrayList totalPoints; - // 第一阶段产生的所有的连通子图作为最初始的聚类 - private ArrayList initClusters; - // 结果簇结合 - private ArrayList resultClusters; - - public ChameleonTool(String filePath, int k, double minMetric) { - this.filePath = filePath; - this.k = k; - this.minMetric = minMetric; - - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - Point p; - totalPoints = new ArrayList<>(); - for (String[] array : dataArray) { - p = new Point(array[0], array[1], array[2]); - totalPoints.add(p); - } - pointNum = totalPoints.size(); - } - - /** - * 递归的合并小聚簇 - */ - private void combineSubClusters() { - Cluster cluster = null; - - resultClusters = new ArrayList<>(); - - // 当最后的聚簇只剩下一个的时候,则退出循环 - while (initClusters.size() > 1) { - cluster = initClusters.get(0); - combineAndRemove(cluster, initClusters); - } - } - - /** - * 递归的合并聚簇和移除聚簇 - * - * @param clusterList - */ - private ArrayList combineAndRemove(Cluster cluster, - ArrayList clusterList) { - ArrayList remainClusters; - double metric = 0; - double maxMetric = -Integer.MAX_VALUE; - Cluster cluster1 = null; - Cluster cluster2 = null; - - for (Cluster c2 : clusterList) { - if(cluster.id == c2.id){ - continue; - } - - metric = calMetricfunction(cluster, c2, 1); - - if (metric > maxMetric) { - maxMetric = metric; - cluster1 = cluster; - cluster2 = c2; - } - } - - // 如果度量函数值超过阈值,则进行合并,继续搜寻可以合并的簇 - if (maxMetric > minMetric) { - clusterList.remove(cluster2); - //将边进行连接 - connectClusterToCluster(cluster1, cluster2); - // 将簇1和簇2合并 - cluster1.points.addAll(cluster2.points); - remainClusters = combineAndRemove(cluster1, clusterList); - } else { - clusterList.remove(cluster); - remainClusters = clusterList; - resultClusters.add(cluster); - } - - return remainClusters; - } - - /** - * 将2个簇进行边的连接 - * @param c1 - * 聚簇1 - * @param c2 - * 聚簇2 - */ - private void connectClusterToCluster(Cluster c1, Cluster c2){ - ArrayList connectedEdges; - - connectedEdges = c1.calNearestEdge(c2, 2); - - for(int[] array: connectedEdges){ - edges[array[0]][array[1]] = 1; - edges[array[1]][array[0]] = 1; - } - } - - /** - * 算法第一阶段形成局部的连通图 - */ - private void connectedGraph() { - double distance = 0; - Point p1; - Point p2; - - // 初始化权重矩阵和连接矩阵 - weights = new double[pointNum][pointNum]; - edges = new int[pointNum][pointNum]; - for (int i = 0; i < pointNum; i++) { - for (int j = 0; j < pointNum; j++) { - p1 = totalPoints.get(i); - p2 = totalPoints.get(j); - - distance = p1.ouDistance(p2); - if (distance == 0) { - // 如果点为自身的话,则权重设置为0 - weights[i][j] = 0; - } else { - // 边的权重采用的值为距离的倒数,距离越近,权重越大 - weights[i][j] = 1.0 / distance; - } - } - } - - double[] tempWeight; - int[] ids; - int id1 = 0; - int id2 = 0; - // 对每个id坐标点,取其权重前k个最大的点进行相连 - for (int i = 0; i < pointNum; i++) { - tempWeight = weights[i]; - // 进行排序 - ids = sortWeightArray(tempWeight); - - // 取出前k个权重最大的边进行连接 - for (int j = 0; j < ids.length; j++) { - if (j < k) { - id1 = i; - id2 = ids[j]; - - edges[id1][id2] = 1; - edges[id2][id1] = 1; - } - } - } - } - - /** - * 权重的冒泡算法排序 - * - * @param array - * 待排序数组 - */ - private int[] sortWeightArray(double[] array) { - double[] copyArray = array.clone(); - int[] ids = null; - int k = 0; - double maxWeight = -1; - - ids = new int[pointNum]; - for(int i=0; i maxWeight){ - maxWeight = copyArray[j]; - k = j; - } - } - - ids[i] = k; - //将当前找到的最大的值重置为-1代表已经找到过了 - copyArray[k] = -1; - } - - return ids; - } - - /** - * 根据边的连通性去深度优先搜索所有的小聚簇 - */ - private void searchSmallCluster() { - int currentId = 0; - Point p; - Cluster cluster; - initClusters = new ArrayList<>(); - ArrayList pointList = null; - - // 以id的方式逐个去dfs搜索 - for (int i = 0; i < pointNum; i++) { - p = totalPoints.get(i); - - if (p.isVisited) { - continue; - } - - pointList = new ArrayList<>(); - pointList.add(p); - recusiveDfsSearch(p, -1, pointList); - - cluster = new Cluster(currentId, pointList); - initClusters.add(cluster); - - currentId++; - } - } - - /** - * 深度优先的方式找到边所连接着的所有坐标点 - * - * @param p - * 当前搜索的起点 - * @param lastId - * 此点的父坐标点 - * @param pList - * 坐标点列表 - */ - private void recusiveDfsSearch(Point p, int parentId, ArrayList pList) { - int id1 = 0; - int id2 = 0; - Point newPoint; - - if (p.isVisited) { - return; - } - - p.isVisited = true; - for (int j = 0; j < pointNum; j++) { - id1 = p.id; - id2 = j; - - if (edges[id1][id2] == 1 && id2 != parentId) { - newPoint = totalPoints.get(j); - pList.add(newPoint); - // 以此点为起点,继续递归搜索 - recusiveDfsSearch(newPoint, id1, pList); - } - } - } - - /** - * 计算连接2个簇的边的权重 - * - * @param c1 - * 聚簇1 - * @param c2 - * 聚簇2 - * @return - */ - private double calEC(Cluster c1, Cluster c2) { - double resultEC = 0; - ArrayList connectedEdges = null; - - connectedEdges = c1.calNearestEdge(c2, 2); - - // 计算连接2部分的边的权重和 - for (int[] array : connectedEdges) { - resultEC += weights[array[0]][array[1]]; - } - - return resultEC; - } - - /** - * 计算2个簇的相对互连性 - * - * @param c1 - * @param c2 - * @return - */ - private double calRI(Cluster c1, Cluster c2) { - double RI = 0; - double EC1 = 0; - double EC2 = 0; - double EC1To2 = 0; - - EC1 = c1.calEC(); - EC2 = c2.calEC(); - EC1To2 = calEC(c1, c2); - - RI = 2 * EC1To2 / (EC1 + EC2); - - return RI; - } - - /** - * 计算簇的相对近似度 - * - * @param c1 - * 簇1 - * @param c2 - * 簇2 - * @return - */ - private double calRC(Cluster c1, Cluster c2) { - double RC = 0; - double EC1 = 0; - double EC2 = 0; - double EC1To2 = 0; - int pNum1 = c1.points.size(); - int pNum2 = c2.points.size(); - - EC1 = c1.calEC(); - EC2 = c2.calEC(); - EC1To2 = calEC(c1, c2); - - RC = EC1To2 * (pNum1 + pNum2) / (pNum2 * EC1 + pNum1 * EC2); - - return RC; - } - - /** - * 计算度量函数的值 - * - * @param c1 - * 簇1 - * @param c2 - * 簇2 - * @param alpha - * 幂的参数值 - * @return - */ - private double calMetricfunction(Cluster c1, Cluster c2, int alpha) { - // 度量函数值 - double metricValue = 0; - double RI = 0; - double RC = 0; - - RI = calRI(c1, c2); - RC = calRC(c1, c2); - // 如果alpha大于1,则更重视相对近似性,如果alpha逍遥于1,注重相对互连性 - metricValue = RI * Math.pow(RC, alpha); - - return metricValue; - } - - /** - * 输出聚簇列 - * @param clusterList - * 输出聚簇列 - */ - private void printClusters(ArrayList clusterList) { - int i = 1; - - for (Cluster cluster : clusterList) { - System.out.print("聚簇" + i + ":"); - for (Point p : cluster.points) { - System.out.print(MessageFormat.format("({0}, {1}) ", p.x, p.y)); - } - System.out.println(); - i++; - } - - } - - /** - * 创建聚簇 - */ - public void buildCluster() { - // 第一阶段形成小聚簇 - connectedGraph(); - searchSmallCluster(); - System.out.println("第一阶段形成的小簇集合:"); - printClusters(initClusters); - - // 第二阶段根据RI和RC的值合并小聚簇形成最终结果聚簇 - combineSubClusters(); - System.out.println("最终的聚簇集合:"); - printClusters(resultClusters); - } +class ChameleonTool { + // 总的连接矩阵的情况,括号表示的是坐标点的id号 + static int[][] edges; + // 点与点之间的边的权重 + static double[][] weights; + // 测试数据点文件地址 + private String filePath; + // 第一阶段的k近邻的k大小 + private int k; + // 簇度量函数阈值 + private double minMetric; + // 总的坐标点的个数 + private int pointNum; + // 原始坐标点数据 + private ArrayList totalPoints; + // 第一阶段产生的所有的连通子图作为最初始的聚类 + private ArrayList initClusters; + // 结果簇结合 + private ArrayList resultClusters; + + ChameleonTool(String filePath, int k, double minMetric){ + this.filePath = filePath; + this.k = k; + this.minMetric = minMetric; + + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + Point p; + totalPoints = new ArrayList<>(); + for (String[] array : dataArray) { + p = new Point(array[0], array[1], array[2]); + totalPoints.add(p); + } + pointNum = totalPoints.size(); + } + + /** + * 递归的合并小聚簇 + */ + private void combineSubClusters(){ + Cluster cluster; + + resultClusters = new ArrayList<>(); + + // 当最后的聚簇只剩下一个的时候,则退出循环 + while (initClusters.size() > 1) { + cluster = initClusters.get(0); + combineAndRemove(cluster, initClusters); + } + } + + /** + * 递归的合并聚簇和移除聚簇 + * + * @param clusterList 聚簇集合 + */ + private ArrayList combineAndRemove(Cluster cluster, + ArrayList clusterList){ + ArrayList remainClusters; + double metric; + double maxMetric = -Integer.MAX_VALUE; + Cluster cluster1 = null; + Cluster cluster2 = null; + + for (Cluster c2 : clusterList) { + if (cluster.id == c2.id) { + continue; + } + + metric = calMetricfunction(cluster, c2, 1); + + if (metric > maxMetric) { + maxMetric = metric; + cluster1 = cluster; + cluster2 = c2; + } + } + + // 如果度量函数值超过阈值,则进行合并,继续搜寻可以合并的簇 + if (maxMetric > minMetric) { + clusterList.remove(cluster2); + //将边进行连接 + connectClusterToCluster(cluster1, cluster2); + // 将簇1和簇2合并 + if (cluster1 != null) { + cluster1.points.addAll(cluster2.points); + } + remainClusters = combineAndRemove(cluster1, clusterList); + } else { + clusterList.remove(cluster); + remainClusters = clusterList; + resultClusters.add(cluster); + } + + return remainClusters; + } + + /** + * 将2个簇进行边的连接 + * + * @param c1 聚簇1 + * @param c2 聚簇2 + */ + private void connectClusterToCluster(Cluster c1, Cluster c2){ + ArrayList connectedEdges; + + connectedEdges = c1.calNearestEdge(c2, 2); + + for (int[] array : connectedEdges) { + edges[array[0]][array[1]] = 1; + edges[array[1]][array[0]] = 1; + } + } + + /** + * 算法第一阶段形成局部的连通图 + */ + private void connectedGraph(){ + double distance; + Point p1; + Point p2; + + // 初始化权重矩阵和连接矩阵 + weights = new double[pointNum][pointNum]; + edges = new int[pointNum][pointNum]; + for (int i = 0; i < pointNum; i++) { + for (int j = 0; j < pointNum; j++) { + p1 = totalPoints.get(i); + p2 = totalPoints.get(j); + + distance = p1.ouDistance(p2); + if (distance == 0) { + // 如果点为自身的话,则权重设置为0 + weights[i][j] = 0; + } else { + // 边的权重采用的值为距离的倒数,距离越近,权重越大 + weights[i][j] = 1.0 / distance; + } + } + } + + double[] tempWeight; + int[] ids; + int id1; + int id2; + // 对每个id坐标点,取其权重前k个最大的点进行相连 + for (int i = 0; i < pointNum; i++) { + tempWeight = weights[i]; + // 进行排序 + ids = sortWeightArray(tempWeight); + + // 取出前k个权重最大的边进行连接 + for (int j = 0; j < ids.length; j++) { + if (j < k) { + id1 = i; + id2 = ids[j]; + + edges[id1][id2] = 1; + edges[id2][id1] = 1; + } + } + } + } + + /** + * 权重的冒泡算法排序 + * + * @param array 待排序数组 + */ + private int[] sortWeightArray(double[] array){ + double[] copyArray = array.clone(); + int[] ids; + int k = 0; + double maxWeight; + + ids = new int[pointNum]; + for (int i = 0; i < pointNum; i++) { + maxWeight = -1; + + for (int j = 0; j < copyArray.length; j++) { + if (copyArray[j] > maxWeight) { + maxWeight = copyArray[j]; + k = j; + } + } + + ids[i] = k; + //将当前找到的最大的值重置为-1代表已经找到过了 + copyArray[k] = -1; + } + + return ids; + } + + /** + * 根据边的连通性去深度优先搜索所有的小聚簇 + */ + private void searchSmallCluster(){ + int currentId = 0; + Point p; + Cluster cluster; + initClusters = new ArrayList<>(); + ArrayList pointList; + + // 以id的方式逐个去dfs搜索 + for (int i = 0; i < pointNum; i++) { + p = totalPoints.get(i); + + if (p.isVisited) { + continue; + } + + pointList = new ArrayList<>(); + pointList.add(p); + recusiveDfsSearch(p, -1, pointList); + + cluster = new Cluster(currentId, pointList); + initClusters.add(cluster); + + currentId++; + } + } + + /** + * 深度优先的方式找到边所连接着的所有坐标点 + * + * @param p 当前搜索的起点 + * @param parentId 此点的父坐标点 + * @param pList 坐标点列表 + */ + private void recusiveDfsSearch(Point p, int parentId, ArrayList pList){ + int id1; + int id2; + Point newPoint; + + if (p.isVisited) { + return; + } + + p.isVisited = true; + for (int j = 0; j < pointNum; j++) { + id1 = p.id; + id2 = j; + + if (edges[id1][id2] == 1 && id2 != parentId) { + newPoint = totalPoints.get(j); + pList.add(newPoint); + // 以此点为起点,继续递归搜索 + recusiveDfsSearch(newPoint, id1, pList); + } + } + } + + /** + * 计算连接2个簇的边的权重 + * + * @param c1 聚簇1 + * @param c2 聚簇2 + */ + private double calEC(Cluster c1, Cluster c2){ + double resultEC = 0; + ArrayList connectedEdges; + + connectedEdges = c1.calNearestEdge(c2, 2); + + // 计算连接2部分的边的权重和 + for (int[] array : connectedEdges) { + resultEC += weights[array[0]][array[1]]; + } + + return resultEC; + } + + /** + * 计算2个簇的相对互连性 + * + * @param c1 聚簇1 + * @param c2 聚簇2 + */ + private double calRI(Cluster c1, Cluster c2){ + double RI; + double EC1; + double EC2; + double EC1To2; + + EC1 = c1.calEC(); + EC2 = c2.calEC(); + EC1To2 = calEC(c1, c2); + + RI = 2 * EC1To2 / (EC1 + EC2); + + return RI; + } + + /** + * 计算簇的相对近似度 + * + * @param c1 簇1 + * @param c2 簇2 + */ + private double calRC(Cluster c1, Cluster c2){ + double RC; + double EC1; + double EC2; + double EC1To2; + int pNum1 = c1.points.size(); + int pNum2 = c2.points.size(); + + EC1 = c1.calEC(); + EC2 = c2.calEC(); + EC1To2 = calEC(c1, c2); + + RC = EC1To2 * (pNum1 + pNum2) / (pNum2 * EC1 + pNum1 * EC2); + + return RC; + } + + /** + * 计算度量函数的值 + * + * @param c1 簇1 + * @param c2 簇2 + * @param alpha 幂的参数值 + */ + private double calMetricfunction(Cluster c1, Cluster c2, int alpha){ + // 度量函数值 + double metricValue; + double RI; + double RC; + + RI = calRI(c1, c2); + RC = calRC(c1, c2); + // 如果alpha大于1,则更重视相对近似性,如果alpha逍遥于1,注重相对互连性 + metricValue = RI * Math.pow(RC, alpha); + + return metricValue; + } + + /** + * 输出聚簇列 + * + * @param clusterList 输出聚簇列 + */ + private void printClusters(ArrayList clusterList){ + int i = 1; + + for (Cluster cluster : clusterList) { + System.out.print("聚簇" + i + ":"); + for (Point p : cluster.points) { + System.out.print(MessageFormat.format("({0}, {1}) ", p.x, p.y)); + } + System.out.println(); + i++; + } + + } + + /** + * 创建聚簇 + */ + void buildCluster(){ + // 第一阶段形成小聚簇 + connectedGraph(); + searchSmallCluster(); + System.out.println("第一阶段形成的小簇集合:"); + printClusters(initClusters); + + // 第二阶段根据RI和RC的值合并小聚簇形成最终结果聚簇 + combineSubClusters(); + System.out.println("最终的聚簇集合:"); + printClusters(resultClusters); + } } diff --git a/Others/DataMining_Chameleon/Client.java b/Others/DataMining_Chameleon/Client.java index 254f760..a7c38c1 100644 --- a/Others/DataMining_Chameleon/Client.java +++ b/Others/DataMining_Chameleon/Client.java @@ -1,19 +1,19 @@ -package DataMining_Chameleon; +package Others.DataMining_Chameleon; /** * Chameleon(变色龙)两阶段聚类算法 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\graphData.txt"; - //k-近邻的k设置 - int k = 1; - //度量函数阈值 - double minMetric = 0.1; - - ChameleonTool tool = new ChameleonTool(filePath, k, minMetric); - tool.buildCluster(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_Chameleon/graphData.txt"; + //k-近邻的k设置 + int k = 1; + //度量函数阈值 + double minMetric = 0.1; + + ChameleonTool tool = new ChameleonTool(filePath, k, minMetric); + tool.buildCluster(); + } } diff --git a/Others/DataMining_Chameleon/Cluster.java b/Others/DataMining_Chameleon/Cluster.java index 42e1f94..e1191c5 100644 --- a/Others/DataMining_Chameleon/Cluster.java +++ b/Others/DataMining_Chameleon/Cluster.java @@ -1,119 +1,115 @@ -package DataMining_Chameleon; +package Others.DataMining_Chameleon; import java.util.ArrayList; /** * 聚簇类 - * - * @author lyq - * + * + * @author Qstar */ -public class Cluster implements Cloneable{ - //簇唯一id标识号 - int id; - // 聚簇内的坐标点集合 - ArrayList points; - // 聚簇内的所有边的权重和 - double weightSum = 0; - - public Cluster(int id, ArrayList points) { - this.id = id; - this.points = points; - } - - /** - * 计算聚簇的内部的边权重和 - * - * @return - */ - public double calEC() { - int id1 = 0; - int id2 = 0; - weightSum = 0; - - for (Point p1 : points) { - for (Point p2 : points) { - id1 = p1.id; - id2 = p2.id; - - // 为了避免重复计算,取id1小的对应大的 - if (id1 < id2 && ChameleonTool.edges[id1][id2] == 1) { - weightSum += ChameleonTool.weights[id1][id2]; - } - } - } - - return weightSum; - } - - /** - * 计算2个簇之间最近的n条边 - * - * @param otherCluster - * 待比较的簇 - * @param n - * 最近的边的数目 - * @return - */ - public ArrayList calNearestEdge(Cluster otherCluster, int n){ - int count = 0; - double distance = 0; - double minDistance = Integer.MAX_VALUE; - Point point1 = null; - Point point2 = null; - ArrayList edgeList = new ArrayList<>(); - ArrayList pointList1 = (ArrayList) points.clone(); - ArrayList pointList2 = null; - Cluster c2 = null; - - try { - c2 = (Cluster) otherCluster.clone(); - pointList2 = c2.points; - } catch (CloneNotSupportedException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } - - int[] tempEdge; - // 循环计算出每次的最近距离 - while (count < n) { - tempEdge = new int[2]; - minDistance = Integer.MAX_VALUE; - - for (Point p1 : pointList1) { - for (Point p2 : pointList2) { - distance = p1.ouDistance(p2); - if (distance < minDistance) { - point1 = p1; - point2 = p2; - tempEdge[0] = p1.id; - tempEdge[1] = p2.id; - - minDistance = distance; - } - } - } - - pointList1.remove(point1); - pointList2.remove(point2); - edgeList.add(tempEdge); - count++; - } - - return edgeList; - } - - @Override - protected Object clone() throws CloneNotSupportedException { - // TODO Auto-generated method stub - - //引用需要再次复制,实现深拷贝 - ArrayList pointList = (ArrayList) this.points.clone(); - Cluster cluster = new Cluster(id, pointList); - - return cluster; - } - - +class Cluster implements Cloneable { + //簇唯一id标识号 + int id; + // 聚簇内的坐标点集合 + ArrayList points; + // 聚簇内的所有边的权重和 + private double weightSum = 0; + + Cluster(int id, ArrayList points){ + this.id = id; + this.points = points; + } + + /** + * 计算聚簇的内部的边权重和 + */ + double calEC(){ + int id1; + int id2; + weightSum = 0; + + for (Point p1 : points) { + for (Point p2 : points) { + id1 = p1.id; + id2 = p2.id; + + // 为了避免重复计算,取id1小的对应大的 + if (id1 < id2 && ChameleonTool.edges[id1][id2] == 1) { + weightSum += ChameleonTool.weights[id1][id2]; + } + } + } + + return weightSum; + } + + /** + * 计算2个簇之间最近的n条边 + * + * @param otherCluster 待比较的簇 + * @param n 最近的边的数目 + */ + ArrayList calNearestEdge(Cluster otherCluster, int n){ + int count = 0; + double distance; + double minDistance; + Point point1 = null; + Point point2 = null; + ArrayList edgeList = new ArrayList<>(); + ArrayList pointList1 = (ArrayList) points.clone(); + ArrayList pointList2 = null; + Cluster c2; + + try { + c2 = (Cluster) otherCluster.clone(); + pointList2 = c2.points; + } catch (CloneNotSupportedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + + int[] tempEdge; + // 循环计算出每次的最近距离 + while (count < n) { + tempEdge = new int[2]; + minDistance = Integer.MAX_VALUE; + + for (Point p1 : pointList1) { + if (pointList2 != null) { + for (Point p2 : pointList2) { + distance = p1.ouDistance(p2); + if (distance < minDistance) { + point1 = p1; + point2 = p2; + tempEdge[0] = p1.id; + tempEdge[1] = p2.id; + + minDistance = distance; + } + } + } + } + + pointList1.remove(point1); + if (pointList2 != null) { + pointList2.remove(point2); + } + edgeList.add(tempEdge); + count++; + } + + return edgeList; + } + + @Override + protected Object clone() throws CloneNotSupportedException{ + // TODO Auto-generated method stub + super.clone(); + //引用需要再次复制,实现深拷贝 + ArrayList pointList = (ArrayList) this.points.clone(); + + return new Cluster(id, pointList); + } + } diff --git a/Others/DataMining_Chameleon/Point.java b/Others/DataMining_Chameleon/Point.java index 2a3b8cc..a5c55f4 100644 --- a/Others/DataMining_Chameleon/Point.java +++ b/Others/DataMining_Chameleon/Point.java @@ -1,59 +1,56 @@ -package DataMining_Chameleon; +package Others.DataMining_Chameleon; +import java.util.Objects; /** * 坐标点类 - * @author lyq * + * @author Qstar */ -public class Point{ - //坐标点id号,id号唯一 - int id; - //坐标横坐标 - Integer x; - //坐标纵坐标 - Integer y; - //是否已经被访问过 - boolean isVisited; - - public Point(String id, String x, String y){ - this.id = Integer.parseInt(id); - this.x = Integer.parseInt(x); - this.y = Integer.parseInt(y); - } - - /** - * 计算当前点与制定点之间的欧式距离 - * - * @param p - * 待计算聚类的p点 - * @return - */ - public double ouDistance(Point p) { - double distance = 0; - - distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) - * (this.y - p.y); - distance = Math.sqrt(distance); - - return distance; - } - - /** - * 判断2个坐标点是否为用个坐标点 - * - * @param p - * 待比较坐标点 - * @return - */ - public boolean isTheSame(Point p) { - boolean isSamed = false; - - if (this.x == p.x && this.y == p.y) { - isSamed = true; - } - - return isSamed; - } +public class Point { + //坐标点id号,id号唯一 + int id; + //坐标横坐标 + Integer x; + //坐标纵坐标 + Integer y; + //是否已经被访问过 + boolean isVisited; + + public Point(String id, String x, String y){ + this.id = Integer.parseInt(id); + this.x = Integer.parseInt(x); + this.y = Integer.parseInt(y); + } + + /** + * 计算当前点与制定点之间的欧式距离 + * + * @param p 待计算聚类的p点 + */ + double ouDistance(Point p){ + double distance; + + distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) + * (this.y - p.y); + distance = Math.sqrt(distance); + + return distance; + } + + /** + * 判断2个坐标点是否为用个坐标点 + * + * @param p 待比较坐标点 + */ + public boolean isTheSame(Point p){ + boolean isSamed = false; + + if (Objects.equals(this.x, p.x) && Objects.equals(this.y, p.y)) { + isSamed = true; + } + + return isSamed; + } } diff --git a/Others/DataMining_DBSCAN/Client.java b/Others/DataMining_DBSCAN/Client.java index f3d810c..26357df 100644 --- a/Others/DataMining_DBSCAN/Client.java +++ b/Others/DataMining_DBSCAN/Client.java @@ -1,19 +1,19 @@ -package DataMining_DBSCAN; +package Others.DataMining_DBSCAN; /** * Dbscan基于密度的聚类算法测试类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - //簇扫描半径 - double eps = 3; - //最小包含点数阈值 - int minPts = 3; - - DBSCANTool tool = new DBSCANTool(filePath, eps, minPts); - tool.dbScanCluster(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_DBSCAN/input.txt"; + //簇扫描半径 + double eps = 3; + //最小包含点数阈值 + int minPts = 3; + + DBSCANTool tool = new DBSCANTool(filePath, eps, minPts); + tool.dbScanCluster(); + } } diff --git a/Others/DataMining_DBSCAN/DBSCANTool.java b/Others/DataMining_DBSCAN/DBSCANTool.java index 27f2f8e..944d94d 100644 --- a/Others/DataMining_DBSCAN/DBSCANTool.java +++ b/Others/DataMining_DBSCAN/DBSCANTool.java @@ -1,4 +1,4 @@ -package DataMining_DBSCAN; +package Others.DataMining_DBSCAN; import java.io.BufferedReader; import java.io.File; @@ -9,201 +9,196 @@ /** * DBSCAN基于密度聚类算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class DBSCANTool { - // 测试数据文件地址 - private String filePath; - // 簇扫描半径 - private double eps; - // 最小包含点数阈值 - private int minPts; - // 所有的数据坐标点 - private ArrayList totalPoints; - // 聚簇结果 - private ArrayList> resultClusters; - //噪声数据 - private ArrayList noisePoint; - - public DBSCANTool(String filePath, double eps, int minPts) { - this.filePath = filePath; - this.eps = eps; - this.minPts = minPts; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - public void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - Point p; - totalPoints = new ArrayList<>(); - for (String[] array : dataArray) { - p = new Point(array[0], array[1]); - totalPoints.add(p); - } - } - - /** - * 递归的寻找聚簇 - * - * @param pointList - * 当前的点列表 - * @param parentCluster - * 父聚簇 - */ - private void recursiveCluster(Point point, ArrayList parentCluster) { - double distance = 0; - ArrayList cluster; - - // 如果已经访问过了,则跳过 - if (point.isVisited) { - return; - } - - point.isVisited = true; - cluster = new ArrayList<>(); - for (Point p2 : totalPoints) { - // 过滤掉自身的坐标点 - if (point.isTheSame(p2)) { - continue; - } - - distance = point.ouDistance(p2); - if (distance <= eps) { - // 如果聚类小于给定的半径,则加入簇中 - cluster.add(p2); - } - } - - if (cluster.size() >= minPts) { - // 将自己也加入到聚簇中 - cluster.add(point); - // 如果附近的节点个数超过最下值,则加入到父聚簇中,同时去除重复的点 - addCluster(parentCluster, cluster); - - for (Point p : cluster) { - recursiveCluster(p, parentCluster); - } - } - } - - /** - * 往父聚簇中添加局部簇坐标点 - * - * @param parentCluster - * 原始父聚簇坐标点 - * @param cluster - * 待合并的聚簇 - */ - private void addCluster(ArrayList parentCluster, - ArrayList cluster) { - boolean isCotained = false; - ArrayList addPoints = new ArrayList<>(); - - for (Point p : cluster) { - isCotained = false; - for (Point p2 : parentCluster) { - if (p.isTheSame(p2)) { - isCotained = true; - break; - } - } - - if (!isCotained) { - addPoints.add(p); - } - } - - parentCluster.addAll(addPoints); - } - - /** - * dbScan算法基于密度的聚类 - */ - public void dbScanCluster() { - ArrayList cluster = null; - resultClusters = new ArrayList<>(); - noisePoint = new ArrayList<>(); - - for (Point p : totalPoints) { - if(p.isVisited){ - continue; - } - - cluster = new ArrayList<>(); - recursiveCluster(p, cluster); - - if (cluster.size() > 0) { - resultClusters.add(cluster); - }else{ - noisePoint.add(p); - } - } - removeFalseNoise(); - - printClusters(); - } - - /** - * 移除被错误分类的噪声点数据 - */ - private void removeFalseNoise(){ - ArrayList totalCluster = new ArrayList<>(); - ArrayList deletePoints = new ArrayList<>(); - - //将聚簇合并 - for(ArrayList list: resultClusters){ - totalCluster.addAll(list); - } - - for(Point p: noisePoint){ - for(Point p2: totalCluster){ - if(p2.isTheSame(p)){ - deletePoints.add(p); - } - } - } - - noisePoint.removeAll(deletePoints); - } - - /** - * 输出聚类结果 - */ - private void printClusters() { - int i = 1; - for (ArrayList pList : resultClusters) { - System.out.print("聚簇" + (i++) + ":"); - for (Point p : pList) { - System.out.print(MessageFormat.format("({0},{1}) ", p.x, p.y)); - } - System.out.println(); - } - - System.out.println(); - System.out.print("噪声数据:"); - for (Point p : noisePoint) { - System.out.print(MessageFormat.format("({0},{1}) ", p.x, p.y)); - } - System.out.println(); - } +class DBSCANTool { + // 测试数据文件地址 + private String filePath; + // 簇扫描半径 + private double eps; + // 最小包含点数阈值 + private int minPts; + // 所有的数据坐标点 + private ArrayList totalPoints; + // 聚簇结果 + private ArrayList> resultClusters; + //噪声数据 + private ArrayList noisePoint; + + DBSCANTool(String filePath, double eps, int minPts){ + this.filePath = filePath; + this.eps = eps; + this.minPts = minPts; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + public void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + Point p; + totalPoints = new ArrayList<>(); + for (String[] array : dataArray) { + p = new Point(array[0], array[1]); + totalPoints.add(p); + } + } + + /** + * 递归的寻找聚簇 + * + * @param point 当前的点列表 + * @param parentCluster 父聚簇 + */ + private void recursiveCluster(Point point, ArrayList parentCluster){ + double distance; + ArrayList cluster; + + // 如果已经访问过了,则跳过 + if (point.isVisited) { + return; + } + + point.isVisited = true; + cluster = new ArrayList<>(); + for (Point p2 : totalPoints) { + // 过滤掉自身的坐标点 + if (point.isTheSame(p2)) { + continue; + } + + distance = point.ouDistance(p2); + if (distance <= eps) { + // 如果聚类小于给定的半径,则加入簇中 + cluster.add(p2); + } + } + + if (cluster.size() >= minPts) { + // 将自己也加入到聚簇中 + cluster.add(point); + // 如果附近的节点个数超过最下值,则加入到父聚簇中,同时去除重复的点 + addCluster(parentCluster, cluster); + + for (Point p : cluster) { + recursiveCluster(p, parentCluster); + } + } + } + + /** + * 往父聚簇中添加局部簇坐标点 + * + * @param parentCluster 原始父聚簇坐标点 + * @param cluster 待合并的聚簇 + */ + private void addCluster(ArrayList parentCluster, + ArrayList cluster){ + boolean isCotained; + ArrayList addPoints = new ArrayList<>(); + + for (Point p : cluster) { + isCotained = false; + for (Point p2 : parentCluster) { + if (p.isTheSame(p2)) { + isCotained = true; + break; + } + } + + if (!isCotained) { + addPoints.add(p); + } + } + + parentCluster.addAll(addPoints); + } + + /** + * dbScan算法基于密度的聚类 + */ + void dbScanCluster(){ + ArrayList cluster; + resultClusters = new ArrayList<>(); + noisePoint = new ArrayList<>(); + + for (Point p : totalPoints) { + if (p.isVisited) { + continue; + } + + cluster = new ArrayList<>(); + recursiveCluster(p, cluster); + + if (cluster.size() > 0) { + resultClusters.add(cluster); + } else { + noisePoint.add(p); + } + } + removeFalseNoise(); + + printClusters(); + } + + /** + * 移除被错误分类的噪声点数据 + */ + private void removeFalseNoise(){ + ArrayList totalCluster = new ArrayList<>(); + ArrayList deletePoints = new ArrayList<>(); + + //将聚簇合并 + for (ArrayList list : resultClusters) { + totalCluster.addAll(list); + } + + for (Point p : noisePoint) { + for (Point p2 : totalCluster) { + if (p2.isTheSame(p)) { + deletePoints.add(p); + } + } + } + + noisePoint.removeAll(deletePoints); + } + + /** + * 输出聚类结果 + */ + private void printClusters(){ + int i = 1; + for (ArrayList pList : resultClusters) { + System.out.print("聚簇" + (i++) + ":"); + for (Point p : pList) { + System.out.print(MessageFormat.format("({0},{1}) ", p.x, p.y)); + } + System.out.println(); + } + + System.out.println(); + System.out.print("噪声数据:"); + for (Point p : noisePoint) { + System.out.print(MessageFormat.format("({0},{1}) ", p.x, p.y)); + } + System.out.println(); + } } diff --git a/Others/DataMining_DBSCAN/Point.java b/Others/DataMining_DBSCAN/Point.java index f773bad..14076c4 100644 --- a/Others/DataMining_DBSCAN/Point.java +++ b/Others/DataMining_DBSCAN/Point.java @@ -1,56 +1,51 @@ -package DataMining_DBSCAN; +package Others.DataMining_DBSCAN; /** * 坐标点类 - * - * @author lyq - * + * + * @author Qstar */ public class Point { - // 坐标点横坐标 - int x; - // 坐标点纵坐标 - int y; - // 此节点是否已经被访问过 - boolean isVisited; - - public Point(String x, String y) { - this.x = (Integer.parseInt(x)); - this.y = (Integer.parseInt(y)); - this.isVisited = false; - } - - /** - * 计算当前点与制定点之间的欧式距离 - * - * @param p - * 待计算聚类的p点 - * @return - */ - public double ouDistance(Point p) { - double distance = 0; - - distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) - * (this.y - p.y); - distance = Math.sqrt(distance); - - return distance; - } - - /** - * 判断2个坐标点是否为用个坐标点 - * - * @param p - * 待比较坐标点 - * @return - */ - public boolean isTheSame(Point p) { - boolean isSamed = false; - - if (this.x == p.x && this.y == p.y) { - isSamed = true; - } - - return isSamed; - } + // 坐标点横坐标 + int x; + // 坐标点纵坐标 + int y; + // 此节点是否已经被访问过 + boolean isVisited; + + public Point(String x, String y){ + this.x = (Integer.parseInt(x)); + this.y = (Integer.parseInt(y)); + this.isVisited = false; + } + + /** + * 计算当前点与制定点之间的欧式距离 + * + * @param p 待计算聚类的p点 + */ + double ouDistance(Point p){ + double distance; + + distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) + * (this.y - p.y); + distance = Math.sqrt(distance); + + return distance; + } + + /** + * 判断2个坐标点是否为用个坐标点 + * + * @param p 待比较坐标点 + */ + boolean isTheSame(Point p){ + boolean isSamed = false; + + if (this.x == p.x && this.y == p.y) { + isSamed = true; + } + + return isSamed; + } } diff --git a/Others/DataMining_GA/Client.java b/Others/DataMining_GA/Client.java index eff2dbc..ebb0fa5 100644 --- a/Others/DataMining_GA/Client.java +++ b/Others/DataMining_GA/Client.java @@ -1,19 +1,19 @@ -package GA; +package Others.DataMining_GA; /** * Genetic遗传算法测试类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - //变量最小值和最大值 - int minNum = 1; - int maxNum = 7; - //初始群体规模 - int initSetsNum = 4; - - GATool tool = new GATool(minNum, maxNum, initSetsNum); - tool.geneticCal(); - } + public static void main(String[] args){ + //变量最小值和最大值 + int minNum = 1; + int maxNum = 7; + //初始群体规模 + int initSetsNum = 4; + + GATool tool = new GATool(minNum, maxNum, initSetsNum); + tool.geneticCal(); + } } diff --git a/Others/DataMining_GA/GATool.java b/Others/DataMining_GA/GATool.java index 567c393..ac56346 100644 --- a/Others/DataMining_GA/GATool.java +++ b/Others/DataMining_GA/GATool.java @@ -1,361 +1,347 @@ -package GA; +package Others.DataMining_GA; import java.util.ArrayList; +import java.util.Arrays; import java.util.Random; /** * 遗传算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class GATool { - // 变量最小值 - private int minNum; - // 变量最大值 - private int maxNum; - // 单个变量的编码位数 - private int codeNum; - // 初始种群的数量 - private int initSetsNum; - // 随机数生成器 - private Random random; - // 初始群体 - private ArrayList initSets; - - public GATool(int minNum, int maxNum, int initSetsNum) { - this.minNum = minNum; - this.maxNum = maxNum; - this.initSetsNum = initSetsNum; - - this.random = new Random(); - produceInitSets(); - } - - /** - * 产生初始化群体 - */ - private void produceInitSets() { - this.codeNum = 0; - int num = maxNum; - int[] array; - - initSets = new ArrayList<>(); - - // 确定编码位数 - while (num != 0) { - codeNum++; - num /= 2; - } - - for (int i = 0; i < initSetsNum; i++) { - array = produceInitCode(); - initSets.add(array); - } - } - - /** - * 产生初始个体的编码 - * - * @return - */ - private int[] produceInitCode() { - int num = 0; - int num2 = 0; - int[] tempArray; - int[] array1; - int[] array2; - - tempArray = new int[2 * codeNum]; - array1 = new int[codeNum]; - array2 = new int[codeNum]; - - num = 0; - while (num < minNum || num > maxNum) { - num = random.nextInt(maxNum) + 1; - } - numToBinaryArray(array1, num); - - while (num2 < minNum || num2 > maxNum) { - num2 = random.nextInt(maxNum) + 1; - } - numToBinaryArray(array2, num2); - - // 组成总的编码 - for (int i = 0, k = 0; i < tempArray.length; i++, k++) { - if (k < codeNum) { - tempArray[i] = array1[k]; - } else { - tempArray[i] = array2[k - codeNum]; - } - } - - return tempArray; - } - - /** - * 选择操作,把适值较高的个体优先遗传到下一代 - * - * @param initCodes - * 初始个体编码 - * @return - */ - private ArrayList selectOperate(ArrayList initCodes) { - double randomNum = 0; - double sumAdaptiveValue = 0; - ArrayList resultCodes = new ArrayList<>(); - double[] adaptiveValue = new double[initSetsNum]; - - for (int i = 0; i < initSetsNum; i++) { - adaptiveValue[i] = calCodeAdaptiveValue(initCodes.get(i)); - sumAdaptiveValue += adaptiveValue[i]; - } - - // 转成概率的形式,做归一化操作 - for (int i = 0; i < initSetsNum; i++) { - adaptiveValue[i] = adaptiveValue[i] / sumAdaptiveValue; - } - - for (int i = 0; i < initSetsNum; i++) { - randomNum = random.nextInt(100) + 1; - randomNum = randomNum / 100; - //因为1.0是无法判断到的,,总和会无限接近1.0取为0.99做判断 - if(randomNum == 1){ - randomNum = randomNum - 0.01; - } - - sumAdaptiveValue = 0; - // 确定区间 - for (int j = 0; j < initSetsNum; j++) { - if (randomNum > sumAdaptiveValue - && randomNum <= sumAdaptiveValue + adaptiveValue[j]) { - //采用拷贝的方式避免引用重复 - resultCodes.add(initCodes.get(j).clone()); - break; - } else { - sumAdaptiveValue += adaptiveValue[j]; - } - } - } - - return resultCodes; - } - - /** - * 交叉运算 - * - * @param selectedCodes - * 上步骤的选择后的编码 - * @return - */ - private ArrayList crossOperate(ArrayList selectedCodes) { - int randomNum = 0; - // 交叉点 - int crossPoint = 0; - ArrayList resultCodes = new ArrayList<>(); - // 随机编码队列,进行随机交叉配对 - ArrayList randomCodeSeqs = new ArrayList<>(); - - // 进行随机排序 - while (selectedCodes.size() > 0) { - randomNum = random.nextInt(selectedCodes.size()); - - randomCodeSeqs.add(selectedCodes.get(randomNum)); - selectedCodes.remove(randomNum); - } - - int temp = 0; - int[] array1; - int[] array2; - // 进行两两交叉运算 - for (int i = 1; i < randomCodeSeqs.size(); i++) { - if (i % 2 == 1) { - array1 = randomCodeSeqs.get(i - 1); - array2 = randomCodeSeqs.get(i); - crossPoint = random.nextInt(2 * codeNum - 1) + 1; - - // 进行交叉点位置后的编码调换 - for (int j = 0; j < 2 * codeNum; j++) { - if (j >= crossPoint) { - temp = array1[j]; - array1[j] = array2[j]; - array2[j] = temp; - } - } - - // 加入到交叉运算结果中 - resultCodes.add(array1); - resultCodes.add(array2); - } - } - - return resultCodes; - } - - /** - * 变异操作 - * - * @param crossCodes - * 交叉运算后的结果 - * @return - */ - private ArrayList variationOperate(ArrayList crossCodes) { - // 变异点 - int variationPoint = 0; - ArrayList resultCodes = new ArrayList<>(); - - for (int[] array : crossCodes) { - variationPoint = random.nextInt(codeNum * 2); - - for (int i = 0; i < array.length; i++) { - // 变异点进行变异 - if (i == variationPoint) { - array[i] = (array[i] == 0 ? 1 : 0); - break; - } - } - - resultCodes.add(array); - } - - return resultCodes; - } - - /** - * 数字转为二进制形式 - * - * @param binaryArray - * 转化后的二进制数组形式 - * @param num - * 待转化数字 - */ - private void numToBinaryArray(int[] binaryArray, int num) { - int index = 0; - int temp = 0; - while (num != 0) { - binaryArray[index] = num % 2; - index++; - num /= 2; - } - - //进行数组前和尾部的调换 - for(int i=0; i=0 ; i--, k++) { - if (binaryArray[i] == 1) { - result += Math.pow(2, k); - } - } - - return result; - } - - /** - * 计算个体编码的适值 - * - * @param codeArray - */ - private int calCodeAdaptiveValue(int[] codeArray) { - int result = 0; - int x1 = 0; - int x2 = 0; - int[] array1 = new int[codeNum]; - int[] array2 = new int[codeNum]; - - for (int i = 0, k = 0; i < codeArray.length; i++, k++) { - if (k < codeNum) { - array1[k] = codeArray[i]; - } else { - array2[k - codeNum] = codeArray[i]; - } - } - - // 进行适值的叠加 - x1 = binaryArrayToNum(array1); - x2 = binaryArrayToNum(array2); - result = x1 * x1 + x2 * x2; - - return result; - } - - /** - * 进行遗传算法计算 - */ - public void geneticCal() { - // 最大适值 - int maxFitness; - //迭代遗传次数 - int loopCount = 0; - boolean canExit = false; - ArrayList initCodes; - ArrayList selectedCodes; - ArrayList crossedCodes; - ArrayList variationCodes; - - int[] maxCode = new int[2*codeNum]; - //计算最大适值 - for(int i=0; i<2*codeNum; i++){ - maxCode[i] = 1; - } - maxFitness = calCodeAdaptiveValue(maxCode); - - initCodes = initSets; - while (true) { - for (int[] array : initCodes) { - // 遗传迭代的终止条件为存在编码达到最大适值 - if (maxFitness == calCodeAdaptiveValue(array)) { - canExit = true; - break; - } - } - - if (canExit) { - break; - } - - selectedCodes = selectOperate(initCodes); - crossedCodes = crossOperate(selectedCodes); - variationCodes = variationOperate(crossedCodes); - initCodes = variationCodes; - - loopCount++; - } - - System.out.println("总共遗传进化了" + loopCount +"次" ); - printFinalCodes(initCodes); - } - - /** - * 输出最后的编码集 - * - * @param finalCodes - * 最后的结果编码 - */ - private void printFinalCodes(ArrayList finalCodes) { - int j = 0; - - for (int[] array : finalCodes) { - System.out.print("个体" + (j + 1) + ":"); - for (int i = 0; i < array.length; i++) { - System.out.print(array[i]); - } - System.out.println(); - j++; - } - } +class GATool { + // 变量最小值 + private int minNum; + // 变量最大值 + private int maxNum; + // 单个变量的编码位数 + private int codeNum; + // 初始种群的数量 + private int initSetsNum; + // 随机数生成器 + private Random random; + // 初始群体 + private ArrayList initSets; + + GATool(int minNum, int maxNum, int initSetsNum){ + this.minNum = minNum; + this.maxNum = maxNum; + this.initSetsNum = initSetsNum; + + this.random = new Random(); + produceInitSets(); + } + + /** + * 产生初始化群体 + */ + private void produceInitSets(){ + this.codeNum = 0; + int num = maxNum; + int[] array; + + initSets = new ArrayList<>(); + + // 确定编码位数 + while (num != 0) { + codeNum++; + num /= 2; + } + + for (int i = 0; i < initSetsNum; i++) { + array = produceInitCode(); + initSets.add(array); + } + } + + /** + * 产生初始个体的编码 + */ + private int[] produceInitCode(){ + int num; + int num2 = 0; + int[] tempArray; + int[] array1; + int[] array2; + + tempArray = new int[2 * codeNum]; + array1 = new int[codeNum]; + array2 = new int[codeNum]; + + num = 0; + while (num < minNum || num > maxNum) { + num = random.nextInt(maxNum) + 1; + } + numToBinaryArray(array1, num); + + while (num2 < minNum || num2 > maxNum) { + num2 = random.nextInt(maxNum) + 1; + } + numToBinaryArray(array2, num2); + + // 组成总的编码 + for (int i = 0, k = 0; i < tempArray.length; i++, k++) { + if (k < codeNum) { + tempArray[i] = array1[k]; + } else { + tempArray[i] = array2[k - codeNum]; + } + } + + return tempArray; + } + + /** + * 选择操作,把适值较高的个体优先遗传到下一代 + * + * @param initCodes 初始个体编码 + */ + private ArrayList selectOperate(ArrayList initCodes){ + double randomNum; + double sumAdaptiveValue = 0; + ArrayList resultCodes = new ArrayList<>(); + double[] adaptiveValue = new double[initSetsNum]; + + for (int i = 0; i < initSetsNum; i++) { + adaptiveValue[i] = calCodeAdaptiveValue(initCodes.get(i)); + sumAdaptiveValue += adaptiveValue[i]; + } + + // 转成概率的形式,做归一化操作 + for (int i = 0; i < initSetsNum; i++) { + adaptiveValue[i] = adaptiveValue[i] / sumAdaptiveValue; + } + + for (int i = 0; i < initSetsNum; i++) { + randomNum = random.nextInt(100) + 1; + randomNum = randomNum / 100; + //因为1.0是无法判断到的,,总和会无限接近1.0取为0.99做判断 + if (randomNum == 1) { + randomNum = randomNum - 0.01; + } + + sumAdaptiveValue = 0; + // 确定区间 + for (int j = 0; j < initSetsNum; j++) { + if (randomNum > sumAdaptiveValue + && randomNum <= sumAdaptiveValue + adaptiveValue[j]) { + //采用拷贝的方式避免引用重复 + resultCodes.add(initCodes.get(j).clone()); + break; + } else { + sumAdaptiveValue += adaptiveValue[j]; + } + } + } + + return resultCodes; + } + + /** + * 交叉运算 + * + * @param selectedCodes 上步骤的选择后的编码 + */ + private ArrayList crossOperate(ArrayList selectedCodes){ + int randomNum; + // 交叉点 + int crossPoint; + ArrayList resultCodes = new ArrayList<>(); + // 随机编码队列,进行随机交叉配对 + ArrayList randomCodeSeqs = new ArrayList<>(); + + // 进行随机排序 + while (selectedCodes.size() > 0) { + randomNum = random.nextInt(selectedCodes.size()); + + randomCodeSeqs.add(selectedCodes.get(randomNum)); + selectedCodes.remove(randomNum); + } + + int temp; + int[] array1; + int[] array2; + // 进行两两交叉运算 + for (int i = 1; i < randomCodeSeqs.size(); i++) { + if (i % 2 == 1) { + array1 = randomCodeSeqs.get(i - 1); + array2 = randomCodeSeqs.get(i); + crossPoint = random.nextInt(2 * codeNum - 1) + 1; + + // 进行交叉点位置后的编码调换 + for (int j = 0; j < 2 * codeNum; j++) { + if (j >= crossPoint) { + temp = array1[j]; + array1[j] = array2[j]; + array2[j] = temp; + } + } + + // 加入到交叉运算结果中 + resultCodes.add(array1); + resultCodes.add(array2); + } + } + + return resultCodes; + } + + /** + * 变异操作 + * + * @param crossCodes 交叉运算后的结果 + */ + private ArrayList variationOperate(ArrayList crossCodes){ + // 变异点 + int variationPoint; + ArrayList resultCodes = new ArrayList<>(); + + for (int[] array : crossCodes) { + variationPoint = random.nextInt(codeNum * 2); + + for (int i = 0; i < array.length; i++) { + // 变异点进行变异 + if (i == variationPoint) { + array[i] = (array[i] == 0 ? 1 : 0); + break; + } + } + + resultCodes.add(array); + } + + return resultCodes; + } + + /** + * 数字转为二进制形式 + * + * @param binaryArray 转化后的二进制数组形式 + * @param num 待转化数字 + */ + private void numToBinaryArray(int[] binaryArray, int num){ + int index = 0; + int temp; + while (num != 0) { + binaryArray[index] = num % 2; + index++; + num /= 2; + } + + //进行数组前和尾部的调换 + for (int i = 0; i < binaryArray.length / 2; i++) { + temp = binaryArray[i]; + binaryArray[i] = binaryArray[binaryArray.length - 1 - i]; + binaryArray[binaryArray.length - 1 - i] = temp; + } + } + + /** + * 二进制数组转化为数字 + * + * @param binaryArray 待转化二进制数组 + */ + private int binaryArrayToNum(int[] binaryArray){ + int result = 0; + + for (int i = binaryArray.length - 1, k = 0; i >= 0; i--, k++) { + if (binaryArray[i] == 1) { + result += Math.pow(2, k); + } + } + + return result; + } + + /** + * 计算个体编码的适值 + * + * @param codeArray 编码数组 + */ + private int calCodeAdaptiveValue(int[] codeArray){ + int result; + int x1; + int x2; + int[] array1 = new int[codeNum]; + int[] array2 = new int[codeNum]; + + for (int i = 0, k = 0; i < codeArray.length; i++, k++) { + if (k < codeNum) { + array1[k] = codeArray[i]; + } else { + array2[k - codeNum] = codeArray[i]; + } + } + + // 进行适值的叠加 + x1 = binaryArrayToNum(array1); + x2 = binaryArrayToNum(array2); + result = x1 * x1 + x2 * x2; + + return result; + } + + /** + * 进行遗传算法计算 + */ + void geneticCal(){ + // 最大适值 + int maxFitness; + //迭代遗传次数 + int loopCount = 0; + boolean canExit = false; + ArrayList initCodes; + ArrayList selectedCodes; + ArrayList crossedCodes; + ArrayList variationCodes; + + int[] maxCode = new int[2 * codeNum]; + //计算最大适值 + for (int i = 0; i < 2 * codeNum; i++) { + maxCode[i] = 1; + } + maxFitness = calCodeAdaptiveValue(maxCode); + + initCodes = initSets; + while (true) { + for (int[] array : initCodes) { + // 遗传迭代的终止条件为存在编码达到最大适值 + if (maxFitness == calCodeAdaptiveValue(array)) { + canExit = true; + break; + } + } + + if (canExit) { + break; + } + + selectedCodes = selectOperate(initCodes); + crossedCodes = crossOperate(selectedCodes); + variationCodes = variationOperate(crossedCodes); + initCodes = variationCodes; + + loopCount++; + } + + System.out.println("总共遗传进化了" + loopCount + "次"); + printFinalCodes(initCodes); + } + + /** + * 输出最后的编码集 + * + * @param finalCodes 最后的结果编码 + */ + private void printFinalCodes(ArrayList finalCodes){ + int j = 0; + + for (int[] array : finalCodes) { + System.out.print("个体" + (j + 1) + ":"); + Arrays.stream(array).forEach(System.out::print); + System.out.println(); + j++; + } + } } diff --git a/Others/DataMining_GA_Maze/Client.java b/Others/DataMining_GA_Maze/Client.java index 0cec9c9..776cab1 100644 --- a/Others/DataMining_GA_Maze/Client.java +++ b/Others/DataMining_GA_Maze/Client.java @@ -1,19 +1,18 @@ -package GA_Maze; +package Others.DataMining_GA_Maze; /** * 遗传算法在走迷宫游戏的应用 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args) { - //迷宫地图文件数据地址 - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\mapData.txt"; - //初始个体数量 - int initSetsNum = 4; - - GATool tool = new GATool(filePath, initSetsNum); - tool.goOutMaze(); - } + public static void main(String[] args){ + //迷宫地图文件数据地址 + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_GA_Maze/mapData.txt"; + //初始个体数量 + int initSetsNum = 4; + GATool tool = new GATool(filePath, initSetsNum); + tool.goOutMaze(); + } } diff --git a/Others/DataMining_GA_Maze/GATool.java b/Others/DataMining_GA_Maze/GATool.java index 39c8270..4a5b2f6 100644 --- a/Others/DataMining_GA_Maze/GATool.java +++ b/Others/DataMining_GA_Maze/GATool.java @@ -1,4 +1,4 @@ -package GA_Maze; +package Others.DataMining_GA_Maze; import java.io.BufferedReader; import java.io.File; @@ -6,447 +6,433 @@ import java.io.IOException; import java.text.MessageFormat; import java.util.ArrayList; +import java.util.Arrays; import java.util.Random; /** * 遗传算法在走迷宫游戏的应用-遗传算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class GATool { - // 迷宫出入口标记 - public static final int MAZE_ENTRANCE_POS = 1; - public static final int MAZE_EXIT_POS = 2; - // 方向对应的编码数组 - public static final int[][] MAZE_DIRECTION_CODE = new int[][] { { 0, 0 }, - { 0, 1 }, { 1, 0 }, { 1, 1 }, }; - // 坐标点方向改变 - public static final int[][] MAZE_DIRECTION_CHANGE = new int[][] { - { -1, 0 }, { 1, 0 }, { 0, -1 }, { 0, 1 }, }; - // 方向的文字描述 - public static final String[] MAZE_DIRECTION_LABEL = new String[] { "上", - "下", "左", "右" }; - - // 地图数据文件地址 - private String filePath; - // 走迷宫的最短步数 - private int stepNum; - // 初始个体的数量 - private int initSetsNum; - // 迷宫入口位置 - private int[] startPos; - // 迷宫出口位置 - private int[] endPos; - // 迷宫地图数据 - private int[][] mazeData; - // 初始个体集 - private ArrayList initSets; - // 随机数产生器 - private Random random; - - public GATool(String filePath, int initSetsNum) { - this.filePath = filePath; - this.initSetsNum = initSetsNum; - - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - public void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - int rowNum = dataArray.size(); - mazeData = new int[rowNum][rowNum]; - for (int i = 0; i < rowNum; i++) { - String[] data = dataArray.get(i); - for (int j = 0; j < data.length; j++) { - mazeData[i][j] = Integer.parseInt(data[j]); - - // 赋值入口和出口位置 - if (mazeData[i][j] == MAZE_ENTRANCE_POS) { - startPos = new int[2]; - startPos[0] = i; - startPos[1] = j; - } else if (mazeData[i][j] == MAZE_EXIT_POS) { - endPos = new int[2]; - endPos[0] = i; - endPos[1] = j; - } - } - } - - // 计算走出迷宫的最短步数 - stepNum = Math.abs(startPos[0] - endPos[0]) - + Math.abs(startPos[1] - endPos[1]); - } - - /** - * 产生初始数据集 - */ - private void produceInitSet() { - // 方向编码 - int directionCode = 0; - random = new Random(); - initSets = new ArrayList<>(); - // 每个步骤的操作需要用2位数字表示 - int[] codeNum; - - for (int i = 0; i < initSetsNum; i++) { - codeNum = new int[stepNum * 2]; - for (int j = 0; j < stepNum; j++) { - directionCode = random.nextInt(4); - codeNum[2 * j] = MAZE_DIRECTION_CODE[directionCode][0]; - codeNum[2 * j + 1] = MAZE_DIRECTION_CODE[directionCode][1]; - } - - initSets.add(codeNum); - } - } - - /** - * 选择操作,把适值较高的个体优先遗传到下一代 - * - * @param initCodes - * 初始个体编码 - * @return - */ - private ArrayList selectOperate(ArrayList initCodes) { - double randomNum = 0; - double sumFitness = 0; - ArrayList resultCodes = new ArrayList<>(); - double[] adaptiveValue = new double[initSetsNum]; - - for (int i = 0; i < initSetsNum; i++) { - adaptiveValue[i] = calFitness(initCodes.get(i)); - sumFitness += adaptiveValue[i]; - } - - // 转成概率的形式,做归一化操作 - for (int i = 0; i < initSetsNum; i++) { - adaptiveValue[i] = adaptiveValue[i] / sumFitness; - } - - for (int i = 0; i < initSetsNum; i++) { - randomNum = random.nextInt(100) + 1; - randomNum = randomNum / 100; - //因为1.0是无法判断到的,,总和会无限接近1.0取为0.99做判断 - if(randomNum == 1){ - randomNum = randomNum - 0.01; - } - - sumFitness = 0; - // 确定区间 - for (int j = 0; j < initSetsNum; j++) { - if (randomNum > sumFitness - && randomNum <= sumFitness + adaptiveValue[j]) { - // 采用拷贝的方式避免引用重复 - resultCodes.add(initCodes.get(j).clone()); - break; - } else { - sumFitness += adaptiveValue[j]; - } - } - } - - return resultCodes; - } - - /** - * 交叉运算 - * - * @param selectedCodes - * 上步骤的选择后的编码 - * @return - */ - private ArrayList crossOperate(ArrayList selectedCodes) { - int randomNum = 0; - // 交叉点 - int crossPoint = 0; - ArrayList resultCodes = new ArrayList<>(); - // 随机编码队列,进行随机交叉配对 - ArrayList randomCodeSeqs = new ArrayList<>(); - - // 进行随机排序 - while (selectedCodes.size() > 0) { - randomNum = random.nextInt(selectedCodes.size()); - - randomCodeSeqs.add(selectedCodes.get(randomNum)); - selectedCodes.remove(randomNum); - } - - int temp = 0; - int[] array1; - int[] array2; - // 进行两两交叉运算 - for (int i = 1; i < randomCodeSeqs.size(); i++) { - if (i % 2 == 1) { - array1 = randomCodeSeqs.get(i - 1); - array2 = randomCodeSeqs.get(i); - crossPoint = random.nextInt(stepNum - 1) + 1; - - // 进行交叉点位置后的编码调换 - for (int j = 0; j < 2 * stepNum; j++) { - if (j >= 2 * crossPoint) { - temp = array1[j]; - array1[j] = array2[j]; - array2[j] = temp; - } - } - - // 加入到交叉运算结果中 - resultCodes.add(array1); - resultCodes.add(array2); - } - } - - return resultCodes; - } - - /** - * 变异操作 - * - * @param crossCodes - * 交叉运算后的结果 - * @return - */ - private ArrayList variationOperate(ArrayList crossCodes) { - // 变异点 - int variationPoint = 0; - ArrayList resultCodes = new ArrayList<>(); - - for (int[] array : crossCodes) { - variationPoint = random.nextInt(stepNum); - - for (int i = 0; i < array.length; i += 2) { - // 变异点进行变异 - if (i % 2 == 0 && i / 2 == variationPoint) { - array[i] = (array[i] == 0 ? 1 : 0); - array[i + 1] = (array[i + 1] == 0 ? 1 : 0); - break; - } - } - - resultCodes.add(array); - } - - return resultCodes; - } - - /** - * 根据编码计算适值 - * - * @param code - * 当前的编码 - * @return - */ - public double calFitness(int[] code) { - double fintness = 0; - // 由编码计算所得的终点横坐标 - int endX = 0; - // 由编码计算所得的终点纵坐标 - int endY = 0; - // 基于片段所代表的行走方向 - int direction = 0; - // 临时坐标点横坐标 - int tempX = 0; - // 临时坐标点纵坐标 - int tempY = 0; - - endX = startPos[0]; - endY = startPos[1]; - for (int i = 0; i < stepNum; i++) { - direction = binaryArrayToNum(new int[] { code[2 * i], - code[2 * i + 1] }); - - // 根据方向改变数组做坐标点的改变 - tempX = endX + MAZE_DIRECTION_CHANGE[direction][0]; - tempY = endY + MAZE_DIRECTION_CHANGE[direction][1]; - - // 判断坐标点是否越界 - if (tempX >= 0 && tempX < mazeData.length && tempY >= 0 - && tempY < mazeData[0].length) { - // 判断坐标点是否走到阻碍块 - if (mazeData[tempX][tempY] != -1) { - endX = tempX; - endY = tempY; - } - } - } - - // 根据适值函数进行适值的计算 - fintness = 1.0 / (Math.abs(endX - endPos[0]) - + Math.abs(endY - endPos[1]) + 1); - - return fintness; - } - - /** - * 根据当前编码判断是否已经找到出口位置 - * - * @param code - * 经过若干次遗传的编码 - * @return - */ - private boolean ifArriveEndPos(int[] code) { - boolean isArrived = false; - // 由编码计算所得的终点横坐标 - int endX = 0; - // 由编码计算所得的终点纵坐标 - int endY = 0; - // 基于片段所代表的行走方向 - int direction = 0; - // 临时坐标点横坐标 - int tempX = 0; - // 临时坐标点纵坐标 - int tempY = 0; - - endX = startPos[0]; - endY = startPos[1]; - for (int i = 0; i < stepNum; i++) { - direction = binaryArrayToNum(new int[] { code[2 * i], - code[2 * i + 1] }); - - // 根据方向改变数组做坐标点的改变 - tempX = endX + MAZE_DIRECTION_CHANGE[direction][0]; - tempY = endY + MAZE_DIRECTION_CHANGE[direction][1]; - - // 判断坐标点是否越界 - if (tempX >= 0 && tempX < mazeData.length && tempY >= 0 - && tempY < mazeData[0].length) { - // 判断坐标点是否走到阻碍块 - if (mazeData[tempX][tempY] != -1) { - endX = tempX; - endY = tempY; - } - } - } - - if (endX == endPos[0] && endY == endPos[1]) { - isArrived = true; - } - - return isArrived; - } - - /** - * 二进制数组转化为数字 - * - * @param binaryArray - * 待转化二进制数组 - */ - private int binaryArrayToNum(int[] binaryArray) { - int result = 0; - - for (int i = binaryArray.length - 1, k = 0; i >= 0; i--, k++) { - if (binaryArray[i] == 1) { - result += Math.pow(2, k); - } - } - - return result; - } - - /** - * 进行遗传算法走出迷宫 - */ - public void goOutMaze() { - // 迭代遗传次数 - int loopCount = 0; - boolean canExit = false; - // 结果路径 - int[] resultCode = null; - ArrayList initCodes; - ArrayList selectedCodes; - ArrayList crossedCodes; - ArrayList variationCodes; - - // 产生初始数据集 - produceInitSet(); - initCodes = initSets; - - while (true) { - for (int[] array : initCodes) { - // 遗传迭代的终止条件为是否找到出口位置 - if (ifArriveEndPos(array)) { - resultCode = array; - canExit = true; - break; - } - } - - if (canExit) { - break; - } - - selectedCodes = selectOperate(initCodes); - crossedCodes = crossOperate(selectedCodes); - variationCodes = variationOperate(crossedCodes); - initCodes = variationCodes; - - loopCount++; - - //如果遗传次数超过100次,则退出 - if(loopCount >= 100){ - break; - } - } - - System.out.println("总共遗传进化了" + loopCount + "次"); - printFindedRoute(resultCode); - } - - /** - * 输出找到的路径 - * - * @param code - */ - private void printFindedRoute(int[] code) { - if(code == null){ - System.out.println("在有限的遗传进化次数内,没有找到最优路径"); - return; - } - - int tempX = startPos[0]; - int tempY = startPos[1]; - int direction = 0; - - System.out.println(MessageFormat.format( - "起始点位置({0},{1}), 出口点位置({2}, {3})", tempX, tempY, endPos[0], - endPos[1])); - - System.out.print("搜索到的结果编码:"); - for(int value: code){ - System.out.print("" + value); - } - System.out.println(); - - for (int i = 0, k = 1; i < code.length; i += 2, k++) { - direction = binaryArrayToNum(new int[] { code[i], code[i + 1] }); - - tempX += MAZE_DIRECTION_CHANGE[direction][0]; - tempY += MAZE_DIRECTION_CHANGE[direction][1]; - - System.out.println(MessageFormat.format( - "第{0}步,编码为{1}{2},向{3}移动,移动后到达({4},{5})", k, code[i], code[i+1], - MAZE_DIRECTION_LABEL[direction], tempX, tempY)); - } - } +class GATool { + // 迷宫出入口标记 + private static final int MAZE_ENTRANCE_POS = 1; + private static final int MAZE_EXIT_POS = 2; + // 方向对应的编码数组 + private static final int[][] MAZE_DIRECTION_CODE = new int[][]{{0, 0}, + {0, 1}, {1, 0}, {1, 1},}; + // 坐标点方向改变 + private static final int[][] MAZE_DIRECTION_CHANGE = new int[][]{ + {-1, 0}, {1, 0}, {0, -1}, {0, 1},}; + // 方向的文字描述 + private static final String[] MAZE_DIRECTION_LABEL = new String[]{"上", + "下", "左", "右"}; + + // 地图数据文件地址 + private String filePath; + // 走迷宫的最短步数 + private int stepNum; + // 初始个体的数量 + private int initSetsNum; + // 迷宫入口位置 + private int[] startPos; + // 迷宫出口位置 + private int[] endPos; + // 迷宫地图数据 + private int[][] mazeData; + // 初始个体集 + private ArrayList initSets; + // 随机数产生器 + private Random random; + + GATool(String filePath, int initSetsNum){ + this.filePath = filePath; + this.initSetsNum = initSetsNum; + + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + public void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + int rowNum = dataArray.size(); + mazeData = new int[rowNum][rowNum]; + for (int i = 0; i < rowNum; i++) { + String[] data = dataArray.get(i); + for (int j = 0; j < data.length; j++) { + mazeData[i][j] = Integer.parseInt(data[j]); + + // 赋值入口和出口位置 + if (mazeData[i][j] == MAZE_ENTRANCE_POS) { + startPos = new int[2]; + startPos[0] = i; + startPos[1] = j; + } else if (mazeData[i][j] == MAZE_EXIT_POS) { + endPos = new int[2]; + endPos[0] = i; + endPos[1] = j; + } + } + } + + // 计算走出迷宫的最短步数 + stepNum = Math.abs(startPos[0] - endPos[0]) + + Math.abs(startPos[1] - endPos[1]); + } + + /** + * 产生初始数据集 + */ + private void produceInitSet(){ + // 方向编码 + int directionCode; + random = new Random(); + initSets = new ArrayList<>(); + // 每个步骤的操作需要用2位数字表示 + int[] codeNum; + + for (int i = 0; i < initSetsNum; i++) { + codeNum = new int[stepNum * 2]; + for (int j = 0; j < stepNum; j++) { + directionCode = random.nextInt(4); + codeNum[2 * j] = MAZE_DIRECTION_CODE[directionCode][0]; + codeNum[2 * j + 1] = MAZE_DIRECTION_CODE[directionCode][1]; + } + + initSets.add(codeNum); + } + } + + /** + * 选择操作,把适值较高的个体优先遗传到下一代 + * + * @param initCodes 初始个体编码 + */ + private ArrayList selectOperate(ArrayList initCodes){ + double randomNum; + double sumFitness = 0; + ArrayList resultCodes = new ArrayList<>(); + double[] adaptiveValue = new double[initSetsNum]; + + for (int i = 0; i < initSetsNum; i++) { + adaptiveValue[i] = calFitness(initCodes.get(i)); + sumFitness += adaptiveValue[i]; + } + + // 转成概率的形式,做归一化操作 + for (int i = 0; i < initSetsNum; i++) { + adaptiveValue[i] = adaptiveValue[i] / sumFitness; + } + + for (int i = 0; i < initSetsNum; i++) { + randomNum = random.nextInt(100) + 1; + randomNum = randomNum / 100; + //因为1.0是无法判断到的,,总和会无限接近1.0取为0.99做判断 + if (randomNum == 1) { + randomNum = randomNum - 0.01; + } + + sumFitness = 0; + // 确定区间 + for (int j = 0; j < initSetsNum; j++) { + if (randomNum > sumFitness + && randomNum <= sumFitness + adaptiveValue[j]) { + // 采用拷贝的方式避免引用重复 + resultCodes.add(initCodes.get(j).clone()); + break; + } else { + sumFitness += adaptiveValue[j]; + } + } + } + + return resultCodes; + } + + /** + * 交叉运算 + * + * @param selectedCodes 上步骤的选择后的编码 + */ + private ArrayList crossOperate(ArrayList selectedCodes){ + int randomNum; + // 交叉点 + int crossPoint; + ArrayList resultCodes = new ArrayList<>(); + // 随机编码队列,进行随机交叉配对 + ArrayList randomCodeSeqs = new ArrayList<>(); + + // 进行随机排序 + while (selectedCodes.size() > 0) { + randomNum = random.nextInt(selectedCodes.size()); + + randomCodeSeqs.add(selectedCodes.get(randomNum)); + selectedCodes.remove(randomNum); + } + + int temp; + int[] array1; + int[] array2; + // 进行两两交叉运算 + for (int i = 1; i < randomCodeSeqs.size(); i++) { + if (i % 2 == 1) { + array1 = randomCodeSeqs.get(i - 1); + array2 = randomCodeSeqs.get(i); + crossPoint = random.nextInt(stepNum - 1) + 1; + + // 进行交叉点位置后的编码调换 + for (int j = 0; j < 2 * stepNum; j++) { + if (j >= 2 * crossPoint) { + temp = array1[j]; + array1[j] = array2[j]; + array2[j] = temp; + } + } + + // 加入到交叉运算结果中 + resultCodes.add(array1); + resultCodes.add(array2); + } + } + return resultCodes; + } + + /** + * 变异操作 + * + * @param crossCodes 交叉运算后的结果 + */ + private ArrayList variationOperate(ArrayList crossCodes){ + // 变异点 + int variationPoint; + ArrayList resultCodes = new ArrayList<>(); + + for (int[] array : crossCodes) { + variationPoint = random.nextInt(stepNum); + + for (int i = 0; i < array.length; i += 2) { + // 变异点进行变异 + if (i % 2 == 0 && i / 2 == variationPoint) { + array[i] = (array[i] == 0 ? 1 : 0); + array[i + 1] = (array[i + 1] == 0 ? 1 : 0); + break; + } + } + + resultCodes.add(array); + } + + return resultCodes; + } + + /** + * 根据编码计算适值 + * + * @param code 当前的编码 + */ + private double calFitness(int[] code){ + double fintness; + // 由编码计算所得的终点横坐标 + int endX; + // 由编码计算所得的终点纵坐标 + int endY; + // 基于片段所代表的行走方向 + int direction; + // 临时坐标点横坐标 + int tempX; + // 临时坐标点纵坐标 + int tempY; + + endX = startPos[0]; + endY = startPos[1]; + for (int i = 0; i < stepNum; i++) { + direction = binaryArrayToNum(new int[]{code[2 * i], + code[2 * i + 1]}); + + // 根据方向改变数组做坐标点的改变 + tempX = endX + MAZE_DIRECTION_CHANGE[direction][0]; + tempY = endY + MAZE_DIRECTION_CHANGE[direction][1]; + + // 判断坐标点是否越界 + if (tempX >= 0 && tempX < mazeData.length && tempY >= 0 + && tempY < mazeData[0].length) { + // 判断坐标点是否走到阻碍块 + if (mazeData[tempX][tempY] != -1) { + endX = tempX; + endY = tempY; + } + } + } + + // 根据适值函数进行适值的计算 + fintness = 1.0 / (Math.abs(endX - endPos[0]) + + Math.abs(endY - endPos[1]) + 1); + + return fintness; + } + + /** + * 根据当前编码判断是否已经找到出口位置 + * + * @param code 经过若干次遗传的编码 + */ + private boolean ifArriveEndPos(int[] code){ + boolean isArrived = false; + // 由编码计算所得的终点横坐标 + int endX; + // 由编码计算所得的终点纵坐标 + int endY; + // 基于片段所代表的行走方向 + int direction; + // 临时坐标点横坐标 + int tempX; + // 临时坐标点纵坐标 + int tempY; + + endX = startPos[0]; + endY = startPos[1]; + for (int i = 0; i < stepNum; i++) { + direction = binaryArrayToNum(new int[]{code[2 * i], + code[2 * i + 1]}); + + // 根据方向改变数组做坐标点的改变 + tempX = endX + MAZE_DIRECTION_CHANGE[direction][0]; + tempY = endY + MAZE_DIRECTION_CHANGE[direction][1]; + + // 判断坐标点是否越界 + if (tempX >= 0 && tempX < mazeData.length && tempY >= 0 + && tempY < mazeData[0].length) { + // 判断坐标点是否走到阻碍块 + if (mazeData[tempX][tempY] != -1) { + endX = tempX; + endY = tempY; + } + } + } + + if (endX == endPos[0] && endY == endPos[1]) { + isArrived = true; + } + + return isArrived; + } + + /** + * 二进制数组转化为数字 + * + * @param binaryArray 待转化二进制数组 + */ + private int binaryArrayToNum(int[] binaryArray){ + int result = 0; + + for (int i = binaryArray.length - 1, k = 0; i >= 0; i--, k++) { + if (binaryArray[i] == 1) { + result += Math.pow(2, k); + } + } + + return result; + } + + /** + * 进行遗传算法走出迷宫 + */ + void goOutMaze(){ + // 迭代遗传次数 + int loopCount = 0; + boolean canExit = false; + // 结果路径 + int[] resultCode = null; + ArrayList initCodes; + ArrayList selectedCodes; + ArrayList crossedCodes; + ArrayList variationCodes; + + // 产生初始数据集 + produceInitSet(); + initCodes = initSets; + + while (true) { + for (int[] array : initCodes) { + // 遗传迭代的终止条件为是否找到出口位置 + if (ifArriveEndPos(array)) { + resultCode = array; + canExit = true; + break; + } + } + + if (canExit) { + break; + } + + selectedCodes = selectOperate(initCodes); + crossedCodes = crossOperate(selectedCodes); + variationCodes = variationOperate(crossedCodes); + initCodes = variationCodes; + + loopCount++; + + //如果遗传次数超过100次,则退出 + if (loopCount >= 100) { + break; + } + } + + System.out.println("总共遗传进化了" + loopCount + "次"); + printFindedRoute(resultCode); + } + + /** + * 输出找到的路径 + * + * @param code 编码 + */ + private void printFindedRoute(int[] code){ + if (code == null) { + System.out.println("在有限的遗传进化次数内,没有找到最优路径"); + return; + } + + int tempX = startPos[0]; + int tempY = startPos[1]; + int direction; + + System.out.println(MessageFormat.format( + "起始点位置({0},{1}), 出口点位置({2}, {3})", tempX, tempY, endPos[0], + endPos[1])); + + System.out.print("搜索到的结果编码:"); + Arrays.stream(code).forEach(value -> System.out.print("" + value)); + System.out.println(); + + for (int i = 0, k = 1; i < code.length; i += 2, k++) { + direction = binaryArrayToNum(new int[]{code[i], code[i + 1]}); + + tempX += MAZE_DIRECTION_CHANGE[direction][0]; + tempY += MAZE_DIRECTION_CHANGE[direction][1]; + + System.out.println(MessageFormat.format( + "第{0}步,编码为{1}{2},向{3}移动,移动后到达({4},{5})", k, code[i], code[i + 1], + MAZE_DIRECTION_LABEL[direction], tempX, tempY)); + } + } } diff --git a/Others/DataMining_KDTree/Client.java b/Others/DataMining_KDTree/Client.java index bba7377..f052ba7 100644 --- a/Others/DataMining_KDTree/Client.java +++ b/Others/DataMining_KDTree/Client.java @@ -1,36 +1,35 @@ -package DataMining_KDTree; +package Others.DataMining_KDTree; import java.text.MessageFormat; /** * KD树算法测试类 - * - * @author lyq - * + * + * @author Qstar */ public class Client { - public static void main(String[] args) { - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - Point queryNode; - Point searchedNode; - KDTreeTool tool = new KDTreeTool(filePath); + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_KDTree/input.txt"; + Point queryNode; + Point searchedNode; + KDTreeTool tool = new KDTreeTool(filePath); - // 进行KD树的构建 - tool.createKDTree(); + // 进行KD树的构建 + tool.createKDTree(); - // 通过KD树进行数据点的最近点查询 - queryNode = new Point(2.1, 3.1); - searchedNode = tool.searchNearestData(queryNode); - System.out.println(MessageFormat.format( - "距离查询点({0}, {1})最近的坐标点为({2}, {3})", queryNode.x, queryNode.y, - searchedNode.x, searchedNode.y)); - - //重新构造KD树,去除之前的访问记录 - tool.createKDTree(); - queryNode = new Point(2, 4.5); - searchedNode = tool.searchNearestData(queryNode); - System.out.println(MessageFormat.format( - "距离查询点({0}, {1})最近的坐标点为({2}, {3})", queryNode.x, queryNode.y, - searchedNode.x, searchedNode.y)); - } + // 通过KD树进行数据点的最近点查询 + queryNode = new Point(2.1, 3.1); + searchedNode = tool.searchNearestData(queryNode); + System.out.println(MessageFormat.format( + "距离查询点({0}, {1})最近的坐标点为({2}, {3})", queryNode.x, queryNode.y, + searchedNode.x, searchedNode.y)); + + //重新构造KD树,去除之前的访问记录 + tool.createKDTree(); + queryNode = new Point(2, 4.5); + searchedNode = tool.searchNearestData(queryNode); + System.out.println(MessageFormat.format( + "距离查询点({0}, {1})最近的坐标点为({2}, {3})", queryNode.x, queryNode.y, + searchedNode.x, searchedNode.y)); + } } diff --git a/Others/DataMining_KDTree/KDTreeTool.java b/Others/DataMining_KDTree/KDTreeTool.java index 0b5a53c..c40ee5a 100644 --- a/Others/DataMining_KDTree/KDTreeTool.java +++ b/Others/DataMining_KDTree/KDTreeTool.java @@ -1,4 +1,4 @@ -package DataMining_KDTree; +package Others.DataMining_KDTree; import java.io.BufferedReader; import java.io.File; @@ -6,381 +6,346 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; -import java.util.Comparator; import java.util.Stack; /** * KD树-k维空间关键数据检索算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class KDTreeTool { - // 空间平面的方向 - public static final int DIRECTION_X = 0; - public static final int DIRECTION_Y = 1; - - // 输入的测试数据坐标点文件 - private String filePath; - // 原始所有数据点数据 - private ArrayList totalDatas; - // KD树根节点 - private TreeNode rootNode; - - public KDTreeTool(String filePath) { - this.filePath = filePath; - - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - Point p; - totalDatas = new ArrayList<>(); - for (String[] array : dataArray) { - p = new Point(array[0], array[1]); - totalDatas.add(p); - } - } - - /** - * 创建KD树 - * - * @return - */ - public TreeNode createKDTree() { - ArrayList copyDatas; - - rootNode = new TreeNode(); - // 根据节点开始时所表示的空间时无限大的 - rootNode.range = new Range(); - copyDatas = (ArrayList) totalDatas.clone(); - recusiveConstructNode(rootNode, copyDatas); - - return rootNode; - } - - /** - * 递归进行KD树的构造 - * - * @param node - * 当前正在构造的节点 - * @param datas - * 该节点对应的正在处理的数据 - * @return - */ - private void recusiveConstructNode(TreeNode node, ArrayList datas) { - int direction = 0; - ArrayList leftSideDatas; - ArrayList rightSideDatas; - Point p; - TreeNode leftNode; - TreeNode rightNode; - Range range; - Range range2; - - // 如果划分的数据点集合只有1个数据,则不再划分 - if (datas.size() == 1) { - node.nodeData = datas.get(0); - return; - } - - // 首先在当前的数据点集合中进行分割方向的选择 - direction = selectSplitDrc(datas); - // 根据方向取出中位数点作为数据矢量 - p = getMiddlePoint(datas, direction); - - node.spilt = direction; - node.nodeData = p; - - leftSideDatas = getLeftSideDatas(datas, p, direction); - datas.removeAll(leftSideDatas); - // 还要去掉自身 - datas.remove(p); - rightSideDatas = datas; - - if (leftSideDatas.size() > 0) { - leftNode = new TreeNode(); - leftNode.parentNode = node; - range2 = Range.initLeftRange(p, direction); - // 获取父节点的空间矢量,进行交集运算做范围拆分 - range = node.range.crossOperation(range2); - leftNode.range = range; - - node.leftNode = leftNode; - recusiveConstructNode(leftNode, leftSideDatas); - } - - if (rightSideDatas.size() > 0) { - rightNode = new TreeNode(); - rightNode.parentNode = node; - range2 = Range.initRightRange(p, direction); - // 获取父节点的空间矢量,进行交集运算做范围拆分 - range = node.range.crossOperation(range2); - rightNode.range = range; - - node.rightNode = rightNode; - recusiveConstructNode(rightNode, rightSideDatas); - } - } - - /** - * 搜索出给定数据点的最近点 - * - * @param p - * 待比较坐标点 - */ - public Point searchNearestData(Point p) { - // 节点距离给定数据点的距离 - TreeNode nearestNode = null; - // 用栈记录遍历过的节点 - Stack stackNodes; - - stackNodes = new Stack<>(); - findedNearestLeafNode(p, rootNode, stackNodes); - - // 取出叶子节点,作为当前找到的最近节点 - nearestNode = stackNodes.pop(); - nearestNode = dfsSearchNodes(stackNodes, p, nearestNode); - - return nearestNode.nodeData; - } - - /** - * 深度优先的方式进行最近点的查找 - * - * @param stack - * KD树节点栈 - * @param desPoint - * 给定的数据点 - * @param nearestNode - * 当前找到的最近节点 - * @return - */ - private TreeNode dfsSearchNodes(Stack stack, Point desPoint, - TreeNode nearestNode) { - // 是否碰到父节点边界 - boolean isCollision; - double minDis; - double dis; - TreeNode parentNode; - - // 如果栈内节点已经全部弹出,则遍历结束 - if (stack.isEmpty()) { - return nearestNode; - } - - // 获取父节点 - parentNode = stack.pop(); - - minDis = desPoint.ouDistance(nearestNode.nodeData); - dis = desPoint.ouDistance(parentNode.nodeData); - - // 如果与当前回溯到的父节点距离更短,则搜索到的节点进行更新 - if (dis < minDis) { - minDis = dis; - nearestNode = parentNode; - } - - // 默认没有碰撞到 - isCollision = false; - // 判断是否触碰到了父节点的空间分割线 - if (parentNode.spilt == DIRECTION_X) { - if (parentNode.nodeData.x > desPoint.x - minDis - && parentNode.nodeData.x < desPoint.x + minDis) { - isCollision = true; - } - } else { - if (parentNode.nodeData.y > desPoint.y - minDis - && parentNode.nodeData.y < desPoint.y + minDis) { - isCollision = true; - } - } - - // 如果触碰到父边界了,并且此节点的孩子节点还未完全遍历完,则可以继续遍历 - if (isCollision - && (!parentNode.leftNode.isVisited || !parentNode.rightNode.isVisited)) { - TreeNode newNode; - // 新建当前的小局部节点栈 - Stack otherStack = new Stack<>(); - // 从parentNode的树以下继续寻找 - findedNearestLeafNode(desPoint, parentNode, otherStack); - newNode = dfsSearchNodes(otherStack, desPoint, otherStack.pop()); - - dis = newNode.nodeData.ouDistance(desPoint); - if (dis < minDis) { - nearestNode = newNode; - } - } - - // 继续往上回溯 - nearestNode = dfsSearchNodes(stack, desPoint, nearestNode); - - return nearestNode; - } - - /** - * 找到与所给定节点的最近的叶子节点 - * - * @param p - * 待比较节点 - * @param node - * 当前搜索到的节点 - * @param stack - * 遍历过的节点栈 - */ - private void findedNearestLeafNode(Point p, TreeNode node, - Stack stack) { - // 分割方向 - int splitDic; - - // 将遍历过的节点加入栈中 - stack.push(node); - // 标记为访问过 - node.isVisited = true; - // 如果此节点没有左右孩子节点说明已经是叶子节点了 - if (node.leftNode == null && node.rightNode == null) { - return; - } - - splitDic = node.spilt; - // 选择一个符合分割范围的节点继续递归搜寻 - if ((splitDic == DIRECTION_X && p.x < node.nodeData.x) - || (splitDic == DIRECTION_Y && p.y < node.nodeData.y)) { - if (!node.leftNode.isVisited) { - findedNearestLeafNode(p, node.leftNode, stack); - } else { - // 如果左孩子节点已经访问过,则访问另一边 - findedNearestLeafNode(p, node.rightNode, stack); - } - } else if ((splitDic == DIRECTION_X && p.x > node.nodeData.x) - || (splitDic == DIRECTION_Y && p.y > node.nodeData.y)) { - if (!node.rightNode.isVisited) { - findedNearestLeafNode(p, node.rightNode, stack); - } else { - // 如果右孩子节点已经访问过,则访问另一边 - findedNearestLeafNode(p, node.leftNode, stack); - } - } - } - - /** - * 根据给定的数据点通过计算反差选择的分割点 - * - * @param datas - * 部分的集合点集合 - * @return - */ - private int selectSplitDrc(ArrayList datas) { - int direction = 0; - double avgX = 0; - double avgY = 0; - double varianceX = 0; - double varianceY = 0; - - for (Point p : datas) { - avgX += p.x; - avgY += p.y; - } - - avgX /= datas.size(); - avgY /= datas.size(); - - for (Point p : datas) { - varianceX += (p.x - avgX) * (p.x - avgX); - varianceY += (p.y - avgY) * (p.y - avgY); - } - - // 求最后的方差 - varianceX /= datas.size(); - varianceY /= datas.size(); - - // 通过比较方差的大小决定分割方向,选择波动较大的进行划分 - direction = varianceX > varianceY ? DIRECTION_X : DIRECTION_Y; - - return direction; - } - - /** - * 根据坐标点方位进行排序,选出中间点的坐标数据 - * - * @param datas - * 数据点集合 - * @param dir - * 排序的坐标方向 - */ - private Point getMiddlePoint(ArrayList datas, int dir) { - int index = 0; - Point middlePoint; - - index = datas.size() / 2; - if (dir == DIRECTION_X) { - Collections.sort(datas, new Comparator() { - - @Override - public int compare(Point o1, Point o2) { - // TODO Auto-generated method stub - return o1.x.compareTo(o2.x); - } - }); - } else { - Collections.sort(datas, new Comparator() { - - @Override - public int compare(Point o1, Point o2) { - // TODO Auto-generated method stub - return o1.y.compareTo(o2.y); - } - }); - } - - // 取出中位数 - middlePoint = datas.get(index); - - return middlePoint; - } - - /** - * 根据方向得到原部分节点集合左侧的数据点 - * - * @param datas - * 原始数据点集合 - * @param nodeData - * 数据矢量 - * @param dir - * 分割方向 - * @return - */ - private ArrayList getLeftSideDatas(ArrayList datas, - Point nodeData, int dir) { - ArrayList leftSideDatas = new ArrayList<>(); - - for (Point p : datas) { - if (dir == DIRECTION_X && p.x < nodeData.x) { - leftSideDatas.add(p); - } else if (dir == DIRECTION_Y && p.y < nodeData.y) { - leftSideDatas.add(p); - } - } - - return leftSideDatas; - } +class KDTreeTool { + // 空间平面的方向 + static final int DIRECTION_X = 0; + private static final int DIRECTION_Y = 1; + + // 输入的测试数据坐标点文件 + private String filePath; + // 原始所有数据点数据 + private ArrayList totalDatas; + // KD树根节点 + private TreeNode rootNode; + + KDTreeTool(String filePath){ + this.filePath = filePath; + + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + Point p; + totalDatas = new ArrayList<>(); + for (String[] array : dataArray) { + p = new Point(array[0], array[1]); + totalDatas.add(p); + } + } + + /** + * 创建KD树 + */ + TreeNode createKDTree(){ + ArrayList copyDatas; + + rootNode = new TreeNode(); + // 根据节点开始时所表示的空间时无限大的 + rootNode.range = new Range(); + copyDatas = (ArrayList) totalDatas.clone(); + recursiveConstructNode(rootNode, copyDatas); + + return rootNode; + } + + /** + * 递归进行KD树的构造 + * + * @param node 当前正在构造的节点 + * @param data 该节点对应的正在处理的数据 + */ + private void recursiveConstructNode(TreeNode node, ArrayList data){ + int direction; + ArrayList leftSideData; + ArrayList rightSideData; + Point p; + TreeNode leftNode; + TreeNode rightNode; + Range range; + Range range2; + + // 如果划分的数据点集合只有1个数据,则不再划分 + if (data.size() == 1) { + node.nodeData = data.get(0); + return; + } + + // 首先在当前的数据点集合中进行分割方向的选择 + direction = selectSplitDrc(data); + // 根据方向取出中位数点作为数据矢量 + p = getMiddlePoint(data, direction); + + node.spilt = direction; + node.nodeData = p; + + leftSideData = getLeftSideData(data, p, direction); + data.removeAll(leftSideData); + // 还要去掉自身 + data.remove(p); + rightSideData = data; + + if (leftSideData.size() > 0) { + leftNode = new TreeNode(); + leftNode.parentNode = node; + range2 = Range.initLeftRange(p, direction); + // 获取父节点的空间矢量,进行交集运算做范围拆分 + range = node.range.crossOperation(range2); + leftNode.range = range; + + node.leftNode = leftNode; + recursiveConstructNode(leftNode, leftSideData); + } + + if (rightSideData.size() > 0) { + rightNode = new TreeNode(); + rightNode.parentNode = node; + range2 = Range.initRightRange(p, direction); + // 获取父节点的空间矢量,进行交集运算做范围拆分 + range = node.range.crossOperation(range2); + rightNode.range = range; + + node.rightNode = rightNode; + recursiveConstructNode(rightNode, rightSideData); + } + } + + /** + * 搜索出给定数据点的最近点 + * + * @param p 待比较坐标点 + */ + Point searchNearestData(Point p){ + // 节点距离给定数据点的距离 + TreeNode nearestNode; + // 用栈记录遍历过的节点 + Stack stackNodes; + + stackNodes = new Stack<>(); + findedNearestLeafNode(p, rootNode, stackNodes); + + // 取出叶子节点,作为当前找到的最近节点 + nearestNode = stackNodes.pop(); + nearestNode = dfsSearchNodes(stackNodes, p, nearestNode); + + return nearestNode.nodeData; + } + + /** + * 深度优先的方式进行最近点的查找 + * + * @param stack KD树节点栈 + * @param desPoint 给定的数据点 + * @param nearestNode 当前找到的最近节点 + */ + private TreeNode dfsSearchNodes(Stack stack, Point desPoint, + TreeNode nearestNode){ + // 是否碰到父节点边界 + boolean isCollision; + double minDis; + double dis; + TreeNode parentNode; + + // 如果栈内节点已经全部弹出,则遍历结束 + if (stack.isEmpty()) { + return nearestNode; + } + + // 获取父节点 + parentNode = stack.pop(); + + minDis = desPoint.ouDistance(nearestNode.nodeData); + dis = desPoint.ouDistance(parentNode.nodeData); + + // 如果与当前回溯到的父节点距离更短,则搜索到的节点进行更新 + if (dis < minDis) { + minDis = dis; + nearestNode = parentNode; + } + + // 默认没有碰撞到 + isCollision = false; + // 判断是否触碰到了父节点的空间分割线 + if (parentNode.spilt == DIRECTION_X) { + if (parentNode.nodeData.x > desPoint.x - minDis + && parentNode.nodeData.x < desPoint.x + minDis) { + isCollision = true; + } + } else { + if (parentNode.nodeData.y > desPoint.y - minDis + && parentNode.nodeData.y < desPoint.y + minDis) { + isCollision = true; + } + } + + // 如果触碰到父边界了,并且此节点的孩子节点还未完全遍历完,则可以继续遍历 + if (isCollision + && (!parentNode.leftNode.isVisited || !parentNode.rightNode.isVisited)) { + TreeNode newNode; + // 新建当前的小局部节点栈 + Stack otherStack = new Stack<>(); + // 从parentNode的树以下继续寻找 + findedNearestLeafNode(desPoint, parentNode, otherStack); + newNode = dfsSearchNodes(otherStack, desPoint, otherStack.pop()); + + dis = newNode.nodeData.ouDistance(desPoint); + if (dis < minDis) { + nearestNode = newNode; + } + } + + // 继续往上回溯 + nearestNode = dfsSearchNodes(stack, desPoint, nearestNode); + + return nearestNode; + } + + /** + * 找到与所给定节点的最近的叶子节点 + * + * @param p 待比较节点 + * @param node 当前搜索到的节点 + * @param stack 遍历过的节点栈 + */ + private void findedNearestLeafNode(Point p, TreeNode node, + Stack stack){ + // 分割方向 + int splitDic; + + // 将遍历过的节点加入栈中 + stack.push(node); + // 标记为访问过 + node.isVisited = true; + // 如果此节点没有左右孩子节点说明已经是叶子节点了 + if (node.leftNode == null && node.rightNode == null) { + return; + } + + splitDic = node.spilt; + // 选择一个符合分割范围的节点继续递归搜寻 + if ((splitDic == DIRECTION_X && p.x < node.nodeData.x) + || (splitDic == DIRECTION_Y && p.y < node.nodeData.y)) { + if (node.leftNode != null) { + if (!node.leftNode.isVisited) { + findedNearestLeafNode(p, node.leftNode, stack); + } else { + // 如果左孩子节点已经访问过,则访问另一边 + findedNearestLeafNode(p, node.rightNode, stack); + } + } + } else if ((splitDic == DIRECTION_X && p.x > node.nodeData.x) + || (splitDic == DIRECTION_Y && p.y > node.nodeData.y)) { + if (!node.rightNode.isVisited) { + findedNearestLeafNode(p, node.rightNode, stack); + } else { + // 如果右孩子节点已经访问过,则访问另一边 + findedNearestLeafNode(p, node.leftNode, stack); + } + } + } + + /** + * 根据给定的数据点通过计算反差选择的分割点 + * + * @param data 部分的集合点集合 + */ + private int selectSplitDrc(ArrayList data){ + int direction; + double avgX = 0; + double avgY = 0; + double varianceX = 0; + double varianceY = 0; + + for (Point p : data) { + avgX += p.x; + avgY += p.y; + } + + avgX /= data.size(); + avgY /= data.size(); + + for (Point p : data) { + varianceX += (p.x - avgX) * (p.x - avgX); + varianceY += (p.y - avgY) * (p.y - avgY); + } + + // 求最后的方差 + varianceX /= data.size(); + varianceY /= data.size(); + + // 通过比较方差的大小决定分割方向,选择波动较大的进行划分 + direction = varianceX > varianceY ? DIRECTION_X : DIRECTION_Y; + + return direction; + } + + /** + * 根据坐标点方位进行排序,选出中间点的坐标数据 + * + * @param datas 数据点集合 + * @param dir 排序的坐标方向 + */ + private Point getMiddlePoint(ArrayList datas, int dir){ + int index; + Point middlePoint; + + index = datas.size() / 2; + if (dir == DIRECTION_X) { + Collections.sort(datas, (o1, o2) -> o1.x.compareTo(o2.x)); + } else { + Collections.sort(datas, (o1, o2) -> o2.y.compareTo(o1.y)); + } + + // 取出中位数 + middlePoint = datas.get(index); + + return middlePoint; + } + + /** + * 根据方向得到原部分节点集合左侧的数据点 + * + * @param data 原始数据点集合 + * @param nodeData 数据矢量 + * @param dir 分割方向 + */ + private ArrayList getLeftSideData(ArrayList data, + Point nodeData, int dir){ + ArrayList leftSideDatas = new ArrayList<>(); + + for (Point p : data) { + if (dir == DIRECTION_X && p.x < nodeData.x) { + leftSideDatas.add(p); + } else if (dir == DIRECTION_Y && p.y < nodeData.y) { + leftSideDatas.add(p); + } + } + + return leftSideDatas; + } } diff --git a/Others/DataMining_KDTree/Point.java b/Others/DataMining_KDTree/Point.java index c98a770..8610143 100644 --- a/Others/DataMining_KDTree/Point.java +++ b/Others/DataMining_KDTree/Point.java @@ -1,58 +1,55 @@ -package DataMining_KDTree; +package Others.DataMining_KDTree; + +import java.util.Objects; /** * 坐标点类 - * - * @author lyq - * + * + * @author Qstar */ -public class Point{ - // 坐标点横坐标 - Double x; - // 坐标点纵坐标 - Double y; - - public Point(double x, double y){ - this.x = x; - this.y = y; - } - - public Point(String x, String y) { - this.x = (Double.parseDouble(x)); - this.y = (Double.parseDouble(y)); - } - - /** - * 计算当前点与制定点之间的欧式距离 - * - * @param p - * 待计算聚类的p点 - * @return - */ - public double ouDistance(Point p) { - double distance = 0; - - distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) - * (this.y - p.y); - distance = Math.sqrt(distance); - - return distance; - } - - /** - * 判断2个坐标点是否为用个坐标点 - * - * @param p - * 待比较坐标点 - * @return - */ - public boolean isTheSame(Point p) { - boolean isSamed = false; - - if (this.x == p.x && this.y == p.y) { - isSamed = true; - } - - return isSamed; - } +public class Point { + // 坐标点横坐标 + Double x; + // 坐标点纵坐标 + Double y; + + public Point(double x, double y){ + this.x = x; + this.y = y; + } + + public Point(String x, String y){ + this.x = (Double.parseDouble(x)); + this.y = (Double.parseDouble(y)); + } + + /** + * 计算当前点与制定点之间的欧式距离 + * + * @param p 待计算聚类的p点 + */ + double ouDistance(Point p){ + double distance; + + distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) + * (this.y - p.y); + distance = Math.sqrt(distance); + + return distance; + } + + /** + * 判断2个坐标点是否为用个坐标点 + * + * @param p 待比较坐标点 + */ + public boolean isTheSame(Point p){ + boolean isSame = false; + + if (Objects.equals(this.x, p.x) && Objects.equals(this.y, p.y)) { + isSame = true; + } + + return isSame; + } } diff --git a/Others/DataMining_KDTree/Range.java b/Others/DataMining_KDTree/Range.java index b36d3d3..97aec96 100644 --- a/Others/DataMining_KDTree/Range.java +++ b/Others/DataMining_KDTree/Range.java @@ -1,114 +1,106 @@ -package DataMining_KDTree; +package Others.DataMining_KDTree; /** * 空间矢量,表示所代表的空间范围 - * - * @author lyq - * + * + * @author Qstar */ -public class Range { - // 边界左边界 - double left; - // 边界右边界 - double right; - // 边界上边界 - double top; - // 边界下边界 - double bottom; +class Range { + // 边界左边界 + private double left; + // 边界右边界 + private double right; + // 边界上边界 + private double top; + // 边界下边界 + private double bottom; - public Range() { - this.left = -Integer.MAX_VALUE; - this.right = Integer.MAX_VALUE; - this.top = Integer.MAX_VALUE; - this.bottom = -Integer.MAX_VALUE; - } + Range(){ + this.left = -Integer.MAX_VALUE; + this.right = Integer.MAX_VALUE; + this.top = Integer.MAX_VALUE; + this.bottom = -Integer.MAX_VALUE; + } - public Range(int left, int right, int top, int bottom) { - this.left = left; - this.right = right; - this.top = top; - this.bottom = bottom; - } + public Range(int left, int right, int top, int bottom){ + this.left = left; + this.right = right; + this.top = top; + this.bottom = bottom; + } - /** - * 空间矢量进行并操作 - * - * @param range - * @return - */ - public Range crossOperation(Range r) { - Range range = new Range(); + /** + * 根据坐标点分割方向确定左侧空间矢量 + * + * @param p 数据矢量 + * @param dir 分割方向 + */ + static Range initLeftRange(Point p, int dir){ + Range range = new Range(); - // 取靠近右侧的左边界 - if (r.left > this.left) { - range.left = r.left; - } else { - range.left = this.left; - } + if (dir == KDTreeTool.DIRECTION_X) { + range.right = p.x; + } else { + range.bottom = p.y; + } - // 取靠近左侧的右边界 - if (r.right < this.right) { - range.right = r.right; - } else { - range.right = this.right; - } + return range; + } - // 取靠近下侧的上边界 - if (r.top < this.top) { - range.top = r.top; - } else { - range.top = this.top; - } + /** + * 根据坐标点分割方向确定右侧空间矢量 + * + * @param p 数据矢量 + * @param dir 分割方向 + */ + static Range initRightRange(Point p, int dir){ + Range range = new Range(); - // 取靠近上侧的下边界 - if (r.bottom > this.bottom) { - range.bottom = r.bottom; - } else { - range.bottom = this.bottom; - } + if (dir == KDTreeTool.DIRECTION_X) { + range.left = p.x; + } else { + range.top = p.y; + } - return range; - } + return range; + } - /** - * 根据坐标点分割方向确定左侧空间矢量 - * - * @param p - * 数据矢量 - * @param dir - * 分割方向 - * @return - */ - public static Range initLeftRange(Point p, int dir) { - Range range = new Range(); + /** + * 空间矢量进行并操作 + * + * @param range1 空间矢量,表示所代表的空间范围 + */ + Range crossOperation(Range range1){ + Range range = new Range(); - if (dir == KDTreeTool.DIRECTION_X) { - range.right = p.x; - } else { - range.bottom = p.y; - } + // 取靠近右侧的左边界 + if (range1.left > this.left) { + range.left = range1.left; + } else { + range.left = this.left; + } - return range; - } + // 取靠近左侧的右边界 + if (range1.right < this.right) { + range.right = range1.right; + } else { + range.right = this.right; + } - /** - * 根据坐标点分割方向确定右侧空间矢量 - * - * @param p - * 数据矢量 - * @param dir - * 分割方向 - * @return - */ - public static Range initRightRange(Point p, int dir) { - Range range = new Range(); + // 取靠近下侧的上边界 + if (range1.top < this.top) { + range.top = range1.top; + } else { + range.top = this.top; + } - if (dir == KDTreeTool.DIRECTION_X) { - range.left = p.x; - } else { - range.top = p.y; - } + // 取靠近上侧的下边界 + if (range1.bottom > this.bottom) { + range.bottom = range1.bottom; + } else { + range.bottom = this.bottom; + } - return range; - } + return range; + } } diff --git a/Others/DataMining_KDTree/TreeNode.java b/Others/DataMining_KDTree/TreeNode.java index 127833c..779bae1 100644 --- a/Others/DataMining_KDTree/TreeNode.java +++ b/Others/DataMining_KDTree/TreeNode.java @@ -1,27 +1,27 @@ -package DataMining_KDTree; +package Others.DataMining_KDTree; /** * KD树节点 - * @author lyq * + * @author Qstar */ -public class TreeNode { - //数据矢量 - Point nodeData; - //分割平面的分割线 - int spilt; - //空间矢量,该节点所表示的空间范围 - Range range; - //父节点 - TreeNode parentNode; - //位于分割超平面左侧的孩子节点 - TreeNode leftNode; - //位于分割超平面右侧的孩子节点 - TreeNode rightNode; - //节点是否被访问过,用于回溯时使用 - boolean isVisited; - - public TreeNode(){ - this.isVisited = false; - } +class TreeNode { + //数据矢量 + Point nodeData; + //分割平面的分割线 + int spilt; + //空间矢量,该节点所表示的空间范围 + Range range; + //父节点 + TreeNode parentNode; + //位于分割超平面左侧的孩子节点 + TreeNode leftNode; + //位于分割超平面右侧的孩子节点 + TreeNode rightNode; + //节点是否被访问过,用于回溯时使用 + boolean isVisited; + + TreeNode(){ + this.isVisited = false; + } } diff --git a/Others/DataMining_MSApriori/Client.java b/Others/DataMining_MSApriori/Client.java index f49e83d..1c7db93 100644 --- a/Others/DataMining_MSApriori/Client.java +++ b/Others/DataMining_MSApriori/Client.java @@ -1,45 +1,44 @@ -package DataMining_MSApriori; +package Others.DataMining_MSApriori; /** * 基于多支持度的Apriori算法测试类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - //是否是事务型数据 - boolean isTransaction; - //测试数据文件地址 - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - //关系表型数据文件地址 - String tableFilePath = "C:\\Users\\lyq\\Desktop\\icon\\input2.txt"; - //最小支持度阈值 - double minSup; - // 最小置信度率 - double minConf; - //最大支持度差别阈值 - double delta; - //多项目的最小支持度数,括号中的下标代表的是商品的ID - double[] mis; - //msApriori算法工具类 - MSAprioriTool tool; - - //为了测试的方便,取一个偏低的置信度值0.3 - minConf = 0.3; - minSup = 0.1; - delta = 0.5; - //每项的支持度率都默认为0.1,第一项不使用 - mis = new double[]{-1, 0.1, 0.1, 0.1, 0.1, 0.1}; - isTransaction = true; - - isTransaction = true; - tool = new MSAprioriTool(filePath, minConf, delta, mis, isTransaction); - tool.calFItems(); - System.out.println(); - - isTransaction = false; - //重新初始化数据 - tool = new MSAprioriTool(tableFilePath, minConf, minSup, isTransaction); - tool.calFItems(); - } + public static void main(String[] args){ + //是否是事务型数据 + boolean isTransaction; + //测试数据文件地址 + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_MSApriori/testInput.txt"; + //关系表型数据文件地址 + String tableFilePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_MSApriori/testInput2.txt"; + //最小支持度阈值 + double minSup; + // 最小置信度率 + double minConf; + //最大支持度差别阈值 + double delta; + //多项目的最小支持度数,括号中的下标代表的是商品的ID + double[] mis; + //msApriori算法工具类 + MSAprioriTool tool; + + //为了测试的方便,取一个偏低的置信度值0.3 + minConf = 0.3; + minSup = 0.1; + delta = 0.5; + //每项的支持度率都默认为0.1,第一项不使用 + mis = new double[]{-1, 0.1, 0.1, 0.1, 0.1, 0.1}; + + isTransaction = true; + tool = new MSAprioriTool(filePath, minConf, delta, mis, isTransaction); + tool.calFItems(); + System.out.println(); + + isTransaction = false; + //重新初始化数据 + tool = new MSAprioriTool(tableFilePath, minConf, minSup, isTransaction); + tool.calFItems(); + } } diff --git a/Others/DataMining_MSApriori/FrequentItem.java b/Others/DataMining_MSApriori/FrequentItem.java index 2ba88c4..35681e8 100644 --- a/Others/DataMining_MSApriori/FrequentItem.java +++ b/Others/DataMining_MSApriori/FrequentItem.java @@ -1,56 +1,51 @@ -package DataMining_MSApriori; +package Others.DataMining_MSApriori; /** * 频繁项集 - * - * @author lyq - * + * + * @author Qstar */ -public class FrequentItem implements Comparable{ - // 频繁项集的集合ID - private String[] idArray; - // 频繁项集的支持度计数 - private int count; - //频繁项集的长度,1项集或是2项集,亦或是3项集 - private int length; - - public FrequentItem(String[] idArray, int count){ - this.idArray = idArray; - this.count = count; - length = idArray.length; - } - - public String[] getIdArray() { - return idArray; - } - - public void setIdArray(String[] idArray) { - this.idArray = idArray; - } - - public int getCount() { - return count; - } - - public void setCount(int count) { - this.count = count; - } - - public int getLength() { - return length; - } - - public void setLength(int length) { - this.length = length; - } - - @Override - public int compareTo(FrequentItem o) { - // TODO Auto-generated method stub - Integer int1 = Integer.parseInt(this.getIdArray()[0]); - Integer int2 = Integer.parseInt(o.getIdArray()[0]); - - return int1.compareTo(int2); - } - +class FrequentItem implements Comparable { + // 频繁项集的集合ID + private String[] idArray; + // 频繁项集的支持度计数 + private int count; + //频繁项集的长度,1项集或是2项集,亦或是3项集 + private int length; + + public FrequentItem(String[] idArray, int count){ + this.idArray = idArray; + this.count = count; + length = idArray.length; + } + + private String[] getIdArray(){ + return idArray; + } + + public int getCount(){ + return count; + } + + public void setCount(int count){ + this.count = count; + } + + public int getLength(){ + return length; + } + + public void setLength(int length){ + this.length = length; + } + + @Override + public int compareTo(FrequentItem o){ + // TODO Auto-generated method stub + Integer int1 = Integer.parseInt(this.getIdArray()[0]); + Integer int2 = Integer.parseInt(o.getIdArray()[0]); + + return int1.compareTo(int2); + } + } diff --git a/Others/DataMining_MSApriori/MSAprioriTool.java b/Others/DataMining_MSApriori/MSAprioriTool.java index ba5d444..0ccb0c8 100644 --- a/Others/DataMining_MSApriori/MSAprioriTool.java +++ b/Others/DataMining_MSApriori/MSAprioriTool.java @@ -1,780 +1,754 @@ -package DataMining_MSApriori; +package Others.DataMining_MSApriori; + +import AssociationAnalysis.DataMining_Apriori.FrequentItem; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; -import java.text.MessageFormat; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import DataMining_Apriori.FrequentItem; +import java.util.*; /** * 基于多支持度的Apriori算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class MSAprioriTool { - // 前件判断的结果值,用于关联规则的推导 - public static final int PREFIX_NOT_SUB = -1; - public static final int PREFIX_EQUAL = 1; - public static final int PREFIX_IS_SUB = 2; - - // 是否读取的是事务型数据 - private boolean isTransaction; - // 最大频繁k项集的k值 - private int initFItemNum; - // 事务数据文件地址 - private String filePath; - // 最小支持度阈值 - private double minSup; - // 最小置信度率 - private double minConf; - // 最大支持度差别阈值 - private double delta; - // 多项目的最小支持度数,括号中的下标代表的是商品的ID - private double[] mis; - // 每个事务中的商品ID - private ArrayList totalGoodsIDs; - // 关系表数据所转化的事务数据 - private ArrayList transactionDatas; - // 过程中计算出来的所有频繁项集列表 - private ArrayList resultItem; - // 过程中计算出来频繁项集的ID集合 - private ArrayList resultItemID; - // 属性到数字的映射图 - private HashMap attr2Num; - // 数字id对应属性的映射图 - private HashMap num2Attr; - // 频繁项集所覆盖的id数值 - private Map fItem2Id; - - /** - * 事务型数据关联挖掘算法 - * - * @param filePath - * @param minConf - * @param delta - * @param mis - * @param isTransaction - */ - public MSAprioriTool(String filePath, double minConf, double delta, - double[] mis, boolean isTransaction) { - this.filePath = filePath; - this.minConf = minConf; - this.delta = delta; - this.mis = mis; - this.isTransaction = isTransaction; - this.fItem2Id = new HashMap<>(); - - readDataFile(); - } - - /** - * 非事务型关联挖掘 - * - * @param filePath - * @param minConf - * @param minSup - * @param isTransaction - */ - public MSAprioriTool(String filePath, double minConf, double minSup, - boolean isTransaction) { - this.filePath = filePath; - this.minConf = minConf; - this.minSup = minSup; - this.isTransaction = isTransaction; - this.delta = 1.0; - this.fItem2Id = new HashMap<>(); - - readRDBMSData(filePath); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - String[] temp = null; - ArrayList dataArray; - - dataArray = readLine(filePath); - totalGoodsIDs = new ArrayList<>(); - - for (String[] array : dataArray) { - temp = new String[array.length - 1]; - System.arraycopy(array, 1, temp, 0, array.length - 1); - - // 将事务ID加入列表吧中 - totalGoodsIDs.add(temp); - } - } - - /** - * 从文件中逐行读数据 - * - * @param filePath - * 数据文件地址 - * @return - */ - private ArrayList readLine(String filePath) { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - return dataArray; - } - - /** - * 计算频繁项集 - */ - public void calFItems() { - FrequentItem fItem; - - computeLink(); - printFItems(); - - if (isTransaction) { - fItem = resultItem.get(resultItem.size() - 1); - // 取出最后一个频繁项集做关联规则的推导 - System.out.println("最后一个频繁项集做关联规则的推导结果:"); - printAttachRuls(fItem.getIdArray()); - } - } - - /** - * 输出频繁项集 - */ - private void printFItems() { - if (isTransaction) { - System.out.println("事务型数据频繁项集输出结果:"); - } else { - System.out.println("非事务(关系)型数据频繁项集输出结果:"); - } - - // 输出频繁项集 - for (int k = 1; k <= initFItemNum; k++) { - System.out.println("频繁" + k + "项集:"); - for (FrequentItem i : resultItem) { - if (i.getLength() == k) { - System.out.print("{"); - for (String t : i.getIdArray()) { - if (!isTransaction) { - // 如果原本是非事务型数据,需要重新做替换 - t = num2Attr.get(Integer.parseInt(t)); - } - - System.out.print(t + ","); - } - System.out.print("},"); - } - } - System.out.println(); - } - } - - /** - * 项集进行连接运算 - */ - private void computeLink() { - // 连接计算的终止数,k项集必须算到k-1子项集为止 - int endNum = 0; - // 当前已经进行连接运算到几项集,开始时就是1项集 - int currentNum = 1; - // 商品,1频繁项集映射图 - HashMap itemMap = new HashMap<>(); - FrequentItem tempItem; - // 初始列表 - ArrayList list = new ArrayList<>(); - // 经过连接运算后产生的结果项集 - resultItem = new ArrayList<>(); - resultItemID = new ArrayList<>(); - // 商品ID的种类 - ArrayList idType = new ArrayList<>(); - for (String[] a : totalGoodsIDs) { - for (String s : a) { - if (!idType.contains(s)) { - tempItem = new FrequentItem(new String[] { s }, 1); - idType.add(s); - resultItemID.add(new String[] { s }); - } else { - // 支持度计数加1 - tempItem = itemMap.get(s); - tempItem.setCount(tempItem.getCount() + 1); - } - itemMap.put(s, tempItem); - } - } - // 将初始频繁项集转入到列表中,以便继续做连接运算 - for (Map.Entry entry : itemMap.entrySet()) { - tempItem = entry.getValue(); - - // 判断1频繁项集是否满足支持度阈值的条件 - if (judgeFItem(tempItem.getIdArray())) { - list.add(tempItem); - } - } - - // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少 - Collections.sort(list); - resultItem.addAll(list); - - String[] array1; - String[] array2; - String[] resultArray; - ArrayList tempIds; - ArrayList resultContainer; - // 总共要算到endNum项集 - endNum = list.size() - 1; - initFItemNum = list.size() - 1; - - while (currentNum < endNum) { - resultContainer = new ArrayList<>(); - for (int i = 0; i < list.size() - 1; i++) { - tempItem = list.get(i); - array1 = tempItem.getIdArray(); - - for (int j = i + 1; j < list.size(); j++) { - tempIds = new ArrayList<>(); - array2 = list.get(j).getIdArray(); - - for (int k = 0; k < array1.length; k++) { - // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作 - if (array1[k].equals(array2[k])) { - tempIds.add(array1[k]); - } else { - tempIds.add(array1[k]); - tempIds.add(array2[k]); - } - } - - resultArray = new String[tempIds.size()]; - tempIds.toArray(resultArray); - - boolean isContain = false; - // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的 - if (resultArray.length == (array1.length + 1)) { - isContain = isIDArrayContains(resultContainer, - resultArray); - if (!isContain) { - resultContainer.add(resultArray); - } - } - } - } - - // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集 - list = cutItem(resultContainer); - currentNum++; - } - } - - /** - * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集 - */ - private ArrayList cutItem(ArrayList resultIds) { - String[] temp; - // 忽略的索引位置,以此构建子集 - int igNoreIndex = 0; - FrequentItem tempItem; - // 剪枝生成新的频繁项集 - ArrayList newItem = new ArrayList<>(); - // 不符合要求的id - ArrayList deleteIdArray = new ArrayList<>(); - // 子项集是否也为频繁子项集 - boolean isContain = true; - - for (String[] array : resultIds) { - // 列举出其中的一个个的子项集,判断存在于频繁项集列表中 - temp = new String[array.length - 1]; - for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) { - isContain = true; - for (int j = 0, k = 0; j < array.length; j++) { - if (j != igNoreIndex) { - temp[k] = array[j]; - k++; - } - } - - if (!isIDArrayContains(resultItemID, temp)) { - isContain = false; - break; - } - } - - if (!isContain) { - deleteIdArray.add(array); - } - } - - // 移除不符合条件的ID组合 - resultIds.removeAll(deleteIdArray); - - // 移除支持度计数不够的id集合 - int tempCount = 0; - boolean isSatisfied = false; - for (String[] array : resultIds) { - isSatisfied = judgeFItem(array); - - // 如果此频繁项集满足多支持度阈值限制条件和支持度差别限制条件,则添加入结果集中 - if (isSatisfied) { - tempItem = new FrequentItem(array, tempCount); - newItem.add(tempItem); - resultItemID.add(array); - resultItem.add(tempItem); - } - } - - return newItem; - } - - /** - * 判断列表结果中是否已经包含此数组 - * - * @param container - * ID数组容器 - * @param array - * 待比较数组 - * @return - */ - private boolean isIDArrayContains(ArrayList container, - String[] array) { - boolean isContain = true; - if (container.size() == 0) { - isContain = false; - return isContain; - } - - for (String[] s : container) { - // 比较的视乎必须保证长度一样 - if (s.length != array.length) { - continue; - } - - isContain = true; - for (int i = 0; i < s.length; i++) { - // 只要有一个id不等,就算不相等 - if (s[i] != array[i]) { - isContain = false; - break; - } - } - - // 如果已经判断是包含在容器中时,直接退出 - if (isContain) { - break; - } - } - - return isContain; - } - - /** - * 判断一个频繁项集是否满足条件 - * - * @param frequentItem - * 待判断频繁项集 - * @return - */ - private boolean judgeFItem(String[] frequentItem) { - boolean isSatisfied = true; - int id; - int count; - double tempMinSup; - // 最小的支持度阈值 - double minMis = Integer.MAX_VALUE; - // 最大的支持度阈值 - double maxMis = -Integer.MAX_VALUE; - - // 如果是事务型数据,用mis数组判断,如果不是统一用同样的最小支持度阈值判断 - if (isTransaction) { - // 寻找频繁项集中的最小支持度阈值 - for (int i = 0; i < frequentItem.length; i++) { - id = i + 1; - - if (mis[id] < minMis) { - minMis = mis[id]; - } - - if (mis[id] > maxMis) { - maxMis = mis[id]; - } - } - } else { - minMis = minSup; - maxMis = minSup; - } - - count = calSupportCount(frequentItem); - tempMinSup = 1.0 * count / totalGoodsIDs.size(); - // 判断频繁项集的支持度阈值是否超过最小的支持度阈值 - if (tempMinSup < minMis) { - isSatisfied = false; - } - - // 如果误差超过了最大支持度差别,也算不满足条件 - if (Math.abs(maxMis - minMis) > delta) { - isSatisfied = false; - } - - return isSatisfied; - } - - /** - * 统计候选频繁项集的支持度数,利用他的子集进行技术,无须扫描整个数据集 - * - * @param frequentItem - * 待计算频繁项集 - * @return - */ - private int calSupportCount(String[] frequentItem) { - int count = 0; - int[] ids; - String key; - String[] array; - ArrayList newIds; - - key = ""; - for (int i = 1; i < frequentItem.length; i++) { - key += frequentItem[i]; - } - - newIds = new ArrayList<>(); - // 找出所属的事务ID - ids = fItem2Id.get(key); - - // 如果没有找到子项集的事务id,则全盘扫描数据集 - if (ids == null || ids.length == 0) { - for (int j = 0; j < totalGoodsIDs.size(); j++) { - array = totalGoodsIDs.get(j); - if (isStrArrayContain(array, frequentItem)) { - count++; - newIds.add(j); - } - } - } else { - for (int index : ids) { - array = totalGoodsIDs.get(index); - if (isStrArrayContain(array, frequentItem)) { - count++; - newIds.add(index); - } - } - } - - ids = new int[count]; - for (int i = 0; i < ids.length; i++) { - ids[i] = newIds.get(i); - } - - key = frequentItem[0] + key; - // 将所求值存入图中,便于下次的计数 - fItem2Id.put(key, ids); - - return count; - } - - /** - * 根据给定的频繁项集输出关联规则 - * - * @param frequentItems - * 频繁项集 - */ - public void printAttachRuls(String[] frequentItem) { - // 关联规则前件,后件对 - Map, ArrayList> rules; - // 前件搜索历史 - Map, ArrayList> searchHistory; - ArrayList prefix; - ArrayList suffix; - - rules = new HashMap, ArrayList>(); - searchHistory = new HashMap<>(); - - for (int i = 0; i < frequentItem.length; i++) { - suffix = new ArrayList<>(); - for (int j = 0; j < frequentItem.length; j++) { - suffix.add(frequentItem[j]); - } - prefix = new ArrayList<>(); - - recusiveFindRules(rules, searchHistory, prefix, suffix); - } - - // 依次输出找到的关联规则 - for (Map.Entry, ArrayList> entry : rules - .entrySet()) { - prefix = entry.getKey(); - suffix = entry.getValue(); - - printRuleDetail(prefix, suffix); - } - } - - /** - * 根据前件后件,输出关联规则 - * - * @param prefix - * @param suffix - */ - private void printRuleDetail(ArrayList prefix, - ArrayList suffix) { - // {A}-->{B}的意思为在A的情况下发生B的概率 - System.out.print("{"); - for (String s : prefix) { - System.out.print(s + ", "); - } - System.out.print("}-->"); - System.out.print("{"); - for (String s : suffix) { - System.out.print(s + ", "); - } - System.out.println("}"); - } - - /** - * 递归扩展关联规则解 - * - * @param rules - * 关联规则结果集 - * @param history - * 前件搜索历史 - * @param prefix - * 关联规则前件 - * @param suffix - * 关联规则后件 - */ - private void recusiveFindRules( - Map, ArrayList> rules, - Map, ArrayList> history, - ArrayList prefix, ArrayList suffix) { - int count1; - int count2; - int compareResult; - // 置信度大小 - double conf; - String[] temp1; - String[] temp2; - ArrayList copyPrefix; - ArrayList copySuffix; - - // 如果后件只有1个,则函数返回 - if (suffix.size() == 1) { - return; - } - - for (String s : suffix) { - count1 = 0; - count2 = 0; - - copyPrefix = (ArrayList) prefix.clone(); - copyPrefix.add(s); - - copySuffix = (ArrayList) suffix.clone(); - // 将拷贝的后件移除添加的一项 - copySuffix.remove(s); - - compareResult = isSubSetInRules(history, copyPrefix); - if (compareResult == PREFIX_EQUAL) { - // 如果曾经已经被搜索过,则跳过 - continue; - } - - // 判断是否为子集,如果是子集则无需计算 - compareResult = isSubSetInRules(rules, copyPrefix); - if (compareResult == PREFIX_IS_SUB) { - rules.put(copyPrefix, copySuffix); - // 加入到搜索历史中 - history.put(copyPrefix, copySuffix); - recusiveFindRules(rules, history, copyPrefix, copySuffix); - continue; - } - - // 暂时合并为总的集合 - copySuffix.addAll(copyPrefix); - temp1 = new String[copyPrefix.size()]; - temp2 = new String[copySuffix.size()]; - copyPrefix.toArray(temp1); - copySuffix.toArray(temp2); - // 之后再次移除之前天剑的前件 - copySuffix.removeAll(copyPrefix); - - for (String[] a : totalGoodsIDs) { - if (isStrArrayContain(a, temp1)) { - count1++; - - // 在group1的条件下,统计group2的事件发生次数 - if (isStrArrayContain(a, temp2)) { - count2++; - } - } - } - - conf = 1.0 * count2 / count1; - if (conf > minConf) { - // 设置此前件条件下,能导出关联规则 - rules.put(copyPrefix, copySuffix); - } - - // 加入到搜索历史中 - history.put(copyPrefix, copySuffix); - recusiveFindRules(rules, history, copyPrefix, copySuffix); - } - } - - /** - * 判断当前的前件是否会关联规则的子集 - * - * @param rules - * 当前已经判断出的关联规则 - * @param prefix - * 待判断的前件 - * @return - */ - private int isSubSetInRules( - Map, ArrayList> rules, - ArrayList prefix) { - int result = PREFIX_NOT_SUB; - String[] temp1; - String[] temp2; - ArrayList tempPrefix; - - for (Map.Entry, ArrayList> entry : rules - .entrySet()) { - tempPrefix = entry.getKey(); - - temp1 = new String[tempPrefix.size()]; - temp2 = new String[prefix.size()]; - - tempPrefix.toArray(temp1); - prefix.toArray(temp2); - - // 判断当前构造的前件是否已经是存在前件的子集 - if (isStrArrayContain(temp2, temp1)) { - if (temp2.length == temp1.length) { - result = PREFIX_EQUAL; - } else { - result = PREFIX_IS_SUB; - } - } - - if (result == PREFIX_EQUAL) { - break; - } - } - - return result; - } - - /** - * 数组array2是否包含于array1中,不需要完全一样 - * - * @param array1 - * @param array2 - * @return - */ - private boolean isStrArrayContain(String[] array1, String[] array2) { - boolean isContain = true; - for (String s2 : array2) { - isContain = false; - for (String s1 : array1) { - // 只要s2字符存在于array1中,这个字符就算包含在array1中 - if (s2.equals(s1)) { - isContain = true; - break; - } - } - - // 一旦发现不包含的字符,则array2数组不包含于array1中 - if (!isContain) { - break; - } - } - - return isContain; - } - - /** - * 读关系表中的数据,并转化为事务数据 - * - * @param filePath - */ - private void readRDBMSData(String filePath) { - String str; - // 属性名称行 - String[] attrNames = null; - String[] temp; - String[] newRecord; - ArrayList datas = null; - - datas = readLine(filePath); - - // 获取首行 - attrNames = datas.get(0); - this.transactionDatas = new ArrayList<>(); - - // 去除首行数据 - for (int i = 1; i < datas.size(); i++) { - temp = datas.get(i); - - // 过滤掉首列id列 - for (int j = 1; j < temp.length; j++) { - str = ""; - // 采用属性名+属性值的形式避免数据的重复 - str = attrNames[j] + ":" + temp[j]; - temp[j] = str; - } - - newRecord = new String[attrNames.length - 1]; - System.arraycopy(temp, 1, newRecord, 0, attrNames.length - 1); - this.transactionDatas.add(newRecord); - } - - attributeReplace(); - // 将事务数转到totalGoodsID中做统一处理 - this.totalGoodsIDs = transactionDatas; - } - - /** - * 属性值的替换,替换成数字的形式,以便进行频繁项的挖掘 - */ - private void attributeReplace() { - int currentValue = 1; - String s; - // 属性名到数字的映射图 - attr2Num = new HashMap<>(); - num2Attr = new HashMap<>(); - - // 按照1列列的方式来,从左往右边扫描,跳过列名称行和id列 - for (int j = 0; j < transactionDatas.get(0).length; j++) { - for (int i = 0; i < transactionDatas.size(); i++) { - s = transactionDatas.get(i)[j]; - - if (!attr2Num.containsKey(s)) { - attr2Num.put(s, currentValue); - num2Attr.put(currentValue, s); - - transactionDatas.get(i)[j] = currentValue + ""; - currentValue++; - } else { - transactionDatas.get(i)[j] = attr2Num.get(s) + ""; - } - } - } - } +class MSAprioriTool { + // 前件判断的结果值,用于关联规则的推导 + private static final int PREFIX_NOT_SUB = -1; + private static final int PREFIX_EQUAL = 1; + private static final int PREFIX_IS_SUB = 2; + + // 是否读取的是事务型数据 + private boolean isTransaction; + // 最大频繁k项集的k值 + private int initFItemNum; + // 事务数据文件地址 + private String filePath; + // 最小支持度阈值 + private double minSup; + // 最小置信度率 + private double minConf; + // 最大支持度差别阈值 + private double delta; + // 多项目的最小支持度数,括号中的下标代表的是商品的ID + private double[] mis; + // 每个事务中的商品ID + private ArrayList totalGoodsIDs; + // 关系表数据所转化的事务数据 + private ArrayList transactionDatas; + // 过程中计算出来的所有频繁项集列表 + private ArrayList resultItem; + // 过程中计算出来频繁项集的ID集合 + private ArrayList resultItemID; + // 属性到数字的映射图 + private HashMap attr2Num; + // 数字id对应属性的映射图 + private HashMap num2Attr; + // 频繁项集所覆盖的id数值 + private Map fItem2Id; + + /** + * 事务型数据关联挖掘算法 + * + * @param filePath 文件路径 + * @param minConf + * @param delta + * @param mis + * @param isTransaction 是否是事务 + */ + MSAprioriTool(String filePath, double minConf, double delta, + double[] mis, boolean isTransaction){ + this.filePath = filePath; + this.minConf = minConf; + this.delta = delta; + this.mis = mis; + this.isTransaction = isTransaction; + this.fItem2Id = new HashMap<>(); + + readDataFile(); + } + + /** + * 非事务型关联挖掘 + * + * @param filePath 文件路径 + * @param minConf + * @param minSup + * @param isTransaction 是否是事务 + */ + public MSAprioriTool(String filePath, double minConf, double minSup, + boolean isTransaction){ + this.filePath = filePath; + this.minConf = minConf; + this.minSup = minSup; + this.isTransaction = isTransaction; + this.delta = 1.0; + this.fItem2Id = new HashMap<>(); + + readRDBMSData(filePath); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + String[] temp; + ArrayList dataArray; + + dataArray = readLine(filePath); + totalGoodsIDs = new ArrayList<>(); + + for (String[] array : dataArray) { + temp = new String[array.length - 1]; + System.arraycopy(array, 1, temp, 0, array.length - 1); + + // 将事务ID加入列表吧中 + totalGoodsIDs.add(temp); + } + } + + /** + * 从文件中逐行读数据 + * + * @param filePath 数据文件地址 + */ + private ArrayList readLine(String filePath){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + return dataArray; + } + + /** + * 计算频繁项集 + */ + void calFItems(){ + FrequentItem fItem; + + computeLink(); + printFItems(); + + if (isTransaction) { + fItem = resultItem.get(resultItem.size() - 1); + // 取出最后一个频繁项集做关联规则的推导 + System.out.println("最后一个频繁项集做关联规则的推导结果:"); + printAttachRuls(fItem.getIdArray()); + } + } + + /** + * 输出频繁项集 + */ + private void printFItems(){ + if (isTransaction) { + System.out.println("事务型数据频繁项集输出结果:"); + } else { + System.out.println("非事务(关系)型数据频繁项集输出结果:"); + } + + // 输出频繁项集 + for (int k = 1; k <= initFItemNum; k++) { + System.out.println("频繁" + k + "项集:"); + for (FrequentItem i : resultItem) { + if (i.getLength() == k) { + System.out.print("{"); + for (String t : i.getIdArray()) { + if (!isTransaction) { + // 如果原本是非事务型数据,需要重新做替换 + t = num2Attr.get(Integer.parseInt(t)); + } + + System.out.print(t + ","); + } + System.out.print("},"); + } + } + System.out.println(); + } + } + + /** + * 项集进行连接运算 + */ + private void computeLink(){ + // 连接计算的终止数,k项集必须算到k-1子项集为止 + int endNum; + // 当前已经进行连接运算到几项集,开始时就是1项集 + int currentNum = 1; + // 商品,1频繁项集映射图 + HashMap itemMap = new HashMap<>(); + FrequentItem tempItem; + // 初始列表 + ArrayList list = new ArrayList<>(); + // 经过连接运算后产生的结果项集 + resultItem = new ArrayList<>(); + resultItemID = new ArrayList<>(); + // 商品ID的种类 + ArrayList idType = new ArrayList<>(); + for (String[] a : totalGoodsIDs) { + for (String s : a) { + if (!idType.contains(s)) { + tempItem = new FrequentItem(new String[]{s}, 1); + idType.add(s); + resultItemID.add(new String[]{s}); + } else { + // 支持度计数加1 + tempItem = itemMap.get(s); + tempItem.setCount(tempItem.getCount() + 1); + } + itemMap.put(s, tempItem); + } + } + // 将初始频繁项集转入到列表中,以便继续做连接运算 + for (Map.Entry entry : itemMap.entrySet()) { + tempItem = entry.getValue(); + + // 判断1频繁项集是否满足支持度阈值的条件 + if (judgeFItem(tempItem.getIdArray())) { + list.add(tempItem); + } + } + + // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少 + Collections.sort(list); + resultItem.addAll(list); + + String[] array1; + String[] array2; + String[] resultArray; + ArrayList tempIds; + ArrayList resultContainer; + // 总共要算到endNum项集 + endNum = list.size() - 1; + initFItemNum = list.size() - 1; + + while (currentNum < endNum) { + resultContainer = new ArrayList<>(); + for (int i = 0; i < list.size() - 1; i++) { + tempItem = list.get(i); + array1 = tempItem.getIdArray(); + + for (int j = i + 1; j < list.size(); j++) { + tempIds = new ArrayList<>(); + array2 = list.get(j).getIdArray(); + + for (int k = 0; k < array1.length; k++) { + // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作 + if (array1[k].equals(array2[k])) { + tempIds.add(array1[k]); + } else { + tempIds.add(array1[k]); + tempIds.add(array2[k]); + } + } + + resultArray = new String[tempIds.size()]; + tempIds.toArray(resultArray); + + boolean isContain; + // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的 + if (resultArray.length == (array1.length + 1)) { + isContain = isIDArrayContains(resultContainer, + resultArray); + if (!isContain) { + resultContainer.add(resultArray); + } + } + } + } + + // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集 + list = cutItem(resultContainer); + currentNum++; + } + } + + /** + * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集 + */ + private ArrayList cutItem(ArrayList resultIds){ + String[] temp; + // 忽略的索引位置,以此构建子集 + int igNoreIndex; + FrequentItem tempItem; + // 剪枝生成新的频繁项集 + ArrayList newItem = new ArrayList<>(); + // 不符合要求的id + ArrayList deleteIdArray = new ArrayList<>(); + // 子项集是否也为频繁子项集 + boolean isContain = true; + + for (String[] array : resultIds) { + // 列举出其中的一个个的子项集,判断存在于频繁项集列表中 + temp = new String[array.length - 1]; + for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) { + isContain = true; + for (int j = 0, k = 0; j < array.length; j++) { + if (j != igNoreIndex) { + temp[k] = array[j]; + k++; + } + } + + if (!isIDArrayContains(resultItemID, temp)) { + isContain = false; + break; + } + } + + if (!isContain) { + deleteIdArray.add(array); + } + } + + // 移除不符合条件的ID组合 + resultIds.removeAll(deleteIdArray); + + // 移除支持度计数不够的id集合 + int tempCount = 0; + boolean isSatisfied; + for (String[] array : resultIds) { + isSatisfied = judgeFItem(array); + + // 如果此频繁项集满足多支持度阈值限制条件和支持度差别限制条件,则添加入结果集中 + if (isSatisfied) { + tempItem = new FrequentItem(array, tempCount); + newItem.add(tempItem); + resultItemID.add(array); + resultItem.add(tempItem); + } + } + + return newItem; + } + + /** + * 判断列表结果中是否已经包含此数组 + * + * @param container ID数组容器 + * @param array 待比较数组 + */ + private boolean isIDArrayContains(ArrayList container, + String[] array){ + boolean isContain = true; + if (container.size() == 0) { + return false; + } + + for (String[] s : container) { + // 比较的视乎必须保证长度一样 + if (s.length != array.length) { + continue; + } + + isContain = true; + for (int i = 0; i < s.length; i++) { + // 只要有一个id不等,就算不相等 + if (!Objects.equals(s[i], array[i])) { + isContain = false; + break; + } + } + + // 如果已经判断是包含在容器中时,直接退出 + if (isContain) { + break; + } + } + + return isContain; + } + + /** + * 判断一个频繁项集是否满足条件 + * + * @param frequentItem 待判断频繁项集 + */ + private boolean judgeFItem(String[] frequentItem){ + boolean isSatisfied = true; + int id; + int count; + double tempMinSup; + // 最小的支持度阈值 + double minMis = Integer.MAX_VALUE; + // 最大的支持度阈值 + double maxMis = -Integer.MAX_VALUE; + + // 如果是事务型数据,用mis数组判断,如果不是统一用同样的最小支持度阈值判断 + if (isTransaction) { + // 寻找频繁项集中的最小支持度阈值 + for (int i = 0; i < frequentItem.length; i++) { + id = i + 1; + + if (mis[id] < minMis) { + minMis = mis[id]; + } + + if (mis[id] > maxMis) { + maxMis = mis[id]; + } + } + } else { + minMis = minSup; + maxMis = minSup; + } + + count = calSupportCount(frequentItem); + tempMinSup = 1.0 * count / totalGoodsIDs.size(); + // 判断频繁项集的支持度阈值是否超过最小的支持度阈值 + if (tempMinSup < minMis) { + isSatisfied = false; + } + + // 如果误差超过了最大支持度差别,也算不满足条件 + if (Math.abs(maxMis - minMis) > delta) { + isSatisfied = false; + } + + return isSatisfied; + } + + /** + * 统计候选频繁项集的支持度数,利用他的子集进行技术,无须扫描整个数据集 + * + * @param frequentItem 待计算频繁项集 + */ + private int calSupportCount(String[] frequentItem){ + int count = 0; + int[] ids; + String key; + String[] array; + ArrayList newIds; + + key = ""; + for (int i = 1; i < frequentItem.length; i++) { + key += frequentItem[i]; + } + + newIds = new ArrayList<>(); + // 找出所属的事务ID + ids = fItem2Id.get(key); + + // 如果没有找到子项集的事务id,则全盘扫描数据集 + if (ids == null || ids.length == 0) { + for (int j = 0; j < totalGoodsIDs.size(); j++) { + array = totalGoodsIDs.get(j); + if (isStrArrayContain(array, frequentItem)) { + count++; + newIds.add(j); + } + } + } else { + for (int index : ids) { + array = totalGoodsIDs.get(index); + if (isStrArrayContain(array, frequentItem)) { + count++; + newIds.add(index); + } + } + } + + ids = new int[count]; + for (int i = 0; i < ids.length; i++) { + ids[i] = newIds.get(i); + } + + key = frequentItem[0] + key; + // 将所求值存入图中,便于下次的计数 + fItem2Id.put(key, ids); + + return count; + } + + /** + * 根据给定的频繁项集输出关联规则 + * + * @param frequentItem 频繁项集 + */ + private void printAttachRuls(String[] frequentItem){ + // 关联规则前件,后件对 + Map, ArrayList> rules; + // 前件搜索历史 + Map, ArrayList> searchHistory; + ArrayList prefix; + ArrayList suffix; + + rules = new HashMap<>(); + searchHistory = new HashMap<>(); + + for (String ignored : frequentItem) { + suffix = new ArrayList<>(); + Collections.addAll(suffix, frequentItem); + prefix = new ArrayList<>(); + + recusiveFindRules(rules, searchHistory, prefix, suffix); + } + + // 依次输出找到的关联规则 + for (Map.Entry, ArrayList> entry : rules + .entrySet()) { + prefix = entry.getKey(); + suffix = entry.getValue(); + + printRuleDetail(prefix, suffix); + } + } + + /** + * 根据前件后件,输出关联规则 + * + * @param prefix + * @param suffix + */ + private void printRuleDetail(ArrayList prefix, + ArrayList suffix){ + // {A}-->{B}的意思为在A的情况下发生B的概率 + System.out.print("{"); + for (String s : prefix) { + System.out.print(s + ", "); + } + System.out.print("}-->"); + System.out.print("{"); + for (String s : suffix) { + System.out.print(s + ", "); + } + System.out.println("}"); + } + + /** + * 递归扩展关联规则解 + * + * @param rules 关联规则结果集 + * @param history 前件搜索历史 + * @param prefix 关联规则前件 + * @param suffix 关联规则后件 + */ + private void recusiveFindRules( + Map, ArrayList> rules, + Map, ArrayList> history, + ArrayList prefix, ArrayList suffix){ + int count1; + int count2; + int compareResult; + // 置信度大小 + double conf; + String[] temp1; + String[] temp2; + ArrayList copyPrefix; + ArrayList copySuffix; + + // 如果后件只有1个,则函数返回 + if (suffix.size() == 1) { + return; + } + + for (String s : suffix) { + count1 = 0; + count2 = 0; + + copyPrefix = (ArrayList) prefix.clone(); + copyPrefix.add(s); + + copySuffix = (ArrayList) suffix.clone(); + // 将拷贝的后件移除添加的一项 + copySuffix.remove(s); + + compareResult = isSubSetInRules(history, copyPrefix); + if (compareResult == PREFIX_EQUAL) { + // 如果曾经已经被搜索过,则跳过 + continue; + } + + // 判断是否为子集,如果是子集则无需计算 + compareResult = isSubSetInRules(rules, copyPrefix); + if (compareResult == PREFIX_IS_SUB) { + rules.put(copyPrefix, copySuffix); + // 加入到搜索历史中 + history.put(copyPrefix, copySuffix); + recusiveFindRules(rules, history, copyPrefix, copySuffix); + continue; + } + + // 暂时合并为总的集合 + copySuffix.addAll(copyPrefix); + temp1 = new String[copyPrefix.size()]; + temp2 = new String[copySuffix.size()]; + copyPrefix.toArray(temp1); + copySuffix.toArray(temp2); + // 之后再次移除之前天剑的前件 + copySuffix.removeAll(copyPrefix); + + for (String[] a : totalGoodsIDs) { + if (isStrArrayContain(a, temp1)) { + count1++; + + // 在group1的条件下,统计group2的事件发生次数 + if (isStrArrayContain(a, temp2)) { + count2++; + } + } + } + + conf = 1.0 * count2 / count1; + if (conf > minConf) { + // 设置此前件条件下,能导出关联规则 + rules.put(copyPrefix, copySuffix); + } + + // 加入到搜索历史中 + history.put(copyPrefix, copySuffix); + recusiveFindRules(rules, history, copyPrefix, copySuffix); + } + } + + /** + * 判断当前的前件是否会关联规则的子集 + * + * @param rules 当前已经判断出的关联规则 + * @param prefix 待判断的前件 + */ + private int isSubSetInRules( + Map, ArrayList> rules, + ArrayList prefix){ + int result = PREFIX_NOT_SUB; + String[] temp1; + String[] temp2; + ArrayList tempPrefix; + + for (Map.Entry, ArrayList> entry : rules + .entrySet()) { + tempPrefix = entry.getKey(); + + temp1 = new String[tempPrefix.size()]; + temp2 = new String[prefix.size()]; + + tempPrefix.toArray(temp1); + prefix.toArray(temp2); + + // 判断当前构造的前件是否已经是存在前件的子集 + if (isStrArrayContain(temp2, temp1)) { + if (temp2.length == temp1.length) { + result = PREFIX_EQUAL; + } else { + result = PREFIX_IS_SUB; + } + } + + if (result == PREFIX_EQUAL) { + break; + } + } + + return result; + } + + /** + * 数组array2是否包含于array1中,不需要完全一样 + * + * @param array1 + * @param array2 + * @return + */ + private boolean isStrArrayContain(String[] array1, String[] array2){ + boolean isContain = true; + for (String s2 : array2) { + isContain = false; + for (String s1 : array1) { + // 只要s2字符存在于array1中,这个字符就算包含在array1中 + if (s2.equals(s1)) { + isContain = true; + break; + } + } + + // 一旦发现不包含的字符,则array2数组不包含于array1中 + if (!isContain) { + break; + } + } + + return isContain; + } + + /** + * 读关系表中的数据,并转化为事务数据 + * + * @param filePath 文件路径 + */ + private void readRDBMSData(String filePath){ + String str; + // 属性名称行 + String[] attrNames; + String[] temp; + String[] newRecord; + ArrayList datas; + + datas = readLine(filePath); + + // 获取首行 + attrNames = datas.get(0); + this.transactionDatas = new ArrayList<>(); + + // 去除首行数据 + for (int i = 1; i < datas.size(); i++) { + temp = datas.get(i); + + // 过滤掉首列id列 + for (int j = 1; j < temp.length; j++) { + // 采用属性名+属性值的形式避免数据的重复 + str = attrNames[j] + ":" + temp[j]; + temp[j] = str; + } + + newRecord = new String[attrNames.length - 1]; + System.arraycopy(temp, 1, newRecord, 0, attrNames.length - 1); + this.transactionDatas.add(newRecord); + } + + attributeReplace(); + // 将事务数转到totalGoodsID中做统一处理 + this.totalGoodsIDs = transactionDatas; + } + + /** + * 属性值的替换,替换成数字的形式,以便进行频繁项的挖掘 + */ + private void attributeReplace(){ + int currentValue = 1; + String s; + // 属性名到数字的映射图 + attr2Num = new HashMap<>(); + num2Attr = new HashMap<>(); + + // 按照1列列的方式来,从左往右边扫描,跳过列名称行和id列 + for (int j = 0; j < transactionDatas.get(0).length; j++) { + for (String[] transactionData : transactionDatas) { + s = transactionData[j]; + + if (!attr2Num.containsKey(s)) { + attr2Num.put(s, currentValue); + num2Attr.put(currentValue, s); + + transactionData[j] = currentValue + ""; + currentValue++; + } else { + transactionData[j] = attr2Num.get(s) + ""; + } + } + } + } } diff --git a/Others/DataMining_RandomForest/CARTTool.java b/Others/DataMining_RandomForest/CARTTool.java index d68aab4..1b12a70 100644 --- a/Others/DataMining_RandomForest/CARTTool.java +++ b/Others/DataMining_RandomForest/CARTTool.java @@ -1,511 +1,482 @@ -package DataMining_RandomForest; +package Others.DataMining_RandomForest; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.Queue; +import java.util.*; /** * CART分类回归树算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class CARTTool { - // 类标号的值类型 - private final String YES = "Yes"; - private final String NO = "No"; - - // 所有属性的类型总数,在这里就是data源数据的列数 - private int attrNum; - private String filePath; - // 初始源数据,用一个二维字符数组存放模仿表格数据 - private String[][] data; - // 数据的属性行的名字 - private String[] attrNames; - // 每个属性的值所有类型 - private HashMap> attrValue; - - public CARTTool(ArrayList dataArray) { - attrValue = new HashMap<>(); - readData(dataArray); - } - - /** - * 根据随机选取的样本数据进行初始化 - * @param dataArray - * 已经读入的样本数据 - */ - public void readData(ArrayList dataArray) { - data = new String[dataArray.size()][]; - dataArray.toArray(data); - attrNum = data[0].length; - attrNames = data[0]; - } - - /** - * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用 - */ - public void initAttrValue() { - ArrayList tempValues; - - // 按照列的方式,从左往右找 - for (int j = 1; j < attrNum; j++) { - // 从一列中的上往下开始寻找值 - tempValues = new ArrayList<>(); - for (int i = 1; i < data.length; i++) { - if (!tempValues.contains(data[i][j])) { - // 如果这个属性的值没有添加过,则添加 - tempValues.add(data[i][j]); - } - } - - // 一列属性的值已经遍历完毕,复制到map属性表中 - attrValue.put(data[0][j], tempValues); - } - } - - /** - * 计算机基尼指数 - * - * @param remainData - * 剩余数据 - * @param attrName - * 属性名称 - * @param value - * 属性值 - * @param beLongValue - * 分类是否属于此属性值 - * @return - */ - public double computeGini(String[][] remainData, String attrName, - String value, boolean beLongValue) { - // 实例总数 - int total = 0; - // 正实例数 - int posNum = 0; - // 负实例数 - int negNum = 0; - // 基尼指数 - double gini = 0; - - // 还是按列从左往右遍历属性 - for (int j = 1; j < attrNames.length; j++) { - // 找到了指定的属性 - if (attrName.equals(attrNames[j])) { - for (int i = 1; i < remainData.length; i++) { - // 统计正负实例按照属于和不属于值类型进行划分 - if ((beLongValue && remainData[i][j].equals(value)) - || (!beLongValue && !remainData[i][j].equals(value))) { - if (remainData[i][attrNames.length - 1].equals(YES)) { - // 判断此行数据是否为正实例 - posNum++; - } else { - negNum++; - } - } - } - } - } - - total = posNum + negNum; - double posProbobly = (double) posNum / total; - double negProbobly = (double) negNum / total; - gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly; - - // 返回计算基尼指数 - return gini; - } - - /** - * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中 - * - * @param remainData - * 剩余谁 - * @param attrName - * 属性名称 - * @return - */ - public String[] computeAttrGini(String[][] remainData, String attrName) { - String[] str = new String[2]; - // 最终该属性的划分类型值 - String spiltValue = ""; - // 临时变量 - int tempNum = 0; - // 保存属性的值划分时的最小的基尼指数 - double minGini = Integer.MAX_VALUE; - ArrayList valueTypes = attrValue.get(attrName); - // 属于此属性值的实例数 - HashMap belongNum = new HashMap<>(); - - for (String string : valueTypes) { - // 重新计数的时候,数字归0 - tempNum = 0; - // 按列从左往右遍历属性 - for (int j = 1; j < attrNames.length; j++) { - // 找到了指定的属性 - if (attrName.equals(attrNames[j])) { - for (int i = 1; i < remainData.length; i++) { - // 统计正负实例按照属于和不属于值类型进行划分 - if (remainData[i][j].equals(string)) { - tempNum++; - } - } - } - } - - belongNum.put(string, tempNum); - } - - double tempGini = 0; - double posProbably = 1.0; - double negProbably = 1.0; - for (String string : valueTypes) { - tempGini = 0; - - posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1); - negProbably = 1 - posProbably; - - tempGini += posProbably - * computeGini(remainData, attrName, string, true); - tempGini += negProbably - * computeGini(remainData, attrName, string, false); - - if (tempGini < minGini) { - minGini = tempGini; - spiltValue = string; - } - } - - str[0] = spiltValue; - str[1] = minGini + ""; - - return str; - } - - public void buildDecisionTree(TreeNode node, String parentAttrValue, - String[][] remainData, ArrayList remainAttr, - boolean beLongParentValue) { - // 属性划分值 - String valueType = ""; - // 划分属性名称 - String spiltAttrName = ""; - double minGini = Integer.MAX_VALUE; - double tempGini = 0; - // 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值 - String[] giniArray; - - if (beLongParentValue) { - node.setParentAttrValue(parentAttrValue); - } else { - node.setParentAttrValue("!" + parentAttrValue); - } - - if (remainAttr.size() == 0) { - if (remainData.length > 1) { - ArrayList indexArray = new ArrayList<>(); - for (int i = 1; i < remainData.length; i++) { - indexArray.add(remainData[i][0]); - } - node.setDataIndex(indexArray); - } - // System.out.println("attr remain null"); - return; - } - - for (String str : remainAttr) { - giniArray = computeAttrGini(remainData, str); - tempGini = Double.parseDouble(giniArray[1]); - - if (tempGini < minGini) { - spiltAttrName = str; - minGini = tempGini; - valueType = giniArray[0]; - } - } - // 移除划分属性 - remainAttr.remove(spiltAttrName); - node.setAttrName(spiltAttrName); - - // 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点 - TreeNode[] childNode = new TreeNode[2]; - String[][] rData; - - boolean[] bArray = new boolean[] { true, false }; - for (int i = 0; i < bArray.length; i++) { - // 二元划分属于属性值的划分 - rData = removeData(remainData, spiltAttrName, valueType, bArray[i]); - - boolean sameClass = true; - ArrayList indexArray = new ArrayList<>(); - for (int k = 1; k < rData.length; k++) { - indexArray.add(rData[k][0]); - // 判断是否为同一类的 - if (!rData[k][attrNames.length - 1] - .equals(rData[1][attrNames.length - 1])) { - // 只要有1个不相等,就不是同类型的 - sameClass = false; - break; - } - } - - childNode[i] = new TreeNode(); - if (!sameClass) { - // 创建新的对象属性,对象的同个引用会出错 - ArrayList rAttr = new ArrayList<>(); - for (String str : remainAttr) { - rAttr.add(str); - } - buildDecisionTree(childNode[i], valueType, rData, rAttr, - bArray[i]); - } else { - String pAtr = (bArray[i] ? valueType : "!" + valueType); - childNode[i].setParentAttrValue(pAtr); - childNode[i].setDataIndex(indexArray); - } - } - - node.setChildAttrNode(childNode); - } - - /** - * 属性划分完毕,进行数据的移除 - * - * @param srcData - * 源数据 - * @param attrName - * 划分的属性名称 - * @param valueType - * 属性的值类型 - * @parame beLongValue 分类是否属于此值类型 - */ - private String[][] removeData(String[][] srcData, String attrName, - String valueType, boolean beLongValue) { - String[][] desDataArray; - ArrayList desData = new ArrayList<>(); - // 待删除数据 - ArrayList selectData = new ArrayList<>(); - selectData.add(attrNames); - - // 数组数据转化到列表中,方便移除 - for (int i = 0; i < srcData.length; i++) { - desData.add(srcData[i]); - } - - // 还是从左往右一列列的查找 - for (int j = 1; j < attrNames.length; j++) { - if (attrNames[j].equals(attrName)) { - for (int i = 1; i < desData.size(); i++) { - if (desData.get(i)[j].equals(valueType)) { - // 如果匹配这个数据,则移除其他的数据 - selectData.add(desData.get(i)); - } - } - } - } - - if (beLongValue) { - desDataArray = new String[selectData.size()][]; - selectData.toArray(desDataArray); - } else { - // 属性名称行不移除 - selectData.remove(attrNames); - // 如果是划分不属于此类型的数据时,进行移除 - desData.removeAll(selectData); - desDataArray = new String[desData.size()][]; - desData.toArray(desDataArray); - } - - return desDataArray; - } - - /** - * 构造分类回归树,并返回根节点 - * @return - */ - public TreeNode startBuildingTree() { - initAttrValue(); - - ArrayList remainAttr = new ArrayList<>(); - // 添加属性,除了最后一个类标号属性 - for (int i = 1; i < attrNames.length - 1; i++) { - remainAttr.add(attrNames[i]); - } - - TreeNode rootNode = new TreeNode(); - buildDecisionTree(rootNode, "", data, remainAttr, false); - setIndexAndAlpah(rootNode, 0, false); - showDecisionTree(rootNode, 1); - - return rootNode; - } - - /** - * 显示决策树 - * - * @param node - * 待显示的节点 - * @param blankNum - * 行空格符,用于显示树型结构 - */ - private void showDecisionTree(TreeNode node, int blankNum) { - System.out.println(); - for (int i = 0; i < blankNum; i++) { - System.out.print(" "); - } - System.out.print("--"); - // 显示分类的属性值 - if (node.getParentAttrValue() != null - && node.getParentAttrValue().length() > 0) { - System.out.print(node.getParentAttrValue()); - } else { - System.out.print("--"); - } - System.out.print("--"); - - if (node.getDataIndex() != null && node.getDataIndex().size() > 0) { - String i = node.getDataIndex().get(0); - System.out.print("【" + node.getNodeIndex() + "】类别:" - + data[Integer.parseInt(i)][attrNames.length - 1]); - System.out.print("["); - for (String index : node.getDataIndex()) { - System.out.print(index + ", "); - } - System.out.print("]"); - } else { - // 递归显示子节点 - System.out.print("【" + node.getNodeIndex() + ":" - + node.getAttrName() + "】"); - if (node.getChildAttrNode() != null) { - for (TreeNode childNode : node.getChildAttrNode()) { - showDecisionTree(childNode, 2 * blankNum); - } - } else { - System.out.print("【 Child Null】"); - } - } - } - - /** - * 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝 - * - * @param node - * 开始的时候传入的是根节点 - * @param index - * 开始的索引号,从1开始 - * @param ifCutNode - * 是否需要剪枝 - */ - private void setIndexAndAlpah(TreeNode node, int index, boolean ifCutNode) { - TreeNode tempNode; - // 最小误差代价节点,即将被剪枝的节点 - TreeNode minAlphaNode = null; - double minAlpah = Integer.MAX_VALUE; - Queue nodeQueue = new LinkedList(); - - nodeQueue.add(node); - while (nodeQueue.size() > 0) { - index++; - // 从队列头部获取首个节点 - tempNode = nodeQueue.poll(); - tempNode.setNodeIndex(index); - if (tempNode.getChildAttrNode() != null) { - for (TreeNode childNode : tempNode.getChildAttrNode()) { - nodeQueue.add(childNode); - } - computeAlpha(tempNode); - if (tempNode.getAlpha() < minAlpah) { - minAlphaNode = tempNode; - minAlpah = tempNode.getAlpha(); - } else if (tempNode.getAlpha() == minAlpah) { - // 如果误差代价值一样,比较包含的叶子节点个数,剪枝有多叶子节点数的节点 - if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) { - minAlphaNode = tempNode; - } - } - } - } - - if (ifCutNode) { - // 进行树的剪枝,让其左右孩子节点为null - minAlphaNode.setChildAttrNode(null); - } - } - - /** - * 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝 - * - * @param node - * 待计算的非叶子节点 - */ - private void computeAlpha(TreeNode node) { - double rt = 0; - double Rt = 0; - double alpha = 0; - // 当前节点的数据总数 - int sumNum = 0; - // 最少的偏差数 - int minNum = 0; - - ArrayList dataIndex; - ArrayList leafNodes = new ArrayList<>(); - - addLeafNode(node, leafNodes); - node.setLeafNum(leafNodes.size()); - for (TreeNode attrNode : leafNodes) { - dataIndex = attrNode.getDataIndex(); - - int num = 0; - sumNum += dataIndex.size(); - for (String s : dataIndex) { - // 统计分类数据中的正负实例数 - if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) { - num++; - } - } - minNum += num; - - // 取小数量的值部分 - if (1.0 * num / dataIndex.size() > 0.5) { - num = dataIndex.size() - num; - } - - rt += (1.0 * num / (data.length - 1)); - } - - //同样取出少偏差的那部分 - if (1.0 * minNum / sumNum > 0.5) { - minNum = sumNum - minNum; - } - - Rt = 1.0 * minNum / (data.length - 1); - alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1); - node.setAlpha(alpha); - } - - /** - * 筛选出节点所包含的叶子节点数 - * - * @param node - * 待筛选节点 - * @param leafNode - * 叶子节点列表容器 - */ - private void addLeafNode(TreeNode node, ArrayList leafNode) { - ArrayList dataIndex; - - if (node.getChildAttrNode() != null) { - for (TreeNode childNode : node.getChildAttrNode()) { - dataIndex = childNode.getDataIndex(); - if (dataIndex != null && dataIndex.size() > 0) { - // 说明此节点为叶子节点 - leafNode.add(childNode); - } else { - // 如果还是非叶子节点则继续递归调用 - addLeafNode(childNode, leafNode); - } - } - } - } +class CARTTool { + // 类标号的值类型 + private final String YES = "Yes"; + private final String NO = "No"; + + // 所有属性的类型总数,在这里就是data源数据的列数 + private int attrNum; + // 初始源数据,用一个二维字符数组存放模仿表格数据 + private String[][] data; + // 数据的属性行的名字 + private String[] attrNames; + // 每个属性的值所有类型 + private HashMap> attrValue; + + CARTTool(ArrayList dataArray){ + attrValue = new HashMap<>(); + readData(dataArray); + } + + /** + * 根据随机选取的样本数据进行初始化 + * + * @param dataArray 已经读入的样本数据 + */ + private void readData(ArrayList dataArray){ + data = new String[dataArray.size()][]; + dataArray.toArray(data); + attrNum = data[0].length; + attrNames = data[0]; + } + + /** + * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用 + */ + private void initAttrValue(){ + ArrayList tempValues; + + // 按照列的方式,从左往右找 + for (int j = 1; j < attrNum; j++) { + // 从一列中的上往下开始寻找值 + tempValues = new ArrayList<>(); + for (int i = 1; i < data.length; i++) { + if (!tempValues.contains(data[i][j])) { + // 如果这个属性的值没有添加过,则添加 + tempValues.add(data[i][j]); + } + } + + // 一列属性的值已经遍历完毕,复制到map属性表中 + attrValue.put(data[0][j], tempValues); + } + } + + /** + * 计算机基尼指数 + * + * @param remainData 剩余数据 + * @param attrName 属性名称 + * @param value 属性值 + * @param beLongValue 分类是否属于此属性值 + */ + private double computeGini(String[][] remainData, String attrName, + String value, boolean beLongValue){ + // 实例总数 + int total; + // 正实例数 + int posNum = 0; + // 负实例数 + int negNum = 0; + // 基尼指数 + double gini; + + // 还是按列从左往右遍历属性 + for (int j = 1; j < attrNames.length; j++) { + // 找到了指定的属性 + if (attrName.equals(attrNames[j])) { + for (int i = 1; i < remainData.length; i++) { + // 统计正负实例按照属于和不属于值类型进行划分 + if ((beLongValue && remainData[i][j].equals(value)) + || (!beLongValue && !remainData[i][j].equals(value))) { + if (remainData[i][attrNames.length - 1].equals(YES)) { + // 判断此行数据是否为正实例 + posNum++; + } else { + negNum++; + } + } + } + } + } + + total = posNum + negNum; + double posProbobly = (double) posNum / total; + double negProbobly = (double) negNum / total; + gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly; + + // 返回计算基尼指数 + return gini; + } + + /** + * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中 + * + * @param remainData 剩余谁 + * @param attrName 属性名称 + */ + private String[] computeAttrGini(String[][] remainData, String attrName){ + String[] str = new String[2]; + // 最终该属性的划分类型值 + String spiltValue = ""; + // 临时变量 + int tempNum; + // 保存属性的值划分时的最小的基尼指数 + double minGini = Integer.MAX_VALUE; + ArrayList valueTypes = attrValue.get(attrName); + // 属于此属性值的实例数 + HashMap belongNum = new HashMap<>(); + + for (String string : valueTypes) { + // 重新计数的时候,数字归0 + tempNum = 0; + // 按列从左往右遍历属性 + for (int j = 1; j < attrNames.length; j++) { + // 找到了指定的属性 + if (attrName.equals(attrNames[j])) { + for (int i = 1; i < remainData.length; i++) { + // 统计正负实例按照属于和不属于值类型进行划分 + if (remainData[i][j].equals(string)) { + tempNum++; + } + } + } + } + + belongNum.put(string, tempNum); + } + + double tempGini; + double posProbably; + double negProbably; + for (String string : valueTypes) { + tempGini = 0; + + posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1); + negProbably = 1 - posProbably; + + tempGini += posProbably + * computeGini(remainData, attrName, string, true); + tempGini += negProbably + * computeGini(remainData, attrName, string, false); + + if (tempGini < minGini) { + minGini = tempGini; + spiltValue = string; + } + } + + str[0] = spiltValue; + str[1] = minGini + ""; + + return str; + } + + private void buildDecisionTree(TreeNode node, String parentAttrValue, + String[][] remainData, ArrayList remainAttr, + boolean beLongParentValue){ + // 属性划分值 + String valueType = ""; + // 划分属性名称 + String spiltAttrName = ""; + double minGini = Integer.MAX_VALUE; + double tempGini; + // 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值 + String[] giniArray; + + if (beLongParentValue) { + node.setParentAttrValue(parentAttrValue); + } else { + node.setParentAttrValue("!" + parentAttrValue); + } + + if (remainAttr.size() == 0) { + if (remainData.length > 1) { + ArrayList indexArray = new ArrayList<>(); + for (int i = 1; i < remainData.length; i++) { + indexArray.add(remainData[i][0]); + } + node.setDataIndex(indexArray); + } + // System.out.println("attr remain null"); + return; + } + + for (String str : remainAttr) { + giniArray = computeAttrGini(remainData, str); + tempGini = Double.parseDouble(giniArray[1]); + + if (tempGini < minGini) { + spiltAttrName = str; + minGini = tempGini; + valueType = giniArray[0]; + } + } + // 移除划分属性 + remainAttr.remove(spiltAttrName); + node.setAttrName(spiltAttrName); + + // 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点 + TreeNode[] childNode = new TreeNode[2]; + String[][] rData; + + boolean[] bArray = new boolean[]{true, false}; + for (int i = 0; i < bArray.length; i++) { + // 二元划分属于属性值的划分 + rData = removeData(remainData, spiltAttrName, valueType, bArray[i]); + + boolean sameClass = true; + ArrayList indexArray = new ArrayList<>(); + for (int k = 1; k < rData.length; k++) { + indexArray.add(rData[k][0]); + // 判断是否为同一类的 + if (!rData[k][attrNames.length - 1] + .equals(rData[1][attrNames.length - 1])) { + // 只要有1个不相等,就不是同类型的 + sameClass = false; + break; + } + } + + childNode[i] = new TreeNode(); + if (!sameClass) { + // 创建新的对象属性,对象的同个引用会出错 + ArrayList rAttr = new ArrayList<>(); + for (String str : remainAttr) { + rAttr.add(str); + } + buildDecisionTree(childNode[i], valueType, rData, rAttr, + bArray[i]); + } else { + String pAtr = (bArray[i] ? valueType : "!" + valueType); + childNode[i].setParentAttrValue(pAtr); + childNode[i].setDataIndex(indexArray); + } + } + + node.setChildAttrNode(childNode); + } + + /** + * 属性划分完毕,进行数据的移除 + * + * @param srcData 源数据 + * @param attrName 划分的属性名称 + * @param valueType 属性的值类型 + * @parame beLongValue 分类是否属于此值类型 + */ + private String[][] removeData(String[][] srcData, String attrName, + String valueType, boolean beLongValue){ + String[][] desDataArray; + ArrayList desData = new ArrayList<>(); + // 待删除数据 + ArrayList selectData = new ArrayList<>(); + selectData.add(attrNames); + + // 数组数据转化到列表中,方便移除 + Collections.addAll(desData, srcData); + + // 还是从左往右一列列的查找 + for (int j = 1; j < attrNames.length; j++) { + if (attrNames[j].equals(attrName)) { + for (int i = 1; i < desData.size(); i++) { + if (desData.get(i)[j].equals(valueType)) { + // 如果匹配这个数据,则移除其他的数据 + selectData.add(desData.get(i)); + } + } + } + } + + if (beLongValue) { + desDataArray = new String[selectData.size()][]; + selectData.toArray(desDataArray); + } else { + // 属性名称行不移除 + selectData.remove(attrNames); + // 如果是划分不属于此类型的数据时,进行移除 + desData.removeAll(selectData); + desDataArray = new String[desData.size()][]; + desData.toArray(desDataArray); + } + + return desDataArray; + } + + /** + * 构造分类回归树,并返回根节点 + */ + TreeNode startBuildingTree(){ + initAttrValue(); + + ArrayList remainAttr = new ArrayList<>(); + // 添加属性,除了最后一个类标号属性 + remainAttr.addAll(Arrays.asList(attrNames).subList(1, attrNames.length - 1)); + + TreeNode rootNode = new TreeNode(); + buildDecisionTree(rootNode, "", data, remainAttr, false); + setIndexAndAlpah(rootNode, 0, false); + showDecisionTree(rootNode, 1); + + return rootNode; + } + + /** + * 显示决策树 + * + * @param node 待显示的节点 + * @param blankNum 行空格符,用于显示树型结构 + */ + private void showDecisionTree(TreeNode node, int blankNum){ + System.out.println(); + for (int i = 0; i < blankNum; i++) { + System.out.print(" "); + } + System.out.print("--"); + // 显示分类的属性值 + if (node.getParentAttrValue() != null + && node.getParentAttrValue().length() > 0) { + System.out.print(node.getParentAttrValue()); + } else { + System.out.print("--"); + } + System.out.print("--"); + + if (node.getDataIndex() != null && node.getDataIndex().size() > 0) { + String i = node.getDataIndex().get(0); + System.out.print("【" + node.getNodeIndex() + "】类别:" + + data[Integer.parseInt(i)][attrNames.length - 1]); + System.out.print("["); + for (String index : node.getDataIndex()) { + System.out.print(index + ", "); + } + System.out.print("]"); + } else { + // 递归显示子节点 + System.out.print("【" + node.getNodeIndex() + ":" + + node.getAttrName() + "】"); + if (node.getChildAttrNode() != null) { + for (TreeNode childNode : node.getChildAttrNode()) { + showDecisionTree(childNode, 2 * blankNum); + } + } else { + System.out.print("【 Child Null】"); + } + } + } + + /** + * 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝 + * + * @param node 开始的时候传入的是根节点 + * @param index 开始的索引号,从1开始 + * @param ifCutNode 是否需要剪枝 + */ + private void setIndexAndAlpah(TreeNode node, int index, boolean ifCutNode){ + TreeNode tempNode; + // 最小误差代价节点,即将被剪枝的节点 + TreeNode minAlphaNode = null; + double minAlpah = Integer.MAX_VALUE; + Queue nodeQueue = new LinkedList<>(); + + nodeQueue.add(node); + while (nodeQueue.size() > 0) { + index++; + // 从队列头部获取首个节点 + tempNode = nodeQueue.poll(); + tempNode.setNodeIndex(index); + if (tempNode.getChildAttrNode() != null) { + Collections.addAll(nodeQueue, tempNode.getChildAttrNode()); + computeAlpha(tempNode); + if (tempNode.getAlpha() < minAlpah) { + minAlphaNode = tempNode; + minAlpah = tempNode.getAlpha(); + } else if (tempNode.getAlpha() == minAlpah) { + // 如果误差代价值一样,比较包含的叶子节点个数,剪枝有多叶子节点数的节点 + if (minAlphaNode != null && tempNode.getLeafNum() > minAlphaNode.getLeafNum()) { + minAlphaNode = tempNode; + } + } + } + } + + if (ifCutNode) { + // 进行树的剪枝,让其左右孩子节点为null + if (minAlphaNode != null) { + minAlphaNode.setChildAttrNode(null); + } + } + } + + /** + * 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝 + * + * @param node 待计算的非叶子节点 + */ + private void computeAlpha(TreeNode node){ + double rt = 0; + double Rt; + double alpha; + // 当前节点的数据总数 + int sumNum = 0; + // 最少的偏差数 + int minNum = 0; + + ArrayList dataIndex; + ArrayList leafNodes = new ArrayList<>(); + + addLeafNode(node, leafNodes); + node.setLeafNum(leafNodes.size()); + for (TreeNode attrNode : leafNodes) { + dataIndex = attrNode.getDataIndex(); + + int num = 0; + sumNum += dataIndex.size(); + for (String s : dataIndex) { + // 统计分类数据中的正负实例数 + if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) { + num++; + } + } + minNum += num; + + // 取小数量的值部分 + if (1.0 * num / dataIndex.size() > 0.5) { + num = dataIndex.size() - num; + } + + rt += (1.0 * num / (data.length - 1)); + } + + //同样取出少偏差的那部分 + if (1.0 * minNum / sumNum > 0.5) { + minNum = sumNum - minNum; + } + + Rt = 1.0 * minNum / (data.length - 1); + alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1); + node.setAlpha(alpha); + } + + /** + * 筛选出节点所包含的叶子节点数 + * + * @param node 待筛选节点 + * @param leafNode 叶子节点列表容器 + */ + private void addLeafNode(TreeNode node, ArrayList leafNode){ + ArrayList dataIndex; + + if (node.getChildAttrNode() != null) { + for (TreeNode childNode : node.getChildAttrNode()) { + dataIndex = childNode.getDataIndex(); + if (dataIndex != null && dataIndex.size() > 0) { + // 说明此节点为叶子节点 + leafNode.add(childNode); + } else { + // 如果还是非叶子节点则继续递归调用 + addLeafNode(childNode, leafNode); + } + } + } + } } diff --git a/Others/DataMining_RandomForest/Client.java b/Others/DataMining_RandomForest/Client.java index 6139d3e..8515c11 100644 --- a/Others/DataMining_RandomForest/Client.java +++ b/Others/DataMining_RandomForest/Client.java @@ -1,33 +1,32 @@ -package DataMining_RandomForest; +package Others.DataMining_RandomForest; import java.text.MessageFormat; /** * 随机森林算法测试场景 - * - * @author lyq - * + * + * @author Qstar */ public class Client { - public static void main(String[] args) { - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - String queryStr = "Age=Youth,Income=Low,Student=No,CreditRating=Fair"; - String resultClassType = ""; - // 决策树的样本占总数的占比率 - double sampleNumRatio = 0.4; - // 样本数据的采集特征数量占总特征的比例 - double featureNumRatio = 0.5; + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_RandomForest/input.txt"; + String queryStr = "Age=Youth,Income=Low,Student=No,CreditRating=Fair"; + String resultClassType; + // 决策树的样本占总数的占比率 + double sampleNumRatio = 0.4; + // 样本数据的采集特征数量占总特征的比例 + double featureNumRatio = 0.5; - RandomForestTool tool = new RandomForestTool(filePath, sampleNumRatio, - featureNumRatio); - tool.constructRandomTree(); + RandomForestTool tool = new RandomForestTool(filePath, sampleNumRatio, + featureNumRatio); + tool.constructRandomTree(); - resultClassType = tool.judgeClassType(queryStr); + resultClassType = tool.judgeClassType(queryStr); - System.out.println(); - System.out - .println(MessageFormat.format( - "查询属性描述{0},预测的分类结果为BuysCompute:{1}", queryStr, - resultClassType)); - } + System.out.println(); + System.out + .println(MessageFormat.format( + "查询属性描述{0},预测的分类结果为BuysCompute:{1}", queryStr, + resultClassType)); + } } diff --git a/Others/DataMining_RandomForest/DecisionTree.java b/Others/DataMining_RandomForest/DecisionTree.java index 119254e..58e2f90 100644 --- a/Others/DataMining_RandomForest/DecisionTree.java +++ b/Others/DataMining_RandomForest/DecisionTree.java @@ -1,4 +1,4 @@ -package DataMining_RandomForest; +package Others.DataMining_RandomForest; import java.util.ArrayList; import java.util.HashMap; @@ -6,160 +6,155 @@ /** * 决策树 - * - * @author lyq - * + * + * @author Qstar */ -public class DecisionTree { - // 树的根节点 - TreeNode rootNode; - // 数据的属性列名称 - String[] featureNames; - // 这棵树所包含的数据 - ArrayList datas; - // 决策树构造的的工具类 - CARTTool tool; - - public DecisionTree(ArrayList datas) { - this.datas = datas; - this.featureNames = datas.get(0); - - tool = new CARTTool(datas); - // 通过CART工具类进行决策树的构建,并返回树的根节点 - rootNode = tool.startBuildingTree(); - } - - /** - * 根据给定的数据特征描述进行类别的判断 - * - * @param features - * @return - */ - public String decideClassType(String features) { - String classType = ""; - // 查询属性组 - String[] queryFeatures; - // 在本决策树中对应的查询的属性值描述 - ArrayList featureStrs; - - featureStrs = new ArrayList<>(); - queryFeatures = features.split(","); - - String[] array; - for (String name : featureNames) { - for (String featureValue : queryFeatures) { - array = featureValue.split("="); - // 将对应的属性值加入到列表中 - if (array[0].equals(name)) { - featureStrs.add(array); - } - } - } - - // 开始从根据节点往下递归搜索 - classType = recusiveSearchClassType(rootNode, featureStrs); - - return classType; - } - - /** - * 递归搜索树,查询属性的分类类别 - * - * @param node - * 当前搜索到的节点 - * @param remainFeatures - * 剩余未判断的属性 - * @return - */ - private String recusiveSearchClassType(TreeNode node, - ArrayList remainFeatures) { - String classType = null; - - // 如果节点包含了数据的id索引,说明已经分类到底了 - if (node.getDataIndex() != null && node.getDataIndex().size() > 0) { - classType = judgeClassType(node.getDataIndex()); - - return classType; - } - - // 取出剩余属性中的一个匹配属性作为当前的判断属性名称 - String[] currentFeature = null; - for (String[] featureValue : remainFeatures) { - if (node.getAttrName().equals(featureValue[0])) { - currentFeature = featureValue; - break; - } - } - - for (TreeNode childNode : node.getChildAttrNode()) { - // 寻找子节点中属于此属性值的分支 - if (childNode.getParentAttrValue().equals(currentFeature[1])) { - remainFeatures.remove(currentFeature); - classType = recusiveSearchClassType(childNode, remainFeatures); - - // 如果找到了分类结果,则直接挑出循环 - break; - }else{ - //进行第二种情况的判断加上!符号的情况 - String value = childNode.getParentAttrValue(); - - if(value.charAt(0) == '!'){ - //去掉第一个!字符 - value = value.substring(1, value.length()); - - if(!value.equals(currentFeature[1])){ - remainFeatures.remove(currentFeature); - classType = recusiveSearchClassType(childNode, remainFeatures); - - break; - } - } - } - } - - return classType; - } - - /** - * 根据得到的数据行分类进行类别的决策 - * - * @param dataIndex - * 根据分类的数据索引号 - * @return - */ - public String judgeClassType(ArrayList dataIndex) { - // 结果类型值 - String resultClassType = ""; - String classType = ""; - int count = 0; - int temp = 0; - Map type2Num = new HashMap(); - - for (String index : dataIndex) { - temp = Integer.parseInt(index); - // 取最后一列的决策类别数据 - classType = datas.get(temp)[featureNames.length - 1]; - - if (type2Num.containsKey(classType)) { - // 如果类别已经存在,则使其计数加1 - count = type2Num.get(classType); - count++; - } else { - count = 1; - } - - type2Num.put(classType, count); - } - - // 选出其中类别支持计数最多的一个类别值 - count = -1; - for (Map.Entry entry : type2Num.entrySet()) { - if ((int) entry.getValue() > count) { - count = (int) entry.getValue(); - resultClassType = (String) entry.getKey(); - } - } - - return resultClassType; - } +class DecisionTree { + // 树的根节点 + private TreeNode rootNode; + // 数据的属性列名称 + private String[] featureNames; + // 这棵树所包含的数据 + private ArrayList datas; + // 决策树构造的的工具类 + private CARTTool tool; + + DecisionTree(ArrayList datas){ + this.datas = datas; + this.featureNames = datas.get(0); + + tool = new CARTTool(datas); + // 通过CART工具类进行决策树的构建,并返回树的根节点 + rootNode = tool.startBuildingTree(); + } + + /** + * 根据给定的数据特征描述进行类别的判断 + * + * @param features 数据特征 + */ + String decideClassType(String features){ + String classType; + // 查询属性组 + String[] queryFeatures; + // 在本决策树中对应的查询的属性值描述 + ArrayList featureStrs; + + featureStrs = new ArrayList<>(); + queryFeatures = features.split(","); + + String[] array; + for (String name : featureNames) { + for (String featureValue : queryFeatures) { + array = featureValue.split("="); + // 将对应的属性值加入到列表中 + if (array[0].equals(name)) { + featureStrs.add(array); + } + } + } + + // 开始从根据节点往下递归搜索 + classType = recusiveSearchClassType(rootNode, featureStrs); + + return classType; + } + + /** + * 递归搜索树,查询属性的分类类别 + * + * @param node 当前搜索到的节点 + * @param remainFeatures 剩余未判断的属性 + */ + private String recusiveSearchClassType(TreeNode node, + ArrayList remainFeatures){ + String classType = null; + + // 如果节点包含了数据的id索引,说明已经分类到底了 + if (node.getDataIndex() != null && node.getDataIndex().size() > 0) { + classType = judgeClassType(node.getDataIndex()); + + return classType; + } + + // 取出剩余属性中的一个匹配属性作为当前的判断属性名称 + String[] currentFeature = null; + for (String[] featureValue : remainFeatures) { + if (node.getAttrName().equals(featureValue[0])) { + currentFeature = featureValue; + break; + } + } + + for (TreeNode childNode : node.getChildAttrNode()) { + // 寻找子节点中属于此属性值的分支 + if (currentFeature != null) { + if (childNode.getParentAttrValue().equals(currentFeature[1])) { + remainFeatures.remove(currentFeature); + classType = recusiveSearchClassType(childNode, remainFeatures); + + // 如果找到了分类结果,则直接挑出循环 + break; + } else { + //进行第二种情况的判断加上!符号的情况 + String value = childNode.getParentAttrValue(); + + if (value.charAt(0) == '!') { + //去掉第一个!字符 + value = value.substring(1, value.length()); + + if (!value.equals(currentFeature[1])) { + remainFeatures.remove(currentFeature); + classType = recusiveSearchClassType(childNode, remainFeatures); + + break; + } + } + } + } + } + + return classType; + } + + /** + * 根据得到的数据行分类进行类别的决策 + * + * @param dataIndex 根据分类的数据索引号 + */ + private String judgeClassType(ArrayList dataIndex){ + // 结果类型值 + String resultClassType = ""; + String classType; + int count; + int temp; + Map type2Num = new HashMap<>(); + + for (String index : dataIndex) { + temp = Integer.parseInt(index); + // 取最后一列的决策类别数据 + classType = datas.get(temp)[featureNames.length - 1]; + + if (type2Num.containsKey(classType)) { + // 如果类别已经存在,则使其计数加1 + count = type2Num.get(classType); + count++; + } else { + count = 1; + } + + type2Num.put(classType, count); + } + + // 选出其中类别支持计数最多的一个类别值 + count = -1; + for (Map.Entry entry : type2Num.entrySet()) { + if ((int) entry.getValue() > count) { + count = (int) entry.getValue(); + resultClassType = (String) entry.getKey(); + } + } + + return resultClassType; + } } diff --git a/Others/DataMining_RandomForest/RandomForestTool.java b/Others/DataMining_RandomForest/RandomForestTool.java index a244cd9..9290a5f 100644 --- a/Others/DataMining_RandomForest/RandomForestTool.java +++ b/Others/DataMining_RandomForest/RandomForestTool.java @@ -1,4 +1,4 @@ -package DataMining_RandomForest; +package Others.DataMining_RandomForest; import java.io.BufferedReader; import java.io.File; @@ -11,213 +11,209 @@ /** * 随机森林算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class RandomForestTool { - // 测试数据文件地址 - private String filePath; - // 决策树的样本占总数的占比率 - private double sampleNumRatio; - // 样本数据的采集特征数量占总特征的比例 - private double featureNumRatio; - // 决策树的采样样本数 - private int sampleNum; - // 样本数据的采集采样特征数 - private int featureNum; - // 随机森林中的决策树的数目,等于总的数据数/用于构造每棵树的数据的数量 - private int treeNum; - // 随机数产生器 - private Random random; - // 样本数据列属性名称行 - private String[] featureNames; - // 原始的总的数据 - private ArrayList totalDatas; - // 决策树森林 - private ArrayList decisionForest; - - public RandomForestTool(String filePath, double sampleNumRatio, - double featureNumRatio) { - this.filePath = filePath; - this.sampleNumRatio = sampleNumRatio; - this.featureNumRatio = featureNumRatio; - - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - totalDatas = dataArray; - featureNames = totalDatas.get(0); - sampleNum = (int) ((totalDatas.size() - 1) * sampleNumRatio); - //算属性数量的时候需要去掉id属性和决策属性,用条件属性计算 - featureNum = (int) ((featureNames.length -2) * featureNumRatio); - // 算数量的时候需要去掉首行属性名称行 - treeNum = (totalDatas.size() - 1) / sampleNum; - } - - /** - * 产生决策树 - */ - private DecisionTree produceDecisionTree() { - int temp = 0; - DecisionTree tree; - String[] tempData; - //采样数据的随机行号组 - ArrayList sampleRandomNum; - //采样属性特征的随机列号组 - ArrayList featureRandomNum; - ArrayList datas; - - sampleRandomNum = new ArrayList<>(); - featureRandomNum = new ArrayList<>(); - datas = new ArrayList<>(); - - for(int i=0; i 0){ - array[0] = temp + ""; - } - - temp++; - } - - tree = new DecisionTree(datas); - - return tree; - } - - /** - * 构造随机森林 - */ - public void constructRandomTree() { - DecisionTree tree; - random = new Random(); - decisionForest = new ArrayList<>(); - - System.out.println("下面是随机森林中的决策树:"); - // 构造决策树加入森林中 - for (int i = 0; i < treeNum; i++) { - System.out.println("\n决策树" + (i+1)); - tree = produceDecisionTree(); - decisionForest.add(tree); - } - } - - /** - * 根据给定的属性条件进行类别的决策 - * - * @param features - * 给定的已知的属性描述 - * @return - */ - public String judgeClassType(String features) { - // 结果类型值 - String resultClassType = ""; - String classType = ""; - int count = 0; - Map type2Num = new HashMap(); - - for (DecisionTree tree : decisionForest) { - classType = tree.decideClassType(features); - if (type2Num.containsKey(classType)) { - // 如果类别已经存在,则使其计数加1 - count = type2Num.get(classType); - count++; - } else { - count = 1; - } - - type2Num.put(classType, count); - } - - // 选出其中类别支持计数最多的一个类别值 - count = -1; - for (Map.Entry entry : type2Num.entrySet()) { - if ((int) entry.getValue() > count) { - count = (int) entry.getValue(); - resultClassType = (String) entry.getKey(); - } - } - - return resultClassType; - } +class RandomForestTool { + // 测试数据文件地址 + private String filePath; + // 决策树的样本占总数的占比率 + private double sampleNumRatio; + // 样本数据的采集特征数量占总特征的比例 + private double featureNumRatio; + // 决策树的采样样本数 + private int sampleNum; + // 样本数据的采集采样特征数 + private int featureNum; + // 随机森林中的决策树的数目,等于总的数据数/用于构造每棵树的数据的数量 + private int treeNum; + // 随机数产生器 + private Random random; + // 样本数据列属性名称行 + private String[] featureNames; + // 原始的总的数据 + private ArrayList totalDatas; + // 决策树森林 + private ArrayList decisionForest; + + RandomForestTool(String filePath, double sampleNumRatio, + double featureNumRatio){ + this.filePath = filePath; + this.sampleNumRatio = sampleNumRatio; + this.featureNumRatio = featureNumRatio; + + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + totalDatas = dataArray; + featureNames = totalDatas.get(0); + sampleNum = (int) ((totalDatas.size() - 1) * sampleNumRatio); + //算属性数量的时候需要去掉id属性和决策属性,用条件属性计算 + featureNum = (int) ((featureNames.length - 2) * featureNumRatio); + // 算数量的时候需要去掉首行属性名称行 + treeNum = (totalDatas.size() - 1) / sampleNum; + } + + /** + * 产生决策树 + */ + private DecisionTree produceDecisionTree(){ + int temp; + DecisionTree tree; + String[] tempData; + //采样数据的随机行号组 + ArrayList sampleRandomNum; + //采样属性特征的随机列号组 + ArrayList featureRandomNum; + ArrayList datas; + + sampleRandomNum = new ArrayList<>(); + featureRandomNum = new ArrayList<>(); + datas = new ArrayList<>(); + + for (int i = 0; i < sampleNum; ) { + temp = random.nextInt(totalDatas.size()); + + //如果是行首属性名称行,则跳过 + if (temp == 0) { + continue; + } + + if (!sampleRandomNum.contains(temp)) { + sampleRandomNum.add(temp); + i++; + } + } + + for (int i = 0; i < featureNum; ) { + temp = random.nextInt(featureNames.length); + + //如果是第一列的数据id号或者是决策属性列,则跳过 + if (temp == 0 || temp == featureNames.length - 1) { + continue; + } + + if (!featureRandomNum.contains(temp)) { + featureRandomNum.add(temp); + i++; + } + } + + String[] singleRecord; + String[] headCulumn = null; + // 获取随机数据行 + for (int dataIndex : sampleRandomNum) { + singleRecord = totalDatas.get(dataIndex); + + //每行的列数=所选的特征数+id号 + tempData = new String[featureNum + 2]; + headCulumn = new String[featureNum + 2]; + + for (int i = 0, k = 1; i < featureRandomNum.size(); i++, k++) { + temp = featureRandomNum.get(i); + + headCulumn[k] = featureNames[temp]; + tempData[k] = singleRecord[temp]; + } + + //加上id列的信息 + headCulumn[0] = featureNames[0]; + //加上决策分类列的信息 + headCulumn[featureNum + 1] = featureNames[featureNames.length - 1]; + tempData[featureNum + 1] = singleRecord[featureNames.length - 1]; + + //加入此行数据 + datas.add(tempData); + } + + //加入行首列出现名称 + datas.add(0, headCulumn); + //对筛选出的数据重新做id分配 + temp = 0; + for (String[] array : datas) { + //从第2行开始赋值 + if (temp > 0) { + array[0] = temp + ""; + } + + temp++; + } + + tree = new DecisionTree(datas); + + return tree; + } + + /** + * 构造随机森林 + */ + void constructRandomTree(){ + DecisionTree tree; + random = new Random(); + decisionForest = new ArrayList<>(); + + System.out.println("下面是随机森林中的决策树:"); + // 构造决策树加入森林中 + for (int i = 0; i < treeNum; i++) { + System.out.println("\n决策树" + (i + 1)); + tree = produceDecisionTree(); + decisionForest.add(tree); + } + } + + /** + * 根据给定的属性条件进行类别的决策 + * + * @param features 给定的已知的属性描述 + */ + String judgeClassType(String features){ + // 结果类型值 + String resultClassType = ""; + String classType; + int count; + Map type2Num = new HashMap<>(); + + for (DecisionTree tree : decisionForest) { + classType = tree.decideClassType(features); + if (type2Num.containsKey(classType)) { + // 如果类别已经存在,则使其计数加1 + count = type2Num.get(classType); + count++; + } else { + count = 1; + } + + type2Num.put(classType, count); + } + + // 选出其中类别支持计数最多的一个类别值 + count = -1; + for (Map.Entry entry : type2Num.entrySet()) { + if ((int) entry.getValue() > count) { + count = (int) entry.getValue(); + resultClassType = (String) entry.getKey(); + } + } + return resultClassType; + } } diff --git a/Others/DataMining_RandomForest/TreeNode.java b/Others/DataMining_RandomForest/TreeNode.java index b118472..96f5f96 100644 --- a/Others/DataMining_RandomForest/TreeNode.java +++ b/Others/DataMining_RandomForest/TreeNode.java @@ -1,85 +1,83 @@ -package DataMining_RandomForest; +package Others.DataMining_RandomForest; import java.util.ArrayList; /** * 回归分类树节点 - * - * @author lyq - * + * + * @author Qstar */ -public class TreeNode { - // 节点属性名字 - private String attrName; - // 节点索引标号 - private int nodeIndex; - //包含的叶子节点数 - private int leafNum; - // 节点误差率 - private double alpha; - // 父亲分类属性值 - private String parentAttrValue; - // 孩子节点 - private TreeNode[] childAttrNode; - // 数据记录索引 - private ArrayList dataIndex; - - public String getAttrName() { - return attrName; - } - - public void setAttrName(String attrName) { - this.attrName = attrName; - } - - public int getNodeIndex() { - return nodeIndex; - } - - public void setNodeIndex(int nodeIndex) { - this.nodeIndex = nodeIndex; - } - - public double getAlpha() { - return alpha; - } - - public void setAlpha(double alpha) { - this.alpha = alpha; - } - - public String getParentAttrValue() { - return parentAttrValue; - } - - public void setParentAttrValue(String parentAttrValue) { - this.parentAttrValue = parentAttrValue; - } - - public TreeNode[] getChildAttrNode() { - return childAttrNode; - } - - public void setChildAttrNode(TreeNode[] childAttrNode) { - this.childAttrNode = childAttrNode; - } - - public ArrayList getDataIndex() { - return dataIndex; - } - - public void setDataIndex(ArrayList dataIndex) { - this.dataIndex = dataIndex; - } - - public int getLeafNum() { - return leafNum; - } - - public void setLeafNum(int leafNum) { - this.leafNum = leafNum; - } - - - +class TreeNode { + // 节点属性名字 + private String attrName; + // 节点索引标号 + private int nodeIndex; + //包含的叶子节点数 + private int leafNum; + // 节点误差率 + private double alpha; + // 父亲分类属性值 + private String parentAttrValue; + // 孩子节点 + private TreeNode[] childAttrNode; + // 数据记录索引 + private ArrayList dataIndex; + + public String getAttrName(){ + return attrName; + } + + public void setAttrName(String attrName){ + this.attrName = attrName; + } + + int getNodeIndex(){ + return nodeIndex; + } + + void setNodeIndex(int nodeIndex){ + this.nodeIndex = nodeIndex; + } + + double getAlpha(){ + return alpha; + } + + void setAlpha(double alpha){ + this.alpha = alpha; + } + + String getParentAttrValue(){ + return parentAttrValue; + } + + void setParentAttrValue(String parentAttrValue){ + this.parentAttrValue = parentAttrValue; + } + + TreeNode[] getChildAttrNode(){ + return childAttrNode; + } + + void setChildAttrNode(TreeNode[] childAttrNode){ + this.childAttrNode = childAttrNode; + } + + ArrayList getDataIndex(){ + return dataIndex; + } + + void setDataIndex(ArrayList dataIndex){ + this.dataIndex = dataIndex; + } + + int getLeafNum(){ + return leafNum; + } + + void setLeafNum(int leafNum){ + this.leafNum = leafNum; + } + + } diff --git a/Others/DataMining_TAN/AttrMutualInfo.java b/Others/DataMining_TAN/AttrMutualInfo.java index 6caf12d..80f09b8 100644 --- a/Others/DataMining_TAN/AttrMutualInfo.java +++ b/Others/DataMining_TAN/AttrMutualInfo.java @@ -1,28 +1,28 @@ -package DataMining_TAN; +package Others.DataMining_TAN; /** * 属性之间的互信息值,表示属性之间的关联性大小 - * @author lyq * + * @author Qstar */ -public class AttrMutualInfo implements Comparable{ - //互信息值 - Double value; - //关联属性值对 - Node[] nodeArray; - - public AttrMutualInfo(double value, Node node1, Node node2){ - this.value = value; - - this.nodeArray = new Node[2]; - this.nodeArray[0] = node1; - this.nodeArray[1] = node2; - } +class AttrMutualInfo implements Comparable { + //关联属性值对 + Node[] nodeArray; + //互信息值 + private Double value; + + AttrMutualInfo(double value, Node node1, Node node2){ + this.value = value; + + this.nodeArray = new Node[2]; + this.nodeArray[0] = node1; + this.nodeArray[1] = node2; + } + + @Override + public int compareTo(AttrMutualInfo o){ + // TODO Auto-generated method stub + return o.value.compareTo(this.value); + } - @Override - public int compareTo(AttrMutualInfo o) { - // TODO Auto-generated method stub - return o.value.compareTo(this.value); - } - } diff --git a/Others/DataMining_TAN/Client.java b/Others/DataMining_TAN/Client.java index bd104bc..69b1830 100644 --- a/Others/DataMining_TAN/Client.java +++ b/Others/DataMining_TAN/Client.java @@ -1,36 +1,35 @@ -package DataMining_TAN; +package Others.DataMining_TAN; /** * TAN树型朴素贝叶斯算法 - * - * @author lyq - * + * + * @author Qstar */ public class Client { - public static void main(String[] args) { - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - // 条件查询语句 - String queryStr; - // 分类结果概率1 - double classResult1; - // 分类结果概率2 - double classResult2; + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_TAN/input.txt"; + // 条件查询语句 + String queryStr; + // 分类结果概率1 + double classResult1; + // 分类结果概率2 + double classResult2; - TANTool tool = new TANTool(filePath); - queryStr = "OutLook=Sunny,Temperature=Hot,Humidity=High,Wind=Weak,PlayTennis=No"; - classResult1 = tool.calHappenedPro(queryStr); + TANTool tool = new TANTool(filePath); + queryStr = "OutLook=Sunny,Temperature=Hot,Humidity=High,Wind=Weak,PlayTennis=No"; + classResult1 = tool.calHappenedPro(queryStr); - queryStr = "OutLook=Sunny,Temperature=Hot,Humidity=High,Wind=Weak,PlayTennis=Yes"; - classResult2 = tool.calHappenedPro(queryStr); + queryStr = "OutLook=Sunny,Temperature=Hot,Humidity=High,Wind=Weak,PlayTennis=Yes"; + classResult2 = tool.calHappenedPro(queryStr); - System.out.println(String.format("类别为%s所求得的概率为%s", "PlayTennis=No", - classResult1)); - System.out.println(String.format("类别为%s所求得的概率为%s", "PlayTennis=Yes", - classResult2)); - if (classResult1 > classResult2) { - System.out.println("分类类别为PlayTennis=No"); - } else { - System.out.println("分类类别为PlayTennis=Yes"); - } - } + System.out.println(String.format("类别为%s所求得的概率为%s", "PlayTennis=No", + classResult1)); + System.out.println(String.format("类别为%s所求得的概率为%s", "PlayTennis=Yes", + classResult2)); + if (classResult1 > classResult2) { + System.out.println("分类类别为PlayTennis=No"); + } else { + System.out.println("分类类别为PlayTennis=Yes"); + } + } } diff --git a/Others/DataMining_TAN/Node.java b/Others/DataMining_TAN/Node.java index f3a3b51..ec3b32b 100644 --- a/Others/DataMining_TAN/Node.java +++ b/Others/DataMining_TAN/Node.java @@ -1,63 +1,59 @@ -package DataMining_TAN; +package Others.DataMining_TAN; import java.util.ArrayList; /** * 贝叶斯网络节点类 - * - * @author lyq - * + * + * @author Qstar */ -public class Node { - //节点唯一id,方便后面节点连接方向的确定 - int id; - // 节点的属性名称 - String name; - // 该节点所连续的节点 - ArrayList connectedNodes; - - public Node(int id, String name) { - this.id = id; - this.name = name; - - // 初始化变量 - this.connectedNodes = new ArrayList<>(); - } - - /** - * 将自身节点连接到目标给定的节点 - * - * @param node - * 下游节点 - */ - public void connectNode(Node node) { - //避免连接自身 - if(this.id == node.id){ - return; - } - - // 将节点加入自身节点的节点列表中 - this.connectedNodes.add(node); - // 将自身节点加入到目标节点的列表中 - node.connectedNodes.add(this); - } - - /** - * 判断与目标节点是否相同,主要比较名称是否相同即可 - * - * @param node - * 目标结点 - * @return - */ - public boolean isEqual(Node node) { - boolean isEqual; - - isEqual = false; - // 节点名称相同则视为相等 - if (this.id == node.id) { - isEqual = true; - } - - return isEqual; - } +class Node { + //节点唯一id,方便后面节点连接方向的确定 + int id; + // 节点的属性名称 + String name; + // 该节点所连续的节点 + ArrayList connectedNodes; + + Node(int id, String name){ + this.id = id; + this.name = name; + + // 初始化变量 + this.connectedNodes = new ArrayList<>(); + } + + /** + * 将自身节点连接到目标给定的节点 + * + * @param node 下游节点 + */ + void connectNode(Node node){ + //避免连接自身 + if (this.id == node.id) { + return; + } + + // 将节点加入自身节点的节点列表中 + this.connectedNodes.add(node); + // 将自身节点加入到目标节点的列表中 + node.connectedNodes.add(this); + } + + /** + * 判断与目标节点是否相同,主要比较名称是否相同即可 + * + * @param node 目标结点 + */ + boolean isEqual(Node node){ + boolean isEqual; + + isEqual = false; + // 节点名称相同则视为相等 + if (this.id == node.id) { + isEqual = true; + } + + return isEqual; + } } diff --git a/Others/DataMining_TAN/TANTool.java b/Others/DataMining_TAN/TANTool.java index 56e90a6..3b7e394 100644 --- a/Others/DataMining_TAN/TANTool.java +++ b/Others/DataMining_TAN/TANTool.java @@ -1,4 +1,4 @@ -package DataMining_TAN; +package Others.DataMining_TAN; import java.io.BufferedReader; import java.io.File; @@ -7,565 +7,543 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.stream.Collectors; /** * TAN树型朴素贝叶斯算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class TANTool { - // 测试数据集地址 - private String filePath; - // 数据集属性总数,其中一个个分类属性 - private int attrNum; - // 分类属性名 - private String classAttrName; - // 属性列名称行 - private String[] attrNames; - // 贝叶斯网络边的方向,数组内的数值为节点id,从i->j - private int[][] edges; - // 属性名到列下标的映射 - private HashMap attr2Column; - // 属性,属性对取值集合映射对 - private HashMap> attr2Values; - // 贝叶斯网络总节点列表 - private ArrayList totalNodes; - // 总的测试数据 - private ArrayList totalDatas; - - public TANTool(String filePath) { - this.filePath = filePath; - - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] array; - - while ((str = in.readLine()) != null) { - array = str.split(" "); - dataArray.add(array); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - this.totalDatas = dataArray; - this.attrNames = this.totalDatas.get(0); - this.attrNum = this.attrNames.length; - this.classAttrName = this.attrNames[attrNum - 1]; - - Node node; - this.edges = new int[attrNum][attrNum]; - this.totalNodes = new ArrayList<>(); - this.attr2Column = new HashMap<>(); - this.attr2Values = new HashMap<>(); - - // 分类属性节点id最小设为0 - node = new Node(0, attrNames[attrNum - 1]); - this.totalNodes.add(node); - for (int i = 0; i < attrNames.length; i++) { - if (i < attrNum - 1) { - // 创建贝叶斯网络节点,每个属性一个节点 - node = new Node(i + 1, attrNames[i]); - this.totalNodes.add(node); - } - - // 添加属性到列下标的映射 - this.attr2Column.put(attrNames[i], i); - } - - String[] temp; - ArrayList values; - // 进行属性名,属性值对的映射匹配 - for (int i = 1; i < this.totalDatas.size(); i++) { - temp = this.totalDatas.get(i); - - for (int j = 0; j < temp.length; j++) { - // 判断map中是否包含此属性名 - if (this.attr2Values.containsKey(attrNames[j])) { - values = this.attr2Values.get(attrNames[j]); - } else { - values = new ArrayList<>(); - } - - if (!values.contains(temp[j])) { - // 加入新的属性值 - values.add(temp[j]); - } - - this.attr2Values.put(attrNames[j], values); - } - } - } - - /** - * 根据条件互信息度对构建最大权重跨度树,返回第一个节点为根节点 - * - * @param iArray - */ - private Node constructWeightTree(ArrayList iArray) { - Node node1; - Node node2; - Node root; - ArrayList existNodes; - - existNodes = new ArrayList<>(); - - for (Node[] i : iArray) { - node1 = i[0]; - node2 = i[1]; - - // 将2个节点进行连接 - node1.connectNode(node2); - // 避免出现环路现象 - addIfNotExist(node1, existNodes); - addIfNotExist(node2, existNodes); - - if (existNodes.size() == attrNum - 1) { - break; - } - } - - // 返回第一个作为根节点 - root = existNodes.get(0); - return root; - } - - /** - * 为树型结构确定边的方向,方向为属性根节点方向指向其他属性节点方向 - * - * @param root - * 当前遍历到的节点 - */ - private void confirmGraphDirection(Node currentNode) { - int i; - int j; - ArrayList connectedNodes; - - connectedNodes = currentNode.connectedNodes; - - i = currentNode.id; - for (Node n : connectedNodes) { - j = n.id; - - // 判断连接此2节点的方向是否被确定 - if (edges[i][j] == 0 && edges[j][i] == 0) { - // 如果没有确定,则制定方向为i->j - edges[i][j] = 1; - - // 递归继续搜索 - confirmGraphDirection(n); - } - } - } - - /** - * 为属性节点添加分类属性节点为父节点 - * - * @param parentNode - * 父节点 - * @param nodeList - * 子节点列表 - */ - private void addParentNode() { - // 分类属性节点 - Node parentNode; - - parentNode = null; - for (Node n : this.totalNodes) { - if (n.id == 0) { - parentNode = n; - break; - } - } - - for (Node child : this.totalNodes) { - parentNode.connectNode(child); - - if (child.id != 0) { - // 确定连接方向 - this.edges[0][child.id] = 1; - } - } - } - - /** - * 在节点集合中添加节点 - * - * @param node - * 待添加节点 - * @param existNodes - * 已存在的节点列表 - * @return - */ - public boolean addIfNotExist(Node node, ArrayList existNodes) { - boolean canAdd; - - canAdd = true; - for (Node n : existNodes) { - // 如果节点列表中已经含有节点,则算添加失败 - if (n.isEqual(node)) { - canAdd = false; - break; - } - } - - if (canAdd) { - existNodes.add(node); - } - - return canAdd; - } - - /** - * 计算节点条件概率 - * - * @param node - * 关于node的后验概率 - * @param queryParam - * 查询的属性参数 - * @return - */ - private double calConditionPro(Node node, HashMap queryParam) { - int id; - double pro; - String value; - String[] attrValue; - - ArrayList priorAttrInfos; - ArrayList backAttrInfos; - ArrayList parentNodes; - - pro = 1; - id = node.id; - parentNodes = new ArrayList<>(); - priorAttrInfos = new ArrayList<>(); - backAttrInfos = new ArrayList<>(); - - for (int i = 0; i < this.edges.length; i++) { - // 寻找父节点id - if (this.edges[i][id] == 1) { - for (Node temp : this.totalNodes) { - // 寻找目标节点id - if (temp.id == i) { - parentNodes.add(temp); - break; - } - } - } - } - - // 获取先验属性的属性值,首先添加先验属性 - value = queryParam.get(node.name); - attrValue = new String[2]; - attrValue[0] = node.name; - attrValue[1] = value; - priorAttrInfos.add(attrValue); - - // 逐一添加后验属性 - for (Node p : parentNodes) { - value = queryParam.get(p.name); - attrValue = new String[2]; - attrValue[0] = p.name; - attrValue[1] = value; - - backAttrInfos.add(attrValue); - } - - pro = queryConditionPro(priorAttrInfos, backAttrInfos); - - return pro; - } - - /** - * 查询条件概率 - * - * @param attrValues - * 条件属性值 - * @return - */ - private double queryConditionPro(ArrayList priorValues, - ArrayList backValues) { - // 判断是否满足先验属性值条件 - boolean hasPrior; - // 判断是否满足后验属性值条件 - boolean hasBack; - int attrIndex; - double backPro; - double totalPro; - double pro; - String[] tempData; - - pro = 0; - totalPro = 0; - backPro = 0; - - // 跳过第一行的属性名称行 - for (int i = 1; i < this.totalDatas.size(); i++) { - tempData = this.totalDatas.get(i); - - hasPrior = true; - hasBack = true; - - // 判断是否满足先验条件 - for (String[] array : priorValues) { - attrIndex = this.attr2Column.get(array[0]); - - // 判断值是否满足条件 - if (!tempData[attrIndex].equals(array[1])) { - hasPrior = false; - break; - } - } - - // 判断是否满足后验条件 - for (String[] array : backValues) { - attrIndex = this.attr2Column.get(array[0]); - - // 判断值是否满足条件 - if (!tempData[attrIndex].equals(array[1])) { - hasBack = false; - break; - } - } - - // 进行计数统计,分别计算满足后验属性的值和同时满足条件的个数 - if (hasBack) { - backPro++; - if (hasPrior) { - totalPro++; - } - } else if (hasPrior && backValues.size() == 0) { - // 如果只有先验概率则为纯概率的计算 - totalPro++; - backPro = 1.0; - } - } - - if (backPro == 0) { - pro = 0; - } else { - // 计算总的概率=都发生概率/只发生后验条件的时间概率 - pro = totalPro / backPro; - } - - return pro; - } - - /** - * 输入查询条件参数,计算发生概率 - * - * @param queryParam - * 条件参数 - * @return - */ - public double calHappenedPro(String queryParam) { - double result; - double temp; - // 分类属性值 - String classAttrValue; - String[] array; - String[] array2; - HashMap params; - - result = 1; - params = new HashMap<>(); - - // 进行查询字符的参数分解 - array = queryParam.split(","); - for (String s : array) { - array2 = s.split("="); - params.put(array2[0], array2[1]); - } - - classAttrValue = params.get(classAttrName); - // 构建贝叶斯网络结构 - constructBayesNetWork(classAttrValue); - - for (Node n : this.totalNodes) { - temp = calConditionPro(n, params); - - // 为了避免出现条件概率为0的现象,进行轻微矫正 - if (temp == 0) { - temp = 0.001; - } - - // 按照联合概率公式,进行乘积运算 - result *= temp; - } - - return result; - } - - /** - * 构建树型贝叶斯网络结构 - * - * @param value - * 类别量值 - */ - private void constructBayesNetWork(String value) { - Node rootNode; - ArrayList mInfoArray; - // 互信息度对 - ArrayList iArray; - - iArray = null; - rootNode = null; - - // 在每次重新构建贝叶斯网络结构的时候,清空原有的连接结构 - for (Node n : this.totalNodes) { - n.connectedNodes.clear(); - } - this.edges = new int[attrNum][attrNum]; - - // 从互信息对象中取出属性值对 - iArray = new ArrayList<>(); - mInfoArray = calAttrMutualInfoArray(value); - for (AttrMutualInfo v : mInfoArray) { - iArray.add(v.nodeArray); - } - - // 构建最大权重跨度树 - rootNode = constructWeightTree(iArray); - // 为无向图确定边的方向 - confirmGraphDirection(rootNode); - // 为每个属性节点添加分类属性父节点 - addParentNode(); - } - - /** - * 给定分类变量值,计算属性之间的互信息值 - * - * @param value - * 分类变量值 - * @return - */ - private ArrayList calAttrMutualInfoArray(String value) { - double iValue; - Node node1; - Node node2; - AttrMutualInfo mInfo; - ArrayList mInfoArray; - - mInfoArray = new ArrayList<>(); - - for (int i = 0; i < this.totalNodes.size() - 1; i++) { - node1 = this.totalNodes.get(i); - // 跳过分类属性节点 - if (node1.id == 0) { - continue; - } - - for (int j = i + 1; j < this.totalNodes.size(); j++) { - node2 = this.totalNodes.get(j); - // 跳过分类属性节点 - if (node2.id == 0) { - continue; - } - - // 计算2个属性节点之间的互信息值 - iValue = calMutualInfoValue(node1, node2, value); - mInfo = new AttrMutualInfo(iValue, node1, node2); - mInfoArray.add(mInfo); - } - } - - // 将结果进行降序排列,让互信息值高的优先用于构建树 - Collections.sort(mInfoArray); - - return mInfoArray; - } - - /** - * 计算2个属性节点的互信息值 - * - * @param node1 - * 节点1 - * @param node2 - * 节点2 - * @param vlaue - * 分类变量值 - */ - private double calMutualInfoValue(Node node1, Node node2, String value) { - double iValue; - double temp; - // 三种不同条件的后验概率 - double pXiXj; - double pXi; - double pXj; - String[] array1; - String[] array2; - ArrayList attrValues1; - ArrayList attrValues2; - ArrayList priorValues; - // 后验概率,在这里就是类变量值 - ArrayList backValues; - - array1 = new String[2]; - array2 = new String[2]; - priorValues = new ArrayList<>(); - backValues = new ArrayList<>(); - - iValue = 0; - array1[0] = classAttrName; - array1[1] = value; - // 后验属性都是类属性 - backValues.add(array1); - - // 获取节点属性的属性值集合 - attrValues1 = this.attr2Values.get(node1.name); - attrValues2 = this.attr2Values.get(node2.name); - - for (String v1 : attrValues1) { - for (String v2 : attrValues2) { - priorValues.clear(); - - array1 = new String[2]; - array1[0] = node1.name; - array1[1] = v1; - priorValues.add(array1); - - array2 = new String[2]; - array2[0] = node2.name; - array2[1] = v2; - priorValues.add(array2); - - // 计算3种条件下的概率 - pXiXj = queryConditionPro(priorValues, backValues); - - priorValues.clear(); - priorValues.add(array1); - pXi = queryConditionPro(priorValues, backValues); - - priorValues.clear(); - priorValues.add(array2); - pXj = queryConditionPro(priorValues, backValues); - - // 如果出现其中一个计数概率为0,则直接赋值为0处理 - if (pXiXj == 0 || pXi == 0 || pXj == 0) { - temp = 0; - } else { - // 利用公式计算针对此属性值对组合的概率 - temp = pXiXj * Math.log(pXiXj / (pXi * pXj)) / Math.log(2); - } - - // 进行和属性值对组合的累加即为整个属性的互信息值 - iValue += temp; - } - } - - return iValue; - } +class TANTool { + // 测试数据集地址 + private String filePath; + // 数据集属性总数,其中一个个分类属性 + private int attrNum; + // 分类属性名 + private String classAttrName; + // 属性列名称行 + private String[] attrNames; + // 贝叶斯网络边的方向,数组内的数值为节点id,从i->j + private int[][] edges; + // 属性名到列下标的映射 + private HashMap attr2Column; + // 属性,属性对取值集合映射对 + private HashMap> attr2Values; + // 贝叶斯网络总节点列表 + private ArrayList totalNodes; + // 总的测试数据 + private ArrayList totalDatas; + + TANTool(String filePath){ + this.filePath = filePath; + + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] array; + + while ((str = in.readLine()) != null) { + array = str.split(" "); + dataArray.add(array); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + this.totalDatas = dataArray; + this.attrNames = this.totalDatas.get(0); + this.attrNum = this.attrNames.length; + this.classAttrName = this.attrNames[attrNum - 1]; + + Node node; + this.edges = new int[attrNum][attrNum]; + this.totalNodes = new ArrayList<>(); + this.attr2Column = new HashMap<>(); + this.attr2Values = new HashMap<>(); + + // 分类属性节点id最小设为0 + node = new Node(0, attrNames[attrNum - 1]); + this.totalNodes.add(node); + for (int i = 0; i < attrNames.length; i++) { + if (i < attrNum - 1) { + // 创建贝叶斯网络节点,每个属性一个节点 + node = new Node(i + 1, attrNames[i]); + this.totalNodes.add(node); + } + + // 添加属性到列下标的映射 + this.attr2Column.put(attrNames[i], i); + } + + String[] temp; + ArrayList values; + // 进行属性名,属性值对的映射匹配 + for (int i = 1; i < this.totalDatas.size(); i++) { + temp = this.totalDatas.get(i); + + for (int j = 0; j < temp.length; j++) { + // 判断map中是否包含此属性名 + if (this.attr2Values.containsKey(attrNames[j])) { + values = this.attr2Values.get(attrNames[j]); + } else { + values = new ArrayList<>(); + } + + if (!values.contains(temp[j])) { + // 加入新的属性值 + values.add(temp[j]); + } + + this.attr2Values.put(attrNames[j], values); + } + } + } + + /** + * 根据条件互信息度对构建最大权重跨度树,返回第一个节点为根节点 + * + * @param iArray + */ + private Node constructWeightTree(ArrayList iArray){ + Node node1; + Node node2; + Node root; + ArrayList existNodes; + + existNodes = new ArrayList<>(); + + for (Node[] i : iArray) { + node1 = i[0]; + node2 = i[1]; + + // 将2个节点进行连接 + node1.connectNode(node2); + // 避免出现环路现象 + addIfNotExist(node1, existNodes); + addIfNotExist(node2, existNodes); + + if (existNodes.size() == attrNum - 1) { + break; + } + } + + // 返回第一个作为根节点 + root = existNodes.get(0); + return root; + } + + /** + * 为树型结构确定边的方向,方向为属性根节点方向指向其他属性节点方向 + * + * @param currentNode 当前遍历到的节点 + */ + private void confirmGraphDirection(Node currentNode){ + int i; + int j; + ArrayList connectedNodes; + + connectedNodes = currentNode.connectedNodes; + + i = currentNode.id; + for (Node n : connectedNodes) { + j = n.id; + + // 判断连接此2节点的方向是否被确定 + if (edges[i][j] == 0 && edges[j][i] == 0) { + // 如果没有确定,则制定方向为i->j + edges[i][j] = 1; + + // 递归继续搜索 + confirmGraphDirection(n); + } + } + } + + /** + * 为属性节点添加分类属性节点为父节点 + *

+ * parentNode 父节点 + * nodeList 子节点列表 + */ + private void addParentNode(){ + // 分类属性节点 + Node parentNode; + + parentNode = null; + for (Node n : this.totalNodes) { + if (n.id == 0) { + parentNode = n; + break; + } + } + + for (Node child : this.totalNodes) { + if (parentNode != null) { + parentNode.connectNode(child); + } + + if (child.id != 0) { + // 确定连接方向 + this.edges[0][child.id] = 1; + } + } + } + + /** + * 在节点集合中添加节点 + * + * @param node 待添加节点 + * @param existNodes 已存在的节点列表 + */ + private boolean addIfNotExist(Node node, ArrayList existNodes){ + boolean canAdd; + + canAdd = true; + for (Node n : existNodes) { + // 如果节点列表中已经含有节点,则算添加失败 + if (n.isEqual(node)) { + canAdd = false; + break; + } + } + + if (canAdd) { + existNodes.add(node); + } + + return canAdd; + } + + /** + * 计算节点条件概率 + * + * @param node 关于node的后验概率 + * @param queryParam 查询的属性参数 + */ + private double calConditionPro(Node node, HashMap queryParam){ + int id; + double pro; + String value; + String[] attrValue; + + ArrayList priorAttrInfos; + ArrayList backAttrInfos; + ArrayList parentNodes; + + id = node.id; + parentNodes = new ArrayList<>(); + priorAttrInfos = new ArrayList<>(); + backAttrInfos = new ArrayList<>(); + + for (int i = 0; i < this.edges.length; i++) { + // 寻找父节点id + if (this.edges[i][id] == 1) { + for (Node temp : this.totalNodes) { + // 寻找目标节点id + if (temp.id == i) { + parentNodes.add(temp); + break; + } + } + } + } + + // 获取先验属性的属性值,首先添加先验属性 + value = queryParam.get(node.name); + attrValue = new String[2]; + attrValue[0] = node.name; + attrValue[1] = value; + priorAttrInfos.add(attrValue); + + // 逐一添加后验属性 + for (Node p : parentNodes) { + value = queryParam.get(p.name); + attrValue = new String[2]; + attrValue[0] = p.name; + attrValue[1] = value; + + backAttrInfos.add(attrValue); + } + + pro = queryConditionPro(priorAttrInfos, backAttrInfos); + + return pro; + } + + /** + * 查询条件概率 + * + * @param priorValues 条件属性值 + * @param backValues 条件属性值 + */ + private double queryConditionPro(ArrayList priorValues, + ArrayList backValues){ + // 判断是否满足先验属性值条件 + boolean hasPrior; + // 判断是否满足后验属性值条件 + boolean hasBack; + int attrIndex; + double backPro; + double totalPro; + double pro; + String[] tempData; + + totalPro = 0; + backPro = 0; + + // 跳过第一行的属性名称行 + for (int i = 1; i < this.totalDatas.size(); i++) { + tempData = this.totalDatas.get(i); + + hasPrior = true; + hasBack = true; + + // 判断是否满足先验条件 + for (String[] array : priorValues) { + attrIndex = this.attr2Column.get(array[0]); + + // 判断值是否满足条件 + if (!tempData[attrIndex].equals(array[1])) { + hasPrior = false; + break; + } + } + + // 判断是否满足后验条件 + for (String[] array : backValues) { + attrIndex = this.attr2Column.get(array[0]); + + // 判断值是否满足条件 + if (!tempData[attrIndex].equals(array[1])) { + hasBack = false; + break; + } + } + + // 进行计数统计,分别计算满足后验属性的值和同时满足条件的个数 + if (hasBack) { + backPro++; + if (hasPrior) { + totalPro++; + } + } else if (hasPrior && backValues.size() == 0) { + // 如果只有先验概率则为纯概率的计算 + totalPro++; + backPro = 1.0; + } + } + + if (backPro == 0) { + pro = 0; + } else { + // 计算总的概率=都发生概率/只发生后验条件的时间概率 + pro = totalPro / backPro; + } + + return pro; + } + + /** + * 输入查询条件参数,计算发生概率 + * + * @param queryParam 条件参数 + */ + double calHappenedPro(String queryParam){ + double result; + double temp; + // 分类属性值 + String classAttrValue; + String[] array; + String[] array2; + HashMap params; + + result = 1; + params = new HashMap<>(); + + // 进行查询字符的参数分解 + array = queryParam.split(","); + for (String s : array) { + array2 = s.split("="); + params.put(array2[0], array2[1]); + } + + classAttrValue = params.get(classAttrName); + // 构建贝叶斯网络结构 + constructBayesNetWork(classAttrValue); + + for (Node n : this.totalNodes) { + temp = calConditionPro(n, params); + + // 为了避免出现条件概率为0的现象,进行轻微矫正 + if (temp == 0) { + temp = 0.001; + } + + // 按照联合概率公式,进行乘积运算 + result *= temp; + } + + return result; + } + + /** + * 构建树型贝叶斯网络结构 + * + * @param value 类别量值 + */ + private void constructBayesNetWork(String value){ + Node rootNode; + ArrayList mInfoArray; + // 互信息度对 + ArrayList iArray; + + // 在每次重新构建贝叶斯网络结构的时候,清空原有的连接结构 + for (Node n : this.totalNodes) { + n.connectedNodes.clear(); + } + this.edges = new int[attrNum][attrNum]; + + // 从互信息对象中取出属性值对 + iArray = new ArrayList<>(); + mInfoArray = calAttrMutualInfoArray(value); + iArray.addAll(mInfoArray.stream() + .map(v -> v.nodeArray) + .collect(Collectors.toList())); + + // 构建最大权重跨度树 + rootNode = constructWeightTree(iArray); + // 为无向图确定边的方向 + confirmGraphDirection(rootNode); + // 为每个属性节点添加分类属性父节点 + addParentNode(); + } + + /** + * 给定分类变量值,计算属性之间的互信息值 + * + * @param value 分类变量值 + */ + private ArrayList calAttrMutualInfoArray(String value){ + double iValue; + Node node1; + Node node2; + AttrMutualInfo mInfo; + ArrayList mInfoArray; + + mInfoArray = new ArrayList<>(); + + for (int i = 0; i < this.totalNodes.size() - 1; i++) { + node1 = this.totalNodes.get(i); + // 跳过分类属性节点 + if (node1.id == 0) { + continue; + } + + for (int j = i + 1; j < this.totalNodes.size(); j++) { + node2 = this.totalNodes.get(j); + // 跳过分类属性节点 + if (node2.id == 0) { + continue; + } + + // 计算2个属性节点之间的互信息值 + iValue = calMutualInfoValue(node1, node2, value); + mInfo = new AttrMutualInfo(iValue, node1, node2); + mInfoArray.add(mInfo); + } + } + + // 将结果进行降序排列,让互信息值高的优先用于构建树 + Collections.sort(mInfoArray); + + return mInfoArray; + } + + /** + * 计算2个属性节点的互信息值 + * + * @param node1 节点1 + * @param node2 节点2 + * @param value 分类变量值 + */ + private double calMutualInfoValue(Node node1, Node node2, String value){ + double iValue; + double temp; + // 三种不同条件的后验概率 + double pXiXj; + double pXi; + double pXj; + String[] array1; + String[] array2; + ArrayList attrValues1; + ArrayList attrValues2; + ArrayList priorValues; + // 后验概率,在这里就是类变量值 + ArrayList backValues; + + array1 = new String[2]; + priorValues = new ArrayList<>(); + backValues = new ArrayList<>(); + + iValue = 0; + array1[0] = classAttrName; + array1[1] = value; + // 后验属性都是类属性 + backValues.add(array1); + + // 获取节点属性的属性值集合 + attrValues1 = this.attr2Values.get(node1.name); + attrValues2 = this.attr2Values.get(node2.name); + + for (String v1 : attrValues1) { + for (String v2 : attrValues2) { + priorValues.clear(); + + array1 = new String[2]; + array1[0] = node1.name; + array1[1] = v1; + priorValues.add(array1); + + array2 = new String[2]; + array2[0] = node2.name; + array2[1] = v2; + priorValues.add(array2); + + // 计算3种条件下的概率 + pXiXj = queryConditionPro(priorValues, backValues); + + priorValues.clear(); + priorValues.add(array1); + pXi = queryConditionPro(priorValues, backValues); + + priorValues.clear(); + priorValues.add(array2); + pXj = queryConditionPro(priorValues, backValues); + + // 如果出现其中一个计数概率为0,则直接赋值为0处理 + if (pXiXj == 0 || pXi == 0 || pXj == 0) { + temp = 0; + } else { + // 利用公式计算针对此属性值对组合的概率 + temp = pXiXj * Math.log(pXiXj / (pXi * pXj)) / Math.log(2); + } + + // 进行和属性值对组合的累加即为整个属性的互信息值 + iValue += temp; + } + } + + return iValue; + } } diff --git a/Others/DataMining_Viterbi/BaseNames.java b/Others/DataMining_Viterbi/BaseNames.java index cca0aaa..9733fb5 100644 --- a/Others/DataMining_Viterbi/BaseNames.java +++ b/Others/DataMining_Viterbi/BaseNames.java @@ -1,24 +1,22 @@ -package DataMining_Viterbi; +package Others.DataMining_Viterbi; /** * 基本变量定义类 - * @author lyq * + * @author Qstar */ -public class BaseNames { - //日期天数下标 - public static final int DAY1 = 0; - public static final int DAY2 = 1; - public static final int DAY3 = 2; - - //天气属性类别 - public static final int WEATHER_SUNNY = 0; - public static final int WEATHER_CLOUDY = 1; - public static final int WEATHER_RAINY = 2; - - //湿度属性类别 - public static final int HUMIDITY_DRY = 0; - public static final int HUMIDITY_DRYISH = 1; - public static final int HUMIDITY_DAMP = 1; - public static final int HUMIDITY_SOGGY = 1; +class BaseNames { + public static final int DAY2 = 1; + public static final int DAY3 = 2; + //天气属性类别 + public static final int WEATHER_SUNNY = 0; + public static final int WEATHER_CLOUDY = 1; + public static final int WEATHER_RAINY = 2; + //湿度属性类别 + public static final int HUMIDITY_DRY = 0; + public static final int HUMIDITY_DRYISH = 1; + public static final int HUMIDITY_DAMP = 1; + public static final int HUMIDITY_SOGGY = 1; + //日期天数下标 + static final int DAY1 = 0; } diff --git a/Others/DataMining_Viterbi/Client.java b/Others/DataMining_Viterbi/Client.java index 577eabd..d6c62ac 100644 --- a/Others/DataMining_Viterbi/Client.java +++ b/Others/DataMining_Viterbi/Client.java @@ -1,31 +1,30 @@ -package DataMining_Viterbi; +package Others.DataMining_Viterbi; /** * 维特比算法 - * - * @author lyq - * + * + * @author Qstar */ public class Client { - public static void main(String[] args) { - // 状态转移概率矩阵路径 - String stmFilePath; - // 混淆矩阵路径 - String cfFilePath; - // 观察到的状态 - String[] observeStates; - // 初始状态 - double[] initStatePro; - ViterbiTool tool; + public static void main(String[] args){ + // 状态转移概率矩阵路径 + String stmFilePath; + // 混淆矩阵路径 + String cfFilePath; + // 观察到的状态 + String[] observeStates; + // 初始状态 + double[] initStatePro; + ViterbiTool tool; - stmFilePath = "C:\\Users\\lyq\\Desktop\\icon\\stmatrix.txt"; - cfFilePath = "C:\\Users\\lyq\\Desktop\\icon\\humidity-matrix.txt"; + stmFilePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_Viterbi/stmatrix.txt"; + cfFilePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Others/DataMining_Viterbi/humidity-matrix.txt"; - initStatePro = new double[] { 0.63, 0.17, 0.20 }; - observeStates = new String[] { "Dry", "Damp", "Soggy" }; + initStatePro = new double[]{0.63, 0.17, 0.20}; + observeStates = new String[]{"Dry", "Damp", "Soggy"}; - tool = new ViterbiTool(stmFilePath, cfFilePath, initStatePro, - observeStates); - tool.calHMMObserve(); - } + tool = new ViterbiTool(stmFilePath, cfFilePath, initStatePro, + observeStates); + tool.calHMMObserve(); + } } diff --git a/Others/DataMining_Viterbi/ViterbiTool.java b/Others/DataMining_Viterbi/ViterbiTool.java index 6f1ade6..0f3fd0a 100644 --- a/Others/DataMining_Viterbi/ViterbiTool.java +++ b/Others/DataMining_Viterbi/ViterbiTool.java @@ -1,240 +1,236 @@ -package DataMining_Viterbi; +package Others.DataMining_Viterbi; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; -import java.util.Map; /** * 维特比算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class ViterbiTool { - // 状态转移概率矩阵文件地址 - private String stmFilePath; - // 混淆矩阵文件地址 - private String confusionFilePath; - // 初始状态概率 - private double[] initStatePro; - // 观察到的状态序列 - public String[] observeStates; - // 状态转移矩阵值 - private double[][] stMatrix; - // 混淆矩阵值 - private double[][] confusionMatrix; - // 各个条件下的潜在特征概率值 - private double[][] potentialValues; - // 潜在特征 - private ArrayList potentialAttrs; - // 属性值列坐标映射图 - private HashMap name2Index; - // 列坐标属性值映射图 - private HashMap index2name; - - public ViterbiTool(String stmFilePath, String confusionFilePath, - double[] initStatePro, String[] observeStates) { - this.stmFilePath = stmFilePath; - this.confusionFilePath = confusionFilePath; - this.initStatePro = initStatePro; - this.observeStates = observeStates; - - initOperation(); - } - - /** - * 初始化数据操作 - */ - private void initOperation() { - double[] temp; - int index; - ArrayList smtDatas; - ArrayList cfDatas; - - smtDatas = readDataFile(stmFilePath); - cfDatas = readDataFile(confusionFilePath); - - index = 0; - this.stMatrix = new double[smtDatas.size()][]; - for (String[] array : smtDatas) { - temp = new double[array.length]; - for (int i = 0; i < array.length; i++) { - try { - temp[i] = Double.parseDouble(array[i]); - } catch (NumberFormatException e) { - temp[i] = -1; - } - } - - // 将转换后的值赋给数组中 - this.stMatrix[index] = temp; - index++; - } - - index = 0; - this.confusionMatrix = new double[cfDatas.size()][]; - for (String[] array : cfDatas) { - temp = new double[array.length]; - for (int i = 0; i < array.length; i++) { - try { - temp[i] = Double.parseDouble(array[i]); - } catch (NumberFormatException e) { - temp[i] = -1; - } - } - - // 将转换后的值赋给数组中 - this.confusionMatrix[index] = temp; - index++; - } - - this.potentialAttrs = new ArrayList<>(); - // 添加潜在特征属性 - for (String s : smtDatas.get(0)) { - this.potentialAttrs.add(s); - } - // 去除首列无效列 - potentialAttrs.remove(0); - - this.name2Index = new HashMap<>(); - this.index2name = new HashMap<>(); - - // 添加名称下标映射关系 - for (int i = 1; i < smtDatas.get(0).length; i++) { - this.name2Index.put(smtDatas.get(0)[i], i); - // 添加下标到名称的映射 - this.index2name.put(i, smtDatas.get(0)[i]); - } - - for (int i = 1; i < cfDatas.get(0).length; i++) { - this.name2Index.put(cfDatas.get(0)[i], i); - } - } - - /** - * 从文件中读取数据 - */ - private ArrayList readDataFile(String filePath) { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - return dataArray; - } - - /** - * 根据观察特征计算隐藏的特征概率矩阵 - */ - private void calPotencialProMatrix() { - String curObserveState; - // 观察特征和潜在特征的下标 - int osIndex; - int psIndex; - double temp; - double maxPro; - // 混淆矩阵概率值,就是相关影响的因素概率 - double confusionPro; - - this.potentialValues = new double[observeStates.length][potentialAttrs - .size() + 1]; - for (int i = 0; i < this.observeStates.length; i++) { - curObserveState = this.observeStates[i]; - osIndex = this.name2Index.get(curObserveState); - maxPro = -1; - - // 因为是第一个观察特征,没有前面的影响,根据初始状态计算 - if (i == 0) { - for (String attr : this.potentialAttrs) { - psIndex = this.name2Index.get(attr); - confusionPro = this.confusionMatrix[psIndex][osIndex]; - - temp = this.initStatePro[psIndex - 1] * confusionPro; - this.potentialValues[BaseNames.DAY1][psIndex] = temp; - } - } else { - // 后面的潜在特征受前一个特征的影响,以及当前的混淆因素影响 - for (String toDayAttr : this.potentialAttrs) { - psIndex = this.name2Index.get(toDayAttr); - confusionPro = this.confusionMatrix[psIndex][osIndex]; - - int index; - maxPro = -1; - // 通过昨天的概率计算今天此特征的最大概率 - for (String yAttr : this.potentialAttrs) { - index = this.name2Index.get(yAttr); - temp = this.potentialValues[i - 1][index] - * this.stMatrix[index][psIndex]; - - // 计算得到今天此潜在特征的最大概率 - if (temp > maxPro) { - maxPro = temp; - } - } - - this.potentialValues[i][psIndex] = maxPro * confusionPro; - } - } - } - } - - /** - * 根据同时期最大概率值输出潜在特征值 - */ - private void outputResultAttr() { - double maxPro; - int maxIndex; - ArrayList psValues; - - psValues = new ArrayList<>(); - for (int i = 0; i < this.potentialValues.length; i++) { - maxPro = -1; - maxIndex = 0; - - for (int j = 0; j < potentialValues[i].length; j++) { - if (this.potentialValues[i][j] > maxPro) { - maxPro = potentialValues[i][j]; - maxIndex = j; - } - } - - // 取出最大概率下标对应的潜在特征 - psValues.add(this.index2name.get(maxIndex)); - } - - System.out.println("观察特征为:"); - for (String s : this.observeStates) { - System.out.print(s + ", "); - } - System.out.println(); - - System.out.println("潜在特征为:"); - for (String s : psValues) { - System.out.print(s + ", "); - } - System.out.println(); - } - - /** - * 根据观察属性,得到潜在属性信息 - */ - public void calHMMObserve() { - calPotencialProMatrix(); - outputResultAttr(); - } +class ViterbiTool { + // 观察到的状态序列 + private String[] observeStates; + // 状态转移概率矩阵文件地址 + private String stmFilePath; + // 混淆矩阵文件地址 + private String confusionFilePath; + // 初始状态概率 + private double[] initStatePro; + // 状态转移矩阵值 + private double[][] stMatrix; + // 混淆矩阵值 + private double[][] confusionMatrix; + // 各个条件下的潜在特征概率值 + private double[][] potentialValues; + // 潜在特征 + private ArrayList potentialAttrs; + // 属性值列坐标映射图 + private HashMap name2Index; + // 列坐标属性值映射图 + private HashMap index2name; + + ViterbiTool(String stmFilePath, String confusionFilePath, + double[] initStatePro, String[] observeStates){ + this.stmFilePath = stmFilePath; + this.confusionFilePath = confusionFilePath; + this.initStatePro = initStatePro; + this.observeStates = observeStates; + + initOperation(); + } + + /** + * 初始化数据操作 + */ + private void initOperation(){ + double[] temp; + int index; + ArrayList smtDatas; + ArrayList cfDatas; + + smtDatas = readDataFile(stmFilePath); + cfDatas = readDataFile(confusionFilePath); + + index = 0; + this.stMatrix = new double[smtDatas.size()][]; + for (String[] array : smtDatas) { + temp = new double[array.length]; + for (int i = 0; i < array.length; i++) { + try { + temp[i] = Double.parseDouble(array[i]); + } catch (NumberFormatException e) { + temp[i] = -1; + } + } + + // 将转换后的值赋给数组中 + this.stMatrix[index] = temp; + index++; + } + + index = 0; + this.confusionMatrix = new double[cfDatas.size()][]; + for (String[] array : cfDatas) { + temp = new double[array.length]; + for (int i = 0; i < array.length; i++) { + try { + temp[i] = Double.parseDouble(array[i]); + } catch (NumberFormatException e) { + temp[i] = -1; + } + } + + // 将转换后的值赋给数组中 + this.confusionMatrix[index] = temp; + index++; + } + + this.potentialAttrs = new ArrayList<>(); + // 添加潜在特征属性 + Collections.addAll(this.potentialAttrs, smtDatas.get(0)); + // 去除首列无效列 + potentialAttrs.remove(0); + + this.name2Index = new HashMap<>(); + this.index2name = new HashMap<>(); + + // 添加名称下标映射关系 + for (int i = 1; i < smtDatas.get(0).length; i++) { + this.name2Index.put(smtDatas.get(0)[i], i); + // 添加下标到名称的映射 + this.index2name.put(i, smtDatas.get(0)[i]); + } + + for (int i = 1; i < cfDatas.get(0).length; i++) { + this.name2Index.put(cfDatas.get(0)[i], i); + } + } + + /** + * 从文件中读取数据 + */ + private ArrayList readDataFile(String filePath){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + return dataArray; + } + + /** + * 根据观察特征计算隐藏的特征概率矩阵 + */ + private void calPotencialProMatrix(){ + String curObserveState; + // 观察特征和潜在特征的下标 + int osIndex; + int psIndex; + double temp; + double maxPro; + // 混淆矩阵概率值,就是相关影响的因素概率 + double confusionPro; + + this.potentialValues = new double[observeStates.length][potentialAttrs + .size() + 1]; + for (int i = 0; i < this.observeStates.length; i++) { + curObserveState = this.observeStates[i]; + osIndex = this.name2Index.get(curObserveState); + + // 因为是第一个观察特征,没有前面的影响,根据初始状态计算 + if (i == 0) { + for (String attr : this.potentialAttrs) { + psIndex = this.name2Index.get(attr); + confusionPro = this.confusionMatrix[psIndex][osIndex]; + + temp = this.initStatePro[psIndex - 1] * confusionPro; + this.potentialValues[BaseNames.DAY1][psIndex] = temp; + } + } else { + // 后面的潜在特征受前一个特征的影响,以及当前的混淆因素影响 + for (String toDayAttr : this.potentialAttrs) { + psIndex = this.name2Index.get(toDayAttr); + confusionPro = this.confusionMatrix[psIndex][osIndex]; + + int index; + maxPro = -1; + // 通过昨天的概率计算今天此特征的最大概率 + for (String yAttr : this.potentialAttrs) { + index = this.name2Index.get(yAttr); + temp = this.potentialValues[i - 1][index] + * this.stMatrix[index][psIndex]; + + // 计算得到今天此潜在特征的最大概率 + if (temp > maxPro) { + maxPro = temp; + } + } + + this.potentialValues[i][psIndex] = maxPro * confusionPro; + } + } + } + } + + /** + * 根据同时期最大概率值输出潜在特征值 + */ + private void outputResultAttr(){ + double maxPro; + int maxIndex; + ArrayList psValues; + + psValues = new ArrayList<>(); + for (double[] potentialValue : this.potentialValues) { + maxPro = -1; + maxIndex = 0; + + for (int j = 0; j < potentialValue.length; j++) { + if (potentialValue[j] > maxPro) { + maxPro = potentialValue[j]; + maxIndex = j; + } + } + + // 取出最大概率下标对应的潜在特征 + psValues.add(this.index2name.get(maxIndex)); + } + + System.out.println("观察特征为:"); + for (String s : this.observeStates) { + System.out.print(s + ", "); + } + System.out.println(); + + System.out.println("潜在特征为:"); + for (String s : psValues) { + System.out.print(s + ", "); + } + System.out.println(); + } + + /** + * 根据观察属性,得到潜在属性信息 + */ + void calHMMObserve(){ + calPotencialProMatrix(); + outputResultAttr(); + } } diff --git a/RoughSets/DataMining_RoughSets/Client.java b/RoughSets/DataMining_RoughSets/Client.java index 2675dd9..e1e144a 100644 --- a/RoughSets/DataMining_RoughSets/Client.java +++ b/RoughSets/DataMining_RoughSets/Client.java @@ -1,15 +1,15 @@ -package DataMining_RoughSets; +package RoughSets.DataMining_RoughSets; /** * 粗糙集约简算法 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - - RoughSetsTool tool = new RoughSetsTool(filePath); - tool.findingReduct(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/RoughSets/DataMining_RoughSets/input.txt"; + + RoughSetsTool tool = new RoughSetsTool(filePath); + tool.findingReduct(); + } } diff --git a/RoughSets/DataMining_RoughSets/KnowledgeSystem.java b/RoughSets/DataMining_RoughSets/KnowledgeSystem.java index 6e655ed..9455f98 100644 --- a/RoughSets/DataMining_RoughSets/KnowledgeSystem.java +++ b/RoughSets/DataMining_RoughSets/KnowledgeSystem.java @@ -1,195 +1,183 @@ -package DataMining_RoughSets; +package RoughSets.DataMining_RoughSets; import java.util.ArrayList; -import java.util.HashMap; /** * 知识系统 - * - * @author lyq - * + * + * @author Qstar */ -public class KnowledgeSystem { - // 知识系统内的集合 - ArrayList ksCollections; - - public KnowledgeSystem(ArrayList ksCollections) { - this.ksCollections = ksCollections; - } - - /** - * 获取集合的上近似集合 - * - * @param rc - * 原始集合 - * @return - */ - public RecordCollection getUpSimilarRC(RecordCollection rc) { - RecordCollection resultRc = null; - ArrayList nameArray; - ArrayList targetArray; - ArrayList copyRcs = new ArrayList<>(); - ArrayList deleteRcs = new ArrayList<>(); - targetArray = rc.getRecordNames(); - - // 做一个集合拷贝 - for (RecordCollection recordCollection : ksCollections) { - copyRcs.add(recordCollection); - } - - for (RecordCollection recordCollection : copyRcs) { - nameArray = recordCollection.getRecordNames(); - - if (strIsContained(targetArray, nameArray)) { - removeOverLaped(targetArray, nameArray); - deleteRcs.add(recordCollection); - - if (resultRc == null) { - resultRc = recordCollection; - } else { - // 进行并运算 - resultRc = resultRc.unionCal(recordCollection); - } - - if (targetArray.size() == 0) { - break; - } - } - } - //去除已经添加过的集合 - copyRcs.removeAll(deleteRcs); - - if (targetArray.size() > 0) { - // 说明已经完全还未找全上近似的集合 - for (RecordCollection recordCollection : copyRcs) { - nameArray = recordCollection.getRecordNames(); - - if (strHasOverlap(targetArray, nameArray)) { - removeOverLaped(targetArray, nameArray); - - if (resultRc == null) { - resultRc = recordCollection; - } else { - // 进行并运算 - resultRc = resultRc.unionCal(recordCollection); - } - - if (targetArray.size() == 0) { - break; - } - } - } - } - - return resultRc; - } - - /** - * 获取集合的下近似集合 - * - * @param rc - * 原始集合 - * @return - */ - public RecordCollection getDownSimilarRC(RecordCollection rc) { - RecordCollection resultRc = null; - ArrayList nameArray; - ArrayList targetArray; - targetArray = rc.getRecordNames(); - - for (RecordCollection recordCollection : ksCollections) { - nameArray = recordCollection.getRecordNames(); - - if (strIsContained(targetArray, nameArray)) { - removeOverLaped(targetArray, nameArray); - - if (resultRc == null) { - resultRc = recordCollection; - } else { - // 进行并运算 - resultRc = resultRc.unionCal(recordCollection); - } - - if (targetArray.size() == 0) { - break; - } - } - } - - return resultRc; - } - - /** - * 判断2个字符数组之间是否有交集 - * - * @param str1 - * 字符列表1 - * @param str2 - * 字符列表2 - * @return - */ - public boolean strHasOverlap(ArrayList str1, ArrayList str2) { - boolean hasOverlap = false; - - for (String s1 : str1) { - for (String s2 : str2) { - if (s1.equals(s2)) { - hasOverlap = true; - break; - } - } - - if (hasOverlap) { - break; - } - } - - return hasOverlap; - } - - /** - * 判断字符集str2是否完全包含于str1中 - * - * @param str1 - * @param str2 - * @return - */ - public boolean strIsContained(ArrayList str1, ArrayList str2) { - boolean isContained = false; - int count = 0; - - for (String s : str2) { - if (str1.contains(s)) { - count++; - } - } - - if (count == str2.size()) { - isContained = true; - } - - return isContained; - } - - /** - * 字符列表移除公共元素 - * - * @param str1 - * @param str2 - */ - public void removeOverLaped(ArrayList str1, ArrayList str2) { - ArrayList deleteStrs = new ArrayList<>(); - - for (String s1 : str1) { - for (String s2 : str2) { - if (s1.equals(s2)) { - deleteStrs.add(s1); - break; - } - } - } - - // 进行公共元素的移除 - str1.removeAll(deleteStrs); - } +class KnowledgeSystem { + // 知识系统内的集合 + private ArrayList ksCollections; + + KnowledgeSystem(ArrayList ksCollections){ + this.ksCollections = ksCollections; + } + + /** + * 获取集合的上近似集合 + * + * @param rc 原始集合 + */ + RecordCollection getUpSimilarRC(RecordCollection rc){ + RecordCollection resultRc = null; + ArrayList nameArray; + ArrayList targetArray; + ArrayList copyRcs = new ArrayList<>(); + ArrayList deleteRcs = new ArrayList<>(); + targetArray = rc.getRecordNames(); + + // 做一个集合拷贝 + copyRcs.addAll(ksCollections); + + for (RecordCollection recordCollection : copyRcs) { + nameArray = recordCollection.getRecordNames(); + + if (strIsContained(targetArray, nameArray)) { + removeOverLaped(targetArray, nameArray); + deleteRcs.add(recordCollection); + + if (resultRc == null) { + resultRc = recordCollection; + } else { + // 进行并运算 + resultRc = resultRc.unionCal(recordCollection); + } + + if (targetArray.size() == 0) { + break; + } + } + } + //去除已经添加过的集合 + copyRcs.removeAll(deleteRcs); + + if (targetArray.size() > 0) { + // 说明已经完全还未找全上近似的集合 + for (RecordCollection recordCollection : copyRcs) { + nameArray = recordCollection.getRecordNames(); + + if (strHasOverlap(targetArray, nameArray)) { + removeOverLaped(targetArray, nameArray); + + if (resultRc == null) { + resultRc = recordCollection; + } else { + // 进行并运算 + resultRc = resultRc.unionCal(recordCollection); + } + + if (targetArray.size() == 0) { + break; + } + } + } + } + + return resultRc; + } + + /** + * 获取集合的下近似集合 + * + * @param rc 原始集合 + */ + RecordCollection getDownSimilarRC(RecordCollection rc){ + RecordCollection resultRc = null; + ArrayList nameArray; + ArrayList targetArray; + targetArray = rc.getRecordNames(); + + for (RecordCollection recordCollection : ksCollections) { + nameArray = recordCollection.getRecordNames(); + + if (strIsContained(targetArray, nameArray)) { + removeOverLaped(targetArray, nameArray); + + if (resultRc == null) { + resultRc = recordCollection; + } else { + // 进行并运算 + resultRc = resultRc.unionCal(recordCollection); + } + + if (targetArray.size() == 0) { + break; + } + } + } + + return resultRc; + } + + /** + * 判断2个字符数组之间是否有交集 + * + * @param str1 字符列表1 + * @param str2 字符列表2 + */ + private boolean strHasOverlap(ArrayList str1, ArrayList str2){ + boolean hasOverlap = false; + + for (String s1 : str1) { + for (String s2 : str2) { + if (s1.equals(s2)) { + hasOverlap = true; + break; + } + } + + if (hasOverlap) { + break; + } + } + + return hasOverlap; + } + + /** + * 判断字符集str2是否完全包含于str1中 + * + * @param str1 字符集1 + * @param str2 字符集2 + */ + private boolean strIsContained(ArrayList str1, ArrayList str2){ + boolean isContained = false; + int count = 0; + + for (String s : str2) { + if (str1.contains(s)) { + count++; + } + } + + if (count == str2.size()) { + isContained = true; + } + + return isContained; + } + + /** + * 字符列表移除公共元素 + * + * @param str1 字符列表1 + * @param str2 字符列表2 + */ + private void removeOverLaped(ArrayList str1, ArrayList str2){ + ArrayList deleteStrs = new ArrayList<>(); + + for (String s1 : str1) { + for (String s2 : str2) { + if (s1.equals(s2)) { + deleteStrs.add(s1); + break; + } + } + } + + // 进行公共元素的移除 + str1.removeAll(deleteStrs); + } } diff --git a/RoughSets/DataMining_RoughSets/Record.java b/RoughSets/DataMining_RoughSets/Record.java index 8cb093e..947cccc 100644 --- a/RoughSets/DataMining_RoughSets/Record.java +++ b/RoughSets/DataMining_RoughSets/Record.java @@ -1,4 +1,4 @@ -package DataMining_RoughSets; +package RoughSets.DataMining_RoughSets; import java.text.MessageFormat; import java.util.ArrayList; @@ -7,99 +7,87 @@ /** * 数据记录,包含这条记录所有属性 - * - * @author lyq - * + * + * @author Qstar */ -public class Record { - // 记录名称 - private String name; - // 记录属性键值对 - private HashMap attrValues; - - public Record(String name, HashMap attrValues) { - this.name = name; - this.attrValues = attrValues; - } - - public String getName() { - return this.name; - } - - /** - * 此数据是否包含此属性值 - * - * @param attr - * 待判断属性值 - * @return - */ - public boolean isContainedAttr(String attr) { - boolean isContained = false; - - if (attrValues.containsValue(attr)) { - isContained = true; - } - - return isContained; - } - - /** - * 判断数据记录是否是同一条记录,根据数据名称来判断 - * - * @param record - * 目标比较对象 - * @return - */ - public boolean isRecordSame(Record record) { - boolean isSame = false; - - if (this.name.equals(record.name)) { - isSame = true; - } - - return isSame; - } - - /** - * 数据的决策属性分类 - * - * @return - */ - public String getRecordDecisionClass() { - String value = null; - - value = attrValues.get(RoughSetsTool.DECISION_ATTR_NAME); - - return value; - } - - /** - * 根据约简属性输出决策规则 - * - * @param reductAttr - * 约简属性集合 - */ - public String getDecisionRule(ArrayList reductAttr) { - String ruleStr = ""; - String attrName = null; - String value = null; - String decisionValue; - - decisionValue = attrValues.get(RoughSetsTool.DECISION_ATTR_NAME); - ruleStr += "属性"; - for (Map.Entry entry : this.attrValues.entrySet()) { - attrName = (String) entry.getKey(); - value = (String) entry.getValue(); - - if (attrName.equals(RoughSetsTool.DECISION_ATTR_NAME) - || reductAttr.contains(attrName) || value.equals(name)) { - continue; - } - - ruleStr += MessageFormat.format("{0}={1},", attrName, value); - } - ruleStr += "他的分类为" + decisionValue; - - return ruleStr; - } +class Record { + // 记录名称 + private String name; + // 记录属性键值对 + private HashMap attrValues; + + Record(String name, HashMap attrValues){ + this.name = name; + this.attrValues = attrValues; + } + + public String getName(){ + return this.name; + } + + /** + * 此数据是否包含此属性值 + * + * @param attr 待判断属性值 + */ + boolean isContainedAttr(String attr){ + boolean isContained = false; + + if (attrValues.containsValue(attr)) { + isContained = true; + } + + return isContained; + } + + /** + * 判断数据记录是否是同一条记录,根据数据名称来判断 + * + * @param record 目标比较对象 + */ + boolean isRecordSame(Record record){ + boolean isSame = false; + + if (this.name.equals(record.name)) { + isSame = true; + } + + return isSame; + } + + /** + * 数据的决策属性分类 + */ + String getRecordDecisionClass(){ + return attrValues.get(RoughSetsTool.DECISION_ATTR_NAME); + } + + /** + * 根据约简属性输出决策规则 + * + * @param reductAttr 约简属性集合 + */ + String getDecisionRule(ArrayList reductAttr){ + String ruleStr = ""; + String attrName; + String value; + String decisionValue; + + decisionValue = attrValues.get(RoughSetsTool.DECISION_ATTR_NAME); + ruleStr += "属性"; + for (Map.Entry entry : this.attrValues.entrySet()) { + attrName = (String) entry.getKey(); + value = (String) entry.getValue(); + + if (attrName.equals(RoughSetsTool.DECISION_ATTR_NAME) + || reductAttr.contains(attrName) || value.equals(name)) { + continue; + } + + ruleStr += MessageFormat.format("{0}={1},", attrName, value); + } + ruleStr += "他的分类为" + decisionValue; + + return ruleStr; + } } diff --git a/RoughSets/DataMining_RoughSets/RecordCollection.java b/RoughSets/DataMining_RoughSets/RecordCollection.java index 1fc5665..6262eb6 100644 --- a/RoughSets/DataMining_RoughSets/RecordCollection.java +++ b/RoughSets/DataMining_RoughSets/RecordCollection.java @@ -1,176 +1,158 @@ -package DataMining_RoughSets; +package RoughSets.DataMining_RoughSets; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; +import java.util.stream.Collectors; /** * 数据记录集合,包含一些共同的属性 - * - * @author lyq - * + * + * @author Qstar */ -public class RecordCollection { - // 集合包含的属性 - private HashMap attrValues; - // 数据记录列表 - private ArrayList recordList; - - public RecordCollection() { - this.attrValues = new HashMap<>(); - this.recordList = new ArrayList<>(); - } - - public RecordCollection(HashMap attrValues, - ArrayList recordList) { - this.attrValues = attrValues; - this.recordList = recordList; - } - - public ArrayList getRecord() { - return this.recordList; - } - - /** - * 返回集合的字符名称数组 - * - * @return - */ - public ArrayList getRecordNames() { - ArrayList names = new ArrayList<>(); - - for (int i = 0; i < recordList.size(); i++) { - names.add(recordList.get(i).getName()); - } - - return names; - } - - /** - * 判断集合是否包含此属性名称对应的属性值 - * - * @param attrName - * 属性名 - * @return - */ - public boolean isContainedAttrName(String attrName) { - boolean isContained = false; - - if (this.attrValues.containsKey(attrName)) { - isContained = true; - } - - return isContained; - } - - /** - * 判断2个集合是否相等,比较包含的数据记录是否完全一致 - * - * @param rc - * 待比较集合 - * @return - */ - public boolean isCollectionSame(RecordCollection rc) { - boolean isSame = false; - - for (Record r : recordList) { - isSame = false; - - for (Record r2 : rc.recordList) { - if (r.isRecordSame(r2)) { - isSame = true; - break; - } - } - - // 如果有1个记录不包含,就算集合不相等 - if (!isSame) { - break; - } - } - - return isSame; - } - - /** - * 集合之间的交运算 - * - * @param rc - * 交运算的参与运算的另外一集合 - * @return - */ - public RecordCollection overlapCalculate(RecordCollection rc) { - String key; - String value; - RecordCollection resultCollection = null; - HashMap resultAttrValues = new HashMap<>(); - ArrayList resultRecords = new ArrayList<>(); - - // 进行集合的交运算,有相同的记录的则进行添加 - for (Record record : this.recordList) { - for (Record record2 : rc.recordList) { - if (record.isRecordSame(record2)) { - resultRecords.add(record); - break; - } - } - } - - // 如果没有交集,则直接返回 - if (resultRecords.size() == 0) { - return null; - } - - // 将2个集合的属性进行合并 - for (Map.Entry entry : this.attrValues.entrySet()) { - key = (String) entry.getKey(); - value = (String) entry.getValue(); - - resultAttrValues.put(key, value); - } - - for (Map.Entry entry : rc.attrValues.entrySet()) { - key = (String) entry.getKey(); - value = (String) entry.getValue(); - - resultAttrValues.put(key, value); - } - - resultCollection = new RecordCollection(resultAttrValues, resultRecords); - return resultCollection; - } - - /** - * 求集合的并集,各自保留各自的属性 - * - * @param rc - * 待合并的集合 - * @return - */ - public RecordCollection unionCal(RecordCollection rc) { - RecordCollection resultRc = null; - ArrayList records = new ArrayList<>(); - - for (Record r1 : this.recordList) { - records.add(r1); - } - - for (Record r2 : rc.recordList) { - records.add(r2); - } - - resultRc = new RecordCollection(null, records); - return resultRc; - } - - /** - * 输出集合中包含的元素 - */ - public void printRc(){ - System.out.print("{"); - for (Record r : this.getRecord()) { - System.out.print(r.getName() + ", "); - } - System.out.println("}"); - } +class RecordCollection { + // 集合包含的属性 + private HashMap attrValues; + // 数据记录列表 + private ArrayList recordList; + + RecordCollection(){ + this.attrValues = new HashMap<>(); + this.recordList = new ArrayList<>(); + } + + RecordCollection(HashMap attrValues, + ArrayList recordList){ + this.attrValues = attrValues; + this.recordList = recordList; + } + + ArrayList getRecord(){ + return this.recordList; + } + + /** + * 返回集合的字符名称数组 + */ + ArrayList getRecordNames(){ + return recordList + .stream() + .map(Record::getName) + .collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * 判断集合是否包含此属性名称对应的属性值 + * + * @param attrName 属性名 + */ + boolean isContainedAttrName(String attrName){ + boolean isContained = false; + + if (this.attrValues.containsKey(attrName)) { + isContained = true; + } + + return isContained; + } + + /** + * 判断2个集合是否相等,比较包含的数据记录是否完全一致 + * + * @param rc 待比较集合 + */ + boolean isCollectionSame(RecordCollection rc){ + boolean isSame = false; + + for (Record r : recordList) { + isSame = false; + + for (Record r2 : rc.recordList) { + if (r.isRecordSame(r2)) { + isSame = true; + break; + } + } + + // 如果有1个记录不包含,就算集合不相等 + if (!isSame) { + break; + } + } + + return isSame; + } + + /** + * 集合之间的交运算 + * + * @param rc 交运算的参与运算的另外一集合 + */ + RecordCollection overlapCalculate(RecordCollection rc){ + String key; + String value; + RecordCollection resultCollection; + HashMap resultAttrValues = new HashMap<>(); + ArrayList resultRecords = new ArrayList<>(); + + // 进行集合的交运算,有相同的记录的则进行添加 + for (Record record : this.recordList) { + for (Record record2 : rc.recordList) { + if (record.isRecordSame(record2)) { + resultRecords.add(record); + break; + } + } + } + + // 如果没有交集,则直接返回 + if (resultRecords.size() == 0) { + return null; + } + + // 将2个集合的属性进行合并 + for (Map.Entry entry : this.attrValues.entrySet()) { + key = (String) entry.getKey(); + value = (String) entry.getValue(); + + resultAttrValues.put(key, value); + } + + for (Map.Entry entry : rc.attrValues.entrySet()) { + key = (String) entry.getKey(); + value = (String) entry.getValue(); + + resultAttrValues.put(key, value); + } + + resultCollection = new RecordCollection(resultAttrValues, resultRecords); + return resultCollection; + } + + /** + * 求集合的并集,各自保留各自的属性 + * + * @param rc 待合并的集合 + */ + RecordCollection unionCal(RecordCollection rc){ + RecordCollection resultRc; + ArrayList records = new ArrayList<>(); + + records.addAll(this.recordList); + records.addAll(rc.recordList); + + resultRc = new RecordCollection(null, records); + return resultRc; + } + + /** + * 输出集合中包含的元素 + */ + void printRc(){ + System.out.print("{"); + for (Record r : this.getRecord()) { + System.out.print(r.getName() + ", "); + } + System.out.println("}"); + } } diff --git a/RoughSets/DataMining_RoughSets/RoughSetsTool.java b/RoughSets/DataMining_RoughSets/RoughSetsTool.java index 98fef23..df7ecc1 100644 --- a/RoughSets/DataMining_RoughSets/RoughSetsTool.java +++ b/RoughSets/DataMining_RoughSets/RoughSetsTool.java @@ -1,430 +1,419 @@ -package DataMining_RoughSets; +package RoughSets.DataMining_RoughSets; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.stream.Collectors; /** * 粗糙集属性约简算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class RoughSetsTool { - // 决策属性名称 - public static String DECISION_ATTR_NAME; - - // 测试数据文件地址 - private String filePath; - // 数据属性列名称 - private String[] attrNames; - // 所有的数据 - private ArrayList totalDatas; - // 所有的数据记录,与上面的区别是记录的属性是可约简的,原始数据是不能变的 - private ArrayList totalRecords; - // 条件属性图 - private HashMap> conditionAttr; - // 属性记录集合 - private ArrayList collectionList; - - public RoughSetsTool(String filePath) { - this.filePath = filePath; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - String[] array; - Record tempRecord; - HashMap attrMap; - ArrayList attrList; - totalDatas = new ArrayList<>(); - totalRecords = new ArrayList<>(); - conditionAttr = new HashMap<>(); - // 赋值属性名称行 - attrNames = dataArray.get(0); - DECISION_ATTR_NAME = attrNames[attrNames.length - 1]; - for (int j = 0; j < dataArray.size(); j++) { - array = dataArray.get(j); - totalDatas.add(array); - if (j == 0) { - // 过滤掉第一行列名称数据 - continue; - } - - attrMap = new HashMap<>(); - for (int i = 0; i < attrNames.length; i++) { - attrMap.put(attrNames[i], array[i]); - - // 寻找条件属性 - if (i > 0 && i < attrNames.length - 1) { - if (conditionAttr.containsKey(attrNames[i])) { - attrList = conditionAttr.get(attrNames[i]); - if (!attrList.contains(array[i])) { - attrList.add(array[i]); - } - } else { - attrList = new ArrayList<>(); - attrList.add(array[i]); - } - conditionAttr.put(attrNames[i], attrList); - } - } - tempRecord = new Record(array[0], attrMap); - totalRecords.add(tempRecord); - } - } - - /** - * 将数据记录根据属性分割到集合中 - */ - private void recordSpiltToCollection() { - String attrName; - ArrayList attrList; - ArrayList recordList; - HashMap collectionAttrValues; - RecordCollection collection; - collectionList = new ArrayList<>(); - - for (Map.Entry entry : conditionAttr.entrySet()) { - attrName = (String) entry.getKey(); - attrList = (ArrayList) entry.getValue(); - - for (String s : attrList) { - recordList = new ArrayList<>(); - // 寻找属性为s的数据记录分入到集合中 - for (Record record : totalRecords) { - if (record.isContainedAttr(s)) { - recordList.add(record); - } - } - collectionAttrValues = new HashMap<>(); - collectionAttrValues.put(attrName, s); - collection = new RecordCollection(collectionAttrValues, - recordList); - - collectionList.add(collection); - } - } - } - - /** - * 构造属性集合图 - * - * @param reductAttr - * 需要约简的属性 - * @return - */ - private HashMap> constructCollectionMap( - ArrayList reductAttr) { - String currentAtttrName; - ArrayList cList; - // 集合属性对应图 - HashMap> collectionMap = new HashMap<>(); - - // 截取出条件属性部分 - for (int i = 1; i < attrNames.length - 1; i++) { - currentAtttrName = attrNames[i]; - - // 判断此属性列是否需要约简 - if (reductAttr != null && reductAttr.contains(currentAtttrName)) { - continue; - } - - cList = new ArrayList<>(); - - for (RecordCollection c : collectionList) { - if (c.isContainedAttrName(currentAtttrName)) { - cList.add(c); - } - } - - collectionMap.put(currentAtttrName, cList); - } - - return collectionMap; - } - - /** - * 根据已有的分裂集合计算知识系统 - */ - private ArrayList computeKnowledgeSystem( - HashMap> collectionMap) { - String attrName = null; - ArrayList cList = null; - // 知识系统 - ArrayList ksCollections; - - ksCollections = new ArrayList<>(); - - // 取出1项 - for (Map.Entry entry : collectionMap.entrySet()) { - attrName = (String) entry.getKey(); - cList = (ArrayList) entry.getValue(); - break; - } - collectionMap.remove(attrName); - - for (RecordCollection rc : cList) { - recurrenceComputeKS(ksCollections, collectionMap, rc); - } - - return ksCollections; - } - - /** - * 递归计算所有的知识系统,通过计算所有集合的交集 - * - * @param ksCollection - * 已经求得知识系统的集合 - * @param map - * 还未曾进行过交运算的集合 - * @param preCollection - * 前个步骤中已经通过交运算计算出的集合 - */ - private void recurrenceComputeKS(ArrayList ksCollections, - HashMap> map, - RecordCollection preCollection) { - String attrName = null; - RecordCollection tempCollection; - ArrayList cList = null; - HashMap> mapCopy = new HashMap<>(); - - //如果已经没有数据了,则直接添加 - if(map.size() == 0){ - ksCollections.add(preCollection); - return; - } - - for (Map.Entry entry : map.entrySet()) { - cList = (ArrayList) entry.getValue(); - mapCopy.put((String) entry.getKey(), cList); - } - - // 取出1项 - for (Map.Entry entry : map.entrySet()) { - attrName = (String) entry.getKey(); - cList = (ArrayList) entry.getValue(); - break; - } - - mapCopy.remove(attrName); - for (RecordCollection rc : cList) { - // 挑选此属性的一个集合进行交运算,然后再次递归 - tempCollection = preCollection.overlapCalculate(rc); - - if (tempCollection == null) { - continue; - } - - // 如果map中已经没有数据了,说明递归到头了 - if (mapCopy.size() == 0) { - ksCollections.add(tempCollection); - } else { - recurrenceComputeKS(ksCollections, mapCopy, tempCollection); - } - } - } - - /** - * 进行粗糙集属性约简算法 - */ - public void findingReduct() { - RecordCollection[] sameClassRcs; - KnowledgeSystem ks; - ArrayList ksCollections; - // 待约简的属性 - ArrayList reductAttr = null; - ArrayList attrNameList; - // 最终可约简的属性组 - ArrayList> canReductAttrs; - HashMap> collectionMap; - - sameClassRcs = selectTheSameClassRC(); - // 这里讲数据按照各个分类的小属性划分了9个集合 - recordSpiltToCollection(); - - collectionMap = constructCollectionMap(reductAttr); - ksCollections = computeKnowledgeSystem(collectionMap); - ks = new KnowledgeSystem(ksCollections); - System.out.println("原始集合分类的上下近似集合"); - ks.getDownSimilarRC(sameClassRcs[0]).printRc(); - ks.getUpSimilarRC(sameClassRcs[0]).printRc(); - ks.getDownSimilarRC(sameClassRcs[1]).printRc(); - ks.getUpSimilarRC(sameClassRcs[1]).printRc(); - - attrNameList = new ArrayList<>(); - for (int i = 1; i < attrNames.length - 1; i++) { - attrNameList.add(attrNames[i]); - } - - ArrayList remainAttr; - canReductAttrs = new ArrayList<>(); - reductAttr = new ArrayList<>(); - // 进行条件属性的递归约简 - for (String s : attrNameList) { - remainAttr = (ArrayList) attrNameList.clone(); - remainAttr.remove(s); - reductAttr = new ArrayList<>(); - reductAttr.add(s); - recurrenceFindingReduct(canReductAttrs, reductAttr, remainAttr, - sameClassRcs); - } - - printRules(canReductAttrs); - } - - /** - * 递归进行属性约简 - * - * @param resultAttr - * 已经计算出的约简属性组 - * @param reductAttr - * 将要约简的属性组 - * @param remainAttr - * 剩余的属性 - * @param sameClassRc - * 待计算上下近似集合的同类集合 - */ - private void recurrenceFindingReduct( - ArrayList> resultAttr, - ArrayList reductAttr, ArrayList remainAttr, - RecordCollection[] sameClassRc) { - KnowledgeSystem ks; - ArrayList ksCollections; - ArrayList copyRemainAttr; - ArrayList copyReductAttr; - HashMap> collectionMap; - RecordCollection upRc1; - RecordCollection downRc1; - RecordCollection upRc2; - RecordCollection downRc2; - - collectionMap = constructCollectionMap(reductAttr); - ksCollections = computeKnowledgeSystem(collectionMap); - ks = new KnowledgeSystem(ksCollections); - - downRc1 = ks.getDownSimilarRC(sameClassRc[0]); - upRc1 = ks.getUpSimilarRC(sameClassRc[0]); - downRc2 = ks.getDownSimilarRC(sameClassRc[1]); - upRc2 = ks.getUpSimilarRC(sameClassRc[1]); - - // 如果上下近似没有完全拟合原集合则认为属性不能被约简 - if (!upRc1.isCollectionSame(sameClassRc[0]) - || !downRc1.isCollectionSame(sameClassRc[0])) { - return; - } - //正类和负类都需比较 - if (!upRc2.isCollectionSame(sameClassRc[1]) - || !downRc2.isCollectionSame(sameClassRc[1])) { - return; - } - - // 加入到结果集中 - resultAttr.add(reductAttr); - //只剩下1个属性不能再约简 - if (remainAttr.size() == 1) { - return; - } - - for (String s : remainAttr) { - copyRemainAttr = (ArrayList) remainAttr.clone(); - copyReductAttr = (ArrayList) reductAttr.clone(); - copyRemainAttr.remove(s); - copyReductAttr.add(s); - recurrenceFindingReduct(resultAttr, copyReductAttr, copyRemainAttr, - sameClassRc); - } - } - - /** - * 选出决策属性一致的集合 - * - * @return - */ - private RecordCollection[] selectTheSameClassRC() { - RecordCollection[] resultRc = new RecordCollection[2]; - resultRc[0] = new RecordCollection(); - resultRc[1] = new RecordCollection(); - String attrValue; - - // 找出第一个记录的决策属性作为一个分类 - attrValue = totalRecords.get(0).getRecordDecisionClass(); - for (Record r : totalRecords) { - if (attrValue.equals(r.getRecordDecisionClass())) { - resultRc[0].getRecord().add(r); - }else{ - resultRc[1].getRecord().add(r); - } - } - - return resultRc; - } - - /** - * 输出决策规则 - * @param reductAttrArray - * 约简属性组 - */ - public void printRules(ArrayList> reductAttrArray){ - //用来保存已经描述过的规则,避免重复输出 - ArrayList rulesArray; - String rule; - - for(ArrayList ra: reductAttrArray){ - rulesArray = new ArrayList<>(); - System.out.print("约简的属性:"); - for(String s: ra){ - System.out.print(s + ","); - } - System.out.println(); - - for(Record r: totalRecords){ - rule = r.getDecisionRule(ra); - if(!rulesArray.contains(rule)){ - rulesArray.add(rule); - System.out.println(rule); - } - } - System.out.println(); - } - } - - /** - * 输出记录集合 - * - * @param rcList - * 待输出记录集合 - */ - public void printRecordCollectionList(ArrayList rcList) { - for (RecordCollection rc : rcList) { - System.out.print("{"); - for (Record r : rc.getRecord()) { - System.out.print(r.getName() + ", "); - } - System.out.println("}"); - } - } +class RoughSetsTool { + // 决策属性名称 + static String DECISION_ATTR_NAME; + + // 测试数据文件地址 + private String filePath; + // 数据属性列名称 + private String[] attrNames; + // 所有的数据 + private ArrayList totalDatas; + // 所有的数据记录,与上面的区别是记录的属性是可约简的,原始数据是不能变的 + private ArrayList totalRecords; + // 条件属性图 + private HashMap> conditionAttr; + // 属性记录集合 + private ArrayList collectionList; + + RoughSetsTool(String filePath){ + this.filePath = filePath; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + String[] array; + Record tempRecord; + HashMap attrMap; + ArrayList attrList; + totalDatas = new ArrayList<>(); + totalRecords = new ArrayList<>(); + conditionAttr = new HashMap<>(); + // 赋值属性名称行 + attrNames = dataArray.get(0); + DECISION_ATTR_NAME = attrNames[attrNames.length - 1]; + for (int j = 0; j < dataArray.size(); j++) { + array = dataArray.get(j); + totalDatas.add(array); + if (j == 0) { + // 过滤掉第一行列名称数据 + continue; + } + + attrMap = new HashMap<>(); + for (int i = 0; i < attrNames.length; i++) { + attrMap.put(attrNames[i], array[i]); + + // 寻找条件属性 + if (i > 0 && i < attrNames.length - 1) { + if (conditionAttr.containsKey(attrNames[i])) { + attrList = conditionAttr.get(attrNames[i]); + if (!attrList.contains(array[i])) { + attrList.add(array[i]); + } + } else { + attrList = new ArrayList<>(); + attrList.add(array[i]); + } + conditionAttr.put(attrNames[i], attrList); + } + } + tempRecord = new Record(array[0], attrMap); + totalRecords.add(tempRecord); + } + } + + /** + * 将数据记录根据属性分割到集合中 + */ + private void recordSpiltToCollection(){ + String attrName; + ArrayList attrList; + ArrayList recordList; + HashMap collectionAttrValues; + RecordCollection collection; + collectionList = new ArrayList<>(); + + for (Map.Entry entry : conditionAttr.entrySet()) { + attrName = (String) entry.getKey(); + attrList = (ArrayList) entry.getValue(); + + for (String s : attrList) { + recordList = new ArrayList<>(); + // 寻找属性为s的数据记录分入到集合中 + recordList.addAll(totalRecords + .stream() + .filter(record -> record.isContainedAttr(s)) + .collect(Collectors.toList())); + collectionAttrValues = new HashMap<>(); + collectionAttrValues.put(attrName, s); + collection = new RecordCollection(collectionAttrValues, + recordList); + + collectionList.add(collection); + } + } + } + + /** + * 构造属性集合图 + * + * @param reductAttr 需要约简的属性 + */ + private HashMap> constructCollectionMap( + ArrayList reductAttr){ + String currentAtttrName; + ArrayList cList; + // 集合属性对应图 + HashMap> collectionMap = new HashMap<>(); + + // 截取出条件属性部分 + for (int i = 1; i < attrNames.length - 1; i++) { + currentAtttrName = attrNames[i]; + + // 判断此属性列是否需要约简 + if (reductAttr != null && reductAttr.contains(currentAtttrName)) { + continue; + } + + cList = new ArrayList<>(); + + for (RecordCollection c : collectionList) { + if (c.isContainedAttrName(currentAtttrName)) { + cList.add(c); + } + } + + collectionMap.put(currentAtttrName, cList); + } + + return collectionMap; + } + + /** + * 根据已有的分裂集合计算知识系统 + */ + private ArrayList computeKnowledgeSystem( + HashMap> collectionMap){ + String attrName = null; + ArrayList cList = null; + // 知识系统 + ArrayList ksCollections; + + ksCollections = new ArrayList<>(); + + // 取出1项 + for (Map.Entry entry : collectionMap.entrySet()) { + attrName = (String) entry.getKey(); + cList = (ArrayList) entry.getValue(); + break; + } + collectionMap.remove(attrName); + + if (cList != null) { + for (RecordCollection rc : cList) { + recurrenceComputeKS(ksCollections, collectionMap, rc); + } + } + + return ksCollections; + } + + /** + * 递归计算所有的知识系统,通过计算所有集合的交集 + * + * @param ksCollections 已经求得知识系统的集合 + * @param map 还未曾进行过交运算的集合 + * @param preCollection 前个步骤中已经通过交运算计算出的集合 + */ + private void recurrenceComputeKS(ArrayList ksCollections, + HashMap> map, + RecordCollection preCollection){ + String attrName = null; + RecordCollection tempCollection; + ArrayList cList = null; + HashMap> mapCopy = new HashMap<>(); + + //如果已经没有数据了,则直接添加 + if (map.size() == 0) { + ksCollections.add(preCollection); + return; + } + + for (Map.Entry entry : map.entrySet()) { + cList = (ArrayList) entry.getValue(); + mapCopy.put((String) entry.getKey(), cList); + } + + // 取出1项 + for (Map.Entry entry : map.entrySet()) { + attrName = (String) entry.getKey(); + cList = (ArrayList) entry.getValue(); + break; + } + + mapCopy.remove(attrName); + if (cList != null) { + for (RecordCollection rc : cList) { + // 挑选此属性的一个集合进行交运算,然后再次递归 + tempCollection = preCollection.overlapCalculate(rc); + + if (tempCollection == null) { + continue; + } + + // 如果map中已经没有数据了,说明递归到头了 + if (mapCopy.size() == 0) { + ksCollections.add(tempCollection); + } else { + recurrenceComputeKS(ksCollections, mapCopy, tempCollection); + } + } + } + } + + /** + * 进行粗糙集属性约简算法 + */ + void findingReduct(){ + RecordCollection[] sameClassRcs; + KnowledgeSystem ks; + ArrayList ksCollections; + // 待约简的属性 + ArrayList reductAttr; + ArrayList attrNameList; + // 最终可约简的属性组 + ArrayList> canReductAttrs; + HashMap> collectionMap; + + sameClassRcs = selectTheSameClassRC(); + // 这里讲数据按照各个分类的小属性划分了9个集合 + recordSpiltToCollection(); + + collectionMap = constructCollectionMap(null); + ksCollections = computeKnowledgeSystem(collectionMap); + ks = new KnowledgeSystem(ksCollections); + System.out.println("原始集合分类的上下近似集合"); + ks.getDownSimilarRC(sameClassRcs[0]).printRc(); + ks.getUpSimilarRC(sameClassRcs[0]).printRc(); + ks.getDownSimilarRC(sameClassRcs[1]).printRc(); + ks.getUpSimilarRC(sameClassRcs[1]).printRc(); + + attrNameList = new ArrayList<>(); + attrNameList.addAll(Arrays.asList(attrNames).subList(1, attrNames.length - 1)); + + ArrayList remainAttr; + canReductAttrs = new ArrayList<>(); + // 进行条件属性的递归约简 + for (String s : attrNameList) { + remainAttr = (ArrayList) attrNameList.clone(); + remainAttr.remove(s); + reductAttr = new ArrayList<>(); + reductAttr.add(s); + recurrenceFindingReduct(canReductAttrs, reductAttr, remainAttr, + sameClassRcs); + } + + printRules(canReductAttrs); + } + + /** + * 递归进行属性约简 + * + * @param resultAttr 已经计算出的约简属性组 + * @param reductAttr 将要约简的属性组 + * @param remainAttr 剩余的属性 + * @param sameClassRc 待计算上下近似集合的同类集合 + */ + private void recurrenceFindingReduct( + ArrayList> resultAttr, + ArrayList reductAttr, ArrayList remainAttr, + RecordCollection[] sameClassRc){ + KnowledgeSystem ks; + ArrayList ksCollections; + ArrayList copyRemainAttr; + ArrayList copyReductAttr; + HashMap> collectionMap; + RecordCollection upRc1; + RecordCollection downRc1; + RecordCollection upRc2; + RecordCollection downRc2; + + collectionMap = constructCollectionMap(reductAttr); + ksCollections = computeKnowledgeSystem(collectionMap); + ks = new KnowledgeSystem(ksCollections); + + downRc1 = ks.getDownSimilarRC(sameClassRc[0]); + upRc1 = ks.getUpSimilarRC(sameClassRc[0]); + downRc2 = ks.getDownSimilarRC(sameClassRc[1]); + upRc2 = ks.getUpSimilarRC(sameClassRc[1]); + + // 如果上下近似没有完全拟合原集合则认为属性不能被约简 + if (!upRc1.isCollectionSame(sameClassRc[0]) + || !downRc1.isCollectionSame(sameClassRc[0])) { + return; + } + //正类和负类都需比较 + if (!upRc2.isCollectionSame(sameClassRc[1]) + || !downRc2.isCollectionSame(sameClassRc[1])) { + return; + } + + // 加入到结果集中 + resultAttr.add(reductAttr); + //只剩下1个属性不能再约简 + if (remainAttr.size() == 1) { + return; + } + + for (String s : remainAttr) { + copyRemainAttr = (ArrayList) remainAttr.clone(); + copyReductAttr = (ArrayList) reductAttr.clone(); + copyRemainAttr.remove(s); + copyReductAttr.add(s); + recurrenceFindingReduct(resultAttr, copyReductAttr, copyRemainAttr, + sameClassRc); + } + } + + /** + * 选出决策属性一致的集合 + */ + private RecordCollection[] selectTheSameClassRC(){ + RecordCollection[] resultRc = new RecordCollection[2]; + resultRc[0] = new RecordCollection(); + resultRc[1] = new RecordCollection(); + String attrValue; + + // 找出第一个记录的决策属性作为一个分类 + attrValue = totalRecords.get(0).getRecordDecisionClass(); + for (Record r : totalRecords) { + if (attrValue.equals(r.getRecordDecisionClass())) { + resultRc[0].getRecord().add(r); + } else { + resultRc[1].getRecord().add(r); + } + } + + return resultRc; + } + + /** + * 输出决策规则 + * + * @param reductAttrArray 约简属性组 + */ + private void printRules(ArrayList> reductAttrArray){ + //用来保存已经描述过的规则,避免重复输出 + ArrayList rulesArray; + String rule; + + for (ArrayList ra : reductAttrArray) { + rulesArray = new ArrayList<>(); + System.out.print("约简的属性:"); + for (String s : ra) { + System.out.print(s + ","); + } + System.out.println(); + + for (Record r : totalRecords) { + rule = r.getDecisionRule(ra); + if (!rulesArray.contains(rule)) { + rulesArray.add(rule); + System.out.println(rule); + } + } + System.out.println(); + } + } + + /** + * 输出记录集合 + * + * @param rcList 待输出记录集合 + */ + public void printRecordCollectionList(ArrayList rcList){ + for (RecordCollection rc : rcList) { + System.out.print("{"); + for (Record r : rc.getRecord()) { + System.out.print(r.getName() + ", "); + } + System.out.println("}"); + } + } } diff --git a/SequentialPatterns/DataMining_GSP/Client.java b/SequentialPatterns/DataMining_GSP/Client.java index 7c9d0d7..43c596f 100644 --- a/SequentialPatterns/DataMining_GSP/Client.java +++ b/SequentialPatterns/DataMining_GSP/Client.java @@ -1,21 +1,21 @@ -package DataMining_GSP; +package SequentialPatterns.DataMining_GSP; /** * GSP序列模式分析算法 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt"; - //最小支持度阈值 - int minSupportCount = 2; - //时间最小间隔 - int min_gap = 1; - //施加最大间隔 - int max_gap = 5; - - GSPTool tool = new GSPTool(filePath, minSupportCount, min_gap, max_gap); - tool.gspCalculate(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/SequentialPatterns/DataMining_GSP/testInput.txt"; + //最小支持度阈值 + int minSupportCount = 2; + //时间最小间隔 + int min_gap = 1; + //施加最大间隔 + int max_gap = 5; + + GSPTool tool = new GSPTool(filePath, minSupportCount, min_gap, max_gap); + tool.gspCalculate(); + } } diff --git a/SequentialPatterns/DataMining_GSP/GSPTool.java b/SequentialPatterns/DataMining_GSP/GSPTool.java index 3f08e0c..0a1903b 100644 --- a/SequentialPatterns/DataMining_GSP/GSPTool.java +++ b/SequentialPatterns/DataMining_GSP/GSPTool.java @@ -1,533 +1,515 @@ -package DataMining_GSP; +package SequentialPatterns.DataMining_GSP; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; +import java.util.*; +import java.util.stream.Collectors; /** * GSP序列模式分析算法 - * - * @author lyq - * + * + * @author Qstar */ -public class GSPTool { - // 测试数据文件地址 - private String filePath; - // 最小支持度阈值 - private int minSupportCount; - // 时间最小间隔 - private int min_gap; - // 时间最大间隔 - private int max_gap; - // 原始数据序列 - private ArrayList totalSequences; - // GSP算法中产生的所有的频繁项集序列 - private ArrayList totalFrequencySeqs; - // 序列项数字对时间的映射图容器 - private ArrayList>> itemNum2Time; - - public GSPTool(String filePath, int minSupportCount, int min_gap, - int max_gap) { - this.filePath = filePath; - this.minSupportCount = minSupportCount; - this.min_gap = min_gap; - this.max_gap = max_gap; - totalFrequencySeqs = new ArrayList<>(); - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - HashMap mapSeq = new HashMap<>(); - Sequence seq; - ItemSet itemSet; - int tID; - String[] itemStr; - for (String[] str : dataArray) { - tID = Integer.parseInt(str[0]); - itemStr = new String[Integer.parseInt(str[1])]; - System.arraycopy(str, 2, itemStr, 0, itemStr.length); - itemSet = new ItemSet(itemStr); - - if (mapSeq.containsKey(tID)) { - seq = mapSeq.get(tID); - } else { - seq = new Sequence(tID); - } - seq.getItemSetList().add(itemSet); - mapSeq.put(tID, seq); - } - - // 将序列图加入到序列List中 - totalSequences = new ArrayList<>(); - for (Map.Entry entry : mapSeq.entrySet()) { - totalSequences.add((Sequence) entry.getValue()); - } - } - - /** - * 生成1频繁项集 - * - * @return - */ - private ArrayList generateOneFrequencyItem() { - int count = 0; - int currentTransanctionID = 0; - Sequence tempSeq; - ItemSet tempItemSet; - HashMap itemNumMap = new HashMap<>(); - ArrayList seqList = new ArrayList<>(); - - for (Sequence seq : totalSequences) { - for (ItemSet itemSet : seq.getItemSetList()) { - for (int num : itemSet.getItems()) { - // 如果没有此种类型项,则进行添加操作 - if (!itemNumMap.containsKey(num)) { - itemNumMap.put(num, 1); - } - } - } - } - - boolean isContain = false; - int number = 0; - for (Map.Entry entry : itemNumMap.entrySet()) { - count = 0; - number = (int) entry.getKey(); - for (Sequence seq : totalSequences) { - isContain = false; - - for (ItemSet itemSet : seq.getItemSetList()) { - for (int num : itemSet.getItems()) { - // 如果没有此种类型项,则进行添加操作 - if (num == number) { - isContain = true; - break; - } - } - - if(isContain){ - break; - } - } - - if(isContain){ - count++; - } - } - - itemNumMap.put(number, count); - } - - - for (Map.Entry entry : itemNumMap.entrySet()) { - count = (int) entry.getValue(); - if (count >= minSupportCount) { - tempSeq = new Sequence(); - tempItemSet = new ItemSet(new int[] { (int) entry.getKey() }); - - tempSeq.getItemSetList().add(tempItemSet); - seqList.add(tempSeq); - } - - } - // 将序列升序排列 - Collections.sort(seqList); - // 将频繁1项集加入总频繁项集列表中 - totalFrequencySeqs.addAll(seqList); - - return seqList; - } - - /** - * 通过1频繁项集连接产生2频繁项集 - * - * @param oneSeq - * 1频繁项集序列 - * @return - */ - private ArrayList generateTwoFrequencyItem( - ArrayList oneSeq) { - Sequence tempSeq; - ArrayList resultSeq = new ArrayList<>(); - ItemSet tempItemSet; - int num1; - int num2; - - // 假如将,2个1频繁项集做连接组合,可以分为,4个序列模式 - // 注意此时的每个序列中包含2个独立项集 - for (int i = 0; i < oneSeq.size(); i++) { - num1 = oneSeq.get(i).getFirstItemSetNum(); - for (int j = 0; j < oneSeq.size(); j++) { - num2 = oneSeq.get(j).getFirstItemSetNum(); - - tempSeq = new Sequence(); - tempItemSet = new ItemSet(new int[] { num1 }); - tempSeq.getItemSetList().add(tempItemSet); - tempItemSet = new ItemSet(new int[] { num2 }); - tempSeq.getItemSetList().add(tempItemSet); - - if (countSupport(tempSeq) >= minSupportCount) { - resultSeq.add(tempSeq); - } - } - } - - // 上面连接还有1种情况是每个序列中只包含有一个项集的情况,此时a,b的划分则是<(a,a)> <(a,b)> <(b,b)> - for (int i = 0; i < oneSeq.size(); i++) { - num1 = oneSeq.get(i).getFirstItemSetNum(); - for (int j = i; j < oneSeq.size(); j++) { - num2 = oneSeq.get(j).getFirstItemSetNum(); - - tempSeq = new Sequence(); - tempItemSet = new ItemSet(new int[] { num1, num2 }); - tempSeq.getItemSetList().add(tempItemSet); - - if (countSupport(tempSeq) >= minSupportCount) { - resultSeq.add(tempSeq); - } - } - } - // 同样将2频繁项集加入到总频繁项集中 - totalFrequencySeqs.addAll(resultSeq); - - return resultSeq; - } - - /** - * 根据上次的频繁集连接产生新的侯选集 - * - * @param seqList - * 上次产生的候选集 - * @return - */ - private ArrayList generateCandidateItem( - ArrayList seqList) { - Sequence tempSeq; - ArrayList tempNumArray; - ArrayList resultSeq = new ArrayList<>(); - // 序列数字项列表 - ArrayList> seqNums = new ArrayList<>(); - - for (int i = 0; i < seqList.size(); i++) { - tempNumArray = new ArrayList<>(); - tempSeq = seqList.get(i); - for (ItemSet itemSet : tempSeq.getItemSetList()) { - tempNumArray.addAll(itemSet.copyItems()); - } - seqNums.add(tempNumArray); - } - - ArrayList array1; - ArrayList array2; - // 序列i,j的拷贝 - Sequence seqi = null; - Sequence seqj = null; - // 判断是否能够连接,默认能连接 - boolean canConnect = true; - // 进行连接运算,包括自己与自己连接 - for (int i = 0; i < seqNums.size(); i++) { - for (int j = 0; j < seqNums.size(); j++) { - array1 = (ArrayList) seqNums.get(i).clone(); - array2 = (ArrayList) seqNums.get(j).clone(); - - // 将第一个数字组去掉第一个,第二个数字组去掉最后一个,如果剩下的部分相等,则可以连接 - array1.remove(0); - array2.remove(array2.size() - 1); - - canConnect = true; - for (int k = 0; k < array1.size(); k++) { - if (array1.get(k) != array2.get(k)) { - canConnect = false; - break; - } - } - - if (canConnect) { - seqi = seqList.get(i).copySeqence(); - seqj = seqList.get(j).copySeqence(); - - int lastItemNum = seqj.getLastItemSetNum(); - if (seqj.isLastItemSetSingleNum()) { - // 如果j序列的最后项集为单一值,则最后一个数字以独立项集加入i序列 - ItemSet itemSet = new ItemSet(new int[] { lastItemNum }); - seqi.getItemSetList().add(itemSet); - } else { - // 如果j序列的最后项集为非单一值,则最后一个数字加入i序列最后一个项集中 - ItemSet itemSet = seqi.getLastItemSet(); - itemSet.getItems().add(lastItemNum); - } - - // 判断是否超过最小支持度阈值 - if (isChildSeqContained(seqi) - && countSupport(seqi) >= minSupportCount) { - resultSeq.add(seqi); - } - } - } - } - - totalFrequencySeqs.addAll(resultSeq); - return resultSeq; - } - - /** - * 判断此序列的所有子序列是否也是频繁序列 - * - * @param seq - * 待比较序列 - * @return - */ - private boolean isChildSeqContained(Sequence seq) { - boolean isContained = false; - ArrayList childSeqs; - - childSeqs = seq.createChildSeqs(); - for (Sequence tempSeq : childSeqs) { - isContained = false; - - for (Sequence frequencySeq : totalFrequencySeqs) { - if (tempSeq.compareIsSame(frequencySeq)) { - isContained = true; - break; - } - } - - if (!isContained) { - break; - } - } - - return isContained; - } - - /** - * 候选集判断支持度的值 - * - * @param seq - * 待判断序列 - * @return - */ - private int countSupport(Sequence seq) { - int count = 0; - int matchNum = 0; - Sequence tempSeq; - ItemSet tempItemSet; - HashMap timeMap; - ArrayList itemSetList; - ArrayList> numArray = new ArrayList<>(); - // 每项集对应的时间链表 - ArrayList> timeArray = new ArrayList<>(); - - for (ItemSet itemSet : seq.getItemSetList()) { - numArray.add(itemSet.getItems()); - } - - for (int i = 0; i < totalSequences.size(); i++) { - timeArray = new ArrayList<>(); - - for (int s = 0; s < numArray.size(); s++) { - ArrayList childNum = numArray.get(s); - ArrayList localTime = new ArrayList<>(); - tempSeq = totalSequences.get(i); - itemSetList = tempSeq.getItemSetList(); - - for (int j = 0; j < itemSetList.size(); j++) { - tempItemSet = itemSetList.get(j); - matchNum = 0; - int t = 0; - - if (tempItemSet.getItems().size() == childNum.size()) { - timeMap = itemNum2Time.get(i).get(j); - // 只有当项集长度匹配时才匹配 - for (int k = 0; k < childNum.size(); k++) { - if (timeMap.containsKey(childNum.get(k))) { - matchNum++; - t = timeMap.get(childNum.get(k)); - } - } - - // 如果完全匹配,则记录时间 - if (matchNum == childNum.size()) { - localTime.add(t); - } - } - - } - - if (localTime.size() > 0) { - timeArray.add(localTime); - } - } - - // 判断时间是否满足时间最大最小约束,如果满足,则此条事务包含候选事务 - if (timeArray.size() == numArray.size() - && judgeTimeInGap(timeArray)) { - count++; - } - } - - return count; - } - - /** - * 判断事务是否满足时间约束 - * - * @param timeArray - * 时间数组,每行代表各项集的在事务中的发生时间链表 - * @return - */ - private boolean judgeTimeInGap(ArrayList> timeArray) { - boolean result = false; - int preTime = 0; - ArrayList firstTimes = timeArray.get(0); - timeArray.remove(0); - - if (timeArray.size() == 0) { - return false; - } - - for (int i = 0; i < firstTimes.size(); i++) { - preTime = firstTimes.get(i); - - if (dfsJudgeTime(preTime, timeArray)) { - result = true; - break; - } - } - - return result; - } - - /** - * 深度优先遍历时间,判断是否有符合条件的时间间隔 - * - * @param preTime - * @param timeArray - * @return - */ - private boolean dfsJudgeTime(int preTime, - ArrayList> timeArray) { - boolean result = false; - ArrayList> timeArrayClone = (ArrayList>) timeArray - .clone(); - ArrayList firstItemItem = timeArrayClone.get(0); - - for (int i = 0; i < firstItemItem.size(); i++) { - if (firstItemItem.get(i) - preTime >= min_gap - && firstItemItem.get(i) - preTime <= max_gap) { - // 如果此2项间隔时间满足时间约束,则继续往下递归 - preTime = firstItemItem.get(i); - timeArrayClone.remove(0); - - if (timeArrayClone.size() == 0) { - return true; - } else { - result = dfsJudgeTime(preTime, timeArrayClone); - if (result) { - return true; - } - } - } - } - - return result; - } - - /** - * 初始化序列项到时间的序列图,为了后面的时间约束计算 - */ - private void initItemNumToTimeMap() { - Sequence seq; - itemNum2Time = new ArrayList<>(); - HashMap tempMap; - ArrayList> tempMapList; - - for (int i = 0; i < totalSequences.size(); i++) { - seq = totalSequences.get(i); - tempMapList = new ArrayList<>(); - - for (int j = 0; j < seq.getItemSetList().size(); j++) { - ItemSet itemSet = seq.getItemSetList().get(j); - tempMap = new HashMap<>(); - for (int itemNum : itemSet.getItems()) { - tempMap.put(itemNum, j + 1); - } - - tempMapList.add(tempMap); - } - - itemNum2Time.add(tempMapList); - } - } - - /** - * 进行GSP算法计算 - */ - public void gspCalculate() { - ArrayList oneSeq; - ArrayList twoSeq; - ArrayList candidateSeq; - - initItemNumToTimeMap(); - oneSeq = generateOneFrequencyItem(); - twoSeq = generateTwoFrequencyItem(oneSeq); - candidateSeq = twoSeq; - - // 不断连接生产候选集,直到没有产生出侯选集 - for (;;) { - candidateSeq = generateCandidateItem(candidateSeq); - - if (candidateSeq.size() == 0) { - break; - } - } - - outputSeqence(totalFrequencySeqs); - - } - - /** - * 输出序列列表信息 - * - * @param outputSeqList - * 待输出序列列表 - */ - private void outputSeqence(ArrayList outputSeqList) { - for (Sequence seq : outputSeqList) { - System.out.print("<"); - for (ItemSet itemSet : seq.getItemSetList()) { - System.out.print("("); - for (int num : itemSet.getItems()) { - System.out.print(num + ","); - } - System.out.print("), "); - } - System.out.println(">"); - } - } +class GSPTool { + // 测试数据文件地址 + private String filePath; + // 最小支持度阈值 + private int minSupportCount; + // 时间最小间隔 + private int min_gap; + // 时间最大间隔 + private int max_gap; + // 原始数据序列 + private ArrayList totalSequences; + // GSP算法中产生的所有的频繁项集序列 + private ArrayList totalFrequencySeqs; + // 序列项数字对时间的映射图容器 + private ArrayList>> itemNum2Time; + + GSPTool(String filePath, int minSupportCount, int min_gap, + int max_gap){ + this.filePath = filePath; + this.minSupportCount = minSupportCount; + this.min_gap = min_gap; + this.max_gap = max_gap; + totalFrequencySeqs = new ArrayList<>(); + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + HashMap mapSeq = new HashMap<>(); + Sequence seq; + ItemSet itemSet; + int tID; + String[] itemStr; + for (String[] str : dataArray) { + tID = Integer.parseInt(str[0]); + itemStr = new String[Integer.parseInt(str[1])]; + System.arraycopy(str, 2, itemStr, 0, itemStr.length); + itemSet = new ItemSet(itemStr); + + if (mapSeq.containsKey(tID)) { + seq = mapSeq.get(tID); + } else { + seq = new Sequence(tID); + } + seq.getItemSetList().add(itemSet); + mapSeq.put(tID, seq); + } + + // 将序列图加入到序列List中 + totalSequences = new ArrayList<>(); + totalSequences.addAll(mapSeq.entrySet() + .stream() + .map(Map.Entry::getValue) + .collect(Collectors.toList())); + } + + /** + * 生成1频繁项集 + */ + private ArrayList generateOneFrequencyItem(){ + int count; + Sequence tempSeq; + ItemSet tempItemSet; + HashMap itemNumMap = new HashMap<>(); + ArrayList seqList = new ArrayList<>(); + + for (Sequence seq : totalSequences) { + for (ItemSet itemSet : seq.getItemSetList()) { + // 如果没有此种类型项,则进行添加操作 + itemSet.getItems() + .stream() + .filter(num -> !itemNumMap.containsKey(num)) + .forEach(num -> itemNumMap.put(num, 1)); + } + } + + boolean isContain; + int number; + for (Map.Entry entry : itemNumMap.entrySet()) { + count = 0; + number = (int) entry.getKey(); + for (Sequence seq : totalSequences) { + isContain = false; + + for (ItemSet itemSet : seq.getItemSetList()) { + for (int num : itemSet.getItems()) { + // 如果没有此种类型项,则进行添加操作 + if (num == number) { + isContain = true; + break; + } + } + + if (isContain) { + break; + } + } + + if (isContain) { + count++; + } + } + + itemNumMap.put(number, count); + } + + + for (Map.Entry entry : itemNumMap.entrySet()) { + count = (int) entry.getValue(); + if (count >= minSupportCount) { + tempSeq = new Sequence(); + tempItemSet = new ItemSet(new int[]{(int) entry.getKey()}); + + tempSeq.getItemSetList().add(tempItemSet); + seqList.add(tempSeq); + } + + } + // 将序列升序排列 + Collections.sort(seqList); + // 将频繁1项集加入总频繁项集列表中 + totalFrequencySeqs.addAll(seqList); + + return seqList; + } + + /** + * 通过1频繁项集连接产生2频繁项集 + * + * @param oneSeq 1频繁项集序列 + */ + private ArrayList generateTwoFrequencyItem( + ArrayList oneSeq){ + Sequence tempSeq; + ArrayList resultSeq = new ArrayList<>(); + ItemSet tempItemSet; + int num1; + int num2; + + // 假如将,2个1频繁项集做连接组合,可以分为,4个序列模式 + // 注意此时的每个序列中包含2个独立项集 + for (int i = 0; i < oneSeq.size(); i++) { + num1 = oneSeq.get(i).getFirstItemSetNum(); + for (Sequence anOneSeq : oneSeq) { + num2 = anOneSeq.getFirstItemSetNum(); + + tempSeq = new Sequence(); + tempItemSet = new ItemSet(new int[]{num1}); + tempSeq.getItemSetList().add(tempItemSet); + tempItemSet = new ItemSet(new int[]{num2}); + tempSeq.getItemSetList().add(tempItemSet); + + if (countSupport(tempSeq) >= minSupportCount) { + resultSeq.add(tempSeq); + } + } + } + + // 上面连接还有1种情况是每个序列中只包含有一个项集的情况,此时a,b的划分则是<(a,a)> <(a,b)> <(b,b)> + for (int i = 0; i < oneSeq.size(); i++) { + num1 = oneSeq.get(i).getFirstItemSetNum(); + for (int j = i; j < oneSeq.size(); j++) { + num2 = oneSeq.get(j).getFirstItemSetNum(); + + tempSeq = new Sequence(); + tempItemSet = new ItemSet(new int[]{num1, num2}); + tempSeq.getItemSetList().add(tempItemSet); + + if (countSupport(tempSeq) >= minSupportCount) { + resultSeq.add(tempSeq); + } + } + } + // 同样将2频繁项集加入到总频繁项集中 + totalFrequencySeqs.addAll(resultSeq); + + return resultSeq; + } + + /** + * 根据上次的频繁集连接产生新的侯选集 + * + * @param seqList 上次产生的候选集 + */ + private ArrayList generateCandidateItem( + ArrayList seqList){ + Sequence tempSeq; + ArrayList tempNumArray; + ArrayList resultSeq = new ArrayList<>(); + // 序列数字项列表 + ArrayList> seqNums = new ArrayList<>(); + + for (Sequence aSeqList : seqList) { + tempNumArray = new ArrayList<>(); + tempSeq = aSeqList; + for (ItemSet itemSet : tempSeq.getItemSetList()) { + tempNumArray.addAll(itemSet.copyItems()); + } + seqNums.add(tempNumArray); + } + + ArrayList array1; + ArrayList array2; + // 序列i,j的拷贝 + Sequence seqi; + Sequence seqj; + // 判断是否能够连接,默认能连接 + boolean canConnect; + // 进行连接运算,包括自己与自己连接 + for (int i = 0; i < seqNums.size(); i++) { + for (int j = 0; j < seqNums.size(); j++) { + array1 = (ArrayList) seqNums.get(i).clone(); + array2 = (ArrayList) seqNums.get(j).clone(); + + // 将第一个数字组去掉第一个,第二个数字组去掉最后一个,如果剩下的部分相等,则可以连接 + array1.remove(0); + array2.remove(array2.size() - 1); + + canConnect = true; + for (int k = 0; k < array1.size(); k++) { + if (!Objects.equals(array1.get(k), array2.get(k))) { + canConnect = false; + break; + } + } + + if (canConnect) { + seqi = seqList.get(i).copySeqence(); + seqj = seqList.get(j).copySeqence(); + + int lastItemNum = seqj.getLastItemSetNum(); + if (seqj.isLastItemSetSingleNum()) { + // 如果j序列的最后项集为单一值,则最后一个数字以独立项集加入i序列 + ItemSet itemSet = new ItemSet(new int[]{lastItemNum}); + seqi.getItemSetList().add(itemSet); + } else { + // 如果j序列的最后项集为非单一值,则最后一个数字加入i序列最后一个项集中 + ItemSet itemSet = seqi.getLastItemSet(); + itemSet.getItems().add(lastItemNum); + } + + // 判断是否超过最小支持度阈值 + if (isChildSeqContained(seqi) + && countSupport(seqi) >= minSupportCount) { + resultSeq.add(seqi); + } + } + } + } + + totalFrequencySeqs.addAll(resultSeq); + return resultSeq; + } + + /** + * 判断此序列的所有子序列是否也是频繁序列 + * + * @param seq 待比较序列 + */ + private boolean isChildSeqContained(Sequence seq){ + boolean isContained = false; + ArrayList childSeqs; + + childSeqs = seq.createChildSeqs(); + for (Sequence tempSeq : childSeqs) { + isContained = false; + + for (Sequence frequencySeq : totalFrequencySeqs) { + if (tempSeq.compareIsSame(frequencySeq)) { + isContained = true; + break; + } + } + + if (!isContained) { + break; + } + } + + return isContained; + } + + /** + * 候选集判断支持度的值 + * + * @param seq 待判断序列 + */ + private int countSupport(Sequence seq){ + int count = 0; + int matchNum; + Sequence tempSeq; + ItemSet tempItemSet; + HashMap timeMap; + ArrayList itemSetList; + ArrayList> numArray = new ArrayList<>(); + // 每项集对应的时间链表 + ArrayList> timeArray; + + numArray.addAll(seq.getItemSetList() + .stream() + .map(ItemSet::getItems) + .collect(Collectors.toList())); + + for (int i = 0; i < totalSequences.size(); i++) { + timeArray = new ArrayList<>(); + + for (ArrayList childNum : numArray) { + ArrayList localTime = new ArrayList<>(); + tempSeq = totalSequences.get(i); + itemSetList = tempSeq.getItemSetList(); + + for (int j = 0; j < itemSetList.size(); j++) { + tempItemSet = itemSetList.get(j); + matchNum = 0; + int t = 0; + + if (tempItemSet.getItems().size() == childNum.size()) { + timeMap = itemNum2Time.get(i).get(j); + // 只有当项集长度匹配时才匹配 + for (Integer aChildNum : childNum) { + if (timeMap.containsKey(aChildNum)) { + matchNum++; + t = timeMap.get(aChildNum); + } + } + + // 如果完全匹配,则记录时间 + if (matchNum == childNum.size()) { + localTime.add(t); + } + } + + } + + if (localTime.size() > 0) { + timeArray.add(localTime); + } + } + + // 判断时间是否满足时间最大最小约束,如果满足,则此条事务包含候选事务 + if (timeArray.size() == numArray.size() + && judgeTimeInGap(timeArray)) { + count++; + } + } + + return count; + } + + /** + * 判断事务是否满足时间约束 + * + * @param timeArray 时间数组,每行代表各项集的在事务中的发生时间链表 + */ + private boolean judgeTimeInGap(ArrayList> timeArray){ + boolean result = false; + int preTime; + ArrayList firstTimes = timeArray.get(0); + timeArray.remove(0); + + if (timeArray.size() == 0) { + return false; + } + + for (Integer firstTime : firstTimes) { + preTime = firstTime; + + if (dfsJudgeTime(preTime, timeArray)) { + result = true; + break; + } + } + + return result; + } + + /** + * 深度优先遍历时间,判断是否有符合条件的时间间隔 + * + * @param preTime + * @param timeArray + */ + private boolean dfsJudgeTime(int preTime, + ArrayList> timeArray){ + boolean result = false; + ArrayList> timeArrayClone = (ArrayList>) timeArray + .clone(); + ArrayList firstItemItem = timeArrayClone.get(0); + + for (Integer aFirstItemItem : firstItemItem) { + if (aFirstItemItem - preTime >= min_gap + && aFirstItemItem - preTime <= max_gap) { + // 如果此2项间隔时间满足时间约束,则继续往下递归 + preTime = aFirstItemItem; + timeArrayClone.remove(0); + + if (timeArrayClone.size() == 0) { + return true; + } else { + result = dfsJudgeTime(preTime, timeArrayClone); + if (result) { + return true; + } + } + } + } + + return result; + } + + /** + * 初始化序列项到时间的序列图,为了后面的时间约束计算 + */ + private void initItemNumToTimeMap(){ + Sequence seq; + itemNum2Time = new ArrayList<>(); + HashMap tempMap; + ArrayList> tempMapList; + + for (Sequence totalSequence : totalSequences) { + seq = totalSequence; + tempMapList = new ArrayList<>(); + + for (int j = 0; j < seq.getItemSetList().size(); j++) { + ItemSet itemSet = seq.getItemSetList().get(j); + tempMap = new HashMap<>(); + for (int itemNum : itemSet.getItems()) { + tempMap.put(itemNum, j + 1); + } + + tempMapList.add(tempMap); + } + + itemNum2Time.add(tempMapList); + } + } + + /** + * 进行GSP算法计算 + */ + void gspCalculate(){ + ArrayList oneSeq; + ArrayList twoSeq; + ArrayList candidateSeq; + + initItemNumToTimeMap(); + oneSeq = generateOneFrequencyItem(); + twoSeq = generateTwoFrequencyItem(oneSeq); + candidateSeq = twoSeq; + + // 不断连接生产候选集,直到没有产生出侯选集 + for (; ; ) { + candidateSeq = generateCandidateItem(candidateSeq); + + if (candidateSeq.size() == 0) { + break; + } + } + + outputSeqence(totalFrequencySeqs); + + } + + /** + * 输出序列列表信息 + * + * @param outputSeqList 待输出序列列表 + */ + private void outputSeqence(ArrayList outputSeqList){ + for (Sequence seq : outputSeqList) { + System.out.print("<"); + for (ItemSet itemSet : seq.getItemSetList()) { + System.out.print("("); + for (int num : itemSet.getItems()) { + System.out.print(num + ","); + } + System.out.print("), "); + } + System.out.println(">"); + } + } } diff --git a/SequentialPatterns/DataMining_GSP/ItemSet.java b/SequentialPatterns/DataMining_GSP/ItemSet.java index fc85c99..bb1fd73 100644 --- a/SequentialPatterns/DataMining_GSP/ItemSet.java +++ b/SequentialPatterns/DataMining_GSP/ItemSet.java @@ -1,82 +1,72 @@ -package DataMining_GSP; +package SequentialPatterns.DataMining_GSP; import java.util.ArrayList; +import java.util.Objects; /** * 序列中的子项集 - * - * @author lyq - * + * + * @author Qstar */ -public class ItemSet { - /** - * 项集中保存的是数字项数组 - */ - private ArrayList items; +class ItemSet { + /** + * 项集中保存的是数字项数组 + */ + private ArrayList items; - public ItemSet(String[] itemStr) { - items = new ArrayList<>(); - for (String s : itemStr) { - items.add(Integer.parseInt(s)); - } - } + ItemSet(String[] itemStr){ + items = new ArrayList<>(); + for (String s : itemStr) { + items.add(Integer.parseInt(s)); + } + } - public ItemSet(int[] itemNum) { - items = new ArrayList<>(); - for (int num : itemNum) { - items.add(num); - } - } - - public ItemSet(ArrayList itemNum) { - this.items = itemNum; - } + ItemSet(int[] itemNum){ + items = new ArrayList<>(); + for (int num : itemNum) { + items.add(num); + } + } - public ArrayList getItems() { - return items; - } + ItemSet(ArrayList itemNum){ + this.items = itemNum; + } - public void setItems(ArrayList items) { - this.items = items; - } + ArrayList getItems(){ + return items; + } - /** - * 判断2个项集是否相等 - * - * @param itemSet - * 比较对象 - * @return - */ - public boolean compareIsSame(ItemSet itemSet) { - boolean result = true; + /** + * 判断2个项集是否相等 + * + * @param itemSet 比较对象 + */ + boolean compareIsSame(ItemSet itemSet){ + boolean result = true; - if (this.items.size() != itemSet.items.size()) { - return false; - } + if (this.items.size() != itemSet.items.size()) { + return false; + } - for (int i = 0; i < itemSet.items.size(); i++) { - if (this.items.get(i) != itemSet.items.get(i)) { - // 只要有值不相等,直接算作不相等 - result = false; - break; - } - } + for (int i = 0; i < itemSet.items.size(); i++) { + if (!Objects.equals(this.items.get(i), itemSet.items.get(i))) { + // 只要有值不相等,直接算作不相等 + result = false; + break; + } + } - return result; - } + return result; + } - /** - * 拷贝项集中同样的数据一份 - * - * @return - */ - public ArrayList copyItems() { - ArrayList copyItems = new ArrayList<>(); + /** + * 拷贝项集中同样的数据一份 + */ + ArrayList copyItems(){ + ArrayList copyItems = new ArrayList<>(); - for (int num : this.items) { - copyItems.add(num); - } + copyItems.addAll(this.items); - return copyItems; - } + return copyItems; + } } diff --git a/SequentialPatterns/DataMining_GSP/Sequence.java b/SequentialPatterns/DataMining_GSP/Sequence.java index ce41d69..401cca9 100644 --- a/SequentialPatterns/DataMining_GSP/Sequence.java +++ b/SequentialPatterns/DataMining_GSP/Sequence.java @@ -1,173 +1,147 @@ -package DataMining_GSP; +package SequentialPatterns.DataMining_GSP; import java.util.ArrayList; /** * 序列,每个序列内部包含多组ItemSet项集 - * - * @author lyq - * + * + * @author Qstar */ -public class Sequence implements Comparable, Cloneable { - // 序列所属事务ID - private int trsanctionID; - // 项集列表 - private ArrayList itemSetList; - - public Sequence(int trsanctionID) { - this.trsanctionID = trsanctionID; - this.itemSetList = new ArrayList<>(); - } - - public Sequence() { - this.itemSetList = new ArrayList<>(); - } - - public int getTrsanctionID() { - return trsanctionID; - } - - public void setTrsanctionID(int trsanctionID) { - this.trsanctionID = trsanctionID; - } - - public ArrayList getItemSetList() { - return itemSetList; - } - - public void setItemSetList(ArrayList itemSetList) { - this.itemSetList = itemSetList; - } - - /** - * 取出序列中第一个项集的第一个元素 - * - * @return - */ - public Integer getFirstItemSetNum() { - return this.getItemSetList().get(0).getItems().get(0); - } - - /** - * 获取序列中最后一个项集 - * - * @return - */ - public ItemSet getLastItemSet() { - return getItemSetList().get(getItemSetList().size() - 1); - } - - /** - * 获取序列中最后一个项集的最后一个一个元素 - * - * @return - */ - public Integer getLastItemSetNum() { - ItemSet lastItemSet = getItemSetList().get(getItemSetList().size() - 1); - int lastItemNum = lastItemSet.getItems().get( - lastItemSet.getItems().size() - 1); - - return lastItemNum; - } - - /** - * 判断序列中最后一个项集是否为单一的值 - * - * @return - */ - public boolean isLastItemSetSingleNum() { - ItemSet lastItemSet = getItemSetList().get(getItemSetList().size() - 1); - int size = lastItemSet.getItems().size(); - - return size == 1 ? true : false; - } - - @Override - public int compareTo(Sequence o) { - // TODO Auto-generated method stub - return this.getFirstItemSetNum().compareTo(o.getFirstItemSetNum()); - } - - @Override - protected Object clone() throws CloneNotSupportedException { - // TODO Auto-generated method stub - return super.clone(); - } - - /** - * 拷贝一份一模一样的序列 - */ - public Sequence copySeqence(){ - Sequence copySeq = new Sequence(); - for(ItemSet itemSet: this.itemSetList){ - copySeq.getItemSetList().add(new ItemSet(itemSet.copyItems())); - } - - return copySeq; - } - - /** - * 比较2个序列是否相等,需要判断内部的每个项集是否完全一致 - * - * @param seq - * 比较的序列对象 - * @return - */ - public boolean compareIsSame(Sequence seq) { - boolean result = true; - ArrayList itemSetList2 = seq.getItemSetList(); - ItemSet tempItemSet1; - ItemSet tempItemSet2; - - if (itemSetList2.size() != this.itemSetList.size()) { - return false; - } - for (int i = 0; i < itemSetList2.size(); i++) { - tempItemSet1 = this.itemSetList.get(i); - tempItemSet2 = itemSetList2.get(i); - - if (!tempItemSet1.compareIsSame(tempItemSet2)) { - // 只要不相等,直接退出函数 - result = false; - break; - } - } - - return result; - } - - /** - * 生成此序列的所有子序列 - * - * @return - */ - public ArrayList createChildSeqs() { - ArrayList childSeqs = new ArrayList<>(); - ArrayList tempItems; - Sequence tempSeq = null; - ItemSet tempItemSet; - - for (int i = 0; i < this.itemSetList.size(); i++) { - tempItemSet = itemSetList.get(i); - if (tempItemSet.getItems().size() == 1) { - tempSeq = this.copySeqence(); - - // 如果只有项集中只有1个元素,则直接移除 - tempSeq.itemSetList.remove(i); - childSeqs.add(tempSeq); - } else { - tempItems = tempItemSet.getItems(); - for (int j = 0; j < tempItems.size(); j++) { - tempSeq = this.copySeqence(); - - // 在拷贝的序列中移除一个数字 - tempSeq.getItemSetList().get(i).getItems().remove(j); - childSeqs.add(tempSeq); - } - } - } - - return childSeqs; - } +class Sequence implements Comparable, Cloneable { + // 序列所属事务ID + private int trsanctionID; + // 项集列表 + private ArrayList itemSetList; + + Sequence(int trsanctionID){ + this.trsanctionID = trsanctionID; + this.itemSetList = new ArrayList<>(); + } + + Sequence(){ + this.itemSetList = new ArrayList<>(); + } + + ArrayList getItemSetList(){ + return itemSetList; + } + + /** + * 取出序列中第一个项集的第一个元素 + */ + Integer getFirstItemSetNum(){ + return this.getItemSetList().get(0).getItems().get(0); + } + + /** + * 获取序列中最后一个项集 + */ + ItemSet getLastItemSet(){ + return getItemSetList().get(getItemSetList().size() - 1); + } + + /** + * 获取序列中最后一个项集的最后一个一个元素 + */ + Integer getLastItemSetNum(){ + ItemSet lastItemSet = getItemSetList().get(getItemSetList().size() - 1); + + return lastItemSet.getItems().get( + lastItemSet.getItems().size() - 1); + } + + /** + * 判断序列中最后一个项集是否为单一的值 + */ + boolean isLastItemSetSingleNum(){ + ItemSet lastItemSet = getItemSetList().get(getItemSetList().size() - 1); + int size = lastItemSet.getItems().size(); + + return size == 1; + } + + @Override + public int compareTo(Sequence o){ + // TODO Auto-generated method stub + return this.getFirstItemSetNum().compareTo(o.getFirstItemSetNum()); + } + + @Override + protected Object clone() throws CloneNotSupportedException{ + // TODO Auto-generated method stub + return super.clone(); + } + + /** + * 拷贝一份一模一样的序列 + */ + Sequence copySeqence(){ + Sequence copySeq = new Sequence(); + for (ItemSet itemSet : this.itemSetList) { + copySeq.getItemSetList().add(new ItemSet(itemSet.copyItems())); + } + + return copySeq; + } + + /** + * 比较2个序列是否相等,需要判断内部的每个项集是否完全一致 + * + * @param seq 比较的序列对象 + */ + boolean compareIsSame(Sequence seq){ + boolean result = true; + ArrayList itemSetList2 = seq.getItemSetList(); + ItemSet tempItemSet1; + ItemSet tempItemSet2; + + if (itemSetList2.size() != this.itemSetList.size()) { + return false; + } + for (int i = 0; i < itemSetList2.size(); i++) { + tempItemSet1 = this.itemSetList.get(i); + tempItemSet2 = itemSetList2.get(i); + + if (!tempItemSet1.compareIsSame(tempItemSet2)) { + // 只要不相等,直接退出函数 + result = false; + break; + } + } + + return result; + } + + /** + * 生成此序列的所有子序列 + */ + ArrayList createChildSeqs(){ + ArrayList childSeqs = new ArrayList<>(); + ArrayList tempItems; + Sequence tempSeq; + ItemSet tempItemSet; + + for (int i = 0; i < this.itemSetList.size(); i++) { + tempItemSet = itemSetList.get(i); + if (tempItemSet.getItems().size() == 1) { + tempSeq = this.copySeqence(); + + // 如果只有项集中只有1个元素,则直接移除 + tempSeq.itemSetList.remove(i); + childSeqs.add(tempSeq); + } else { + tempItems = tempItemSet.getItems(); + for (int j = 0; j < tempItems.size(); j++) { + tempSeq = this.copySeqence(); + + // 在拷贝的序列中移除一个数字 + tempSeq.getItemSetList().get(i).getItems().remove(j); + childSeqs.add(tempSeq); + } + } + } + + return childSeqs; + } } diff --git a/SequentialPatterns/DataMining_PrefixSpan/Client.java b/SequentialPatterns/DataMining_PrefixSpan/Client.java index 4bde489..819f394 100644 --- a/SequentialPatterns/DataMining_PrefixSpan/Client.java +++ b/SequentialPatterns/DataMining_PrefixSpan/Client.java @@ -1,18 +1,18 @@ -package DataMining_PrefixSpan; +package SequentialPatterns.DataMining_PrefixSpan; /** * PrefixSpan序列模式挖掘算法 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] agrs){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - //最小支持度阈值率 - double minSupportRate = 0.4; - - PrefixSpanTool tool = new PrefixSpanTool(filePath, minSupportRate); - tool.prefixSpanCalculate(); - } + public static void main(String[] agrs){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/SequentialPatterns/DataMining_PrefixSpan/input.txt"; + //最小支持度阈值率 + double minSupportRate = 0.4; + + PrefixSpanTool tool = new PrefixSpanTool(filePath, minSupportRate); + tool.prefixSpanCalculate(); + } } diff --git a/SequentialPatterns/DataMining_PrefixSpan/ItemSet.java b/SequentialPatterns/DataMining_PrefixSpan/ItemSet.java index a401650..b1e7cb2 100644 --- a/SequentialPatterns/DataMining_PrefixSpan/ItemSet.java +++ b/SequentialPatterns/DataMining_PrefixSpan/ItemSet.java @@ -1,51 +1,43 @@ -package DataMining_PrefixSpan; +package SequentialPatterns.DataMining_PrefixSpan; import java.util.ArrayList; +import java.util.Collections; /** * 字符项集类 - * - * @author lyq - * + * + * @author Qstar */ -public class ItemSet { - // 项集内的字符 - private ArrayList items; - - public ItemSet(String[] str) { - items = new ArrayList<>(); - for (String s : str) { - items.add(s); - } - } - - public ItemSet(ArrayList itemsList) { - this.items = itemsList; - } - - public ItemSet(String s) { - items = new ArrayList<>(); - for (int i = 0; i < s.length(); i++) { - items.add(s.charAt(i) + ""); - } - } - - public ArrayList getItems() { - return items; - } - - public void setItems(ArrayList items) { - this.items = items; - } - - /** - * 获取项集最后1个元素 - * - * @return - */ - public String getLastValue() { - int size = this.items.size(); - - return this.items.get(size - 1); - } +class ItemSet { + // 项集内的字符 + private ArrayList items; + + public ItemSet(String[] str){ + items = new ArrayList<>(); + Collections.addAll(items, str); + } + + ItemSet(ArrayList itemsList){ + this.items = itemsList; + } + + ItemSet(String s){ + items = new ArrayList<>(); + for (int i = 0; i < s.length(); i++) { + items.add(s.charAt(i) + ""); + } + } + + ArrayList getItems(){ + return items; + } + + /** + * 获取项集最后1个元素 + */ + String getLastValue(){ + int size = this.items.size(); + + return this.items.get(size - 1); + } } diff --git a/SequentialPatterns/DataMining_PrefixSpan/PrefixSpanTool.java b/SequentialPatterns/DataMining_PrefixSpan/PrefixSpanTool.java index 620baf0..883b696 100644 --- a/SequentialPatterns/DataMining_PrefixSpan/PrefixSpanTool.java +++ b/SequentialPatterns/DataMining_PrefixSpan/PrefixSpanTool.java @@ -1,4 +1,4 @@ -package DataMining_PrefixSpan; +package SequentialPatterns.DataMining_PrefixSpan; import java.io.BufferedReader; import java.io.File; @@ -8,346 +8,335 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.stream.Collectors; /** * PrefixSpanTool序列模式分析算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class PrefixSpanTool { - // 测试数据文件地址 - private String filePath; - // 最小支持度阈值比例 - private double minSupportRate; - // 最小支持度,通过序列总数乘以阈值比例计算 - private int minSupport; - // 原始序列组 - private ArrayList totalSeqs; - // 挖掘出的所有序列频繁模式 - private ArrayList totalFrequentSeqs; - // 所有的单一项,用于递归枚举 - private ArrayList singleItems; - - public PrefixSpanTool(String filePath, double minSupportRate) { - this.filePath = filePath; - this.minSupportRate = minSupportRate; - readDataFile(); - } - - /** - * 从文件中读取数据 - */ - private void readDataFile() { - File file = new File(filePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - minSupport = (int) (dataArray.size() * minSupportRate); - totalSeqs = new ArrayList<>(); - totalFrequentSeqs = new ArrayList<>(); - Sequence tempSeq; - ItemSet tempItemSet; - for (String[] str : dataArray) { - tempSeq = new Sequence(); - for (String s : str) { - tempItemSet = new ItemSet(s); - tempSeq.getItemSetList().add(tempItemSet); - } - totalSeqs.add(tempSeq); - } - - System.out.println("原始序列数据:"); - outputSeqence(totalSeqs); - } - - /** - * 输出序列列表内容 - * - * @param seqList - * 待输出序列列表 - */ - private void outputSeqence(ArrayList seqList) { - for (Sequence seq : seqList) { - System.out.print("<"); - for (ItemSet itemSet : seq.getItemSetList()) { - if (itemSet.getItems().size() > 1) { - System.out.print("("); - } - - for (String s : itemSet.getItems()) { - System.out.print(s + " "); - } - - if (itemSet.getItems().size() > 1) { - System.out.print(")"); - } - } - System.out.println(">"); - } - } - - /** - * 移除初始序列中不满足最小支持度阈值的单项 - */ - private void removeInitSeqsItem() { - int count = 0; - HashMap itemMap = new HashMap<>(); - singleItems = new ArrayList<>(); - - for (Sequence seq : totalSeqs) { - for (ItemSet itemSet : seq.getItemSetList()) { - for (String s : itemSet.getItems()) { - if (!itemMap.containsKey(s)) { - itemMap.put(s, 1); - } - } - } - } - - String key; - for (Map.Entry entry : itemMap.entrySet()) { - count = 0; - key = (String) entry.getKey(); - for (Sequence seq : totalSeqs) { - if (seq.strIsContained(key)) { - count++; - } - } - - itemMap.put(key, count); - - } - - for (Map.Entry entry : itemMap.entrySet()) { - key = (String) entry.getKey(); - count = (int) entry.getValue(); - - if (count < minSupport) { - // 如果支持度阈值小于所得的最小支持度阈值,则删除该项 - for (Sequence seq : totalSeqs) { - seq.deleteSingleItem(key); - } - } else { - singleItems.add(key); - } - } - - Collections.sort(singleItems); - } - - /** - * 递归搜索满足条件的序列模式 - * - * @param beforeSeq - * 前缀序列 - * @param afterSeqList - * 后缀序列列表 - */ - private void recursiveSearchSeqs(Sequence beforeSeq, - ArrayList afterSeqList) { - ItemSet tempItemSet; - Sequence tempSeq2; - Sequence tempSeq; - ArrayList tempSeqList = new ArrayList<>(); - - for (String s : singleItems) { - // 分成2种形式递归,以为起始项,第一种直接加入独立项集遍历, .. - if (isLargerThanMinSupport(s, afterSeqList)) { - tempSeq = beforeSeq.copySeqence(); - tempItemSet = new ItemSet(s); - tempSeq.getItemSetList().add(tempItemSet); - - totalFrequentSeqs.add(tempSeq); - - tempSeqList = new ArrayList<>(); - for (Sequence seq : afterSeqList) { - if (seq.strIsContained(s)) { - tempSeq2 = seq.extractItem(s); - tempSeqList.add(tempSeq2); - } - } - - recursiveSearchSeqs(tempSeq, tempSeqList); - } - - // 第二种递归为以元素的身份加入最后的项集内以a为例<(aa)>,<(ab)>,<(ac)>... - // a在这里可以理解为一个前缀序列,里面可能是单个元素或者已经是多元素的项集 - tempSeq = beforeSeq.copySeqence(); - int size = tempSeq.getItemSetList().size(); - tempItemSet = tempSeq.getItemSetList().get(size - 1); - tempItemSet.getItems().add(s); - - if (isLargerThanMinSupport(tempItemSet, afterSeqList)) { - tempSeqList = new ArrayList<>(); - for (Sequence seq : afterSeqList) { - if (seq.compoentItemIsContain(tempItemSet)) { - tempSeq2 = seq.extractCompoentItem(tempItemSet - .getItems()); - tempSeqList.add(tempSeq2); - } - } - totalFrequentSeqs.add(tempSeq); - - recursiveSearchSeqs(tempSeq, tempSeqList); - } - } - } - - /** - * 所传入的项组合在所给定序列中的支持度是否超过阈值 - * - * @param s - * 所需匹配的项 - * @param seqList - * 比较序列数据 - * @return - */ - private boolean isLargerThanMinSupport(String s, ArrayList seqList) { - boolean isLarge = false; - int count = 0; - - for (Sequence seq : seqList) { - if (seq.strIsContained(s)) { - count++; - } - } - - if (count >= minSupport) { - isLarge = true; - } - - return isLarge; - } - - /** - * 所传入的组合项集在序列中的支持度是否大于阈值 - * - * @param itemSet - * 组合元素项集 - * @param seqList - * 比较的序列列表 - * @return - */ - private boolean isLargerThanMinSupport(ItemSet itemSet, - ArrayList seqList) { - boolean isLarge = false; - int count = 0; - - if (seqList == null) { - return false; - } - - for (Sequence seq : seqList) { - if (seq.compoentItemIsContain(itemSet)) { - count++; - } - } - - if (count >= minSupport) { - isLarge = true; - } - - return isLarge; - } - - /** - * 序列模式分析计算 - */ - public void prefixSpanCalculate() { - Sequence seq; - Sequence tempSeq; - ArrayList tempSeqList = new ArrayList<>(); - ItemSet itemSet; - removeInitSeqsItem(); - - for (String s : singleItems) { - // 从最开始的a,b,d开始递归往下寻找频繁序列模式 - seq = new Sequence(); - itemSet = new ItemSet(s); - seq.getItemSetList().add(itemSet); - - if (isLargerThanMinSupport(s, totalSeqs)) { - tempSeqList = new ArrayList<>(); - for (Sequence s2 : totalSeqs) { - // 判断单一项是否包含于在序列中,包含才进行提取操作 - if (s2.strIsContained(s)) { - tempSeq = s2.extractItem(s); - tempSeqList.add(tempSeq); - } - } - - totalFrequentSeqs.add(seq); - recursiveSearchSeqs(seq, tempSeqList); - } - } - - printTotalFreSeqs(); - } - - /** - * 按模式类别输出频繁序列模式 - */ - private void printTotalFreSeqs() { - System.out.println("序列模式挖掘结果:"); - - ArrayList seqList; - HashMap> seqMap = new HashMap<>(); - for (String s : singleItems) { - seqList = new ArrayList<>(); - for (Sequence seq : totalFrequentSeqs) { - if (seq.getItemSetList().get(0).getItems().get(0).equals(s)) { - seqList.add(seq); - } - } - seqMap.put(s, seqList); - } - - int count = 0; - for (String s : singleItems) { - count = 0; - System.out.println(); - System.out.println(); - - seqList = (ArrayList) seqMap.get(s); - for (Sequence tempSeq : seqList) { - count++; - System.out.print("<"); - for (ItemSet itemSet : tempSeq.getItemSetList()) { - if (itemSet.getItems().size() > 1) { - System.out.print("("); - } - - for (String str : itemSet.getItems()) { - System.out.print(str + " "); - } - - if (itemSet.getItems().size() > 1) { - System.out.print(")"); - } - } - System.out.print(">, "); - - // 每5个序列换一行 - if (count == 5) { - count = 0; - System.out.println(); - } - } - - } - } +class PrefixSpanTool { + // 测试数据文件地址 + private String filePath; + // 最小支持度阈值比例 + private double minSupportRate; + // 最小支持度,通过序列总数乘以阈值比例计算 + private int minSupport; + // 原始序列组 + private ArrayList totalSeqs; + // 挖掘出的所有序列频繁模式 + private ArrayList totalFrequentSeqs; + // 所有的单一项,用于递归枚举 + private ArrayList singleItems; + + PrefixSpanTool(String filePath, double minSupportRate){ + this.filePath = filePath; + this.minSupportRate = minSupportRate; + readDataFile(); + } + + /** + * 从文件中读取数据 + */ + private void readDataFile(){ + File file = new File(filePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + minSupport = (int) (dataArray.size() * minSupportRate); + totalSeqs = new ArrayList<>(); + totalFrequentSeqs = new ArrayList<>(); + Sequence tempSeq; + ItemSet tempItemSet; + for (String[] str : dataArray) { + tempSeq = new Sequence(); + for (String s : str) { + tempItemSet = new ItemSet(s); + tempSeq.getItemSetList().add(tempItemSet); + } + totalSeqs.add(tempSeq); + } + + System.out.println("原始序列数据:"); + outputSeqence(totalSeqs); + } + + /** + * 输出序列列表内容 + * + * @param seqList 待输出序列列表 + */ + private void outputSeqence(ArrayList seqList){ + for (Sequence seq : seqList) { + System.out.print("<"); + for (ItemSet itemSet : seq.getItemSetList()) { + if (itemSet.getItems().size() > 1) { + System.out.print("("); + } + + for (String s : itemSet.getItems()) { + System.out.print(s + " "); + } + + if (itemSet.getItems().size() > 1) { + System.out.print(")"); + } + } + System.out.println(">"); + } + } + + /** + * 移除初始序列中不满足最小支持度阈值的单项 + */ + private void removeInitSeqsItem(){ + int count; + HashMap itemMap = new HashMap<>(); + singleItems = new ArrayList<>(); + + for (Sequence seq : totalSeqs) { + for (ItemSet itemSet : seq.getItemSetList()) { + itemSet.getItems() + .stream() + .filter(s -> !itemMap.containsKey(s)) + .forEach(s -> itemMap.put(s, 1)); + } + } + + String key; + for (Map.Entry entry : itemMap.entrySet()) { + count = 0; + key = (String) entry.getKey(); + for (Sequence seq : totalSeqs) { + if (seq.strIsContained(key)) { + count++; + } + } + + itemMap.put(key, count); + + } + + for (Map.Entry entry : itemMap.entrySet()) { + key = (String) entry.getKey(); + count = (int) entry.getValue(); + + if (count < minSupport) { + // 如果支持度阈值小于所得的最小支持度阈值,则删除该项 + for (Sequence seq : totalSeqs) { + seq.deleteSingleItem(key); + } + } else { + singleItems.add(key); + } + } + + Collections.sort(singleItems); + } + + /** + * 递归搜索满足条件的序列模式 + * + * @param beforeSeq 前缀序列 + * @param afterSeqList 后缀序列列表 + */ + private void recursiveSearchSeqs(Sequence beforeSeq, + ArrayList afterSeqList){ + ItemSet tempItemSet; + Sequence tempSeq2; + Sequence tempSeq; + ArrayList tempSeqList; + + for (String s : singleItems) { + // 分成2种形式递归,以为起始项,第一种直接加入独立项集遍历, .. + if (isLargerThanMinSupport(s, afterSeqList)) { + tempSeq = beforeSeq.copySeqence(); + tempItemSet = new ItemSet(s); + tempSeq.getItemSetList().add(tempItemSet); + + totalFrequentSeqs.add(tempSeq); + + tempSeqList = new ArrayList<>(); + for (Sequence seq : afterSeqList) { + if (seq.strIsContained(s)) { + tempSeq2 = seq.extractItem(s); + tempSeqList.add(tempSeq2); + } + } + + recursiveSearchSeqs(tempSeq, tempSeqList); + } + + // 第二种递归为以元素的身份加入最后的项集内以a为例<(aa)>,<(ab)>,<(ac)>... + // a在这里可以理解为一个前缀序列,里面可能是单个元素或者已经是多元素的项集 + tempSeq = beforeSeq.copySeqence(); + int size = tempSeq.getItemSetList().size(); + tempItemSet = tempSeq.getItemSetList().get(size - 1); + tempItemSet.getItems().add(s); + + if (isLargerThanMinSupport(tempItemSet, afterSeqList)) { + tempSeqList = new ArrayList<>(); + for (Sequence seq : afterSeqList) { + if (seq.compoentItemIsContain(tempItemSet)) { + tempSeq2 = seq.extractCompoentItem(tempItemSet + .getItems()); + tempSeqList.add(tempSeq2); + } + } + totalFrequentSeqs.add(tempSeq); + + recursiveSearchSeqs(tempSeq, tempSeqList); + } + } + } + + /** + * 所传入的项组合在所给定序列中的支持度是否超过阈值 + * + * @param s 所需匹配的项 + * @param seqList 比较序列数据 + */ + private boolean isLargerThanMinSupport(String s, ArrayList seqList){ + boolean isLarge = false; + int count = 0; + + for (Sequence seq : seqList) { + if (seq.strIsContained(s)) { + count++; + } + } + + if (count >= minSupport) { + isLarge = true; + } + + return isLarge; + } + + /** + * 所传入的组合项集在序列中的支持度是否大于阈值 + * + * @param itemSet 组合元素项集 + * @param seqList 比较的序列列表 + */ + private boolean isLargerThanMinSupport(ItemSet itemSet, + ArrayList seqList){ + boolean isLarge = false; + int count = 0; + + if (seqList == null) { + return false; + } + + for (Sequence seq : seqList) { + if (seq.compoentItemIsContain(itemSet)) { + count++; + } + } + + if (count >= minSupport) { + isLarge = true; + } + + return isLarge; + } + + /** + * 序列模式分析计算 + */ + void prefixSpanCalculate(){ + Sequence seq; + Sequence tempSeq; + ArrayList tempSeqList; + ItemSet itemSet; + removeInitSeqsItem(); + + for (String s : singleItems) { + // 从最开始的a,b,d开始递归往下寻找频繁序列模式 + seq = new Sequence(); + itemSet = new ItemSet(s); + seq.getItemSetList().add(itemSet); + + if (isLargerThanMinSupport(s, totalSeqs)) { + tempSeqList = new ArrayList<>(); + for (Sequence s2 : totalSeqs) { + // 判断单一项是否包含于在序列中,包含才进行提取操作 + if (s2.strIsContained(s)) { + tempSeq = s2.extractItem(s); + tempSeqList.add(tempSeq); + } + } + + totalFrequentSeqs.add(seq); + recursiveSearchSeqs(seq, tempSeqList); + } + } + + printTotalFreSeqs(); + } + + /** + * 按模式类别输出频繁序列模式 + */ + private void printTotalFreSeqs(){ + System.out.println("序列模式挖掘结果:"); + + ArrayList seqList; + HashMap> seqMap = new HashMap<>(); + for (String s : singleItems) { + seqList = new ArrayList<>(); + seqList.addAll(totalFrequentSeqs + .stream() + .filter(seq -> seq.getItemSetList().get(0).getItems().get(0).equals(s)) + .collect(Collectors.toList())); + seqMap.put(s, seqList); + } + + int count; + for (String s : singleItems) { + count = 0; + System.out.println(); + System.out.println(); + + seqList = seqMap.get(s); + for (Sequence tempSeq : seqList) { + count++; + System.out.print("<"); + for (ItemSet itemSet : tempSeq.getItemSetList()) { + if (itemSet.getItems().size() > 1) { + System.out.print("("); + } + + for (String str : itemSet.getItems()) { + System.out.print(str + " "); + } + + if (itemSet.getItems().size() > 1) { + System.out.print(")"); + } + } + System.out.print(">, "); + + // 每5个序列换一行 + if (count == 5) { + count = 0; + System.out.println(); + } + } + + } + } } diff --git a/SequentialPatterns/DataMining_PrefixSpan/Sequence.java b/SequentialPatterns/DataMining_PrefixSpan/Sequence.java index 04aed96..b1c00f0 100644 --- a/SequentialPatterns/DataMining_PrefixSpan/Sequence.java +++ b/SequentialPatterns/DataMining_PrefixSpan/Sequence.java @@ -1,298 +1,269 @@ -package DataMining_PrefixSpan; +package SequentialPatterns.DataMining_PrefixSpan; import java.util.ArrayList; +import java.util.stream.Collectors; /** * 序列类 - * - * @author lyq - * + * + * @author Qstar */ -public class Sequence { - // 序列内的项集 - private ArrayList itemSetList; - - public Sequence() { - this.itemSetList = new ArrayList<>(); - } - - public ArrayList getItemSetList() { - return itemSetList; - } - - public void setItemSetList(ArrayList itemSetList) { - this.itemSetList = itemSetList; - } - - /** - * 判断单一项是否包含于此序列 - * - * @param c - * 待判断项 - * @return - */ - public boolean strIsContained(String c) { - boolean isContained = false; - - for (ItemSet itemSet : itemSetList) { - isContained = false; - - for (String s : itemSet.getItems()) { - if (itemSet.getItems().contains("_")) { - continue; - } - - if (s.equals(c)) { - isContained = true; - break; - } - } - - if (isContained) { - // 如果已经检测出包含了,直接挑出循环 - break; - } - } - - return isContained; - } - - /** - * 判断组合项集是否包含于序列中 - * - * @param itemSet - * 组合的项集,元素超过1个 - * @return - */ - public boolean compoentItemIsContain(ItemSet itemSet) { - boolean isContained = false; - ArrayList tempItems; - String lastItem = itemSet.getLastValue(); - - for (int i = 0; i < this.itemSetList.size(); i++) { - tempItems = this.itemSetList.get(i).getItems(); - // 分2种情况查找,第一种从_X中找出x等于项集最后的元素,因为_前缀已经为原本的元素 - if (tempItems.size() > 1 && tempItems.get(0).equals("_") - && tempItems.get(1).equals(lastItem)) { - isContained = true; - break; - } else if (!tempItems.get(0).equals("_")) { - // 从没有_前缀的项集开始寻找,第二种为从后面的后缀中找出直接找出连续字符为ab为同一项集的项集 - if (strArrayContains(tempItems, itemSet.getItems())) { - isContained = true; - break; - } - } - - if (isContained) { - break; - } - } - - return isContained; - } - - /** - * 删除单个项 - * - * @param s - * 待删除项 - */ - public void deleteSingleItem(String s) { - ArrayList tempItems; - ArrayList deleteItems = new ArrayList<>(); - - for (ItemSet itemSet : this.itemSetList) { - tempItems = itemSet.getItems(); - deleteItems = new ArrayList<>(); - - for (int i = 0; i < tempItems.size(); i++) { - if (tempItems.get(i).equals(s)) { - deleteItems.add(tempItems.get(i)); - } - } - - tempItems.removeAll(deleteItems); - } - } - - /** - * 提取项s之后所得的序列 - * - * @param s - * 目标提取项s - */ - public Sequence extractItem(String s) { - Sequence extractSeq = this.copySeqence(); - ItemSet itemSet; - ArrayList items; - ArrayList deleteItemSets = new ArrayList<>(); - ArrayList tempItems = new ArrayList<>(); - - for (int k = 0; k < extractSeq.itemSetList.size(); k++) { - itemSet = extractSeq.itemSetList.get(k); - items = itemSet.getItems(); - if (items.size() == 1 && items.get(0).equals(s)) { - //如果找到的是单项,则完全移除,跳出循环 - extractSeq.itemSetList.remove(k); - break; - } else if (items.size() > 1 && !items.get(0).equals("_")) { - //在后续的多元素项中判断是否包含此元素 - if (items.contains(s)) { - //如果包含把s后面的元素加入到临时字符数组中 - int index = items.indexOf(s); - for (int j = index; j < items.size(); j++) { - tempItems.add(items.get(j)); - } - //将第一位的s变成下标符"_" - tempItems.set(0, "_"); - if (tempItems.size() == 1) { - // 如果此匹配为在最末端,同样移除 - deleteItemSets.add(itemSet); - } else { - //将变化后的项集替换原来的 - extractSeq.itemSetList.set(k, new ItemSet(tempItems)); - } - break; - } else { - deleteItemSets.add(itemSet); - } - } else { - // 不符合以上2项条件的统统移除 - deleteItemSets.add(itemSet); - } - } - extractSeq.itemSetList.removeAll(deleteItemSets); - - return extractSeq; - } - - /** - * 提取组合项之后的序列 - * - * @param array - * 组合数组 - * @return - */ - public Sequence extractCompoentItem(ArrayList array) { - // 找到目标项,是否立刻停止 - boolean stopExtract = false; - Sequence seq = this.copySeqence(); - String lastItem = array.get(array.size() - 1); - ArrayList tempItems; - ArrayList deleteItems = new ArrayList<>(); - - for (int i = 0; i < seq.itemSetList.size(); i++) { - if (stopExtract) { - break; - } - - tempItems = seq.itemSetList.get(i).getItems(); - // 分2种情况查找,第一种从_X中找出x等于项集最后的元素,因为_前缀已经为原本的元素 - if (tempItems.size() > 1 && tempItems.get(0).equals("_") - && tempItems.get(1).equals(lastItem)) { - if (tempItems.size() == 2) { - seq.itemSetList.remove(i); - } else { - // 把1号位置变为下标符"_",往后移1个字符的位置 - tempItems.set(1, "_"); - // 移除第一个的"_"下划符 - tempItems.remove(0); - } - stopExtract = true; - break; - } else if (!tempItems.get(0).equals("_")) { - // 从没有_前缀的项集开始寻找,第二种为从后面的后缀中找出直接找出连续字符为ab为同一项集的项集 - if (strArrayContains(tempItems, array)) { - // 从左往右找出第一个给定字符的位置,把后面的部分截取出来 - int index = tempItems.indexOf(lastItem); - ArrayList array2 = new ArrayList(); - - for (int j = index; j < tempItems.size(); j++) { - array2.add(tempItems.get(j)); - } - array2.set(0, "_"); - - if (array2.size() == 1) { - //如果此项在末尾的位置,则移除该项,否则进行替换 - deleteItems.add(seq.itemSetList.get(i)); - } else { - seq.itemSetList.set(i, new ItemSet(array2)); - } - stopExtract = true; - break; - } else { - deleteItems.add(seq.itemSetList.get(i)); - } - } else { - // 这种情况是处理_X中X不等于最后一个元素的情况 - deleteItems.add(seq.itemSetList.get(i)); - } - } - - seq.itemSetList.removeAll(deleteItems); - - return seq; - } - - /** - * 深拷贝一个序列 - * - * @return - */ - public Sequence copySeqence() { - Sequence copySeq = new Sequence(); - ItemSet tempItemSet; - ArrayList items; - - for (ItemSet itemSet : this.itemSetList) { - items = (ArrayList) itemSet.getItems().clone(); - tempItemSet = new ItemSet(items); - copySeq.getItemSetList().add(tempItemSet); - } - - return copySeq; - } - - /** - * 获取序列中最后一个项集的最后1个元素 - * - * @return - */ - public String getLastItemSetValue() { - int size = this.getItemSetList().size(); - ItemSet itemSet = this.getItemSetList().get(size - 1); - size = itemSet.getItems().size(); - - return itemSet.getItems().get(size - 1); - } - - /** - * 判断strList2是否是strList1的子序列 - * - * @param strList1 - * @param strList2 - * @return - */ - public boolean strArrayContains(ArrayList strList1, - ArrayList strList2) { - boolean isContained = false; - - for (int i = 0; i < strList1.size() - strList2.size() + 1; i++) { - isContained = true; - - for (int j = 0, k = i; j < strList2.size(); j++, k++) { - if (!strList1.get(k).equals(strList2.get(j))) { - isContained = false; - break; - } - } - - if (isContained) { - break; - } - } - - return isContained; - } +class Sequence { + // 序列内的项集 + private ArrayList itemSetList; + + Sequence(){ + this.itemSetList = new ArrayList<>(); + } + + ArrayList getItemSetList(){ + return itemSetList; + } + + /** + * 判断单一项是否包含于此序列 + * + * @param c 待判断项 + */ + boolean strIsContained(String c){ + boolean isContained = false; + + for (ItemSet itemSet : itemSetList) { + isContained = false; + + for (String s : itemSet.getItems()) { + if (itemSet.getItems().contains("_")) { + continue; + } + + if (s.equals(c)) { + isContained = true; + break; + } + } + + if (isContained) { + // 如果已经检测出包含了,直接挑出循环 + break; + } + } + + return isContained; + } + + /** + * 判断组合项集是否包含于序列中 + * + * @param itemSet 组合的项集,元素超过1个 + */ + boolean compoentItemIsContain(ItemSet itemSet){ + boolean isContained = false; + ArrayList tempItems; + String lastItem = itemSet.getLastValue(); + + for (ItemSet anItemSetList : this.itemSetList) { + tempItems = anItemSetList.getItems(); + // 分2种情况查找,第一种从_X中找出x等于项集最后的元素,因为_前缀已经为原本的元素 + if (tempItems.size() > 1 && tempItems.get(0).equals("_") + && tempItems.get(1).equals(lastItem)) { + isContained = true; + break; + } else if (!tempItems.get(0).equals("_")) { + // 从没有_前缀的项集开始寻找,第二种为从后面的后缀中找出直接找出连续字符为ab为同一项集的项集 + if (strArrayContains(tempItems, itemSet.getItems())) { + isContained = true; + break; + } + } + } + return isContained; + } + + /** + * 删除单个项 + * + * @param s 待删除项 + */ + void deleteSingleItem(String s){ + ArrayList tempItems; + ArrayList deleteItems; + + for (ItemSet itemSet : this.itemSetList) { + tempItems = itemSet.getItems(); + deleteItems = new ArrayList<>(); + + deleteItems.addAll(tempItems + .stream() + .filter(tempItem -> tempItem.equals(s)) + .collect(Collectors.toList())); + + tempItems.removeAll(deleteItems); + } + } + + /** + * 提取项s之后所得的序列 + * + * @param s 目标提取项s + */ + Sequence extractItem(String s){ + Sequence extractSeq = this.copySeqence(); + ItemSet itemSet; + ArrayList items; + ArrayList deleteItemSets = new ArrayList<>(); + ArrayList tempItems = new ArrayList<>(); + + for (int k = 0; k < extractSeq.itemSetList.size(); k++) { + itemSet = extractSeq.itemSetList.get(k); + items = itemSet.getItems(); + if (items.size() == 1 && items.get(0).equals(s)) { + //如果找到的是单项,则完全移除,跳出循环 + extractSeq.itemSetList.remove(k); + break; + } else if (items.size() > 1 && !items.get(0).equals("_")) { + //在后续的多元素项中判断是否包含此元素 + if (items.contains(s)) { + //如果包含把s后面的元素加入到临时字符数组中 + int index = items.indexOf(s); + for (int j = index; j < items.size(); j++) { + tempItems.add(items.get(j)); + } + //将第一位的s变成下标符"_" + tempItems.set(0, "_"); + if (tempItems.size() == 1) { + // 如果此匹配为在最末端,同样移除 + deleteItemSets.add(itemSet); + } else { + //将变化后的项集替换原来的 + extractSeq.itemSetList.set(k, new ItemSet(tempItems)); + } + break; + } else { + deleteItemSets.add(itemSet); + } + } else { + // 不符合以上2项条件的统统移除 + deleteItemSets.add(itemSet); + } + } + extractSeq.itemSetList.removeAll(deleteItemSets); + + return extractSeq; + } + + /** + * 提取组合项之后的序列 + * + * @param array 组合数组 + */ + Sequence extractCompoentItem(ArrayList array){ + // 找到目标项,是否立刻停止 + Sequence seq = this.copySeqence(); + String lastItem = array.get(array.size() - 1); + ArrayList tempItems; + ArrayList deleteItems = new ArrayList<>(); + + for (int i = 0; i < seq.itemSetList.size(); i++) { + + tempItems = seq.itemSetList.get(i).getItems(); + // 分2种情况查找,第一种从_X中找出x等于项集最后的元素,因为_前缀已经为原本的元素 + if (tempItems.size() > 1 && tempItems.get(0).equals("_") + && tempItems.get(1).equals(lastItem)) { + if (tempItems.size() == 2) { + seq.itemSetList.remove(i); + } else { + // 把1号位置变为下标符"_",往后移1个字符的位置 + tempItems.set(1, "_"); + // 移除第一个的"_"下划符 + tempItems.remove(0); + } + break; + } else if (!tempItems.get(0).equals("_")) { + // 从没有_前缀的项集开始寻找,第二种为从后面的后缀中找出直接找出连续字符为ab为同一项集的项集 + if (strArrayContains(tempItems, array)) { + // 从左往右找出第一个给定字符的位置,把后面的部分截取出来 + int index = tempItems.indexOf(lastItem); + ArrayList array2 = new ArrayList<>(); + + for (int j = index; j < tempItems.size(); j++) { + array2.add(tempItems.get(j)); + } + array2.set(0, "_"); + + if (array2.size() == 1) { + //如果此项在末尾的位置,则移除该项,否则进行替换 + deleteItems.add(seq.itemSetList.get(i)); + } else { + seq.itemSetList.set(i, new ItemSet(array2)); + } + break; + } else { + deleteItems.add(seq.itemSetList.get(i)); + } + } else { + // 这种情况是处理_X中X不等于最后一个元素的情况 + deleteItems.add(seq.itemSetList.get(i)); + } + } + + seq.itemSetList.removeAll(deleteItems); + + return seq; + } + + /** + * 深拷贝一个序列 + */ + Sequence copySeqence(){ + Sequence copySeq = new Sequence(); + ItemSet tempItemSet; + ArrayList items; + + for (ItemSet itemSet : this.itemSetList) { + items = (ArrayList) itemSet.getItems().clone(); + tempItemSet = new ItemSet(items); + copySeq.getItemSetList().add(tempItemSet); + } + + return copySeq; + } + + /** + * 获取序列中最后一个项集的最后1个元素 + */ + public String getLastItemSetValue(){ + int size = this.getItemSetList().size(); + ItemSet itemSet = this.getItemSetList().get(size - 1); + size = itemSet.getItems().size(); + + return itemSet.getItems().get(size - 1); + } + + /** + * 判断strList2是否是strList1的子序列 + * + * @param strList1 序列1 + * @param strList2 序列2 + */ + private boolean strArrayContains(ArrayList strList1, + ArrayList strList2){ + boolean isContained = false; + + for (int i = 0; i < strList1.size() - strList2.size() + 1; i++) { + isContained = true; + + for (int j = 0, k = i; j < strList2.size(); j++, k++) { + if (!strList1.get(k).equals(strList2.get(j))) { + isContained = false; + break; + } + } + + if (isContained) { + break; + } + } + + return isContained; + } } diff --git a/StatisticalLearning/DataMining_EM/Client.java b/StatisticalLearning/DataMining_EM/Client.java index 044ebdb..8878f90 100644 --- a/StatisticalLearning/DataMining_EM/Client.java +++ b/StatisticalLearning/DataMining_EM/Client.java @@ -1,16 +1,16 @@ -package DataMining_EM; +package StatisticalLearning.DataMining_EM; /** * EM期望最大化算法场景调用类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; - - EMTool tool = new EMTool(filePath); - tool.readDataFile(); - tool.exceptMaxStep(); - } + public static void main(String[] args){ + String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/StatisticalLearning/DataMining_EM/input.txt"; + + EMTool tool = new EMTool(filePath); + tool.readDataFile(); + tool.exceptMaxStep(); + } } diff --git a/StatisticalLearning/DataMining_EM/EMTool.java b/StatisticalLearning/DataMining_EM/EMTool.java index 4014bc2..6c5fcfa 100644 --- a/StatisticalLearning/DataMining_EM/EMTool.java +++ b/StatisticalLearning/DataMining_EM/EMTool.java @@ -1,4 +1,4 @@ -package DataMining_EM; +package StatisticalLearning.DataMining_EM; import java.io.BufferedReader; import java.io.File; @@ -9,138 +9,136 @@ /** * EM最大期望算法工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class EMTool { - // 测试数据文件地址 - private String dataFilePath; - // 测试坐标点数据 - private String[][] data; - // 测试坐标点数据列表 - private ArrayList pointArray; - // 目标C1点 - private Point p1; - // 目标C2点 - private Point p2; - - public EMTool(String dataFilePath) { - this.dataFilePath = dataFilePath; - pointArray = new ArrayList<>(); - } - - /** - * 从文件中读取数据 - */ - public void readDataFile() { - File file = new File(dataFilePath); - ArrayList dataArray = new ArrayList(); - - try { - BufferedReader in = new BufferedReader(new FileReader(file)); - String str; - String[] tempArray; - while ((str = in.readLine()) != null) { - tempArray = str.split(" "); - dataArray.add(tempArray); - } - in.close(); - } catch (IOException e) { - e.getStackTrace(); - } - - data = new String[dataArray.size()][]; - dataArray.toArray(data); - - // 开始时默认取头2个点作为2个簇中心 - p1 = new Point(Integer.parseInt(data[0][0]), - Integer.parseInt(data[0][1])); - p2 = new Point(Integer.parseInt(data[1][0]), - Integer.parseInt(data[1][1])); - - Point p; - for (String[] array : data) { - // 将数据转换为对象加入列表方便计算 - p = new Point(Integer.parseInt(array[0]), - Integer.parseInt(array[1])); - pointArray.add(p); - } - } - - /** - * 计算坐标点对于2个簇中心点的隶属度 - * - * @param p - * 待测试坐标点 - */ - private void computeMemberShip(Point p) { - // p点距离第一个簇中心点的距离 - double distance1 = 0; - // p距离第二个中心点的距离 - double distance2 = 0; - - // 用欧式距离计算 - distance1 = Math.pow(p.getX() - p1.getX(), 2) - + Math.pow(p.getY() - p1.getY(), 2); - distance2 = Math.pow(p.getX() - p2.getX(), 2) - + Math.pow(p.getY() - p2.getY(), 2); - - // 计算对于p1点的隶属度,与距离成反比关系,距离靠近越小,隶属度越大,所以要用大的distance2另外的距离来表示 - p.setMemberShip1(distance2 / (distance1 + distance2)); - // 计算对于p2点的隶属度 - p.setMemberShip2(distance1 / (distance1 + distance2)); - } - - /** - * 执行期望最大化步骤 - */ - public void exceptMaxStep() { - // 新的优化过的簇中心点 - double p1X = 0; - double p1Y = 0; - double p2X = 0; - double p2Y = 0; - double temp1 = 0; - double temp2 = 0; - // 误差值 - double errorValue1 = 0; - double errorValue2 = 0; - // 上次更新的簇点坐标 - Point lastP1 = null; - Point lastP2 = null; - - // 当开始计算的时候,或是中心点的误差值超过1的时候都需要再次迭代计算 - while (lastP1 == null || errorValue1 > 1.0 || errorValue2 > 1.0) { - for (Point p : pointArray) { - computeMemberShip(p); - p1X += p.getMemberShip1() * p.getMemberShip1() * p.getX(); - p1Y += p.getMemberShip1() * p.getMemberShip1() * p.getY(); - temp1 += p.getMemberShip1() * p.getMemberShip1(); - - p2X += p.getMemberShip2() * p.getMemberShip2() * p.getX(); - p2Y += p.getMemberShip2() * p.getMemberShip2() * p.getY(); - temp2 += p.getMemberShip2() * p.getMemberShip2(); - } - - lastP1 = new Point(p1.getX(), p1.getY()); - lastP2 = new Point(p2.getX(), p2.getY()); - - // 套公式计算新的簇中心点坐标,最最大化处理 - p1.setX(p1X / temp1); - p1.setY(p1Y / temp1); - p2.setX(p2X / temp2); - p2.setY(p2Y / temp2); - - errorValue1 = Math.abs(lastP1.getX() - p1.getX()) - + Math.abs(lastP1.getY() - p1.getY()); - errorValue2 = Math.abs(lastP2.getX() - p2.getX()) - + Math.abs(lastP2.getY() - p2.getY()); - } - - System.out.println(MessageFormat.format( - "簇中心节点p1({0}, {1}), p2({2}, {3})", p1.getX(), p1.getY(), - p2.getX(), p2.getY())); - } +class EMTool { + // 测试数据文件地址 + private String dataFilePath; + // 测试坐标点数据 + private String[][] data; + // 测试坐标点数据列表 + private ArrayList pointArray; + // 目标C1点 + private Point p1; + // 目标C2点 + private Point p2; + + EMTool(String dataFilePath){ + this.dataFilePath = dataFilePath; + pointArray = new ArrayList<>(); + } + + /** + * 从文件中读取数据 + */ + public void readDataFile(){ + File file = new File(dataFilePath); + ArrayList dataArray = new ArrayList<>(); + + try { + BufferedReader in = new BufferedReader(new FileReader(file)); + String str; + String[] tempArray; + while ((str = in.readLine()) != null) { + tempArray = str.split(" "); + dataArray.add(tempArray); + } + in.close(); + } catch (IOException e) { + e.getStackTrace(); + } + + data = new String[dataArray.size()][]; + dataArray.toArray(data); + + // 开始时默认取头2个点作为2个簇中心 + p1 = new Point(Integer.parseInt(data[0][0]), + Integer.parseInt(data[0][1])); + p2 = new Point(Integer.parseInt(data[1][0]), + Integer.parseInt(data[1][1])); + + Point p; + for (String[] array : data) { + // 将数据转换为对象加入列表方便计算 + p = new Point(Integer.parseInt(array[0]), + Integer.parseInt(array[1])); + pointArray.add(p); + } + } + + /** + * 计算坐标点对于2个簇中心点的隶属度 + * + * @param p 待测试坐标点 + */ + private void computeMemberShip(Point p){ + // p点距离第一个簇中心点的距离 + double distance1; + // p距离第二个中心点的距离 + double distance2; + + // 用欧式距离计算 + distance1 = Math.pow(p.getX() - p1.getX(), 2) + + Math.pow(p.getY() - p1.getY(), 2); + distance2 = Math.pow(p.getX() - p2.getX(), 2) + + Math.pow(p.getY() - p2.getY(), 2); + + // 计算对于p1点的隶属度,与距离成反比关系,距离靠近越小,隶属度越大,所以要用大的distance2另外的距离来表示 + p.setMemberShip1(distance2 / (distance1 + distance2)); + // 计算对于p2点的隶属度 + p.setMemberShip2(distance1 / (distance1 + distance2)); + } + + /** + * 执行期望最大化步骤 + */ + void exceptMaxStep(){ + // 新的优化过的簇中心点 + double p1X = 0; + double p1Y = 0; + double p2X = 0; + double p2Y = 0; + double temp1 = 0; + double temp2 = 0; + // 误差值 + double errorValue1 = 0; + double errorValue2 = 0; + // 上次更新的簇点坐标 + Point lastP1 = null; + Point lastP2; + + // 当开始计算的时候,或是中心点的误差值超过1的时候都需要再次迭代计算 + while (lastP1 == null || errorValue1 > 1.0 || errorValue2 > 1.0) { + for (Point p : pointArray) { + computeMemberShip(p); + p1X += p.getMemberShip1() * p.getMemberShip1() * p.getX(); + p1Y += p.getMemberShip1() * p.getMemberShip1() * p.getY(); + temp1 += p.getMemberShip1() * p.getMemberShip1(); + + p2X += p.getMemberShip2() * p.getMemberShip2() * p.getX(); + p2Y += p.getMemberShip2() * p.getMemberShip2() * p.getY(); + temp2 += p.getMemberShip2() * p.getMemberShip2(); + } + + lastP1 = new Point(p1.getX(), p1.getY()); + lastP2 = new Point(p2.getX(), p2.getY()); + + // 套公式计算新的簇中心点坐标,最最大化处理 + p1.setX(p1X / temp1); + p1.setY(p1Y / temp1); + p2.setX(p2X / temp2); + p2.setY(p2Y / temp2); + + errorValue1 = Math.abs(lastP1.getX() - p1.getX()) + + Math.abs(lastP1.getY() - p1.getY()); + errorValue2 = Math.abs(lastP2.getX() - p2.getX()) + + Math.abs(lastP2.getY() - p2.getY()); + } + + System.out.println(MessageFormat.format( + "簇中心节点p1({0}, {1}), p2({2}, {3})", p1.getX(), p1.getY(), + p2.getX(), p2.getY())); + } } diff --git a/StatisticalLearning/DataMining_EM/Point.java b/StatisticalLearning/DataMining_EM/Point.java index d4f3ae6..b8fb3b4 100644 --- a/StatisticalLearning/DataMining_EM/Point.java +++ b/StatisticalLearning/DataMining_EM/Point.java @@ -1,56 +1,55 @@ -package DataMining_EM; +package StatisticalLearning.DataMining_EM; /** * 坐标点类 - * - * @author lyq - * + * + * @author Qstar */ public class Point { - // 坐标点横坐标 - private double x; - // 坐标点纵坐标 - private double y; - // 坐标点对于P1的隶属度 - private double memberShip1; - // 坐标点对于P2的隶属度 - private double memberShip2; - - public Point(double d, double e) { - this.x = d; - this.y = e; - } - - public double getX() { - return x; - } - - public void setX(double x) { - this.x = x; - } - - public double getY() { - return y; - } - - public void setY(double y) { - this.y = y; - } - - public double getMemberShip1() { - return memberShip1; - } - - public void setMemberShip1(double memberShip1) { - this.memberShip1 = memberShip1; - } - - public double getMemberShip2() { - return memberShip2; - } - - public void setMemberShip2(double memberShip2) { - this.memberShip2 = memberShip2; - } + // 坐标点横坐标 + private double x; + // 坐标点纵坐标 + private double y; + // 坐标点对于P1的隶属度 + private double memberShip1; + // 坐标点对于P2的隶属度 + private double memberShip2; + + public Point(double d, double e){ + this.x = d; + this.y = e; + } + + public double getX(){ + return x; + } + + public void setX(double x){ + this.x = x; + } + + public double getY(){ + return y; + } + + public void setY(double y){ + this.y = y; + } + + double getMemberShip1(){ + return memberShip1; + } + + void setMemberShip1(double memberShip1){ + this.memberShip1 = memberShip1; + } + + double getMemberShip2(){ + return memberShip2; + } + + void setMemberShip2(double memberShip2){ + this.memberShip2 = memberShip2; + } } diff --git a/StatisticalLearning/DataMining_SVM/Client.java b/StatisticalLearning/DataMining_SVM/Client.java index c300298..8d2cbfd 100644 --- a/StatisticalLearning/DataMining_SVM/Client.java +++ b/StatisticalLearning/DataMining_SVM/Client.java @@ -1,20 +1,20 @@ -package DataMining_SVM; +package StatisticalLearning.DataMining_SVM; /** * SVM支持向量机场景调用类 - * @author lyq * + * @author Qstar */ public class Client { - public static void main(String[] args){ - //训练集数据文件路径 - String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt"; - //测试数据文件路径 - String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt"; - - SVMTool tool = new SVMTool(trainDataPath); - //对测试数据进行svm支持向量机分类 - tool.svmPredictData(testDataPath); - } + public static void main(String[] args){ + //训练集数据文件路径 + String trainDataPath = "/Users/Qstar/Desktop/DataMiningAlgorithm/StatisticalLearning/DataMining_SVM/trainInput.txt"; + //测试数据文件路径 + String testDataPath = "/Users/Qstar/Desktop/DataMiningAlgorithm/StatisticalLearning/DataMining_SVM/testInput.txt"; + + SVMTool tool = new SVMTool(trainDataPath); + //对测试数据进行svm支持向量机分类 + tool.svmPredictData(testDataPath); + } } diff --git a/StatisticalLearning/DataMining_SVM/SVMTool.java b/StatisticalLearning/DataMining_SVM/SVMTool.java index 72adbbb..f23bcc9 100644 --- a/StatisticalLearning/DataMining_SVM/SVMTool.java +++ b/StatisticalLearning/DataMining_SVM/SVMTool.java @@ -1,4 +1,6 @@ -package DataMining_SVM; +package StatisticalLearning.DataMining_SVM; + +import StatisticalLearning.DataMining_SVM.libsvm.*; import java.io.BufferedReader; import java.io.File; @@ -6,168 +8,153 @@ import java.util.ArrayList; import java.util.List; -import DataMining_SVM.libsvm.svm; -import DataMining_SVM.libsvm.svm_model; -import DataMining_SVM.libsvm.svm_node; -import DataMining_SVM.libsvm.svm_parameter; -import DataMining_SVM.libsvm.svm_problem; - /** * SVM支持向量机工具类 - * - * @author lyq - * + * + * @author Qstar */ -public class SVMTool { - // 训练集数据文件路径 - private String trainDataPath; - // svm_problem对象,用于构造svm model模型 - private svm_problem sProblem; - // svm参数,里面有svm支持向量机的类型和不同 的svm的核函数类型 - private svm_parameter sParam; - - public SVMTool(String trainDataPath) { - this.trainDataPath = trainDataPath; - - // 初始化svm相关变量 - sProblem = initSvmProblem(); - sParam = initSvmParam(); - } - - /** - * 初始化操作,根据训练集数据构造分类模型 - */ - private void initOperation(){ - - } - - /** - * svm_problem对象,训练集数据的相关信息配置 - * - * @return - */ - private svm_problem initSvmProblem() { - List label = new ArrayList(); - List nodeSet = new ArrayList(); - getData(nodeSet, label, trainDataPath); - - int dataRange = nodeSet.get(0).length; - svm_node[][] datas = new svm_node[nodeSet.size()][dataRange]; // 训练集的向量表 - for (int i = 0; i < datas.length; i++) { - for (int j = 0; j < dataRange; j++) { - datas[i][j] = nodeSet.get(i)[j]; - } - } - double[] lables = new double[label.size()]; // a,b 对应的lable - for (int i = 0; i < lables.length; i++) { - lables[i] = label.get(i); - } - - // 定义svm_problem对象 - svm_problem problem = new svm_problem(); - problem.l = nodeSet.size(); // 向量个数 - problem.x = datas; // 训练集向量表 - problem.y = lables; // 对应的lable数组 - - return problem; - } - - /** - * 初始化svm支持向量机的参数,包括svm的类型和核函数的类型 - * - * @return - */ - private svm_parameter initSvmParam() { - // 定义svm_parameter对象 - svm_parameter param = new svm_parameter(); - param.svm_type = svm_parameter.EPSILON_SVR; - // 设置svm的核函数类型为线型 - param.kernel_type = svm_parameter.LINEAR; - // 后面的参数配置只针对训练集的数据 - param.cache_size = 100; - param.eps = 0.00001; - param.C = 1.9; - - return param; - } - - /** - * 通过svm方式预测数据的类型 - * - * @param testDataPath - */ - public void svmPredictData(String testDataPath) { - // 获取测试数据 - List testlabel = new ArrayList(); - List testnodeSet = new ArrayList(); - getData(testnodeSet, testlabel, testDataPath); - int dataRange = testnodeSet.get(0).length; - - svm_node[][] testdatas = new svm_node[testnodeSet.size()][dataRange]; // 训练集的向量表 - for (int i = 0; i < testdatas.length; i++) { - for (int j = 0; j < dataRange; j++) { - testdatas[i][j] = testnodeSet.get(i)[j]; - } - } - // 测试数据的真实值,在后面将会与svm的预测值做比较 - double[] testlables = new double[testlabel.size()]; // a,b 对应的lable - for (int i = 0; i < testlables.length; i++) { - testlables[i] = testlabel.get(i); - } - - // 如果参数没有问题,则svm.svm_check_parameter()函数返回null,否则返回error描述。 - // 对svm的配置参数叫验证,因为有些参数只针对部分的支持向量机的类型 - System.out.println(svm.svm_check_parameter(sProblem, sParam)); - System.out.println("------------检验参数-----------"); - // 训练SVM分类模型 - svm_model model = svm.svm_train(sProblem, sParam); - - // 预测测试数据的lable - double err = 0.0; - for (int i = 0; i < testdatas.length; i++) { - double truevalue = testlables[i]; - // 测试数据真实值 - System.out.print(truevalue + " "); - double predictValue = svm.svm_predict(model, testdatas[i]); - // 测试数据预测值 - System.out.println(predictValue); - } - } - - /** - * 从文件中获取数据 - * - * @param nodeSet - * 向量节点 - * @param label - * 节点值类型值 - * @param filename - * 数据文件地址 - */ - private void getData(List nodeSet, List label, - String filename) { - try { - - FileReader fr = new FileReader(new File(filename)); - BufferedReader br = new BufferedReader(fr); - String line = null; - while ((line = br.readLine()) != null) { - String[] datas = line.split(","); - svm_node[] vector = new svm_node[datas.length - 1]; - for (int i = 0; i < datas.length - 1; i++) { - svm_node node = new svm_node(); - node.index = i + 1; - node.value = Double.parseDouble(datas[i]); - vector[i] = node; - } - nodeSet.add(vector); - double lablevalue = Double.parseDouble(datas[datas.length - 1]); - label.add(lablevalue); - } - } catch (Exception e) { - e.printStackTrace(); - } - - } +class SVMTool { + // 训练集数据文件路径 + private String trainDataPath; + // svm_problem对象,用于构造svm model模型 + private svm_problem sProblem; + // svm参数,里面有svm支持向量机的类型和不同 的svm的核函数类型 + private svm_parameter sParam; + + SVMTool(String trainDataPath){ + this.trainDataPath = trainDataPath; + + // 初始化svm相关变量 + sProblem = initSvmProblem(); + sParam = initSvmParam(); + } + + /** + * 初始化操作,根据训练集数据构造分类模型 + */ + private void initOperation(){ + + } + + /** + * svm_problem对象,训练集数据的相关信息配置 + */ + private svm_problem initSvmProblem(){ + List label = new ArrayList<>(); + List nodeSet = new ArrayList<>(); + getData(nodeSet, label, trainDataPath); + + int dataRange = nodeSet.get(0).length; + svm_node[][] datas = new svm_node[nodeSet.size()][dataRange]; // 训练集的向量表 + for (int i = 0; i < datas.length; i++) { + for (int j = 0; j < dataRange; j++) { + datas[i][j] = nodeSet.get(i)[j]; + } + } + double[] lables = new double[label.size()]; // a,b 对应的lable + for (int i = 0; i < lables.length; i++) { + lables[i] = label.get(i); + } + + // 定义svm_problem对象 + svm_problem problem = new svm_problem(); + problem.l = nodeSet.size(); // 向量个数 + problem.x = datas; // 训练集向量表 + problem.y = lables; // 对应的lable数组 + + return problem; + } + + /** + * 初始化svm支持向量机的参数,包括svm的类型和核函数的类型 + */ + private svm_parameter initSvmParam(){ + // 定义svm_parameter对象 + svm_parameter param = new svm_parameter(); + param.svm_type = svm_parameter.EPSILON_SVR; + // 设置svm的核函数类型为线型 + param.kernel_type = svm_parameter.LINEAR; + // 后面的参数配置只针对训练集的数据 + param.cache_size = 100; + param.eps = 0.00001; + param.C = 1.9; + + return param; + } + + /** + * 通过svm方式预测数据的类型 + * + * @param testDataPath 测试数据路径 + */ + void svmPredictData(String testDataPath){ + // 获取测试数据 + List testlabel = new ArrayList<>(); + List testnodeSet = new ArrayList<>(); + getData(testnodeSet, testlabel, testDataPath); + int dataRange = testnodeSet.get(0).length; + + svm_node[][] testdatas = new svm_node[testnodeSet.size()][dataRange]; // 训练集的向量表 + for (int i = 0; i < testdatas.length; i++) { + for (int j = 0; j < dataRange; j++) { + testdatas[i][j] = testnodeSet.get(i)[j]; + } + } + // 测试数据的真实值,在后面将会与svm的预测值做比较 + double[] testlables = new double[testlabel.size()]; // a,b 对应的lable + for (int i = 0; i < testlables.length; i++) { + testlables[i] = testlabel.get(i); + } + + // 如果参数没有问题,则svm.svm_check_parameter()函数返回null,否则返回error描述。 + // 对svm的配置参数叫验证,因为有些参数只针对部分的支持向量机的类型 + System.out.println(svm.svm_check_parameter(sProblem, sParam)); + System.out.println("------------检验参数-----------"); + // 训练SVM分类模型 + svm_model model = svm.svm_train(sProblem, sParam); + + // 预测测试数据的lable + for (int i = 0; i < testdatas.length; i++) { + double truevalue = testlables[i]; + // 测试数据真实值 + System.out.print(truevalue + " "); + double predictValue = svm.svm_predict(model, testdatas[i]); + // 测试数据预测值 + System.out.println(predictValue); + } + } + + /** + * 从文件中获取数据 + * + * @param nodeSet 向量节点 + * @param label 节点值类型值 + * @param filename 数据文件地址 + */ + private void getData(List nodeSet, List label, + String filename){ + try { + + FileReader fr = new FileReader(new File(filename)); + BufferedReader br = new BufferedReader(fr); + String line; + while ((line = br.readLine()) != null) { + String[] datas = line.split(","); + svm_node[] vector = new svm_node[datas.length - 1]; + for (int i = 0; i < datas.length - 1; i++) { + svm_node node = new svm_node(); + node.index = i + 1; + node.value = Double.parseDouble(datas[i]); + vector[i] = node; + } + nodeSet.add(vector); + double lablevalue = Double.parseDouble(datas[datas.length - 1]); + label.add(lablevalue); + } + } catch (Exception e) { + e.printStackTrace(); + } + + } } diff --git a/StatisticalLearning/DataMining_SVM/libsvm/svm.java b/StatisticalLearning/DataMining_SVM/libsvm/svm.java index f4f24d1..06af67e 100644 --- a/StatisticalLearning/DataMining_SVM/libsvm/svm.java +++ b/StatisticalLearning/DataMining_SVM/libsvm/svm.java @@ -1,10 +1,7 @@ +package StatisticalLearning.DataMining_SVM.libsvm; - - - -package DataMining_SVM.libsvm; import java.io.*; -import java.util.*; +import java.util.StringTokenizer; // // Kernel Cache @@ -13,110 +10,120 @@ // size is the cache size limit in bytes // class Cache { - private final int l; - private long size; - private final class head_t - { - head_t prev, next; // a cicular list - float[] data; - int len; // data[0,len) is cached in this entry - } - private final head_t[] head; - private head_t lru_head; - - Cache(int l_, long size_) - { - l = l_; - size = size_; - head = new head_t[l]; - for(int i=0;i= len if nothing needs to be filled) - // java: simulate pointer using single-element array - int get_data(int index, float[][] data, int len) - { - head_t h = head[index]; - if(h.len > 0) lru_delete(h); - int more = len - h.len; - - if(more > 0) - { - // free old space - while(size < more) - { - head_t old = lru_head.next; - lru_delete(old); - size += old.len; - old.data = null; - old.len = 0; - } - - // allocate new space - float[] new_data = new float[len]; - if(h.data != null) System.arraycopy(h.data,0,new_data,0,h.len); - h.data = new_data; - size -= more; - do {int _=h.len; h.len=len; len=_;} while(false); - } - - lru_insert(h); - data[0] = h.data; - return len; - } - - void swap_index(int i, int j) - { - if(i==j) return; - - if(head[i].len > 0) lru_delete(head[i]); - if(head[j].len > 0) lru_delete(head[j]); - do {float[] _=head[i].data; head[i].data=head[j].data; head[j].data=_;} while(false); - do {int _=head[i].len; head[i].len=head[j].len; head[j].len=_;} while(false); - if(head[i].len > 0) lru_insert(head[i]); - if(head[j].len > 0) lru_insert(head[j]); - - if(i>j) do {int _=i; i=j; j=_;} while(false); - for(head_t h = lru_head.next; h!=lru_head; h=h.next) - { - if(h.len > i) - { - if(h.len > j) - do {float _=h.data[i]; h.data[i]=h.data[j]; h.data[j]=_;} while(false); - else - { - // give up - lru_delete(h); - size += h.len; - h.data = null; - h.len = 0; - } - } - } - } + private final int l; + private final head_t[] head; + private long size; + private head_t lru_head; + + Cache(int l_, long size_){ + l = l_; + size = size_; + head = new head_t[l]; + for (int i = 0; i < l; i++) head[i] = new head_t(); + size /= 4; + size -= l * (16 / 4); // sizeof(head_t) == 16 + size = Math.max(size, 2 * (long) l); // cache must be large enough for two columns + lru_head = new head_t(); + lru_head.next = lru_head.prev = lru_head; + } + + private void lru_delete(head_t h){ + // delete from current location + h.prev.next = h.next; + h.next.prev = h.prev; + } + + private void lru_insert(head_t h){ + // insert to last position + h.next = lru_head; + h.prev = lru_head.prev; + h.prev.next = h; + h.next.prev = h; + } + + // request data [0,len) + // return some position p where [p,len) need to be filled + // (p >= len if nothing needs to be filled) + // java: simulate pointer using single-element array + int get_data(int index, float[][] data, int len){ + head_t h = head[index]; + if (h.len > 0) lru_delete(h); + int more = len - h.len; + + if (more > 0) { + // free old space + while (size < more) { + head_t old = lru_head.next; + lru_delete(old); + size += old.len; + old.data = null; + old.len = 0; + } + + // allocate new space + float[] new_data = new float[len]; + if (h.data != null) System.arraycopy(h.data, 0, new_data, 0, h.len); + h.data = new_data; + size -= more; + do { + int temp = h.len; + h.len = len; + len = temp; + } while (false); + } + + lru_insert(h); + data[0] = h.data; + return len; + } + + void swap_index(int i, int j){ + if (i == j) return; + + if (head[i].len > 0) lru_delete(head[i]); + if (head[j].len > 0) lru_delete(head[j]); + do { + float[] temp = head[i].data; + head[i].data = head[j].data; + head[j].data = temp; + } while (false); + do { + int temp = head[i].len; + head[i].len = head[j].len; + head[j].len = temp; + } while (false); + if (head[i].len > 0) lru_insert(head[i]); + if (head[j].len > 0) lru_insert(head[j]); + + if (i > j) do { + int temp = i; + i = j; + j = temp; + } while (false); + for (head_t h = lru_head.next; h != lru_head; h = h.next) { + if (h.len > i) { + if (h.len > j) + do { + float temp = h.data[i]; + h.data[i] = h.data[j]; + h.data[j] = temp; + } while (false); + else { + // give up + lru_delete(h); + size += h.len; + h.data = null; + h.len = 0; + } + } + } + } + + private final class head_t { + head_t prev, next; // a cicular list + float[] data; + int len; // data[0,len) is cached in this entry + } } // @@ -127,158 +134,146 @@ void swap_index(int i, int j) // the member function get_Q is for getting one column from the Q Matrix // abstract class QMatrix { - abstract float[] get_Q(int column, int len); - abstract float[] get_QD(); - abstract void swap_index(int i, int j); -}; + abstract float[] get_Q(int column, int len); + + abstract float[] get_QD(); + + abstract void swap_index(int i, int j); +} abstract class Kernel extends QMatrix { - private svm_node[][] x; - private final double[] x_square; - - // svm_parameter - private final int kernel_type; - private final int degree; - private final double gamma; - private final double coef0; - - abstract float[] get_Q(int column, int len); - abstract float[] get_QD(); - - void swap_index(int i, int j) - { - do {svm_node[] _=x[i]; x[i]=x[j]; x[j]=_;} while(false); - if(x_square != null) do {double _=x_square[i]; x_square[i]=x_square[j]; x_square[j]=_;} while(false); - } - - private static double powi(double base, int times) - { - double tmp = base, ret = 1.0; - - for(int t=times; t>0; t/=2) - { - if(t%2==1) ret*=tmp; - tmp = tmp * tmp; - } - return ret; - } - - double kernel_function(int i, int j) - { - switch(kernel_type) - { - case svm_parameter.LINEAR: - return dot(x[i],x[j]); - case svm_parameter.POLY: - return powi(gamma*dot(x[i],x[j])+coef0,degree); - case svm_parameter.RBF: - return Math.exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j]))); - case svm_parameter.SIGMOID: - return Math.tanh(gamma*dot(x[i],x[j])+coef0); - case svm_parameter.PRECOMPUTED: - return x[i][(int)(x[j][0].value)].value; - default: - return 0; // java - } - } - - Kernel(int l, svm_node[][] x_, svm_parameter param) - { - this.kernel_type = param.kernel_type; - this.degree = param.degree; - this.gamma = param.gamma; - this.coef0 = param.coef0; - - x = (svm_node[][])x_.clone(); - - if(kernel_type == svm_parameter.RBF) - { - x_square = new double[l]; - for(int i=0;i y[j].index) - ++j; - else - ++i; - } - } - return sum; - } - - static double k_function(svm_node[] x, svm_node[] y, - svm_parameter param) - { - switch(param.kernel_type) - { - case svm_parameter.LINEAR: - return dot(x,y); - case svm_parameter.POLY: - return powi(param.gamma*dot(x,y)+param.coef0,param.degree); - case svm_parameter.RBF: - { - double sum = 0; - int xlen = x.length; - int ylen = y.length; - int i = 0; - int j = 0; - while(i < xlen && j < ylen) - { - if(x[i].index == y[j].index) - { - double d = x[i++].value - y[j++].value; - sum += d*d; - } - else if(x[i].index > y[j].index) - { - sum += y[j].value * y[j].value; - ++j; - } - else - { - sum += x[i].value * x[i].value; - ++i; - } - } - - while(i < xlen) - { - sum += x[i].value * x[i].value; - ++i; - } - - while(j < ylen) - { - sum += y[j].value * y[j].value; - ++j; - } - - return Math.exp(-param.gamma*sum); - } - case svm_parameter.SIGMOID: - return Math.tanh(param.gamma*dot(x,y)+param.coef0); - case svm_parameter.PRECOMPUTED: - return x[(int)(y[0].value)].value; - default: - return 0; // java - } - } + private final double[] x_square; + // svm_parameter + private final int kernel_type; + private final int degree; + private final double gamma; + private final double coef0; + private svm_node[][] x; + + Kernel(int l, svm_node[][] x_, svm_parameter param){ + this.kernel_type = param.kernel_type; + this.degree = param.degree; + this.gamma = param.gamma; + this.coef0 = param.coef0; + + x = x_.clone(); + + if (kernel_type == svm_parameter.RBF) { + x_square = new double[l]; + for (int i = 0; i < l; i++) + x_square[i] = dot(x[i], x[i]); + } else x_square = null; + } + + private static double powi(double base, int times){ + double tmp = base, ret = 1.0; + + for (int t = times; t > 0; t /= 2) { + if (t % 2 == 1) ret *= tmp; + tmp = tmp * tmp; + } + return ret; + } + + static double dot(svm_node[] x, svm_node[] y){ + double sum = 0; + int xlen = x.length; + int ylen = y.length; + int i = 0; + int j = 0; + while (i < xlen && j < ylen) { + if (x[i].index == y[j].index) + sum += x[i++].value * y[j++].value; + else { + if (x[i].index > y[j].index) + ++j; + else + ++i; + } + } + return sum; + } + + static double k_function(svm_node[] x, svm_node[] y, + svm_parameter param){ + switch (param.kernel_type) { + case svm_parameter.LINEAR: + return dot(x, y); + case svm_parameter.POLY: + return powi(param.gamma * dot(x, y) + param.coef0, param.degree); + case svm_parameter.RBF: { + double sum = 0; + int xlen = x.length; + int ylen = y.length; + int i = 0; + int j = 0; + while (i < xlen && j < ylen) { + if (x[i].index == y[j].index) { + double d = x[i++].value - y[j++].value; + sum += d * d; + } else if (x[i].index > y[j].index) { + sum += y[j].value * y[j].value; + ++j; + } else { + sum += x[i].value * x[i].value; + ++i; + } + } + + while (i < xlen) { + sum += x[i].value * x[i].value; + ++i; + } + + while (j < ylen) { + sum += y[j].value * y[j].value; + ++j; + } + + return Math.exp(-param.gamma * sum); + } + case svm_parameter.SIGMOID: + return Math.tanh(param.gamma * dot(x, y) + param.coef0); + case svm_parameter.PRECOMPUTED: + return x[(int) (y[0].value)].value; + default: + return 0; // java + } + } + + abstract float[] get_Q(int column, int len); + + abstract float[] get_QD(); + + void swap_index(int i, int j){ + do { + svm_node[] temp = x[i]; + x[i] = x[j]; + x[j] = temp; + } while (false); + if (x_square != null) do { + double temp = x_square[i]; + x_square[i] = x_square[j]; + x_square[j] = temp; + } while (false); + } + + double kernel_function(int i, int j){ + switch (kernel_type) { + case svm_parameter.LINEAR: + return dot(x[i], x[j]); + case svm_parameter.POLY: + return powi(gamma * dot(x[i], x[j]) + coef0, degree); + case svm_parameter.RBF: + return Math.exp(-gamma * (x_square[i] + x_square[j] - 2 * dot(x[i], x[j]))); + case svm_parameter.SIGMOID: + return Math.tanh(gamma * dot(x[i], x[j]) + coef0); + case svm_parameter.PRECOMPUTED: + return x[i][(int) (x[j][0].value)].value; + default: + return 0; // java + } + } } // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918 @@ -300,578 +295,539 @@ else if(x[i].index > y[j].index) // solution will be put in \alpha, objective value will be put in obj // class Solver { - int active_size; - byte[] y; - double[] G; // gradient of objective function - static final byte LOWER_BOUND = 0; - static final byte UPPER_BOUND = 1; - static final byte FREE = 2; - byte[] alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE - double[] alpha; - QMatrix Q; - float[] QD; - double eps; - double Cp,Cn; - double[] p; - int[] active_set; - double[] G_bar; // gradient, if we treat free variables as 0 - int l; - boolean unshrink; // XXX - - static final double INF = java.lang.Double.POSITIVE_INFINITY; - - double get_C(int i) - { - return (y[i] > 0)? Cp : Cn; - } - void update_alpha_status(int i) - { - if(alpha[i] >= get_C(i)) - alpha_status[i] = UPPER_BOUND; - else if(alpha[i] <= 0) - alpha_status[i] = LOWER_BOUND; - else alpha_status[i] = FREE; - } - boolean is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; } - boolean is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; } - boolean is_free(int i) { return alpha_status[i] == FREE; } - - // java: information about solution except alpha, - // because we cannot return multiple values otherwise... - static class SolutionInfo { - double obj; - double rho; - double upper_bound_p; - double upper_bound_n; - double r; // for Solver_NU - } - - void swap_index(int i, int j) - { - Q.swap_index(i,j); - do {byte _=y[i]; y[i]=y[j]; y[j]=_;} while(false); - do {double _=G[i]; G[i]=G[j]; G[j]=_;} while(false); - do {byte _=alpha_status[i]; alpha_status[i]=alpha_status[j]; alpha_status[j]=_;} while(false); - do {double _=alpha[i]; alpha[i]=alpha[j]; alpha[j]=_;} while(false); - do {double _=p[i]; p[i]=p[j]; p[j]=_;} while(false); - do {int _=active_set[i]; active_set[i]=active_set[j]; active_set[j]=_;} while(false); - do {double _=G_bar[i]; G_bar[i]=G_bar[j]; G_bar[j]=_;} while(false); - } - - void reconstruct_gradient() - { - // reconstruct inactive elements of G from G_bar and free variables - - if(active_size == l) return; - - int i,j; - int nr_free = 0; - - for(j=active_size;j 2*active_size*(l-active_size)) - { - for(i=active_size;i 0) - { - if(alpha[j] < 0) - { - alpha[j] = 0; - alpha[i] = diff; - } - } - else - { - if(alpha[i] < 0) - { - alpha[i] = 0; - alpha[j] = -diff; - } - } - if(diff > C_i - C_j) - { - if(alpha[i] > C_i) - { - alpha[i] = C_i; - alpha[j] = C_i - diff; - } - } - else - { - if(alpha[j] > C_j) - { - alpha[j] = C_j; - alpha[i] = C_j + diff; - } - } - } - else - { - double quad_coef = Q_i[i]+Q_j[j]-2*Q_i[j]; - if (quad_coef <= 0) - quad_coef = 1e-12; - double delta = (G[i]-G[j])/quad_coef; - double sum = alpha[i] + alpha[j]; - alpha[i] -= delta; - alpha[j] += delta; - - if(sum > C_i) - { - if(alpha[i] > C_i) - { - alpha[i] = C_i; - alpha[j] = sum - C_i; - } - } - else - { - if(alpha[j] < 0) - { - alpha[j] = 0; - alpha[i] = sum; - } - } - if(sum > C_j) - { - if(alpha[j] > C_j) - { - alpha[j] = C_j; - alpha[i] = sum - C_j; - } - } - else - { - if(alpha[i] < 0) - { - alpha[i] = 0; - alpha[j] = sum; - } - } - } - - // update G - - double delta_alpha_i = alpha[i] - old_alpha_i; - double delta_alpha_j = alpha[j] - old_alpha_j; - - for(int k=0;k= Gmax) - { - Gmax = -G[t]; - Gmax_idx = t; - } - } - else - { - if(!is_lower_bound(t)) - if(G[t] >= Gmax) - { - Gmax = G[t]; - Gmax_idx = t; - } - } - - int i = Gmax_idx; - float[] Q_i = null; - if(i != -1) // null Q_i not accessed: Gmax=-INF if i=-1 - Q_i = Q.get_Q(i,active_size); - - for(int j=0;j= Gmax2) - Gmax2 = G[j]; - if (grad_diff > 0) - { - double obj_diff; - double quad_coef=Q_i[i]+QD[j]-2.0*y[i]*Q_i[j]; - if (quad_coef > 0) - obj_diff = -(grad_diff*grad_diff)/quad_coef; - else - obj_diff = -(grad_diff*grad_diff)/1e-12; - - if (obj_diff <= obj_diff_min) - { - Gmin_idx=j; - obj_diff_min = obj_diff; - } - } - } - } - else - { - if (!is_upper_bound(j)) - { - double grad_diff= Gmax-G[j]; - if (-G[j] >= Gmax2) - Gmax2 = -G[j]; - if (grad_diff > 0) - { - double obj_diff; - double quad_coef=Q_i[i]+QD[j]+2.0*y[i]*Q_i[j]; - if (quad_coef > 0) - obj_diff = -(grad_diff*grad_diff)/quad_coef; - else - obj_diff = -(grad_diff*grad_diff)/1e-12; - - if (obj_diff <= obj_diff_min) - { - Gmin_idx=j; - obj_diff_min = obj_diff; - } - } - } - } - } - - if(Gmax+Gmax2 < eps) - return 1; - - working_set[0] = Gmax_idx; - working_set[1] = Gmin_idx; - return 0; - } - - private boolean be_shrunk(int i, double Gmax1, double Gmax2) - { - if(is_upper_bound(i)) - { - if(y[i]==+1) - return(-G[i] > Gmax1); - else - return(-G[i] > Gmax2); - } - else if(is_lower_bound(i)) - { - if(y[i]==+1) - return(G[i] > Gmax2); - else - return(G[i] > Gmax1); - } - else - return(false); - } - - void do_shrinking() - { - int i; - double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) } - double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) } - - // find maximal violating pair first - for(i=0;i= Gmax1) - Gmax1 = -G[i]; - } - if(!is_lower_bound(i)) - { - if(G[i] >= Gmax2) - Gmax2 = G[i]; - } - } - else - { - if(!is_upper_bound(i)) - { - if(-G[i] >= Gmax2) - Gmax2 = -G[i]; - } - if(!is_lower_bound(i)) - { - if(G[i] >= Gmax1) - Gmax1 = G[i]; - } - } - } - - if(unshrink == false && Gmax1 + Gmax2 <= eps*10) - { - unshrink = true; - reconstruct_gradient(); - active_size = l; - } - - for(i=0;i i) - { - if (!be_shrunk(active_size, Gmax1, Gmax2)) - { - swap_index(i,active_size); - break; - } - active_size--; - } - } - } - - double calculate_rho() - { - double r; - int nr_free = 0; - double ub = INF, lb = -INF, sum_free = 0; - for(int i=0;i 0) - ub = Math.min(ub,yG); - else - lb = Math.max(lb,yG); - } - else if(is_upper_bound(i)) - { - if(y[i] < 0) - ub = Math.min(ub,yG); - else - lb = Math.max(lb,yG); - } - else - { - ++nr_free; - sum_free += yG; - } - } - - if(nr_free>0) - r = sum_free/nr_free; - else - r = (ub+lb)/2; - - return r; - } + static final double INF = java.lang.Double.POSITIVE_INFINITY; + private static final byte LOWER_BOUND = 0; + private static final byte UPPER_BOUND = 1; + private static final byte FREE = 2; + int active_size; + byte[] y; + double[] G; // gradient of objective function + QMatrix Q; + float[] QD; + double eps; + int l; + boolean unshrink; // XXX + private double[] G_bar; // gradient, if we treat free variables as 0 + private byte[] alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE + private double[] alpha; + private double Cp, Cn; + private double[] p; + private int[] active_set; + + private double get_C(int i){ + return (y[i] > 0) ? Cp : Cn; + } + + private void update_alpha_status(int i){ + if (alpha[i] >= get_C(i)) + alpha_status[i] = UPPER_BOUND; + else if (alpha[i] <= 0) + alpha_status[i] = LOWER_BOUND; + else alpha_status[i] = FREE; + } + + boolean is_upper_bound(int i){ + return alpha_status[i] == UPPER_BOUND; + } + + boolean is_lower_bound(int i){ + return alpha_status[i] == LOWER_BOUND; + } + + private boolean is_free(int i){ + return alpha_status[i] == FREE; + } + + void swap_index(int i, int j){ + Q.swap_index(i, j); + do { + byte temp = y[i]; + y[i] = y[j]; + y[j] = temp; + } while (false); + do { + double temp = G[i]; + G[i] = G[j]; + G[j] = temp; + } while (false); + do { + byte temp = alpha_status[i]; + alpha_status[i] = alpha_status[j]; + alpha_status[j] = temp; + } while (false); + do { + double temp = alpha[i]; + alpha[i] = alpha[j]; + alpha[j] = temp; + } while (false); + do { + double temp = p[i]; + p[i] = p[j]; + p[j] = temp; + } while (false); + do { + int temp = active_set[i]; + active_set[i] = active_set[j]; + active_set[j] = temp; + } while (false); + do { + double temp = G_bar[i]; + G_bar[i] = G_bar[j]; + G_bar[j] = temp; + } while (false); + } + + void reconstruct_gradient(){ + // reconstruct inactive elements of G from G_bar and free variables + + if (active_size == l) return; + + int i, j; + int nr_free = 0; + + for (j = active_size; j < l; j++) + G[j] = G_bar[j] + p[j]; + + for (j = 0; j < active_size; j++) + if (is_free(j)) + nr_free++; + + if (2 * nr_free < active_size) + svm.info("\nWarning: using -h 0 may be faster\n"); + + if (nr_free * l > 2 * active_size * (l - active_size)) { + for (i = active_size; i < l; i++) { + float[] Q_i = Q.get_Q(i, active_size); + for (j = 0; j < active_size; j++) + if (is_free(j)) + G[i] += alpha[j] * Q_i[j]; + } + } else { + for (i = 0; i < active_size; i++) + if (is_free(i)) { + float[] Q_i = Q.get_Q(i, l); + double alpha_i = alpha[i]; + for (j = active_size; j < l; j++) + G[j] += alpha_i * Q_i[j]; + } + } + } + + void Solve(int l, QMatrix Q, double[] p_, byte[] y_, + double[] alpha_, double Cp, double Cn, double eps, SolutionInfo si, int shrinking){ + this.l = l; + this.Q = Q; + QD = Q.get_QD(); + p = p_.clone(); + y = y_.clone(); + alpha = alpha_.clone(); + this.Cp = Cp; + this.Cn = Cn; + this.eps = eps; + this.unshrink = false; + + // initialize alpha_status + { + alpha_status = new byte[l]; + for (int i = 0; i < l; i++) + update_alpha_status(i); + } + + // initialize active set (for shrinking) + { + active_set = new int[l]; + for (int i = 0; i < l; i++) + active_set[i] = i; + active_size = l; + } + + // initialize gradient + { + G = new double[l]; + G_bar = new double[l]; + int i; + for (i = 0; i < l; i++) { + G[i] = p[i]; + G_bar[i] = 0; + } + for (i = 0; i < l; i++) + if (!is_lower_bound(i)) { + float[] Q_i = Q.get_Q(i, l); + double alpha_i = alpha[i]; + int j; + for (j = 0; j < l; j++) + G[j] += alpha_i * Q_i[j]; + if (is_upper_bound(i)) + for (j = 0; j < l; j++) + G_bar[j] += get_C(i) * Q_i[j]; + } + } + + // optimization step + + int iter = 0; + int counter = Math.min(l, 1000) + 1; + int[] working_set = new int[2]; + + while (true) { + // show progress and do shrinking + + if (--counter == 0) { + counter = Math.min(l, 1000); + if (shrinking != 0) do_shrinking(); + svm.info("."); + } + + if (select_working_set(working_set) != 0) { + // reconstruct the whole gradient + reconstruct_gradient(); + // reset active set size and check + active_size = l; + svm.info("*"); + if (select_working_set(working_set) != 0) + break; + else + counter = 1; // do shrinking next iteration + } + + int i = working_set[0]; + int j = working_set[1]; + + ++iter; + + // update alpha[i] and alpha[j], handle bounds carefully + + float[] Q_i = Q.get_Q(i, active_size); + float[] Q_j = Q.get_Q(j, active_size); + + double C_i = get_C(i); + double C_j = get_C(j); + + double old_alpha_i = alpha[i]; + double old_alpha_j = alpha[j]; + + if (y[i] != y[j]) { + double quad_coef = Q_i[i] + Q_j[j] + 2 * Q_i[j]; + if (quad_coef <= 0) + quad_coef = 1e-12; + double delta = (-G[i] - G[j]) / quad_coef; + double diff = alpha[i] - alpha[j]; + alpha[i] += delta; + alpha[j] += delta; + + if (diff > 0) { + if (alpha[j] < 0) { + alpha[j] = 0; + alpha[i] = diff; + } + } else { + if (alpha[i] < 0) { + alpha[i] = 0; + alpha[j] = -diff; + } + } + if (diff > C_i - C_j) { + if (alpha[i] > C_i) { + alpha[i] = C_i; + alpha[j] = C_i - diff; + } + } else { + if (alpha[j] > C_j) { + alpha[j] = C_j; + alpha[i] = C_j + diff; + } + } + } else { + double quad_coef = Q_i[i] + Q_j[j] - 2 * Q_i[j]; + if (quad_coef <= 0) + quad_coef = 1e-12; + double delta = (G[i] - G[j]) / quad_coef; + double sum = alpha[i] + alpha[j]; + alpha[i] -= delta; + alpha[j] += delta; + + if (sum > C_i) { + if (alpha[i] > C_i) { + alpha[i] = C_i; + alpha[j] = sum - C_i; + } + } else { + if (alpha[j] < 0) { + alpha[j] = 0; + alpha[i] = sum; + } + } + if (sum > C_j) { + if (alpha[j] > C_j) { + alpha[j] = C_j; + alpha[i] = sum - C_j; + } + } else { + if (alpha[i] < 0) { + alpha[i] = 0; + alpha[j] = sum; + } + } + } + + // update G + + double delta_alpha_i = alpha[i] - old_alpha_i; + double delta_alpha_j = alpha[j] - old_alpha_j; + + for (int k = 0; k < active_size; k++) { + G[k] += Q_i[k] * delta_alpha_i + Q_j[k] * delta_alpha_j; + } + + // update alpha_status and G_bar + + { + boolean ui = is_upper_bound(i); + boolean uj = is_upper_bound(j); + update_alpha_status(i); + update_alpha_status(j); + int k; + if (ui != is_upper_bound(i)) { + Q_i = Q.get_Q(i, l); + if (ui) + for (k = 0; k < l; k++) + G_bar[k] -= C_i * Q_i[k]; + else + for (k = 0; k < l; k++) + G_bar[k] += C_i * Q_i[k]; + } + + if (uj != is_upper_bound(j)) { + Q_j = Q.get_Q(j, l); + if (uj) + for (k = 0; k < l; k++) + G_bar[k] -= C_j * Q_j[k]; + else + for (k = 0; k < l; k++) + G_bar[k] += C_j * Q_j[k]; + } + } + + } + + // calculate rho + + si.rho = calculate_rho(); + + // calculate objective value + { + double v = 0; + int i; + for (i = 0; i < l; i++) + v += alpha[i] * (G[i] + p[i]); + + si.obj = v / 2; + } + + // put back the solution + { + for (int i = 0; i < l; i++) + alpha_[active_set[i]] = alpha[i]; + } + + si.upper_bound_p = Cp; + si.upper_bound_n = Cn; + + svm.info("\noptimization finished, #iter = " + iter + "\n"); + } + + // return 1 if already optimal, return 0 otherwise + int select_working_set(int[] working_set){ + // return i,j such that + // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) + // j: mimimizes the decrease of obj value + // (if quadratic coefficeint <= 0, replace it with tau) + // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) + + double Gmax = -INF; + double Gmax2 = -INF; + int Gmax_idx = -1; + int Gmin_idx = -1; + double obj_diff_min = INF; + + for (int t = 0; t < active_size; t++) + if (y[t] == +1) { + if (!is_upper_bound(t)) + if (-G[t] >= Gmax) { + Gmax = -G[t]; + Gmax_idx = t; + } + } else { + if (!is_lower_bound(t)) + if (G[t] >= Gmax) { + Gmax = G[t]; + Gmax_idx = t; + } + } + + int i = Gmax_idx; + float[] Q_i = null; + if (i != -1) // null Q_i not accessed: Gmax=-INF if i=-1 + Q_i = Q.get_Q(i, active_size); + + for (int j = 0; j < active_size; j++) { + if (y[j] == +1) { + if (!is_lower_bound(j)) { + double grad_diff = Gmax + G[j]; + if (G[j] >= Gmax2) + Gmax2 = G[j]; + if (grad_diff > 0) { + double obj_diff; + double quad_coef = 0; + if (Q_i != null) { + quad_coef = Q_i[i] + QD[j] - 2.0 * y[i] * Q_i[j]; + } + if (quad_coef > 0) + obj_diff = -(grad_diff * grad_diff) / quad_coef; + else + obj_diff = -(grad_diff * grad_diff) / 1e-12; + + if (obj_diff <= obj_diff_min) { + Gmin_idx = j; + obj_diff_min = obj_diff; + } + } + } + } else { + if (!is_upper_bound(j)) { + double grad_diff = Gmax - G[j]; + if (-G[j] >= Gmax2) + Gmax2 = -G[j]; + if (grad_diff > 0) { + double obj_diff; + double quad_coef = 0; + if (Q_i != null) { + quad_coef = Q_i[i] + QD[j] + 2.0 * y[i] * Q_i[j]; + } + if (quad_coef > 0) + obj_diff = -(grad_diff * grad_diff) / quad_coef; + else + obj_diff = -(grad_diff * grad_diff) / 1e-12; + + if (obj_diff <= obj_diff_min) { + Gmin_idx = j; + obj_diff_min = obj_diff; + } + } + } + } + } + + if (Gmax + Gmax2 < eps) + return 1; + + working_set[0] = Gmax_idx; + working_set[1] = Gmin_idx; + return 0; + } + + private boolean be_shrunk(int i, double Gmax1, double Gmax2){ + if (is_upper_bound(i)) { + if (y[i] == +1) + return (-G[i] > Gmax1); + else + return (-G[i] > Gmax2); + } else if (is_lower_bound(i)) { + if (y[i] == +1) + return (G[i] > Gmax2); + else + return (G[i] > Gmax1); + } else + return (false); + } + + void do_shrinking(){ + int i; + double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) } + double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) } + + // find maximal violating pair first + for (i = 0; i < active_size; i++) { + if (y[i] == +1) { + if (!is_upper_bound(i)) { + if (-G[i] >= Gmax1) + Gmax1 = -G[i]; + } + if (!is_lower_bound(i)) { + if (G[i] >= Gmax2) + Gmax2 = G[i]; + } + } else { + if (!is_upper_bound(i)) { + if (-G[i] >= Gmax2) + Gmax2 = -G[i]; + } + if (!is_lower_bound(i)) { + if (G[i] >= Gmax1) + Gmax1 = G[i]; + } + } + } + + if (!unshrink && Gmax1 + Gmax2 <= eps * 10) { + unshrink = true; + reconstruct_gradient(); + active_size = l; + } + + for (i = 0; i < active_size; i++) + if (be_shrunk(i, Gmax1, Gmax2)) { + active_size--; + while (active_size > i) { + if (!be_shrunk(active_size, Gmax1, Gmax2)) { + swap_index(i, active_size); + break; + } + active_size--; + } + } + } + + double calculate_rho(){ + double r; + int nr_free = 0; + double ub = INF, lb = -INF, sum_free = 0; + for (int i = 0; i < active_size; i++) { + double yG = y[i] * G[i]; + + if (is_lower_bound(i)) { + if (y[i] > 0) + ub = Math.min(ub, yG); + else + lb = Math.max(lb, yG); + } else if (is_upper_bound(i)) { + if (y[i] < 0) + ub = Math.min(ub, yG); + else + lb = Math.max(lb, yG); + } else { + ++nr_free; + sum_free += yG; + } + } + + if (nr_free > 0) + r = sum_free / nr_free; + else + r = (ub + lb) / 2; + + return r; + } + + // java: information about solution except alpha, + // because we cannot return multiple values otherwise... + static class SolutionInfo { + double obj; + double rho; + double upper_bound_p; + double upper_bound_n; + double r; // for Solver_NU + } } @@ -880,1894 +836,1716 @@ else if(is_upper_bound(i)) // // additional constraint: e^T \alpha = constant // -final class Solver_NU extends Solver -{ - private SolutionInfo si; - - void Solve(int l, QMatrix Q, double[] p, byte[] y, - double[] alpha, double Cp, double Cn, double eps, - SolutionInfo si, int shrinking) - { - this.si = si; - super.Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking); - } - - // return 1 if already optimal, return 0 otherwise - int select_working_set(int[] working_set) - { - // return i,j such that y_i = y_j and - // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) - // j: minimizes the decrease of obj value - // (if quadratic coefficeint <= 0, replace it with tau) - // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) - - double Gmaxp = -INF; - double Gmaxp2 = -INF; - int Gmaxp_idx = -1; - - double Gmaxn = -INF; - double Gmaxn2 = -INF; - int Gmaxn_idx = -1; - - int Gmin_idx = -1; - double obj_diff_min = INF; - - for(int t=0;t= Gmaxp) - { - Gmaxp = -G[t]; - Gmaxp_idx = t; - } - } - else - { - if(!is_lower_bound(t)) - if(G[t] >= Gmaxn) - { - Gmaxn = G[t]; - Gmaxn_idx = t; - } - } - - int ip = Gmaxp_idx; - int in = Gmaxn_idx; - float[] Q_ip = null; - float[] Q_in = null; - if(ip != -1) // null Q_ip not accessed: Gmaxp=-INF if ip=-1 - Q_ip = Q.get_Q(ip,active_size); - if(in != -1) - Q_in = Q.get_Q(in,active_size); - - for(int j=0;j= Gmaxp2) - Gmaxp2 = G[j]; - if (grad_diff > 0) - { - double obj_diff; - double quad_coef = Q_ip[ip]+QD[j]-2*Q_ip[j]; - if (quad_coef > 0) - obj_diff = -(grad_diff*grad_diff)/quad_coef; - else - obj_diff = -(grad_diff*grad_diff)/1e-12; - - if (obj_diff <= obj_diff_min) - { - Gmin_idx=j; - obj_diff_min = obj_diff; - } - } - } - } - else - { - if (!is_upper_bound(j)) - { - double grad_diff=Gmaxn-G[j]; - if (-G[j] >= Gmaxn2) - Gmaxn2 = -G[j]; - if (grad_diff > 0) - { - double obj_diff; - double quad_coef = Q_in[in]+QD[j]-2*Q_in[j]; - if (quad_coef > 0) - obj_diff = -(grad_diff*grad_diff)/quad_coef; - else - obj_diff = -(grad_diff*grad_diff)/1e-12; - - if (obj_diff <= obj_diff_min) - { - Gmin_idx=j; - obj_diff_min = obj_diff; - } - } - } - } - } - - if(Math.max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps) - return 1; - - if(y[Gmin_idx] == +1) - working_set[0] = Gmaxp_idx; - else - working_set[0] = Gmaxn_idx; - working_set[1] = Gmin_idx; - - return 0; - } - - private boolean be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4) - { - if(is_upper_bound(i)) - { - if(y[i]==+1) - return(-G[i] > Gmax1); - else - return(-G[i] > Gmax4); - } - else if(is_lower_bound(i)) - { - if(y[i]==+1) - return(G[i] > Gmax2); - else - return(G[i] > Gmax3); - } - else - return(false); - } - - void do_shrinking() - { - double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) } - double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) } - double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) } - double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) } - - // find maximal violating pair first - int i; - for(i=0;i Gmax1) Gmax1 = -G[i]; - } - else if(-G[i] > Gmax4) Gmax4 = -G[i]; - } - if(!is_lower_bound(i)) - { - if(y[i]==+1) - { - if(G[i] > Gmax2) Gmax2 = G[i]; - } - else if(G[i] > Gmax3) Gmax3 = G[i]; - } - } - - if(unshrink == false && Math.max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) - { - unshrink = true; - reconstruct_gradient(); - active_size = l; - } - - for(i=0;i i) - { - if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4)) - { - swap_index(i,active_size); - break; - } - active_size--; - } - } - } - - double calculate_rho() - { - int nr_free1 = 0,nr_free2 = 0; - double ub1 = INF, ub2 = INF; - double lb1 = -INF, lb2 = -INF; - double sum_free1 = 0, sum_free2 = 0; - - for(int i=0;i 0) - r1 = sum_free1/nr_free1; - else - r1 = (ub1+lb1)/2; - - if(nr_free2 > 0) - r2 = sum_free2/nr_free2; - else - r2 = (ub2+lb2)/2; - - si.r = (r1+r2)/2; - return (r1-r2)/2; - } +final class Solver_NU extends Solver { + private SolutionInfo si; + + void Solve(int l, QMatrix Q, double[] p, byte[] y, + double[] alpha, double Cp, double Cn, double eps, + SolutionInfo si, int shrinking){ + this.si = si; + super.Solve(l, Q, p, y, alpha, Cp, Cn, eps, si, shrinking); + } + + // return 1 if already optimal, return 0 otherwise + int select_working_set(int[] working_set){ + // return i,j such that y_i = y_j and + // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) + // j: minimizes the decrease of obj value + // (if quadratic coefficeint <= 0, replace it with tau) + // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) + + double Gmaxp = -INF; + double Gmaxp2 = -INF; + int Gmaxp_idx = -1; + + double Gmaxn = -INF; + double Gmaxn2 = -INF; + int Gmaxn_idx = -1; + + int Gmin_idx = -1; + double obj_diff_min = INF; + + for (int t = 0; t < active_size; t++) + if (y[t] == +1) { + if (!is_upper_bound(t)) + if (-G[t] >= Gmaxp) { + Gmaxp = -G[t]; + Gmaxp_idx = t; + } + } else { + if (!is_lower_bound(t)) + if (G[t] >= Gmaxn) { + Gmaxn = G[t]; + Gmaxn_idx = t; + } + } + + int ip = Gmaxp_idx; + int in = Gmaxn_idx; + float[] Q_ip = null; + float[] Q_in = null; + if (ip != -1) // null Q_ip not accessed: Gmaxp=-INF if ip=-1 + Q_ip = Q.get_Q(ip, active_size); + if (in != -1) + Q_in = Q.get_Q(in, active_size); + + for (int j = 0; j < active_size; j++) { + if (y[j] == +1) { + if (!is_lower_bound(j)) { + double grad_diff = Gmaxp + G[j]; + if (G[j] >= Gmaxp2) + Gmaxp2 = G[j]; + if (grad_diff > 0) { + double obj_diff; + double quad_coef = Q_ip[ip] + QD[j] - 2 * Q_ip[j]; + if (quad_coef > 0) + obj_diff = -(grad_diff * grad_diff) / quad_coef; + else + obj_diff = -(grad_diff * grad_diff) / 1e-12; + + if (obj_diff <= obj_diff_min) { + Gmin_idx = j; + obj_diff_min = obj_diff; + } + } + } + } else { + if (!is_upper_bound(j)) { + double grad_diff = Gmaxn - G[j]; + if (-G[j] >= Gmaxn2) + Gmaxn2 = -G[j]; + if (grad_diff > 0) { + double obj_diff; + double quad_coef = Q_in[in] + QD[j] - 2 * Q_in[j]; + if (quad_coef > 0) + obj_diff = -(grad_diff * grad_diff) / quad_coef; + else + obj_diff = -(grad_diff * grad_diff) / 1e-12; + + if (obj_diff <= obj_diff_min) { + Gmin_idx = j; + obj_diff_min = obj_diff; + } + } + } + } + } + + if (Math.max(Gmaxp + Gmaxp2, Gmaxn + Gmaxn2) < eps) + return 1; + + if (y[Gmin_idx] == +1) + working_set[0] = Gmaxp_idx; + else + working_set[0] = Gmaxn_idx; + working_set[1] = Gmin_idx; + + return 0; + } + + private boolean be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4){ + if (is_upper_bound(i)) { + if (y[i] == +1) + return (-G[i] > Gmax1); + else + return (-G[i] > Gmax4); + } else if (is_lower_bound(i)) { + if (y[i] == +1) + return (G[i] > Gmax2); + else + return (G[i] > Gmax3); + } else + return (false); + } + + void do_shrinking(){ + double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) } + double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) } + double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) } + double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) } + + // find maximal violating pair first + int i; + for (i = 0; i < active_size; i++) { + if (!is_upper_bound(i)) { + if (y[i] == +1) { + if (-G[i] > Gmax1) Gmax1 = -G[i]; + } else if (-G[i] > Gmax4) Gmax4 = -G[i]; + } + if (!is_lower_bound(i)) { + if (y[i] == +1) { + if (G[i] > Gmax2) Gmax2 = G[i]; + } else if (G[i] > Gmax3) Gmax3 = G[i]; + } + } + + if (!unshrink && Math.max(Gmax1 + Gmax2, Gmax3 + Gmax4) <= eps * 10) { + unshrink = true; + reconstruct_gradient(); + active_size = l; + } + + for (i = 0; i < active_size; i++) + if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4)) { + active_size--; + while (active_size > i) { + if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4)) { + swap_index(i, active_size); + break; + } + active_size--; + } + } + } + + double calculate_rho(){ + int nr_free1 = 0, nr_free2 = 0; + double ub1 = INF, ub2 = INF; + double lb1 = -INF, lb2 = -INF; + double sum_free1 = 0, sum_free2 = 0; + + for (int i = 0; i < active_size; i++) { + if (y[i] == +1) { + if (is_lower_bound(i)) + ub1 = Math.min(ub1, G[i]); + else if (is_upper_bound(i)) + lb1 = Math.max(lb1, G[i]); + else { + ++nr_free1; + sum_free1 += G[i]; + } + } else { + if (is_lower_bound(i)) + ub2 = Math.min(ub2, G[i]); + else if (is_upper_bound(i)) + lb2 = Math.max(lb2, G[i]); + else { + ++nr_free2; + sum_free2 += G[i]; + } + } + } + + double r1, r2; + if (nr_free1 > 0) + r1 = sum_free1 / nr_free1; + else + r1 = (ub1 + lb1) / 2; + + if (nr_free2 > 0) + r2 = sum_free2 / nr_free2; + else + r2 = (ub2 + lb2) / 2; + + si.r = (r1 + r2) / 2; + return (r1 - r2) / 2; + } } // // Q matrices for various formulations // -class SVC_Q extends Kernel -{ - private final byte[] y; - private final Cache cache; - private final float[] QD; - - SVC_Q(svm_problem prob, svm_parameter param, byte[] y_) - { - super(prob.l, prob.x, param); - y = (byte[])y_.clone(); - cache = new Cache(prob.l,(long)(param.cache_size*(1<<20))); - QD = new float[prob.l]; - for(int i=0;i 0) y[i] = +1; else y[i]=-1; - } - - Solver s = new Solver(); - s.Solve(l, new SVC_Q(prob,param,y), minus_ones, y, - alpha, Cp, Cn, param.eps, si, param.shrinking); - - double sum_alpha=0; - for(i=0;i0) - y[i] = +1; - else - y[i] = -1; - - double sum_pos = nu*l/2; - double sum_neg = nu*l/2; - - for(i=0;i 0) - { - ++nSV; - if(prob.y[i] > 0) - { - if(Math.abs(alpha[i]) >= si.upper_bound_p) - ++nBSV; - } - else - { - if(Math.abs(alpha[i]) >= si.upper_bound_n) - ++nBSV; - } - } - } - - svm.info("nSV = "+nSV+", nBSV = "+nBSV+"\n"); - - decision_function f = new decision_function(); - f.alpha = alpha; - f.rho = si.rho; - return f; - } - - // Platt's binary SVM Probablistic Output: an improvement from Lin et al. - private static void sigmoid_train(int l, double[] dec_values, double[] labels, - double[] probAB) - { - double A, B; - double prior1=0, prior0 = 0; - int i; - - for (i=0;i 0) prior1+=1; - else prior0+=1; - - int max_iter=100; // Maximal number of iterations - double min_step=1e-10; // Minimal step taken in line search - double sigma=1e-12; // For numerically strict PD of Hessian - double eps=1e-5; - double hiTarget=(prior1+1.0)/(prior1+2.0); - double loTarget=1/(prior0+2.0); - double[] t= new double[l]; - double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; - double newA,newB,newf,d1,d2; - int iter; - - // Initial Point and Initial Fun Value - A=0.0; B=Math.log((prior0+1.0)/(prior1+1.0)); - double fval = 0.0; - - for (i=0;i0) t[i]=hiTarget; - else t[i]=loTarget; - fApB = dec_values[i]*A+B; - if (fApB>=0) - fval += t[i]*fApB + Math.log(1+Math.exp(-fApB)); - else - fval += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB)); - } - for (iter=0;iter= 0) - { - p=Math.exp(-fApB)/(1.0+Math.exp(-fApB)); - q=1.0/(1.0+Math.exp(-fApB)); - } - else - { - p=1.0/(1.0+Math.exp(fApB)); - q=Math.exp(fApB)/(1.0+Math.exp(fApB)); - } - d2=p*q; - h11+=dec_values[i]*dec_values[i]*d2; - h22+=d2; - h21+=dec_values[i]*d2; - d1=t[i]-p; - g1+=dec_values[i]*d1; - g2+=d1; - } - - // Stopping Criteria - if (Math.abs(g1)= min_step) - { - newA = A + stepsize * dA; - newB = B + stepsize * dB; - - // New function value - newf = 0.0; - for (i=0;i= 0) - newf += t[i]*fApB + Math.log(1+Math.exp(-fApB)); - else - newf += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB)); - } - // Check sufficient decrease - if (newf=max_iter) - svm.info("Reaching maximal iterations in two-class probability estimates\n"); - probAB[0]=A;probAB[1]=B; - } - - private static double sigmoid_predict(double decision_value, double A, double B) - { - double fApB = decision_value*A+B; - if (fApB >= 0) - return Math.exp(-fApB)/(1.0+Math.exp(-fApB)); - else - return 1.0/(1+Math.exp(fApB)) ; - } - - // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng - private static void multiclass_probability(int k, double[][] r, double[] p) - { - int t,j; - int iter = 0, max_iter=Math.max(100,k); - double[][] Q=new double[k][k]; - double[] Qp= new double[k]; - double pQp, eps=0.005/k; - - for (t=0;tmax_error) - max_error=error; - } - if (max_error=max_iter) - svm.info("Exceeds max_iter in multiclass_prob\n"); - } - - // Cross-validation decision values for probability estimates - private static void svm_binary_svc_probability(svm_problem prob, svm_parameter param, double Cp, double Cn, double[] probAB) - { - int i; - int nr_fold = 5; - int[] perm = new int[prob.l]; - double[] dec_values = new double[prob.l]; - - // random shuffle - for(i=0;i0) - p_count++; - else - n_count++; - - if(p_count==0 && n_count==0) - for(j=begin;j 0 && n_count == 0) - for(j=begin;j 0) - for(j=begin;j 5*std) - count=count+1; - else - mae+=Math.abs(ymv[i]); - mae /= (prob.l-count); - svm.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+mae+"\n"); - return mae; - } - - // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data - // perm, length l, must be allocated before calling this subroutine - private static void svm_group_classes(svm_problem prob, int[] nr_class_ret, int[][] label_ret, int[][] start_ret, int[][] count_ret, int[] perm) - { - int l = prob.l; - int max_nr_class = 16; - int nr_class = 0; - int[] label = new int[max_nr_class]; - int[] count = new int[max_nr_class]; - int[] data_label = new int[l]; - int i; - - for(i=0;i 0) ++nSV; - model.l = nSV; - model.SV = new svm_node[nSV][]; - model.sv_coef[0] = new double[nSV]; - int j = 0; - for(i=0;i 0) - { - model.SV[j] = prob.x[i]; - model.sv_coef[0][j] = f.alpha[i]; - ++j; - } - } - else - { - // classification - int l = prob.l; - int[] tmp_nr_class = new int[1]; - int[][] tmp_label = new int[1][]; - int[][] tmp_start = new int[1][]; - int[][] tmp_count = new int[1][]; - int[] perm = new int[l]; - - // group training data of the same class - svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm); - int nr_class = tmp_nr_class[0]; - int[] label = tmp_label[0]; - int[] start = tmp_start[0]; - int[] count = tmp_count[0]; - svm_node[][] x = new svm_node[l][]; - int i; - for(i=0;i 0) - nonzero[si+k] = true; - for(k=0;k 0) - nonzero[sj+k] = true; - ++p; - } - - // build output - - model.nr_class = nr_class; - - model.label = new int[nr_class]; - for(i=0;i some folds may have zero elements - if((param.svm_type == svm_parameter.C_SVC || - param.svm_type == svm_parameter.NU_SVC) && nr_fold < l) - { - int[] tmp_nr_class = new int[1]; - int[][] tmp_label = new int[1][]; - int[][] tmp_start = new int[1][]; - int[][] tmp_count = new int[1][]; - - svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm); - - int nr_class = tmp_nr_class[0]; - int[] label = tmp_label[0]; - int[] start = tmp_start[0]; - int[] count = tmp_count[0]; - - // random shuffle and then data grouped by fold using the array perm - int[] fold_count = new int[nr_fold]; - int c; - int[] index = new int[l]; - for(i=0;i0)?1:-1; - else - return res[0]; - } - else - { - int i; - int nr_class = model.nr_class; - double[] dec_values = new double[nr_class*(nr_class-1)/2]; - svm_predict_values(model, x, dec_values); - - int[] vote = new int[nr_class]; - for(i=0;i 0) - ++vote[i]; - else - ++vote[j]; - } - - int vote_max_idx = 0; - for(i=1;i vote[vote_max_idx]) - vote_max_idx = i; - return model.label[vote_max_idx]; - } - } - - public static double svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates) - { - if ((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) && - model.probA!=null && model.probB!=null) - { - int i; - int nr_class = model.nr_class; - double[] dec_values = new double[nr_class*(nr_class-1)/2]; - svm_predict_values(model, x, dec_values); - - double min_prob=1e-7; - double[][] pairwise_prob=new double[nr_class][nr_class]; - - int k=0; - for(i=0;i prob_estimates[prob_max_idx]) - prob_max_idx = i; - return model.label[prob_max_idx]; - } - else - return svm_predict(model, x); - } - - static final String svm_type_table[] = - { - "c_svc","nu_svc","one_class","epsilon_svr","nu_svr", - }; - - static final String kernel_type_table[]= - { - "linear","polynomial","rbf","sigmoid","precomputed" - }; - - public static void svm_save_model(String model_file_name, svm_model model) throws IOException - { - DataOutputStream fp = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(model_file_name))); - - svm_parameter param = model.param; - - fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"\n"); - fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"\n"); - - if(param.kernel_type == svm_parameter.POLY) - fp.writeBytes("degree "+param.degree+"\n"); - - if(param.kernel_type == svm_parameter.POLY || - param.kernel_type == svm_parameter.RBF || - param.kernel_type == svm_parameter.SIGMOID) - fp.writeBytes("gamma "+param.gamma+"\n"); - - if(param.kernel_type == svm_parameter.POLY || - param.kernel_type == svm_parameter.SIGMOID) - fp.writeBytes("coef0 "+param.coef0+"\n"); - - int nr_class = model.nr_class; - int l = model.l; - fp.writeBytes("nr_class "+nr_class+"\n"); - fp.writeBytes("total_sv "+l+"\n"); - - { - fp.writeBytes("rho"); - for(int i=0;i 1) - return "nu <= 0 or nu > 1"; - - if(svm_type == svm_parameter.EPSILON_SVR) - if(param.p < 0) - return "p < 0"; - - if(param.shrinking != 0 && - param.shrinking != 1) - return "shrinking != 0 and shrinking != 1"; - - if(param.probability != 0 && - param.probability != 1) - return "probability != 0 and probability != 1"; - - if(param.probability == 1 && - svm_type == svm_parameter.ONE_CLASS) - return "one-class SVM probability output not supported yet"; - - // check whether nu-svc is feasible - - if(svm_type == svm_parameter.NU_SVC) - { - int l = prob.l; - int max_nr_class = 16; - int nr_class = 0; - int[] label = new int[max_nr_class]; - int[] count = new int[max_nr_class]; - - int i; - for(i=0;i Math.min(n1,n2)) - return "specified nu is infeasible"; - } - } - } - - return null; - } - - public static int svm_check_probability_model(svm_model model) - { - if (((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) && - model.probA!=null && model.probB!=null) || - ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) && - model.probA!=null)) - return 1; - else - return 0; - } + private static final String svm_type_table[] = + { + "c_svc", "nu_svc", "one_class", "epsilon_svr", "nu_svr", + }; + private static final String kernel_type_table[] = + { + "linear", "polynomial", "rbf", "sigmoid", "precomputed" + }; + private static svm_print_interface svm_print_string = System.out::print; + + static void info(String s){ + svm_print_string.print(s); + } + + private static void solve_c_svc(svm_problem prob, svm_parameter param, + double[] alpha, Solver.SolutionInfo si, + double Cp, double Cn){ + int l = prob.l; + double[] minus_ones = new double[l]; + byte[] y = new byte[l]; + + int i; + + for (i = 0; i < l; i++) { + alpha[i] = 0; + minus_ones[i] = -1; + if (prob.y[i] > 0) y[i] = +1; + else y[i] = -1; + } + + Solver s = new Solver(); + s.Solve(l, new SVC_Q(prob, param, y), minus_ones, y, + alpha, Cp, Cn, param.eps, si, param.shrinking); + + double sum_alpha = 0; + for (i = 0; i < l; i++) + sum_alpha += alpha[i]; + + if (Cp == Cn) + svm.info("nu = " + sum_alpha / (Cp * prob.l) + "\n"); + + for (i = 0; i < l; i++) + alpha[i] *= y[i]; + } + + private static void solve_nu_svc(svm_problem prob, svm_parameter param, + double[] alpha, Solver.SolutionInfo si){ + int i; + int l = prob.l; + double nu = param.nu; + + byte[] y = new byte[l]; + + for (i = 0; i < l; i++) + if (prob.y[i] > 0) + y[i] = +1; + else + y[i] = -1; + + double sum_pos = nu * l / 2; + double sum_neg = nu * l / 2; + + for (i = 0; i < l; i++) + if (y[i] == +1) { + alpha[i] = Math.min(1.0, sum_pos); + sum_pos -= alpha[i]; + } else { + alpha[i] = Math.min(1.0, sum_neg); + sum_neg -= alpha[i]; + } + + double[] zeros = new double[l]; + + for (i = 0; i < l; i++) + zeros[i] = 0; + + Solver_NU s = new Solver_NU(); + s.Solve(l, new SVC_Q(prob, param, y), zeros, y, + alpha, 1.0, 1.0, param.eps, si, param.shrinking); + double r = si.r; + + svm.info("C = " + 1 / r + "\n"); + + for (i = 0; i < l; i++) + alpha[i] *= y[i] / r; + + si.rho /= r; + si.obj /= (r * r); + si.upper_bound_p = 1 / r; + si.upper_bound_n = 1 / r; + } + + private static void solve_one_class(svm_problem prob, svm_parameter param, + double[] alpha, Solver.SolutionInfo si){ + int l = prob.l; + double[] zeros = new double[l]; + byte[] ones = new byte[l]; + int i; + + int n = (int) (param.nu * prob.l); // # of alpha's at upper bound + + for (i = 0; i < n; i++) + alpha[i] = 1; + if (n < prob.l) + alpha[n] = param.nu * prob.l - n; + for (i = n + 1; i < l; i++) + alpha[i] = 0; + + for (i = 0; i < l; i++) { + zeros[i] = 0; + ones[i] = 1; + } + + Solver s = new Solver(); + s.Solve(l, new ONE_CLASS_Q(prob, param), zeros, ones, + alpha, 1.0, 1.0, param.eps, si, param.shrinking); + } + + private static void solve_epsilon_svr(svm_problem prob, svm_parameter param, + double[] alpha, Solver.SolutionInfo si){ + int l = prob.l; + double[] alpha2 = new double[2 * l]; + double[] linear_term = new double[2 * l]; + byte[] y = new byte[2 * l]; + int i; + + for (i = 0; i < l; i++) { + alpha2[i] = 0; + linear_term[i] = param.p - prob.y[i]; + y[i] = 1; + + alpha2[i + l] = 0; + linear_term[i + l] = param.p + prob.y[i]; + y[i + l] = -1; + } + + Solver s = new Solver(); + s.Solve(2 * l, new SVR_Q(prob, param), linear_term, y, + alpha2, param.C, param.C, param.eps, si, param.shrinking); + + double sum_alpha = 0; + for (i = 0; i < l; i++) { + alpha[i] = alpha2[i] - alpha2[i + l]; + sum_alpha += Math.abs(alpha[i]); + } + svm.info("nu = " + sum_alpha / (param.C * l) + "\n"); + } + + private static void solve_nu_svr(svm_problem prob, svm_parameter param, + double[] alpha, Solver.SolutionInfo si){ + int l = prob.l; + double C = param.C; + double[] alpha2 = new double[2 * l]; + double[] linear_term = new double[2 * l]; + byte[] y = new byte[2 * l]; + int i; + + double sum = C * param.nu * l / 2; + for (i = 0; i < l; i++) { + alpha2[i] = alpha2[i + l] = Math.min(sum, C); + sum -= alpha2[i]; + + linear_term[i] = -prob.y[i]; + y[i] = 1; + + linear_term[i + l] = prob.y[i]; + y[i + l] = -1; + } + + Solver_NU s = new Solver_NU(); + s.Solve(2 * l, new SVR_Q(prob, param), linear_term, y, + alpha2, C, C, param.eps, si, param.shrinking); + + svm.info("epsilon = " + (-si.r) + "\n"); + + for (i = 0; i < l; i++) + alpha[i] = alpha2[i] - alpha2[i + l]; + } + + private static decision_function svm_train_one( + svm_problem prob, svm_parameter param, + double Cp, double Cn){ + double[] alpha = new double[prob.l]; + Solver.SolutionInfo si = new Solver.SolutionInfo(); + switch (param.svm_type) { + case svm_parameter.C_SVC: + solve_c_svc(prob, param, alpha, si, Cp, Cn); + break; + case svm_parameter.NU_SVC: + solve_nu_svc(prob, param, alpha, si); + break; + case svm_parameter.ONE_CLASS: + solve_one_class(prob, param, alpha, si); + break; + case svm_parameter.EPSILON_SVR: + solve_epsilon_svr(prob, param, alpha, si); + break; + case svm_parameter.NU_SVR: + solve_nu_svr(prob, param, alpha, si); + break; + } + + svm.info("obj = " + si.obj + ", rho = " + si.rho + "\n"); + + // output SVs + + int nSV = 0; + int nBSV = 0; + for (int i = 0; i < prob.l; i++) { + if (Math.abs(alpha[i]) > 0) { + ++nSV; + if (prob.y[i] > 0) { + if (Math.abs(alpha[i]) >= si.upper_bound_p) + ++nBSV; + } else { + if (Math.abs(alpha[i]) >= si.upper_bound_n) + ++nBSV; + } + } + } + + svm.info("nSV = " + nSV + ", nBSV = " + nBSV + "\n"); + + decision_function f = new decision_function(); + f.alpha = alpha; + f.rho = si.rho; + return f; + } + + // Platt's binary SVM Probablistic Output: an improvement from Lin et al. + private static void sigmoid_train(int l, double[] dec_values, double[] labels, + double[] probAB){ + double A, B; + double prior1 = 0, prior0 = 0; + int i; + + for (i = 0; i < l; i++) + if (labels[i] > 0) prior1 += 1; + else prior0 += 1; + + int max_iter = 100; // Maximal number of iterations + double min_step = 1e-10; // Minimal step taken in line search + double sigma = 1e-12; // For numerically strict PD of Hessian + double eps = 1e-5; + double hiTarget = (prior1 + 1.0) / (prior1 + 2.0); + double loTarget = 1 / (prior0 + 2.0); + double[] t = new double[l]; + double fApB, p, q, h11, h22, h21, g1, g2, det, dA, dB, gd, stepsize; + double newA, newB, newf, d1, d2; + int iter; + + // Initial Point and Initial Fun Value + A = 0.0; + B = Math.log((prior0 + 1.0) / (prior1 + 1.0)); + double fval = 0.0; + + for (i = 0; i < l; i++) { + if (labels[i] > 0) t[i] = hiTarget; + else t[i] = loTarget; + fApB = dec_values[i] * A + B; + if (fApB >= 0) + fval += t[i] * fApB + Math.log(1 + Math.exp(-fApB)); + else + fval += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB)); + } + for (iter = 0; iter < max_iter; iter++) { + // Update Gradient and Hessian (use H' = H + sigma I) + h11 = sigma; // numerically ensures strict PD + h22 = sigma; + h21 = 0.0; + g1 = 0.0; + g2 = 0.0; + for (i = 0; i < l; i++) { + fApB = dec_values[i] * A + B; + if (fApB >= 0) { + p = Math.exp(-fApB) / (1.0 + Math.exp(-fApB)); + q = 1.0 / (1.0 + Math.exp(-fApB)); + } else { + p = 1.0 / (1.0 + Math.exp(fApB)); + q = Math.exp(fApB) / (1.0 + Math.exp(fApB)); + } + d2 = p * q; + h11 += dec_values[i] * dec_values[i] * d2; + h22 += d2; + h21 += dec_values[i] * d2; + d1 = t[i] - p; + g1 += dec_values[i] * d1; + g2 += d1; + } + + // Stopping Criteria + if (Math.abs(g1) < eps && Math.abs(g2) < eps) + break; + + // Finding Newton direction: -inv(H') * g + det = h11 * h22 - h21 * h21; + dA = -(h22 * g1 - h21 * g2) / det; + dB = -(-h21 * g1 + h11 * g2) / det; + gd = g1 * dA + g2 * dB; + + + stepsize = 1; // Line Search + while (stepsize >= min_step) { + newA = A + stepsize * dA; + newB = B + stepsize * dB; + + // New function value + newf = 0.0; + for (i = 0; i < l; i++) { + fApB = dec_values[i] * newA + newB; + if (fApB >= 0) + newf += t[i] * fApB + Math.log(1 + Math.exp(-fApB)); + else + newf += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB)); + } + // Check sufficient decrease + if (newf < fval + 0.0001 * stepsize * gd) { + A = newA; + B = newB; + fval = newf; + break; + } else + stepsize = stepsize / 2.0; + } + + if (stepsize < min_step) { + svm.info("Line search fails in two-class probability estimates\n"); + break; + } + } + + if (iter >= max_iter) + svm.info("Reaching maximal iterations in two-class probability estimates\n"); + probAB[0] = A; + probAB[1] = B; + } + + private static double sigmoid_predict(double decision_value, double A, double B){ + double fApB = decision_value * A + B; + if (fApB >= 0) + return Math.exp(-fApB) / (1.0 + Math.exp(-fApB)); + else + return 1.0 / (1 + Math.exp(fApB)); + } + + // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng + private static void multiclass_probability(int k, double[][] r, double[] p){ + int t, j; + int iter, max_iter = Math.max(100, k); + double[][] Q = new double[k][k]; + double[] Qp = new double[k]; + double pQp, eps = 0.005 / k; + + for (t = 0; t < k; t++) { + p[t] = 1.0 / k; // Valid if k = 1 + Q[t][t] = 0; + for (j = 0; j < t; j++) { + Q[t][t] += r[j][t] * r[j][t]; + Q[t][j] = Q[j][t]; + } + for (j = t + 1; j < k; j++) { + Q[t][t] += r[j][t] * r[j][t]; + Q[t][j] = -r[j][t] * r[t][j]; + } + } + for (iter = 0; iter < max_iter; iter++) { + // stopping condition, recalculate QP,pQP for numerical accuracy + pQp = 0; + for (t = 0; t < k; t++) { + Qp[t] = 0; + for (j = 0; j < k; j++) + Qp[t] += Q[t][j] * p[j]; + pQp += p[t] * Qp[t]; + } + double max_error = 0; + for (t = 0; t < k; t++) { + double error = Math.abs(Qp[t] - pQp); + if (error > max_error) + max_error = error; + } + if (max_error < eps) break; + + for (t = 0; t < k; t++) { + double diff = (-Qp[t] + pQp) / Q[t][t]; + p[t] += diff; + pQp = (pQp + diff * (diff * Q[t][t] + 2 * Qp[t])) / (1 + diff) / (1 + diff); + for (j = 0; j < k; j++) { + Qp[j] = (Qp[j] + diff * Q[t][j]) / (1 + diff); + p[j] /= (1 + diff); + } + } + } + if (iter >= max_iter) + svm.info("Exceeds max_iter in multiclass_prob\n"); + } + + // Cross-validation decision values for probability estimates + private static void svm_binary_svc_probability(svm_problem prob, svm_parameter param, double Cp, double Cn, double[] probAB){ + int i; + int nr_fold = 5; + int[] perm = new int[prob.l]; + double[] dec_values = new double[prob.l]; + + // random shuffle + for (i = 0; i < prob.l; i++) perm[i] = i; + for (i = 0; i < prob.l; i++) { + int j = i + (int) (Math.random() * (prob.l - i)); + do { + int temp = perm[i]; + perm[i] = perm[j]; + perm[j] = temp; + } while (false); + } + for (i = 0; i < nr_fold; i++) { + int begin = i * prob.l / nr_fold; + int end = (i + 1) * prob.l / nr_fold; + int j, k; + svm_problem subprob = new svm_problem(); + + subprob.l = prob.l - (end - begin); + subprob.x = new svm_node[subprob.l][]; + subprob.y = new double[subprob.l]; + + k = 0; + for (j = 0; j < begin; j++) { + subprob.x[k] = prob.x[perm[j]]; + subprob.y[k] = prob.y[perm[j]]; + ++k; + } + for (j = end; j < prob.l; j++) { + subprob.x[k] = prob.x[perm[j]]; + subprob.y[k] = prob.y[perm[j]]; + ++k; + } + int p_count = 0, n_count = 0; + for (j = 0; j < k; j++) + if (subprob.y[j] > 0) + p_count++; + else + n_count++; + + if (p_count == 0 && n_count == 0) + for (j = begin; j < end; j++) + dec_values[perm[j]] = 0; + else if (p_count > 0 && n_count == 0) + for (j = begin; j < end; j++) + dec_values[perm[j]] = 1; + else if (p_count == 0 && n_count > 0) + for (j = begin; j < end; j++) + dec_values[perm[j]] = -1; + else { + svm_parameter subparam = (svm_parameter) param.clone(); + subparam.probability = 0; + subparam.C = 1.0; + subparam.nr_weight = 2; + subparam.weight_label = new int[2]; + subparam.weight = new double[2]; + subparam.weight_label[0] = +1; + subparam.weight_label[1] = -1; + subparam.weight[0] = Cp; + subparam.weight[1] = Cn; + svm_model submodel = svm_train(subprob, subparam); + for (j = begin; j < end; j++) { + double[] dec_value = new double[1]; + svm_predict_values(submodel, prob.x[perm[j]], dec_value); + dec_values[perm[j]] = dec_value[0]; + // ensure +1 -1 order; reason not using CV subroutine + dec_values[perm[j]] *= submodel.label[0]; + } + } + } + sigmoid_train(prob.l, dec_values, prob.y, probAB); + } + + // Return parameter of a Laplace distribution + private static double svm_svr_probability(svm_problem prob, svm_parameter param){ + int i; + int nr_fold = 5; + double[] ymv = new double[prob.l]; + double mae = 0; + + svm_parameter newparam = (svm_parameter) param.clone(); + newparam.probability = 0; + svm_cross_validation(prob, newparam, nr_fold, ymv); + for (i = 0; i < prob.l; i++) { + ymv[i] = prob.y[i] - ymv[i]; + mae += Math.abs(ymv[i]); + } + mae /= prob.l; + double std = Math.sqrt(2 * mae * mae); + int count = 0; + mae = 0; + for (i = 0; i < prob.l; i++) + if (Math.abs(ymv[i]) > 5 * std) + count = count + 1; + else + mae += Math.abs(ymv[i]); + mae /= (prob.l - count); + svm.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=" + mae + "\n"); + return mae; + } + + // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data + // perm, length l, must be allocated before calling this subroutine + private static void svm_group_classes(svm_problem prob, int[] nr_class_ret, int[][] label_ret, int[][] start_ret, int[][] count_ret, int[] perm){ + int l = prob.l; + int max_nr_class = 16; + int nr_class = 0; + int[] label = new int[max_nr_class]; + int[] count = new int[max_nr_class]; + int[] data_label = new int[l]; + int i; + + for (i = 0; i < l; i++) { + int this_label = (int) (prob.y[i]); + int j; + for (j = 0; j < nr_class; j++) { + if (this_label == label[j]) { + ++count[j]; + break; + } + } + data_label[i] = j; + if (j == nr_class) { + if (nr_class == max_nr_class) { + max_nr_class *= 2; + int[] new_data = new int[max_nr_class]; + System.arraycopy(label, 0, new_data, 0, label.length); + label = new_data; + new_data = new int[max_nr_class]; + System.arraycopy(count, 0, new_data, 0, count.length); + count = new_data; + } + label[nr_class] = this_label; + count[nr_class] = 1; + ++nr_class; + } + } + + int[] start = new int[nr_class]; + start[0] = 0; + for (i = 1; i < nr_class; i++) + start[i] = start[i - 1] + count[i - 1]; + for (i = 0; i < l; i++) { + perm[start[data_label[i]]] = i; + ++start[data_label[i]]; + } + start[0] = 0; + for (i = 1; i < nr_class; i++) + start[i] = start[i - 1] + count[i - 1]; + + nr_class_ret[0] = nr_class; + label_ret[0] = label; + start_ret[0] = start; + count_ret[0] = count; + } + + // + // Interface functions + // + public static svm_model svm_train(svm_problem prob, svm_parameter param){ + svm_model model = new svm_model(); + model.param = param; + + if (param.svm_type == svm_parameter.ONE_CLASS || + param.svm_type == svm_parameter.EPSILON_SVR || + param.svm_type == svm_parameter.NU_SVR) { + // regression or one-class-svm + model.nr_class = 2; + model.label = null; + model.nSV = null; + model.probA = null; + model.probB = null; + model.sv_coef = new double[1][]; + + if (param.probability == 1 && + (param.svm_type == svm_parameter.EPSILON_SVR || + param.svm_type == svm_parameter.NU_SVR)) { + model.probA = new double[1]; + model.probA[0] = svm_svr_probability(prob, param); + } + + decision_function f = svm_train_one(prob, param, 0, 0); + model.rho = new double[1]; + model.rho[0] = f.rho; + + int nSV = 0; + int i; + for (i = 0; i < prob.l; i++) + if (Math.abs(f.alpha[i]) > 0) ++nSV; + model.l = nSV; + model.SV = new svm_node[nSV][]; + model.sv_coef[0] = new double[nSV]; + int j = 0; + for (i = 0; i < prob.l; i++) + if (Math.abs(f.alpha[i]) > 0) { + model.SV[j] = prob.x[i]; + model.sv_coef[0][j] = f.alpha[i]; + ++j; + } + } else { + // classification + int l = prob.l; + int[] tmp_nr_class = new int[1]; + int[][] tmp_label = new int[1][]; + int[][] tmp_start = new int[1][]; + int[][] tmp_count = new int[1][]; + int[] perm = new int[l]; + + // group training data of the same class + svm_group_classes(prob, tmp_nr_class, tmp_label, tmp_start, tmp_count, perm); + int nr_class = tmp_nr_class[0]; + int[] label = tmp_label[0]; + int[] start = tmp_start[0]; + int[] count = tmp_count[0]; + svm_node[][] x = new svm_node[l][]; + int i; + for (i = 0; i < l; i++) + x[i] = prob.x[perm[i]]; + + // calculate weighted C + + double[] weighted_C = new double[nr_class]; + for (i = 0; i < nr_class; i++) + weighted_C[i] = param.C; + for (i = 0; i < param.nr_weight; i++) { + int j; + for (j = 0; j < nr_class; j++) + if (param.weight_label[i] == label[j]) + break; + if (j == nr_class) + System.err.print("warning: class label " + param.weight_label[i] + " specified in weight is not found\n"); + else + weighted_C[j] *= param.weight[i]; + } + + // train k*(k-1)/2 models + + boolean[] nonzero = new boolean[l]; + for (i = 0; i < l; i++) + nonzero[i] = false; + decision_function[] f = new decision_function[nr_class * (nr_class - 1) / 2]; + + double[] probA = null, probB = null; + if (param.probability == 1) { + probA = new double[nr_class * (nr_class - 1) / 2]; + probB = new double[nr_class * (nr_class - 1) / 2]; + } + + int p = 0; + for (i = 0; i < nr_class; i++) + for (int j = i + 1; j < nr_class; j++) { + svm_problem sub_prob = new svm_problem(); + int si = start[i], sj = start[j]; + int ci = count[i], cj = count[j]; + sub_prob.l = ci + cj; + sub_prob.x = new svm_node[sub_prob.l][]; + sub_prob.y = new double[sub_prob.l]; + int k; + for (k = 0; k < ci; k++) { + sub_prob.x[k] = x[si + k]; + sub_prob.y[k] = +1; + } + for (k = 0; k < cj; k++) { + sub_prob.x[ci + k] = x[sj + k]; + sub_prob.y[ci + k] = -1; + } + + if (param.probability == 1) { + double[] probAB = new double[2]; + svm_binary_svc_probability(sub_prob, param, weighted_C[i], weighted_C[j], probAB); + probA[p] = probAB[0]; + probB[p] = probAB[1]; + } + + f[p] = svm_train_one(sub_prob, param, weighted_C[i], weighted_C[j]); + for (k = 0; k < ci; k++) + if (!nonzero[si + k] && Math.abs(f[p].alpha[k]) > 0) + nonzero[si + k] = true; + for (k = 0; k < cj; k++) + if (!nonzero[sj + k] && Math.abs(f[p].alpha[ci + k]) > 0) + nonzero[sj + k] = true; + ++p; + } + + // build output + + model.nr_class = nr_class; + + model.label = new int[nr_class]; + for (i = 0; i < nr_class; i++) + model.label[i] = label[i]; + + model.rho = new double[nr_class * (nr_class - 1) / 2]; + for (i = 0; i < nr_class * (nr_class - 1) / 2; i++) + model.rho[i] = f[i].rho; + + if (param.probability == 1) { + model.probA = new double[nr_class * (nr_class - 1) / 2]; + model.probB = new double[nr_class * (nr_class - 1) / 2]; + for (i = 0; i < nr_class * (nr_class - 1) / 2; i++) { + model.probA[i] = probA[i]; + model.probB[i] = probB[i]; + } + } else { + model.probA = null; + model.probB = null; + } + + int nnz = 0; + int[] nz_count = new int[nr_class]; + model.nSV = new int[nr_class]; + for (i = 0; i < nr_class; i++) { + int nSV = 0; + for (int j = 0; j < count[i]; j++) + if (nonzero[start[i] + j]) { + ++nSV; + ++nnz; + } + model.nSV[i] = nSV; + nz_count[i] = nSV; + } + + svm.info("Total nSV = " + nnz + "\n"); + + model.l = nnz; + model.SV = new svm_node[nnz][]; + p = 0; + for (i = 0; i < l; i++) + if (nonzero[i]) model.SV[p++] = x[i]; + + int[] nz_start = new int[nr_class]; + nz_start[0] = 0; + for (i = 1; i < nr_class; i++) + nz_start[i] = nz_start[i - 1] + nz_count[i - 1]; + + model.sv_coef = new double[nr_class - 1][]; + for (i = 0; i < nr_class - 1; i++) + model.sv_coef[i] = new double[nnz]; + + p = 0; + for (i = 0; i < nr_class; i++) + for (int j = i + 1; j < nr_class; j++) { + // classifier (i,j): coefficients with + // i are in sv_coef[j-1][nz_start[i]...], + // j are in sv_coef[i][nz_start[j]...] + + int si = start[i]; + int sj = start[j]; + int ci = count[i]; + int cj = count[j]; + + int q = nz_start[i]; + int k; + for (k = 0; k < ci; k++) + if (nonzero[si + k]) + model.sv_coef[j - 1][q++] = f[p].alpha[k]; + q = nz_start[j]; + for (k = 0; k < cj; k++) + if (nonzero[sj + k]) + model.sv_coef[i][q++] = f[p].alpha[ci + k]; + ++p; + } + } + return model; + } + + // Stratified cross validation + private static void svm_cross_validation(svm_problem prob, svm_parameter param, int nr_fold, double[] target){ + int i; + int[] fold_start = new int[nr_fold + 1]; + int l = prob.l; + int[] perm = new int[l]; + + // stratified cv may not give leave-one-out rate + // Each class to l folds -> some folds may have zero elements + if ((param.svm_type == svm_parameter.C_SVC || + param.svm_type == svm_parameter.NU_SVC) && nr_fold < l) { + int[] tmp_nr_class = new int[1]; + int[][] tmp_label = new int[1][]; + int[][] tmp_start = new int[1][]; + int[][] tmp_count = new int[1][]; + + svm_group_classes(prob, tmp_nr_class, tmp_label, tmp_start, tmp_count, perm); + + int nr_class = tmp_nr_class[0]; + int[] start = tmp_start[0]; + int[] count = tmp_count[0]; + + // random shuffle and then data grouped by fold using the array perm + int[] fold_count = new int[nr_fold]; + int c; + int[] index = new int[l]; + for (i = 0; i < l; i++) + index[i] = perm[i]; + for (c = 0; c < nr_class; c++) + for (i = 0; i < count[c]; i++) { + int j = i + (int) (Math.random() * (count[c] - i)); + do { + int temp = index[start[c] + j]; + index[start[c] + j] = index[start[c] + i]; + index[start[c] + i] = temp; + } while (false); + } + for (i = 0; i < nr_fold; i++) { + fold_count[i] = 0; + for (c = 0; c < nr_class; c++) + fold_count[i] += (i + 1) * count[c] / nr_fold - i * count[c] / nr_fold; + } + fold_start[0] = 0; + for (i = 1; i <= nr_fold; i++) + fold_start[i] = fold_start[i - 1] + fold_count[i - 1]; + for (c = 0; c < nr_class; c++) + for (i = 0; i < nr_fold; i++) { + int begin = start[c] + i * count[c] / nr_fold; + int end = start[c] + (i + 1) * count[c] / nr_fold; + for (int j = begin; j < end; j++) { + perm[fold_start[i]] = index[j]; + fold_start[i]++; + } + } + fold_start[0] = 0; + for (i = 1; i <= nr_fold; i++) + fold_start[i] = fold_start[i - 1] + fold_count[i - 1]; + } else { + for (i = 0; i < l; i++) perm[i] = i; + for (i = 0; i < l; i++) { + int j = i + (int) (Math.random() * (l - i)); + do { + int temp = perm[i]; + perm[i] = perm[j]; + perm[j] = temp; + } while (false); + } + for (i = 0; i <= nr_fold; i++) + fold_start[i] = i * l / nr_fold; + } + + for (i = 0; i < nr_fold; i++) { + int begin = fold_start[i]; + int end = fold_start[i + 1]; + int j, k; + svm_problem subprob = new svm_problem(); + + subprob.l = l - (end - begin); + subprob.x = new svm_node[subprob.l][]; + subprob.y = new double[subprob.l]; + + k = 0; + for (j = 0; j < begin; j++) { + subprob.x[k] = prob.x[perm[j]]; + subprob.y[k] = prob.y[perm[j]]; + ++k; + } + for (j = end; j < l; j++) { + subprob.x[k] = prob.x[perm[j]]; + subprob.y[k] = prob.y[perm[j]]; + ++k; + } + svm_model submodel = svm_train(subprob, param); + if (param.probability == 1 && + (param.svm_type == svm_parameter.C_SVC || + param.svm_type == svm_parameter.NU_SVC)) { + double[] prob_estimates = new double[svm_get_nr_class(submodel)]; + for (j = begin; j < end; j++) + target[perm[j]] = svm_predict_probability(submodel, prob.x[perm[j]], prob_estimates); + } else + for (j = begin; j < end; j++) + target[perm[j]] = svm_predict(submodel, prob.x[perm[j]]); + } + } + + public static int svm_get_svm_type(svm_model model){ + return model.param.svm_type; + } + + public static int svm_get_nr_class(svm_model model){ + return model.nr_class; + } + + public static void svm_get_labels(svm_model model, int[] label){ + if (model.label != null) + System.arraycopy(model.label, 0, label, 0, model.nr_class); + } + + /** + * @param model + * @return + */ + public static double svm_get_svr_probability(svm_model model){ + if ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) && + model.probA != null) + return model.probA[0]; + else { + System.err.print("Model doesn't contain information for SVR probability inference\n"); + return 0; + } + } + + private static void svm_predict_values(svm_model model, svm_node[] x, double[] dec_values){ + if (model.param.svm_type == svm_parameter.ONE_CLASS || + model.param.svm_type == svm_parameter.EPSILON_SVR || + model.param.svm_type == svm_parameter.NU_SVR) { + double[] sv_coef = model.sv_coef[0]; + double sum = 0; + for (int i = 0; i < model.l; i++) + sum += sv_coef[i] * Kernel.k_function(x, model.SV[i], model.param); + sum -= model.rho[0]; + dec_values[0] = sum; + } else { + int i; + int nr_class = model.nr_class; + int l = model.l; + + double[] kvalue = new double[l]; + for (i = 0; i < l; i++) + kvalue[i] = Kernel.k_function(x, model.SV[i], model.param); + + int[] start = new int[nr_class]; + start[0] = 0; + for (i = 1; i < nr_class; i++) + start[i] = start[i - 1] + model.nSV[i - 1]; + + int p = 0; + for (i = 0; i < nr_class; i++) + for (int j = i + 1; j < nr_class; j++) { + double sum = 0; + int si = start[i]; + int sj = start[j]; + int ci = model.nSV[i]; + int cj = model.nSV[j]; + + int k; + double[] coef1 = model.sv_coef[j - 1]; + double[] coef2 = model.sv_coef[i]; + for (k = 0; k < ci; k++) + sum += coef1[si + k] * kvalue[si + k]; + for (k = 0; k < cj; k++) + sum += coef2[sj + k] * kvalue[sj + k]; + sum -= model.rho[p]; + dec_values[p] = sum; + p++; + } + } + } + + public static double svm_predict(svm_model model, svm_node[] x){ + if (model.param.svm_type == svm_parameter.ONE_CLASS || + model.param.svm_type == svm_parameter.EPSILON_SVR || + model.param.svm_type == svm_parameter.NU_SVR) { + double[] res = new double[1]; + svm_predict_values(model, x, res); + + if (model.param.svm_type == svm_parameter.ONE_CLASS) + return (res[0] > 0) ? 1 : -1; + else + return res[0]; + } else { + int i; + int nr_class = model.nr_class; + double[] dec_values = new double[nr_class * (nr_class - 1) / 2]; + svm_predict_values(model, x, dec_values); + + int[] vote = new int[nr_class]; + for (i = 0; i < nr_class; i++) + vote[i] = 0; + int pos = 0; + for (i = 0; i < nr_class; i++) + for (int j = i + 1; j < nr_class; j++) { + if (dec_values[pos++] > 0) + ++vote[i]; + else + ++vote[j]; + } + + int vote_max_idx = 0; + for (i = 1; i < nr_class; i++) + if (vote[i] > vote[vote_max_idx]) + vote_max_idx = i; + return model.label[vote_max_idx]; + } + } + + private static double svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates){ + if ((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) && + model.probA != null && model.probB != null) { + int i; + int nr_class = model.nr_class; + double[] dec_values = new double[nr_class * (nr_class - 1) / 2]; + svm_predict_values(model, x, dec_values); + + double min_prob = 1e-7; + double[][] pairwise_prob = new double[nr_class][nr_class]; + + int k = 0; + for (i = 0; i < nr_class; i++) + for (int j = i + 1; j < nr_class; j++) { + pairwise_prob[i][j] = Math.min(Math.max(sigmoid_predict(dec_values[k], model.probA[k], model.probB[k]), min_prob), 1 - min_prob); + pairwise_prob[j][i] = 1 - pairwise_prob[i][j]; + k++; + } + multiclass_probability(nr_class, pairwise_prob, prob_estimates); + + int prob_max_idx = 0; + for (i = 1; i < nr_class; i++) + if (prob_estimates[i] > prob_estimates[prob_max_idx]) + prob_max_idx = i; + return model.label[prob_max_idx]; + } else + return svm_predict(model, x); + } + + public static void svm_save_model(String model_file_name, svm_model model) throws IOException{ + DataOutputStream fp = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(model_file_name))); + + svm_parameter param = model.param; + + fp.writeBytes("svm_type " + svm_type_table[param.svm_type] + "\n"); + fp.writeBytes("kernel_type " + kernel_type_table[param.kernel_type] + "\n"); + + if (param.kernel_type == svm_parameter.POLY) + fp.writeBytes("degree " + param.degree + "\n"); + + if (param.kernel_type == svm_parameter.POLY || + param.kernel_type == svm_parameter.RBF || + param.kernel_type == svm_parameter.SIGMOID) + fp.writeBytes("gamma " + param.gamma + "\n"); + + if (param.kernel_type == svm_parameter.POLY || + param.kernel_type == svm_parameter.SIGMOID) + fp.writeBytes("coef0 " + param.coef0 + "\n"); + + int nr_class = model.nr_class; + int l = model.l; + fp.writeBytes("nr_class " + nr_class + "\n"); + fp.writeBytes("total_sv " + l + "\n"); + + { + fp.writeBytes("rho"); + for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++) + fp.writeBytes(" " + model.rho[i]); + fp.writeBytes("\n"); + } + + if (model.label != null) { + fp.writeBytes("label"); + for (int i = 0; i < nr_class; i++) + fp.writeBytes(" " + model.label[i]); + fp.writeBytes("\n"); + } + + if (model.probA != null) // regression has probA only + { + fp.writeBytes("probA"); + for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++) + fp.writeBytes(" " + model.probA[i]); + fp.writeBytes("\n"); + } + if (model.probB != null) { + fp.writeBytes("probB"); + for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++) + fp.writeBytes(" " + model.probB[i]); + fp.writeBytes("\n"); + } + + if (model.nSV != null) { + fp.writeBytes("nr_sv"); + for (int i = 0; i < nr_class; i++) + fp.writeBytes(" " + model.nSV[i]); + fp.writeBytes("\n"); + } + + fp.writeBytes("SV\n"); + double[][] sv_coef = model.sv_coef; + svm_node[][] SV = model.SV; + + for (int i = 0; i < l; i++) { + for (int j = 0; j < nr_class - 1; j++) + fp.writeBytes(sv_coef[j][i] + " "); + + svm_node[] p = SV[i]; + if (param.kernel_type == svm_parameter.PRECOMPUTED) + fp.writeBytes("0:" + (int) (p[0].value)); + else + for (svm_node aP : p) { + fp.writeBytes(aP.index + ":" + aP.value + " "); + } + fp.writeBytes("\n"); + } + + fp.close(); + } + + private static double atof(String s){ + return Double.valueOf(s); + } + + private static int atoi(String s){ + return Integer.parseInt(s); + } + + /** + * @param model_file_name + * @throws IOException + */ + public static svm_model svm_load_model(String model_file_name) throws IOException{ + BufferedReader fp = new BufferedReader(new FileReader(model_file_name)); + + // read parameters + + svm_model model = new svm_model(); + svm_parameter param = new svm_parameter(); + model.param = param; + model.rho = null; + model.probA = null; + model.probB = null; + model.label = null; + model.nSV = null; + + while (true) { + String cmd = fp.readLine(); + String arg = cmd.substring(cmd.indexOf(' ') + 1); + + if (cmd.startsWith("svm_type")) { + int i; + for (i = 0; i < svm_type_table.length; i++) { + if (arg.contains(svm_type_table[i])) { + param.svm_type = i; + break; + } + } + if (i == svm_type_table.length) { + System.err.print("unknown svm type.\n"); + return null; + } + } else if (cmd.startsWith("kernel_type")) { + int i; + for (i = 0; i < kernel_type_table.length; i++) { + if (arg.contains(kernel_type_table[i])) { + param.kernel_type = i; + break; + } + } + if (i == kernel_type_table.length) { + System.err.print("unknown kernel function.\n"); + return null; + } + } else if (cmd.startsWith("degree")) + param.degree = atoi(arg); + else if (cmd.startsWith("gamma")) + param.gamma = atof(arg); + else if (cmd.startsWith("coef0")) + param.coef0 = atof(arg); + else if (cmd.startsWith("nr_class")) + model.nr_class = atoi(arg); + else if (cmd.startsWith("total_sv")) + model.l = atoi(arg); + else if (cmd.startsWith("rho")) { + int n = model.nr_class * (model.nr_class - 1) / 2; + model.rho = new double[n]; + StringTokenizer st = new StringTokenizer(arg); + for (int i = 0; i < n; i++) + model.rho[i] = atof(st.nextToken()); + } else if (cmd.startsWith("label")) { + int n = model.nr_class; + model.label = new int[n]; + StringTokenizer st = new StringTokenizer(arg); + for (int i = 0; i < n; i++) + model.label[i] = atoi(st.nextToken()); + } else if (cmd.startsWith("probA")) { + int n = model.nr_class * (model.nr_class - 1) / 2; + model.probA = new double[n]; + StringTokenizer st = new StringTokenizer(arg); + for (int i = 0; i < n; i++) + model.probA[i] = atof(st.nextToken()); + } else if (cmd.startsWith("probB")) { + int n = model.nr_class * (model.nr_class - 1) / 2; + model.probB = new double[n]; + StringTokenizer st = new StringTokenizer(arg); + for (int i = 0; i < n; i++) + model.probB[i] = atof(st.nextToken()); + } else if (cmd.startsWith("nr_sv")) { + int n = model.nr_class; + model.nSV = new int[n]; + StringTokenizer st = new StringTokenizer(arg); + for (int i = 0; i < n; i++) + model.nSV[i] = atoi(st.nextToken()); + } else if (cmd.startsWith("SV")) { + break; + } else { + System.err.print("unknown text in model file: [" + cmd + "]\n"); + return null; + } + } + + // read sv_coef and SV + + int m = model.nr_class - 1; + int l = model.l; + model.sv_coef = new double[m][l]; + model.SV = new svm_node[l][]; + + for (int i = 0; i < l; i++) { + String line = fp.readLine(); + StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:"); + + for (int k = 0; k < m; k++) + model.sv_coef[k][i] = atof(st.nextToken()); + int n = st.countTokens() / 2; + model.SV[i] = new svm_node[n]; + for (int j = 0; j < n; j++) { + model.SV[i][j] = new svm_node(); + model.SV[i][j].index = atoi(st.nextToken()); + model.SV[i][j].value = atof(st.nextToken()); + } + } + + fp.close(); + return model; + } + + /** + * 对svm的配置参数叫验证,因为有些参数只针对部分的支持向量机的类型 + * + * @param prob 问题 + * @param param 参数 + */ + public static String svm_check_parameter(svm_problem prob, svm_parameter param){ + // svm_type + + int svm_type = param.svm_type; + if (svm_type != svm_parameter.C_SVC && + svm_type != svm_parameter.NU_SVC && + svm_type != svm_parameter.ONE_CLASS && + svm_type != svm_parameter.EPSILON_SVR && + svm_type != svm_parameter.NU_SVR) + return "unknown svm type"; + + // kernel_type, degree + + int kernel_type = param.kernel_type; + if (kernel_type != svm_parameter.LINEAR && + kernel_type != svm_parameter.POLY && + kernel_type != svm_parameter.RBF && + kernel_type != svm_parameter.SIGMOID && + kernel_type != svm_parameter.PRECOMPUTED) + return "unknown kernel type"; + + if (param.degree < 0) + return "degree of polynomial kernel < 0"; + + // cache_size,eps,C,nu,p,shrinking + + if (param.cache_size <= 0) + return "cache_size <= 0"; + + if (param.eps <= 0) + return "eps <= 0"; + + if (svm_type == svm_parameter.C_SVC || + svm_type == svm_parameter.EPSILON_SVR || + svm_type == svm_parameter.NU_SVR) + if (param.C <= 0) + return "C <= 0"; + + if (svm_type == svm_parameter.NU_SVC || + svm_type == svm_parameter.ONE_CLASS || + svm_type == svm_parameter.NU_SVR) + if (param.nu <= 0 || param.nu > 1) + return "nu <= 0 or nu > 1"; + + if (svm_type == svm_parameter.EPSILON_SVR) + if (param.p < 0) + return "p < 0"; + + if (param.shrinking != 0 && + param.shrinking != 1) + return "shrinking != 0 and shrinking != 1"; + + if (param.probability != 0 && + param.probability != 1) + return "probability != 0 and probability != 1"; + + if (param.probability == 1 && + svm_type == svm_parameter.ONE_CLASS) + return "one-class SVM probability output not supported yet"; + + // check whether nu-svc is feasible + + if (svm_type == svm_parameter.NU_SVC) { + int l = prob.l; + int max_nr_class = 16; + int nr_class = 0; + int[] label = new int[max_nr_class]; + int[] count = new int[max_nr_class]; + + int i; + for (i = 0; i < l; i++) { + int this_label = (int) prob.y[i]; + int j; + for (j = 0; j < nr_class; j++) + if (this_label == label[j]) { + ++count[j]; + break; + } + + if (j == nr_class) { + if (nr_class == max_nr_class) { + max_nr_class *= 2; + int[] new_data = new int[max_nr_class]; + System.arraycopy(label, 0, new_data, 0, label.length); + label = new_data; + + new_data = new int[max_nr_class]; + System.arraycopy(count, 0, new_data, 0, count.length); + count = new_data; + } + label[nr_class] = this_label; + count[nr_class] = 1; + ++nr_class; + } + } + + for (i = 0; i < nr_class; i++) { + int n1 = count[i]; + for (int j = i + 1; j < nr_class; j++) { + int n2 = count[j]; + if (param.nu * (n1 + n2) / 2 > Math.min(n1, n2)) + return "specified nu is infeasible"; + } + } + } + + return null; + } + + public static int svm_check_probability_model(svm_model model){ + if (((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) && + model.probA != null && model.probB != null) || + ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) && + model.probA != null)) + return 1; + else + return 0; + } + + // + // decision_function + // + private static class decision_function { + double[] alpha; + double rho; + } } diff --git a/StatisticalLearning/DataMining_SVM/libsvm/svm_model.java b/StatisticalLearning/DataMining_SVM/libsvm/svm_model.java index 728f82f..a010c8f 100644 --- a/StatisticalLearning/DataMining_SVM/libsvm/svm_model.java +++ b/StatisticalLearning/DataMining_SVM/libsvm/svm_model.java @@ -1,24 +1,24 @@ // // svm_model // -package DataMining_SVM.libsvm; -public class svm_model implements java.io.Serializable -{ - //svm支持向量机的参数 - svm_parameter param; // parameter - //分类的类型数 - int nr_class; // number of classes, = 2 in regression/one class svm - int l; // total #SV - svm_node[][] SV; // SVs (SV[l]) - double[][] sv_coef; // coefficients for SVs in decision functions (sv_coef[k-1][l]) - double[] rho; // constants in decision functions (rho[k*(k-1)/2]) - double[] probA; // pariwise probability information - double[] probB; +package StatisticalLearning.DataMining_SVM.libsvm; - // for classification only +public class svm_model implements java.io.Serializable { + //svm支持向量机的参数 + svm_parameter param; // parameter + //分类的类型数 + int nr_class; // number of classes, = 2 in regression/one class svm + int l; // total #SV + svm_node[][] SV; // SVs (SV[l]) + double[][] sv_coef; // coefficients for SVs in decision functions (sv_coef[k-1][l]) + double[] rho; // constants in decision functions (rho[k*(k-1)/2]) + double[] probA; // pariwise probability information + double[] probB; - //每个类型的类型值 - int[] label; // label of each class (label[k]) - int[] nSV; // number of SVs for each class (nSV[k]) - // nSV[0] + nSV[1] + ... + nSV[k-1] = l -}; + // for classification only + + //每个类型的类型值 + int[] label; // label of each class (label[k]) + int[] nSV; // number of SVs for each class (nSV[k]) + // nSV[0] + nSV[1] + ... + nSV[k-1] = l +} diff --git a/StatisticalLearning/DataMining_SVM/libsvm/svm_node.java b/StatisticalLearning/DataMining_SVM/libsvm/svm_node.java index 1433f4e..1a3d910 100644 --- a/StatisticalLearning/DataMining_SVM/libsvm/svm_node.java +++ b/StatisticalLearning/DataMining_SVM/libsvm/svm_node.java @@ -1,14 +1,13 @@ -package DataMining_SVM.libsvm; +package StatisticalLearning.DataMining_SVM.libsvm; + /** - * * svm向量节点 - * @author lyq * + * @author Qstar */ -public class svm_node implements java.io.Serializable -{ - //节点索引 - public int index; - //节点的值 - public double value; +public class svm_node implements java.io.Serializable { + //节点索引 + public int index; + //节点的值 + public double value; } diff --git a/StatisticalLearning/DataMining_SVM/libsvm/svm_parameter.java b/StatisticalLearning/DataMining_SVM/libsvm/svm_parameter.java index 45b1acc..70c4730 100644 --- a/StatisticalLearning/DataMining_SVM/libsvm/svm_parameter.java +++ b/StatisticalLearning/DataMining_SVM/libsvm/svm_parameter.java @@ -1,52 +1,47 @@ -package DataMining_SVM.libsvm; -public class svm_parameter implements Cloneable,java.io.Serializable -{ - /* svm_type 支持向量机的类型*/ - public static final int C_SVC = 0; - public static final int NU_SVC = 1; - //一类svm - public static final int ONE_CLASS = 2; - public static final int EPSILON_SVR = 3; - public static final int NU_SVR = 4; +package StatisticalLearning.DataMining_SVM.libsvm; - /* kernel_type 核函数类型*/ - //线型核函数 - public static final int LINEAR = 0; - //多项式核函数 - public static final int POLY = 1; - //RBF径向基函数 - public static final int RBF = 2; - //二层神经网络核函数 - public static final int SIGMOID = 3; - public static final int PRECOMPUTED = 4; +public class svm_parameter implements Cloneable, java.io.Serializable { + public static final int EPSILON_SVR = 3; + /* kernel_type 核函数类型*/ + //线型核函数 + public static final int LINEAR = 0; + /* svm_type 支持向量机的类型*/ + static final int C_SVC = 0; + static final int NU_SVC = 1; + //一类svm + static final int ONE_CLASS = 2; + static final int NU_SVR = 4; + //多项式核函数 + static final int POLY = 1; + //RBF径向基函数 + static final int RBF = 2; + //二层神经网络核函数 + static final int SIGMOID = 3; + static final int PRECOMPUTED = 4; - public int svm_type; - public int kernel_type; - public int degree; // for poly - public double gamma; // for poly/rbf/sigmoid - public double coef0; // for poly/sigmoid + public int svm_type; + public int kernel_type; + // these are for training only 后面这些参数只针对训练集的数据 + public double cache_size; // in MB + public double eps; // stopping criteria + public double C; // for C_SVC, EPSILON_SVR and NU_SVR + public double p; // for EPSILON_SVR + int degree; // for poly + double gamma; // for poly/rbf/sigmoid + double coef0; // for poly/sigmoid + int nr_weight; // for C_SVC + int[] weight_label; // for C_SVC + double[] weight; // for C_SVC + double nu; // for NU_SVC, ONE_CLASS, and NU_SVR + int shrinking; // use the shrinking heuristics + int probability; // do probability estimates - // these are for training only 后面这些参数只针对训练集的数据 - public double cache_size; // in MB - public double eps; // stopping criteria - public double C; // for C_SVC, EPSILON_SVR and NU_SVR - public int nr_weight; // for C_SVC - public int[] weight_label; // for C_SVC - public double[] weight; // for C_SVC - public double nu; // for NU_SVC, ONE_CLASS, and NU_SVR - public double p; // for EPSILON_SVR - public int shrinking; // use the shrinking heuristics - public int probability; // do probability estimates - - public Object clone() - { - try - { - return super.clone(); - } catch (CloneNotSupportedException e) - { - return null; - } - } + public Object clone(){ + try { + return super.clone(); + } catch (CloneNotSupportedException e) { + return null; + } + } } diff --git a/StatisticalLearning/DataMining_SVM/libsvm/svm_print_interface.java b/StatisticalLearning/DataMining_SVM/libsvm/svm_print_interface.java index dedaa6c..8c3c71d 100644 --- a/StatisticalLearning/DataMining_SVM/libsvm/svm_print_interface.java +++ b/StatisticalLearning/DataMining_SVM/libsvm/svm_print_interface.java @@ -1,5 +1,5 @@ -package DataMining_SVM.libsvm; -public interface svm_print_interface -{ - public void print(String s); +package StatisticalLearning.DataMining_SVM.libsvm; + +interface svm_print_interface { + void print(String s); } diff --git a/StatisticalLearning/DataMining_SVM/libsvm/svm_problem.java b/StatisticalLearning/DataMining_SVM/libsvm/svm_problem.java index ddddf7c..ce192a1 100644 --- a/StatisticalLearning/DataMining_SVM/libsvm/svm_problem.java +++ b/StatisticalLearning/DataMining_SVM/libsvm/svm_problem.java @@ -1,15 +1,15 @@ -package DataMining_SVM.libsvm; +package StatisticalLearning.DataMining_SVM.libsvm; + /** * 包含了训练集数据的基本信息 - * @author lyq * + * @author Qstar */ -public class svm_problem implements java.io.Serializable -{ - //定义了向量的总个数 - public int l; - //分类类型值数组 - public double[] y; - //训练集向量表 - public svm_node[][] x; +public class svm_problem implements java.io.Serializable { + //定义了向量的总个数 + public int l; + //分类类型值数组 + public double[] y; + //训练集向量表 + public svm_node[][] x; }